rest/
server.rs

1use crate::config::RestConf;
2use crate::http::HandlerFunc;
3use crate::middleware::Middleware;
4use crate::router::{Route, Router};
5use anyhow::Context;
6use core::service::Mode;
7use hyper::Server as HyperServer;
8use hyper::service::{make_service_fn, service_fn};
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::net::{TcpListener, TcpSocket};
12use tokio::sync::oneshot;
13use tokio::task::JoinHandle;
14
15/// Lightweight web server built on hyper.
16#[derive(Clone)]
17pub struct Server {
18    conf: RestConf,
19    routes: Vec<Route>,
20    // For debug printing (mode dev/test): (group_prefix, method, path)
21    debug_routes: Vec<(String, String, String)>,
22    debug_group: Option<String>,
23    middlewares: Vec<Middleware>,
24    not_found: Option<HandlerFunc>,
25    not_allowed: Option<HandlerFunc>,
26    prefix_chain: Vec<String>,
27}
28
29impl Server {
30    pub fn new(conf: RestConf) -> Self {
31        Self {
32            conf,
33            routes: Vec::new(),
34            debug_routes: Vec::new(),
35            debug_group: None,
36            middlewares: Vec::new(),
37            not_found: None,
38            not_allowed: None,
39            prefix_chain: Vec::new(),
40        }
41    }
42
43    /// Get server config (read-only).
44    pub fn conf(&self) -> &RestConf {
45        &self.conf
46    }
47
48    /// Debug print all registered routes (after prefixes applied) when mode is dev/test.
49    fn debug_print_routes(&self) {
50        let is_verbose = matches!(self.conf.service.mode, Mode::Dev | Mode::Test);
51        if !is_verbose {
52            return;
53        }
54        // Flatten to rows, compute column widths, print as table.
55        let mut rows: Vec<(String, String, String)> = self.debug_routes.clone();
56        // Sort by group, then path, then method for stable output.
57        rows.sort_by(|a, b| a.0.cmp(&b.0).then(a.2.cmp(&b.2)).then(a.1.cmp(&b.1)));
58        let mut rows_dedup = Vec::new();
59        let mut last: Option<(String, String, String)> = None;
60        for r in rows {
61            if last.as_ref() == Some(&r) {
62                continue;
63            }
64            rows_dedup.push(r.clone());
65            last = Some(r);
66        }
67        let mut group_w = "Group".len();
68        let mut method_w = "Method".len();
69        let mut path_w = "Path".len();
70        for (g, m, p) in &rows_dedup {
71            group_w = group_w.max(g.len());
72            method_w = method_w.max(m.len());
73            path_w = path_w.max(p.len());
74        }
75        let header = format!(
76            "{:<group_w$} | {:<method_w$} | {:<path_w$}",
77            "Group",
78            "Method",
79            "Path",
80            group_w = group_w,
81            method_w = method_w,
82            path_w = path_w
83        );
84        println!("Registered routes (mode {:?}):", self.conf.service.mode);
85        println!("{header}");
86        println!(
87            "{}-+-{}-+-{}",
88            "-".repeat(group_w),
89            "-".repeat(method_w),
90            "-".repeat(path_w)
91        );
92        for (g, m, p) in rows_dedup {
93            let label = if g == "-" { "-" } else { g.as_str() };
94            println!(
95                "{:<group_w$} | {:<method_w$} | {:<path_w$}",
96                label,
97                m,
98                p,
99                group_w = group_w,
100                method_w = method_w,
101                path_w = path_w
102            );
103        }
104    }
105
106    /// Register routes (in-place, chain-friendly).
107    pub fn add_routes<I>(&mut self, routes: I) -> &mut Self
108    where
109        I: IntoIterator<Item = Route>,
110    {
111        let routes_vec: Vec<Route> = routes.into_iter().collect();
112        let group_prefix = self.current_group_label();
113        let routes = self.apply_prefixes(routes_vec);
114        for r in &routes {
115            self.debug_routes
116                .push((group_prefix.clone(), r.method.to_string(), r.path.clone()));
117        }
118        self.routes.extend(routes);
119        self.reset_to_root();
120        self
121    }
122
123    /// Register a single route (in-place).
124    pub fn add_route(&mut self, route: Route) -> &mut Self {
125        let group_prefix = self.current_group_label();
126        let routes = self.apply_prefixes(vec![route]);
127        for r in &routes {
128            self.debug_routes
129                .push((group_prefix.clone(), r.method.to_string(), r.path.clone()));
130        }
131        self.routes.extend(routes);
132        self.reset_to_root();
133        self
134    }
135
136    fn current_group_label(&self) -> String {
137        if let Some(g) = &self.debug_group {
138            return g.clone();
139        }
140        if self.prefix_chain.is_empty() {
141            return "-".to_string();
142        }
143        let mut joined = self.prefix_chain.join("");
144        if !joined.starts_with('/') {
145            joined.insert(0, '/');
146        }
147        while joined.contains("//") {
148            joined = joined.replace("//", "/");
149        }
150        if joined == "/" {
151            "-".to_string()
152        } else {
153            joined
154        }
155    }
156
157    fn reset_to_root(&mut self) {
158        if self.prefix_chain.is_empty() {
159            return;
160        }
161        let root = self.prefix_chain.first().cloned();
162        self.prefix_chain.clear();
163        if let Some(r) = root {
164            self.prefix_chain.push(r);
165        }
166    }
167
168    pub fn set_debug_group(&mut self, name: impl Into<String>) {
169        self.debug_group = Some(name.into());
170    }
171
172    pub fn clear_debug_group(&mut self) {
173        self.debug_group = None;
174    }
175
176    /// Register global middleware applied to all routes.
177    pub fn use_middleware(&mut self, middleware: Middleware) -> &mut Self {
178        self.middlewares.push(middleware);
179        self
180    }
181
182    /// Chain-friendly append for global middlewares.
183    pub fn with_middlewares<I>(&mut self, middlewares: I) -> &mut Self
184    where
185        I: IntoIterator<Item = Middleware>,
186    {
187        self.middlewares.extend(middlewares);
188        self
189    }
190
191    /// Chain-friendly append for a single global middleware.
192    pub fn with_middleware(&mut self, middleware: Middleware) -> &mut Self {
193        self.with_middlewares(std::iter::once(middleware))
194    }
195
196    /// Set custom 404 handler.
197    pub fn set_not_found_handler(&mut self, handler: HandlerFunc) {
198        self.not_found = Some(handler);
199    }
200
201    /// Set custom 405 handler.
202    pub fn set_not_allowed_handler(&mut self, handler: HandlerFunc) {
203        self.not_allowed = Some(handler);
204    }
205
206    /// Set root prefix (replaces previous prefix chain).
207    pub fn with_root_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
208        self.prefix_chain.clear();
209        self.prefix_chain.push(prefix.into());
210        self
211    }
212
213    /// Append prefix (can be chained).
214    pub fn with_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
215        self.prefix_chain.push(prefix.into());
216        self
217    }
218
219    /// Deprecated aliases removed to keep pure Rust naming.
220    /// Start HTTP server and return controllable handle.
221    pub async fn start(self) -> anyhow::Result<ServerHandle> {
222        self.debug_print_routes();
223        let listen_addr: SocketAddr = self
224            .conf
225            .addr_string()
226            .parse()
227            .context("parse listen addr")?;
228        let router = Arc::new(self.build_router()?);
229
230        if self.conf.reuse_port {
231            start_with_reuse_port(listen_addr, router, &self.conf).await
232        } else {
233            let socket = TcpSocket::new_v4()?;
234            socket.set_reuseaddr(true)?;
235            if self.conf.tcp_keepalive_secs.is_some() {
236                socket.set_keepalive(true)?;
237            }
238            socket.bind(listen_addr)?;
239            let listener = socket.listen(1024)?;
240            start_single(listener, router, &self.conf).await
241        }
242    }
243
244    fn build_router(&self) -> anyhow::Result<Router> {
245        let mut router = Router::new();
246        if let Some(h) = &self.not_found {
247            router.set_not_found_handler(h.clone());
248        }
249        if let Some(h) = &self.not_allowed {
250            router.set_not_allowed_handler(h.clone());
251        }
252        // Auto middlewares: max_bytes -> rate/concurrency/timeout -> gzip (decode/encode) -> user
253        let mut auto_mws: Vec<Middleware> = Vec::new();
254        if self.conf.middlewares.max_bytes {
255            auto_mws.push(crate::middleware::max_bytes(self.conf.max_bytes as u64));
256        }
257        if let Some(rl) = &self.conf.rate_limit {
258            auto_mws.push(crate::middleware::rate_limit(
259                rl.permits_per_second,
260                rl.burst,
261            ));
262        }
263        if let Some(c) = self.conf.concurrency_limit {
264            auto_mws.push(crate::middleware::concurrency_limit(c));
265        }
266        if let Some(ms) = self.conf.timeout {
267            auto_mws.push(crate::middleware::timeout(
268                std::time::Duration::from_millis(ms),
269            ));
270        }
271        if self.conf.middlewares.gzip {
272            auto_mws.push(crate::middleware::gzip());
273        }
274
275        for route in &self.routes {
276            let mut mws = auto_mws.clone();
277            mws.extend(self.middlewares.clone());
278            let route = route.clone().with_middlewares(&mws);
279            router.add_route(route)?;
280        }
281        Ok(router)
282    }
283
284    fn apply_prefixes(&self, routes: Vec<Route>) -> Vec<Route> {
285        // Add shorter suffixes first, then prepend root prefix to ensure `/api` + `/v1` + `/hello`.
286        let mut acc = routes;
287        for p in self.prefix_chain.iter().rev() {
288            acc = crate::with_prefix(p, acc);
289        }
290        acc
291    }
292}
293
294/// Handle to control server lifecycle (graceful stop).
295pub struct ServerHandle {
296    addr: SocketAddr,
297    shutdowns: Vec<oneshot::Sender<()>>,
298    joins: Vec<JoinHandle<anyhow::Result<()>>>,
299}
300
301impl ServerHandle {
302    pub fn addr(&self) -> SocketAddr {
303        self.addr
304    }
305
306    /// Send shutdown signal and wait for graceful exit.
307    pub async fn stop(mut self) -> anyhow::Result<()> {
308        let shutdowns = std::mem::take(&mut self.shutdowns);
309        for tx in shutdowns {
310            let _ = tx.send(());
311        }
312        let joins = std::mem::take(&mut self.joins);
313        for j in joins {
314            j.await.context("join server task")?.context("server run")?;
315        }
316        Ok(())
317    }
318}
319
320async fn start_single(
321    listener: TcpListener,
322    router: Arc<Router>,
323    conf: &RestConf,
324) -> anyhow::Result<ServerHandle> {
325    let local_addr = listener.local_addr().context("get local addr")?;
326    let (shutdown_tx, shutdown) = oneshot::channel::<()>();
327    let mut builder = HyperServer::from_tcp(listener.into_std()?)?;
328    if conf.http2 {
329        builder = builder.http2_only(true);
330    } else {
331        builder = builder.http1_only(true);
332        builder = builder.http1_keepalive(conf.http1_keep_alive);
333        if let Some(sz) = conf.http1_max_buf_size {
334            builder = builder.http1_max_buf_size(sz);
335        }
336    }
337
338    let svc = make_service_fn(move |_conn| {
339        let router = router.clone();
340        async move {
341            Ok::<_, std::convert::Infallible>(service_fn(move |req| {
342                let router = router.clone();
343                async move { Ok::<_, std::convert::Infallible>(router.dispatch(req).await) }
344            }))
345        }
346    });
347
348    let server = builder.serve(svc).with_graceful_shutdown(async move {
349        let _ = shutdown.await;
350    });
351
352    let join: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
353        server
354            .await
355            .map_err(|e| anyhow::anyhow!("hyper server error: {e}"))
356    });
357
358    Ok(ServerHandle {
359        addr: local_addr,
360        shutdowns: vec![shutdown_tx],
361        joins: vec![join],
362    })
363}
364
365async fn start_with_reuse_port(
366    addr: SocketAddr,
367    router: Arc<Router>,
368    conf: &RestConf,
369) -> anyhow::Result<ServerHandle> {
370    let workers = conf.workers.unwrap_or_else(|| {
371        std::thread::available_parallelism()
372            .map(|n| n.get())
373            .unwrap_or(1)
374    });
375    let mut joins = Vec::with_capacity(workers);
376    let mut shutdowns = Vec::with_capacity(workers);
377    let mut bound_addr = None;
378
379    for _ in 0..workers {
380        let socket = TcpSocket::new_v4()?;
381        socket.set_reuseaddr(true)?;
382        if conf.tcp_keepalive_secs.is_some() {
383            socket.set_keepalive(true)?;
384        }
385        #[cfg(any(
386            target_os = "linux",
387            target_os = "android",
388            target_os = "macos",
389            target_os = "ios",
390            target_os = "freebsd",
391            target_os = "dragonfly",
392            target_os = "netbsd",
393            target_os = "openbsd"
394        ))]
395        socket.set_reuseport(true)?;
396
397        socket.bind(addr)?;
398        let listener = socket.listen(1024)?;
399        let local = listener.local_addr().context("get local addr")?;
400        if bound_addr.is_none() {
401            bound_addr = Some(local);
402        }
403
404        let router_clone = router.clone();
405        let (tx, shutdown) = oneshot::channel::<()>();
406        let mut builder = HyperServer::from_tcp(listener.into_std()?)?;
407        if conf.http2 {
408            builder = builder.http2_only(true);
409        } else {
410            builder = builder.http1_only(true);
411            builder = builder.http1_keepalive(conf.http1_keep_alive);
412            if let Some(sz) = conf.http1_max_buf_size {
413                builder = builder.http1_max_buf_size(sz);
414            }
415        }
416        let server = builder
417            .serve(make_service_fn(move |_conn| {
418                let router = router_clone.clone();
419                async move {
420                    Ok::<_, std::convert::Infallible>(service_fn(move |req| {
421                        let router = router.clone();
422                        async move { Ok::<_, std::convert::Infallible>(router.dispatch(req).await) }
423                    }))
424                }
425            }))
426            .with_graceful_shutdown(async move {
427                let _ = shutdown.await;
428            });
429
430        let join: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
431            server
432                .await
433                .map_err(|e| anyhow::anyhow!("hyper server error: {e}"))
434        });
435        joins.push(join);
436        shutdowns.push(tx);
437    }
438
439    Ok(ServerHandle {
440        addr: bound_addr.unwrap_or(addr),
441        shutdowns,
442        joins,
443    })
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use hyper::body::to_bytes;
    use hyper::{Body, Client};
    use tokio::runtime::Runtime;

    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    /// Config bound to 127.0.0.1 on port 0, so the OS assigns a free port at bind time.
    ///
    /// Tests read the real address back via `ServerHandle::addr`. This avoids the race of
    /// probing a free port, releasing it, and binding it again later.
    fn local_conf() -> RestConf {
        let mut conf = RestConf::default();
        conf.host = "127.0.0.1".to_string();
        conf.port = 0;
        conf
    }

    fn ok_route(path: &str) -> Route {
        Route::new(Method::GET, path, |_: http::Request<Body>| async {
            http::Response::builder()
                .status(StatusCode::OK)
                .body(Body::from("ok"))
                .unwrap()
        })
    }

    #[test]
    fn add_routes_should_store() {
        runtime().block_on(async {
            let mut server = Server::new(RestConf::default());
            server.add_route(ok_route("/hello"));
            assert_eq!(server.routes.len(), 1);
        });
    }

    #[test]
    fn start_should_serve_requests() {
        runtime().block_on(async {
            let mut server = Server::new(local_conf());
            server.add_route(ok_route("/ping"));
            let handle = server.start().await.unwrap();

            let client = Client::new();
            let uri = format!("http://{}{}", handle.addr(), "/ping")
                .parse()
                .unwrap();
            let resp = client.get(uri).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
            let body = to_bytes(resp.into_body()).await.unwrap();
            assert_eq!(&body[..], b"ok");

            handle.stop().await.unwrap();
        });
    }

    #[test]
    fn demo_service_with_middleware_and_prefix() {
        runtime().block_on(async {
            let routes = vec![Route::new(
                Method::GET,
                "/hello",
                |_: http::Request<Body>| async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(Body::from("hi"))
                        .unwrap()
                },
            )];

            // Global middleware: add response header
            let mw = crate::middleware(|req, next| async move {
                let mut resp = next.call(req).await;
                resp.headers_mut()
                    .insert("X-Demo", http::HeaderValue::from_static("1"));
                resp
            });

            let mut server = Server::new(local_conf());
            server.with_root_prefix("/api").with_prefix("/session");
            server.use_middleware(mw);
            server.add_routes(routes);
            let handle = server.start().await.unwrap();

            let client = Client::new();
            let uri = format!("http://{}{}", handle.addr(), "/api/session/hello")
                .parse()
                .unwrap();
            let resp = client.get(uri).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
            assert_eq!(resp.headers().get("X-Demo").unwrap(), "1");

            handle.stop().await.unwrap();
        });
    }
}