Skip to main content

nano_web/
server.rs

1use anyhow::Result;
2use bytes::Bytes;
3use http_body_util::Full;
4use hyper::header::{self, HeaderName, HeaderValue};
5use hyper::server::conn::http1;
6use hyper::service::service_fn;
7use hyper::{Method, Request, Response, StatusCode};
8use hyper_util::rt::TokioIo;
9use socket2::{Domain, Protocol, Socket, Type};
10use std::net::SocketAddr;
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tracing::{debug, info, warn};
15
16use crate::routes::NanoWeb;
17
/// Options controlling how [`start_server`] loads and serves files.
#[derive(Clone, Debug)]
pub struct ServeConfig {
    /// Directory whose contents are loaded into the route table and served.
    pub public_dir: PathBuf,
    /// TCP port the server listens on.
    pub port: u16,
    /// Development mode: before serving, the requested path is passed to the
    /// route table's `refresh_if_modified` so on-disk edits can be picked up.
    pub dev: bool,
    /// Single-page-app mode: a request that matches no route (even after the
    /// trailing-slash retry) falls back to the `/` route instead of `404`.
    pub spa_mode: bool,
    /// Prefix passed through to route loading and refreshing.
    /// NOTE(review): presumably selects per-route config files/keys — confirm in `routes`.
    pub config_prefix: String,
    /// When `true`, emit an `info!` access-log line for requests.
    pub log_requests: bool,
}
27
28struct AppState {
29    server: Arc<NanoWeb>,
30    config: ServeConfig,
31}
32
33/// Create a TCP listener with `SO_REUSEPORT` for better multi-core scaling
34fn create_reuse_port_listener(addr: SocketAddr) -> Result<std::net::TcpListener> {
35    let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
36    socket.set_reuse_address(true)?;
37    #[cfg(unix)]
38    socket.set_reuse_port(true)?;
39    socket.set_nonblocking(true)?;
40    socket.bind(&addr.into())?;
41    socket.listen(8192)?; // Large backlog for high concurrency
42    Ok(socket.into())
43}
44
45pub async fn start_server(config: ServeConfig) -> Result<()> {
46    let server = Arc::new(NanoWeb::new());
47    server.populate_routes(&config.public_dir, &config.config_prefix)?;
48
49    let state = Arc::new(AppState {
50        server,
51        config: config.clone(),
52    });
53
54    info!("Routes loaded: {}", state.server.route_count());
55
56    let addr: SocketAddr = ([0, 0, 0, 0], config.port).into();
57    let std_listener = create_reuse_port_listener(addr)?;
58    let listener = TcpListener::from_std(std_listener)?;
59
60    info!("Starting server on http://{}", addr);
61    info!("Serving directory: {:?}", config.public_dir);
62
63    let shutdown = shutdown_signal();
64    tokio::pin!(shutdown);
65
66    loop {
67        tokio::select! {
68            result = listener.accept() => {
69                let (stream, _) = result?;
70                let io = TokioIo::new(stream);
71                let state = state.clone();
72
73                tokio::spawn(async move {
74                    let service = service_fn(move |req| {
75                        let state = state.clone();
76                        async move { handle_request(req, state) }
77                    });
78
79                    if let Err(e) = http1::Builder::new()
80                        .keep_alive(true)
81                        .pipeline_flush(true)
82                        .serve_connection(io, service)
83                        .await
84                    {
85                        debug!("Connection error: {}", e);
86                    }
87                });
88            }
89            () = &mut shutdown => {
90                info!("Shutdown signal received, stopping server");
91                break;
92            }
93        }
94    }
95
96    Ok(())
97}
98
99async fn shutdown_signal() {
100    let ctrl_c = async {
101        tokio::signal::ctrl_c()
102            .await
103            .expect("failed to install Ctrl+C handler");
104    };
105
106    #[cfg(unix)]
107    let terminate = async {
108        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
109            .expect("failed to install SIGTERM handler")
110            .recv()
111            .await;
112    };
113
114    #[cfg(not(unix))]
115    let terminate = std::future::pending::<()>();
116
117    tokio::select! {
118        () = ctrl_c => {},
119        () = terminate => {},
120    }
121}
122
123type HyperResponse = Response<Full<Bytes>>;
124
125#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
126fn handle_request(
127    req: Request<hyper::body::Incoming>,
128    state: Arc<AppState>,
129) -> Result<HyperResponse, std::convert::Infallible> {
130    let is_head = req.method() == Method::HEAD;
131
132    if req.method() != Method::GET && !is_head {
133        return Ok(response(
134            StatusCode::METHOD_NOT_ALLOWED,
135            "Method Not Allowed",
136        ));
137    }
138
139    let raw_path = req.uri().path();
140
141    // Health check
142    if raw_path == "/_health" {
143        let body = format!(
144            r#"{{"status":"ok","timestamp":"{}"}}"#,
145            httpdate::fmt_http_date(std::time::SystemTime::now())
146        );
147        return Ok(Response::builder()
148            .status(StatusCode::OK)
149            .header("content-type", "application/json")
150            .body(Full::new(Bytes::from(body)))
151            .expect("health check response"));
152    }
153
154    // Path validation — use the sanitized path for all lookups
155    let path = match crate::path::validate_request_path(raw_path) {
156        Ok(sanitized) => sanitized,
157        Err(e) => {
158            warn!("Path validation failed for '{}': {}", raw_path, e);
159            return Ok(response(StatusCode::BAD_REQUEST, "Bad Request"));
160        }
161    };
162
163    // Dev mode: refresh if modified
164    if state.config.dev {
165        let _ = state.server.refresh_if_modified(
166            &path,
167            &state.config.public_dir,
168            &state.config.config_prefix,
169        );
170    }
171
172    let accept_encoding = req
173        .headers()
174        .get("accept-encoding")
175        .and_then(|h| h.to_str().ok())
176        .unwrap_or("");
177
178    let if_none_match = req
179        .headers()
180        .get("if-none-match")
181        .and_then(|h| h.to_str().ok());
182
183    let mut buf = state.server.get_response(&path, accept_encoding);
184
185    // Try with trailing slash
186    if buf.is_none() && !path.ends_with('/') {
187        let with_slash = format!("{path}/");
188        buf = state.server.get_response(&with_slash, accept_encoding);
189    }
190
191    // SPA fallback
192    if buf.is_none() && state.config.spa_mode {
193        debug!("SPA fallback for: {}", path);
194        buf = state.server.get_response("/", accept_encoding);
195    }
196
197    let resp = if let Some(ref b) = buf {
198        // ETag conditional: return 304 if client already has this version
199        if let Some(etag) = if_none_match {
200            if etag == b.etag.as_ref() {
201                return Ok(Response::builder()
202                    .status(StatusCode::NOT_MODIFIED)
203                    .header("etag", b.etag.as_ref())
204                    .header("cache-control", b.cache_control.as_ref())
205                    .body(Full::new(Bytes::new()))
206                    .expect("304 response"));
207            }
208        }
209        build_response(b, is_head)
210    } else {
211        debug!("Route not found: {path}");
212        response(StatusCode::NOT_FOUND, "Not Found")
213    };
214
215    if state.config.log_requests {
216        info!(
217            method = %req.method(),
218            path = %path,
219            status = resp.status().as_u16(),
220            "request"
221        );
222    }
223
224    Ok(resp)
225}
226
227fn response(status: StatusCode, body: &'static str) -> HyperResponse {
228    Response::builder()
229        .status(status)
230        .body(Full::new(Bytes::from_static(body.as_bytes())))
231        .expect("error response")
232}
233
234fn build_response(buf: &crate::response_buffer::ResponseBuffer, head_only: bool) -> HyperResponse {
235    let mut builder = Response::builder()
236        .status(StatusCode::OK)
237        .header(header::CONTENT_TYPE, buf.content_type.as_ref())
238        .header(header::ETAG, buf.etag.as_ref())
239        .header(header::LAST_MODIFIED, buf.last_modified.as_ref())
240        .header(header::CACHE_CONTROL, buf.cache_control.as_ref())
241        // Pre-computed at route creation to avoid per-request integer→string alloc
242        .header(header::CONTENT_LENGTH, buf.content_length.as_ref())
243        .header(
244            header::X_CONTENT_TYPE_OPTIONS,
245            HeaderValue::from_static("nosniff"),
246        )
247        .header(
248            header::X_FRAME_OPTIONS,
249            HeaderValue::from_static("SAMEORIGIN"),
250        )
251        .header(
252            header::REFERRER_POLICY,
253            HeaderValue::from_static("strict-origin-when-cross-origin"),
254        )
255        .header(
256            header::STRICT_TRANSPORT_SECURITY,
257            HeaderValue::from_static("max-age=63072000; includeSubDomains"),
258        )
259        .header(
260            HeaderName::from_static("permissions-policy"),
261            HeaderValue::from_static("camera=(), microphone=(), geolocation=()"),
262        )
263        .header(
264            HeaderName::from_static("x-dns-prefetch-control"),
265            HeaderValue::from_static("off"),
266        );
267
268    if let Some(encoding) = buf.content_encoding {
269        builder = builder.header(header::CONTENT_ENCODING, HeaderValue::from_static(encoding));
270    }
271
272    // Vary: Accept-Encoding for all compressible content, not just compressed responses
273    if buf.vary_encoding {
274        builder = builder.header(header::VARY, HeaderValue::from_static("Accept-Encoding"));
275    }
276
277    let body = if head_only {
278        Bytes::new()
279    } else {
280        buf.body.clone()
281    };
282
283    builder.body(Full::new(body)).expect("response body")
284}