1use crate::config::RestConf;
2use crate::http::HandlerFunc;
3use crate::middleware::Middleware;
4use crate::router::{Route, Router};
5use anyhow::Context;
6use core::service::Mode;
7use hyper::Server as HyperServer;
8use hyper::service::{make_service_fn, service_fn};
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::net::{TcpListener, TcpSocket};
12use tokio::sync::oneshot;
13use tokio::task::JoinHandle;
14
/// Builder for a REST server: collects routes, middleware and fallback
/// handlers, then binds and serves them via `Server::start`.
#[derive(Clone)]
pub struct Server {
    /// Service configuration (listen address, protocol options, auto-enabled middleware).
    conf: RestConf,
    /// Registered routes, stored after the prefix chain has been applied.
    routes: Vec<Route>,
    /// `(group, method, path)` rows recorded at registration time; printed as a
    /// table on start in `Mode::Dev` / `Mode::Test`.
    debug_routes: Vec<(String, String, String)>,
    /// Explicit group label for the debug table; overrides the prefix-derived label.
    debug_group: Option<String>,
    /// User middleware, appended after the config-driven middleware in each route's stack.
    middlewares: Vec<Middleware>,
    /// Handler passed to `Router::set_not_found_handler` when set.
    not_found: Option<HandlerFunc>,
    /// Handler passed to `Router::set_not_allowed_handler` when set.
    not_allowed: Option<HandlerFunc>,
    /// Prefixes applied to the next registered routes; the first entry is the
    /// root prefix that survives after each registration.
    prefix_chain: Vec<String>,
}
28
29impl Server {
30 pub fn new(conf: RestConf) -> Self {
31 Self {
32 conf,
33 routes: Vec::new(),
34 debug_routes: Vec::new(),
35 debug_group: None,
36 middlewares: Vec::new(),
37 not_found: None,
38 not_allowed: None,
39 prefix_chain: Vec::new(),
40 }
41 }
42
43 pub fn conf(&self) -> &RestConf {
45 &self.conf
46 }
47
48 fn debug_print_routes(&self) {
50 let is_verbose = matches!(self.conf.service.mode, Mode::Dev | Mode::Test);
51 if !is_verbose {
52 return;
53 }
54 let mut rows: Vec<(String, String, String)> = self.debug_routes.clone();
56 rows.sort_by(|a, b| a.0.cmp(&b.0).then(a.2.cmp(&b.2)).then(a.1.cmp(&b.1)));
58 let mut rows_dedup = Vec::new();
59 let mut last: Option<(String, String, String)> = None;
60 for r in rows {
61 if last.as_ref() == Some(&r) {
62 continue;
63 }
64 rows_dedup.push(r.clone());
65 last = Some(r);
66 }
67 let mut group_w = "Group".len();
68 let mut method_w = "Method".len();
69 let mut path_w = "Path".len();
70 for (g, m, p) in &rows_dedup {
71 group_w = group_w.max(g.len());
72 method_w = method_w.max(m.len());
73 path_w = path_w.max(p.len());
74 }
75 let header = format!(
76 "{:<group_w$} | {:<method_w$} | {:<path_w$}",
77 "Group",
78 "Method",
79 "Path",
80 group_w = group_w,
81 method_w = method_w,
82 path_w = path_w
83 );
84 println!("Registered routes (mode {:?}):", self.conf.service.mode);
85 println!("{header}");
86 println!(
87 "{}-+-{}-+-{}",
88 "-".repeat(group_w),
89 "-".repeat(method_w),
90 "-".repeat(path_w)
91 );
92 for (g, m, p) in rows_dedup {
93 let label = if g == "-" { "-" } else { g.as_str() };
94 println!(
95 "{:<group_w$} | {:<method_w$} | {:<path_w$}",
96 label,
97 m,
98 p,
99 group_w = group_w,
100 method_w = method_w,
101 path_w = path_w
102 );
103 }
104 }
105
106 pub fn add_routes<I>(&mut self, routes: I) -> &mut Self
108 where
109 I: IntoIterator<Item = Route>,
110 {
111 let routes_vec: Vec<Route> = routes.into_iter().collect();
112 let group_prefix = self.current_group_label();
113 let routes = self.apply_prefixes(routes_vec);
114 for r in &routes {
115 self.debug_routes
116 .push((group_prefix.clone(), r.method.to_string(), r.path.clone()));
117 }
118 self.routes.extend(routes);
119 self.reset_to_root();
120 self
121 }
122
123 pub fn add_route(&mut self, route: Route) -> &mut Self {
125 let group_prefix = self.current_group_label();
126 let routes = self.apply_prefixes(vec![route]);
127 for r in &routes {
128 self.debug_routes
129 .push((group_prefix.clone(), r.method.to_string(), r.path.clone()));
130 }
131 self.routes.extend(routes);
132 self.reset_to_root();
133 self
134 }
135
136 fn current_group_label(&self) -> String {
137 if let Some(g) = &self.debug_group {
138 return g.clone();
139 }
140 if self.prefix_chain.is_empty() {
141 return "-".to_string();
142 }
143 let mut joined = self.prefix_chain.join("");
144 if !joined.starts_with('/') {
145 joined.insert(0, '/');
146 }
147 while joined.contains("//") {
148 joined = joined.replace("//", "/");
149 }
150 if joined == "/" {
151 "-".to_string()
152 } else {
153 joined
154 }
155 }
156
157 fn reset_to_root(&mut self) {
158 if self.prefix_chain.is_empty() {
159 return;
160 }
161 let root = self.prefix_chain.first().cloned();
162 self.prefix_chain.clear();
163 if let Some(r) = root {
164 self.prefix_chain.push(r);
165 }
166 }
167
168 pub fn set_debug_group(&mut self, name: impl Into<String>) {
169 self.debug_group = Some(name.into());
170 }
171
172 pub fn clear_debug_group(&mut self) {
173 self.debug_group = None;
174 }
175
176 pub fn use_middleware(&mut self, middleware: Middleware) -> &mut Self {
178 self.middlewares.push(middleware);
179 self
180 }
181
182 pub fn with_middlewares<I>(&mut self, middlewares: I) -> &mut Self
184 where
185 I: IntoIterator<Item = Middleware>,
186 {
187 self.middlewares.extend(middlewares);
188 self
189 }
190
191 pub fn with_middleware(&mut self, middleware: Middleware) -> &mut Self {
193 self.with_middlewares(std::iter::once(middleware))
194 }
195
196 pub fn set_not_found_handler(&mut self, handler: HandlerFunc) {
198 self.not_found = Some(handler);
199 }
200
201 pub fn set_not_allowed_handler(&mut self, handler: HandlerFunc) {
203 self.not_allowed = Some(handler);
204 }
205
206 pub fn with_root_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
208 self.prefix_chain.clear();
209 self.prefix_chain.push(prefix.into());
210 self
211 }
212
213 pub fn with_prefix(&mut self, prefix: impl Into<String>) -> &mut Self {
215 self.prefix_chain.push(prefix.into());
216 self
217 }
218
219 pub async fn start(self) -> anyhow::Result<ServerHandle> {
222 self.debug_print_routes();
223 let listen_addr: SocketAddr = self
224 .conf
225 .addr_string()
226 .parse()
227 .context("parse listen addr")?;
228 let router = Arc::new(self.build_router()?);
229
230 if self.conf.reuse_port {
231 start_with_reuse_port(listen_addr, router, &self.conf).await
232 } else {
233 let socket = TcpSocket::new_v4()?;
234 socket.set_reuseaddr(true)?;
235 if self.conf.tcp_keepalive_secs.is_some() {
236 socket.set_keepalive(true)?;
237 }
238 socket.bind(listen_addr)?;
239 let listener = socket.listen(1024)?;
240 start_single(listener, router, &self.conf).await
241 }
242 }
243
244 fn build_router(&self) -> anyhow::Result<Router> {
245 let mut router = Router::new();
246 if let Some(h) = &self.not_found {
247 router.set_not_found_handler(h.clone());
248 }
249 if let Some(h) = &self.not_allowed {
250 router.set_not_allowed_handler(h.clone());
251 }
252 let mut auto_mws: Vec<Middleware> = Vec::new();
254 if self.conf.middlewares.max_bytes {
255 auto_mws.push(crate::middleware::max_bytes(self.conf.max_bytes as u64));
256 }
257 if let Some(rl) = &self.conf.rate_limit {
258 auto_mws.push(crate::middleware::rate_limit(
259 rl.permits_per_second,
260 rl.burst,
261 ));
262 }
263 if let Some(c) = self.conf.concurrency_limit {
264 auto_mws.push(crate::middleware::concurrency_limit(c));
265 }
266 if let Some(ms) = self.conf.timeout {
267 auto_mws.push(crate::middleware::timeout(
268 std::time::Duration::from_millis(ms),
269 ));
270 }
271 if self.conf.middlewares.gzip {
272 auto_mws.push(crate::middleware::gzip());
273 }
274
275 for route in &self.routes {
276 let mut mws = auto_mws.clone();
277 mws.extend(self.middlewares.clone());
278 let route = route.clone().with_middlewares(&mws);
279 router.add_route(route)?;
280 }
281 Ok(router)
282 }
283
284 fn apply_prefixes(&self, routes: Vec<Route>) -> Vec<Route> {
285 let mut acc = routes;
287 for p in self.prefix_chain.iter().rev() {
288 acc = crate::with_prefix(p, acc);
289 }
290 acc
291 }
292}
293
294pub struct ServerHandle {
296 addr: SocketAddr,
297 shutdowns: Vec<oneshot::Sender<()>>,
298 joins: Vec<JoinHandle<anyhow::Result<()>>>,
299}
300
301impl ServerHandle {
302 pub fn addr(&self) -> SocketAddr {
303 self.addr
304 }
305
306 pub async fn stop(mut self) -> anyhow::Result<()> {
308 let shutdowns = std::mem::take(&mut self.shutdowns);
309 for tx in shutdowns {
310 let _ = tx.send(());
311 }
312 let joins = std::mem::take(&mut self.joins);
313 for j in joins {
314 j.await.context("join server task")?.context("server run")?;
315 }
316 Ok(())
317 }
318}
319
320async fn start_single(
321 listener: TcpListener,
322 router: Arc<Router>,
323 conf: &RestConf,
324) -> anyhow::Result<ServerHandle> {
325 let local_addr = listener.local_addr().context("get local addr")?;
326 let (shutdown_tx, shutdown) = oneshot::channel::<()>();
327 let mut builder = HyperServer::from_tcp(listener.into_std()?)?;
328 if conf.http2 {
329 builder = builder.http2_only(true);
330 } else {
331 builder = builder.http1_only(true);
332 builder = builder.http1_keepalive(conf.http1_keep_alive);
333 if let Some(sz) = conf.http1_max_buf_size {
334 builder = builder.http1_max_buf_size(sz);
335 }
336 }
337
338 let svc = make_service_fn(move |_conn| {
339 let router = router.clone();
340 async move {
341 Ok::<_, std::convert::Infallible>(service_fn(move |req| {
342 let router = router.clone();
343 async move { Ok::<_, std::convert::Infallible>(router.dispatch(req).await) }
344 }))
345 }
346 });
347
348 let server = builder.serve(svc).with_graceful_shutdown(async move {
349 let _ = shutdown.await;
350 });
351
352 let join: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
353 server
354 .await
355 .map_err(|e| anyhow::anyhow!("hyper server error: {e}"))
356 });
357
358 Ok(ServerHandle {
359 addr: local_addr,
360 shutdowns: vec![shutdown_tx],
361 joins: vec![join],
362 })
363}
364
365async fn start_with_reuse_port(
366 addr: SocketAddr,
367 router: Arc<Router>,
368 conf: &RestConf,
369) -> anyhow::Result<ServerHandle> {
370 let workers = conf.workers.unwrap_or_else(|| {
371 std::thread::available_parallelism()
372 .map(|n| n.get())
373 .unwrap_or(1)
374 });
375 let mut joins = Vec::with_capacity(workers);
376 let mut shutdowns = Vec::with_capacity(workers);
377 let mut bound_addr = None;
378
379 for _ in 0..workers {
380 let socket = TcpSocket::new_v4()?;
381 socket.set_reuseaddr(true)?;
382 if conf.tcp_keepalive_secs.is_some() {
383 socket.set_keepalive(true)?;
384 }
385 #[cfg(any(
386 target_os = "linux",
387 target_os = "android",
388 target_os = "macos",
389 target_os = "ios",
390 target_os = "freebsd",
391 target_os = "dragonfly",
392 target_os = "netbsd",
393 target_os = "openbsd"
394 ))]
395 socket.set_reuseport(true)?;
396
397 socket.bind(addr)?;
398 let listener = socket.listen(1024)?;
399 let local = listener.local_addr().context("get local addr")?;
400 if bound_addr.is_none() {
401 bound_addr = Some(local);
402 }
403
404 let router_clone = router.clone();
405 let (tx, shutdown) = oneshot::channel::<()>();
406 let mut builder = HyperServer::from_tcp(listener.into_std()?)?;
407 if conf.http2 {
408 builder = builder.http2_only(true);
409 } else {
410 builder = builder.http1_only(true);
411 builder = builder.http1_keepalive(conf.http1_keep_alive);
412 if let Some(sz) = conf.http1_max_buf_size {
413 builder = builder.http1_max_buf_size(sz);
414 }
415 }
416 let server = builder
417 .serve(make_service_fn(move |_conn| {
418 let router = router_clone.clone();
419 async move {
420 Ok::<_, std::convert::Infallible>(service_fn(move |req| {
421 let router = router.clone();
422 async move { Ok::<_, std::convert::Infallible>(router.dispatch(req).await) }
423 }))
424 }
425 }))
426 .with_graceful_shutdown(async move {
427 let _ = shutdown.await;
428 });
429
430 let join: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
431 server
432 .await
433 .map_err(|e| anyhow::anyhow!("hyper server error: {e}"))
434 });
435 joins.push(join);
436 shutdowns.push(tx);
437 }
438
439 Ok(ServerHandle {
440 addr: bound_addr.unwrap_or(addr),
441 shutdowns,
442 joins,
443 })
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use hyper::body::to_bytes;
    use hyper::{Body, Client};
    use tokio::runtime::Runtime;

    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    /// Loopback config on port 0, so the OS assigns a free port at bind time.
    ///
    /// This replaces probing for a free port and dropping the probe, which
    /// races with anything else that grabs the port before the server binds.
    /// Tests read the real port from `ServerHandle::addr`.
    fn loopback_conf() -> RestConf {
        let mut conf = RestConf::default();
        conf.host = "127.0.0.1".to_string();
        conf.port = 0;
        conf
    }

    /// A GET route at `path` that always answers `200 ok`.
    fn ok_route(path: &str) -> Route {
        Route::new(Method::GET, path, |_: http::Request<Body>| async {
            http::Response::builder()
                .status(StatusCode::OK)
                .body(Body::from("ok"))
                .unwrap()
        })
    }

    #[test]
    fn add_routes_should_store() {
        runtime().block_on(async {
            let mut server = Server::new(RestConf::default());
            server.add_route(ok_route("/hello"));
            assert_eq!(server.routes.len(), 1);
        });
    }

    #[test]
    fn add_routes_records_group_and_resets_to_root_prefix() {
        runtime().block_on(async {
            let mut server = Server::new(RestConf::default());
            server.with_root_prefix("/api").with_prefix("/session");
            server.add_routes(vec![ok_route("/hello")]);
            assert_eq!(server.routes.len(), 1);
            assert_eq!(server.debug_routes.len(), 1);
            assert_eq!(server.debug_routes[0].0, "/api/session");
            // Only the root prefix survives a registration.
            assert_eq!(server.prefix_chain, vec!["/api".to_string()]);
        });
    }

    #[test]
    fn start_should_serve_requests() {
        runtime().block_on(async {
            let mut server = Server::new(loopback_conf());
            server.add_route(ok_route("/ping"));
            let handle = server.start().await.unwrap();

            let client = Client::new();
            let uri = format!("http://{}{}", handle.addr(), "/ping")
                .parse()
                .unwrap();
            let resp = client.get(uri).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
            let body = to_bytes(resp.into_body()).await.unwrap();
            assert_eq!(&body[..], b"ok");

            handle.stop().await.unwrap();
        });
    }

    #[test]
    fn demo_service_with_middleware_and_prefix() {
        runtime().block_on(async {
            let routes = vec![Route::new(
                Method::GET,
                "/hello",
                |_: http::Request<Body>| async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(Body::from("hi"))
                        .unwrap()
                },
            )];

            // Tags every response so the test can see the middleware ran.
            let mw = crate::middleware(|req, next| async move {
                let mut resp = next.call(req).await;
                resp.headers_mut()
                    .insert("X-Demo", http::HeaderValue::from_static("1"));
                resp
            });

            let mut server = Server::new(loopback_conf());
            server.with_root_prefix("/api").with_prefix("/session");
            server.use_middleware(mw);
            server.add_routes(routes);
            let handle = server.start().await.unwrap();

            let client = Client::new();
            let uri = format!("http://{}{}", handle.addr(), "/api/session/hello")
                .parse()
                .unwrap();
            let resp = client.get(uri).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
            assert_eq!(resp.headers().get("X-Demo").unwrap(), "1");

            handle.stop().await.unwrap();
        });
    }
}