1use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::sync::Arc;
7
8use bytes::Bytes;
9use http_body_util::{BodyExt, Full};
10use hyper::body::Incoming;
11use hyper::service::service_fn;
12use hyper::StatusCode;
13use hyper_util::rt::TokioIo;
14use tokio::net::TcpListener;
15
16use crate::error::Result;
17use crate::http::{Request, Response};
18use crate::router::Router;
19
20pub struct Server {
22 router: Arc<Router>,
23 addr: SocketAddr,
24}
25
26impl Server {
27 pub fn new(router: Router, addr: SocketAddr) -> Self {
29 Self {
30 router: Arc::new(router),
31 addr,
32 }
33 }
34
35 pub async fn run(self) -> Result<()> {
39 let listener = TcpListener::bind(self.addr).await?;
40 let shutdown = shutdown_signal();
47 tokio::pin!(shutdown);
48
49 loop {
50 tokio::select! {
51 accept = listener.accept() => {
52 let (stream, peer) = accept?;
53 let io = TokioIo::new(stream);
54 let router = self.router.clone();
55 tokio::spawn(async move {
56 let svc = service_fn(move |req: hyper::Request<Incoming>| {
57 let router = router.clone();
58 async move { handle(router, req, peer).await }
59 });
60 let conn = hyper::server::conn::http1::Builder::new()
61 .keep_alive(true)
62 .serve_connection(io, svc);
63 if let Err(e) = conn.await {
64 log::debug!("connection error: {e}");
67 }
68 });
69 }
70 _ = &mut shutdown => {
71 log::info!("shutdown signal received, stopping accept loop");
72 break;
73 }
74 }
75 }
76
77 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
78 Ok(())
79 }
80}
81
82async fn handle(
83 router: Arc<Router>,
84 hyper_req: hyper::Request<Incoming>,
85 _peer: SocketAddr,
86) -> std::result::Result<hyper::Response<Full<Bytes>>, hyper::Error> {
87 let method = hyper_req.method().clone();
88 let uri = hyper_req.uri().clone();
89 let path = uri.path().to_string();
90 let query = uri.query().unwrap_or("").to_string();
91
92 let mut headers = HashMap::new();
93 for (name, value) in hyper_req.headers() {
94 if let Ok(v) = value.to_str() {
95 headers.insert(name.as_str().to_ascii_lowercase(), v.to_string());
96 }
97 }
98
99 let body = match hyper_req.into_body().collect().await {
100 Ok(b) => b.to_bytes(),
101 Err(_) => {
102 return Ok(simple_response(
103 StatusCode::BAD_REQUEST,
104 "could not read body",
105 ));
106 }
107 };
108
109 let our_req = Request::new(method, path, query, headers, body);
110 let our_resp = router.dispatch(our_req).await;
111 Ok(to_hyper(our_resp))
112}
113
114fn to_hyper(resp: Response) -> hyper::Response<Full<Bytes>> {
115 let mut builder = hyper::Response::builder().status(resp.status);
116 for (name, value) in resp.headers {
117 builder = builder.header(name, value);
118 }
119 builder.body(Full::new(resp.body)).unwrap_or_else(|_| {
120 hyper::Response::builder()
121 .status(StatusCode::INTERNAL_SERVER_ERROR)
122 .body(Full::new(Bytes::from("internal error")))
123 .unwrap()
124 })
125}
126
127fn simple_response(status: StatusCode, body: &str) -> hyper::Response<Full<Bytes>> {
128 hyper::Response::builder()
129 .status(status)
130 .header("content-type", "text/plain; charset=utf-8")
131 .body(Full::new(Bytes::from(body.to_string())))
132 .unwrap()
133}
134
135async fn shutdown_signal() {
136 let ctrl_c = async {
137 tokio::signal::ctrl_c().await.ok();
138 };
139
140 #[cfg(unix)]
141 let terminate = async {
142 if let Ok(mut sig) =
143 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
144 {
145 sig.recv().await;
146 }
147 };
148
149 #[cfg(not(unix))]
150 let terminate = std::future::pending::<()>();
151
152 tokio::select! {
153 _ = ctrl_c => {}
154 _ = terminate => {}
155 }
156}
157
158pub async fn serve_static(root: std::path::PathBuf, name: &str) -> Result<Response> {
162 let safe: String = name
163 .chars()
164 .filter(|c| *c != '/' && *c != '\\' && *c != '\0')
165 .collect();
166 if safe.contains("..") {
167 return Err(crate::error::Error::BadRequest("invalid path".into()));
168 }
169 let path = root.join(&safe);
170 if !path.is_file() {
171 return Err(crate::error::Error::NotFound(safe));
172 }
173 let bytes = tokio::fs::read(&path).await?;
174 Ok(Response::new(StatusCode::OK, Bytes::from(bytes))
175 .with_header("content-type", guess_content_type(&safe)))
176}
177
/// Guesses a `Content-Type` value from the extension of a file name.
///
/// Only the text after the last `.` is considered, and it is compared
/// case-insensitively, so `LOGO.PNG` and `logo.png` are treated the same.
/// Names without a dot, or with an unknown extension, fall back to
/// `application/octet-stream`.
fn guess_content_type(name: &str) -> &'static str {
    let ext = match name.rsplit_once('.') {
        Some((_, ext)) => ext.to_ascii_lowercase(),
        None => return "application/octet-stream",
    };
    match ext.as_str() {
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "html" | "htm" => "text/html; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "json" => "application/json; charset=utf-8",
        // Browsers require this exact type for streaming WebAssembly compilation.
        "wasm" => "application/wasm",
        _ => "application/octet-stream",
    }
}