1use crate::config::{Config, ServerConfig};
2use crate::http::{HttpResponse, Request};
3use crate::inertia::InertiaContext;
4use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
5use crate::routing::Router;
6use bytes::Bytes;
7use http_body_util::Full;
8use hyper::server::conn::http1;
9use hyper::service::service_fn;
10use hyper_util::rt::TokioIo;
11use std::convert::Infallible;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use tokio::net::TcpListener;
15
16pub struct Server {
17 router: Arc<Router>,
18 middleware: MiddlewareRegistry,
19 host: String,
20 port: u16,
21}
22
23impl Server {
24 pub fn new(router: impl Into<Router>) -> Self {
25 Self {
26 router: Arc::new(router.into()),
27 middleware: MiddlewareRegistry::new(),
28 host: "127.0.0.1".to_string(),
29 port: 8000,
30 }
31 }
32
33 pub fn from_config(router: impl Into<Router>) -> Self {
34 let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
35 Self {
36 router: Arc::new(router.into()),
37 middleware: MiddlewareRegistry::new(),
38 host: config.host,
39 port: config.port,
40 }
41 }
42
43 pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
57 self.middleware = self.middleware.append(middleware);
58 self
59 }
60
61 pub fn host(mut self, host: &str) -> Self {
62 self.host = host.to_string();
63 self
64 }
65
66 pub fn port(mut self, port: u16) -> Self {
67 self.port = port;
68 self
69 }
70
71 fn get_addr(&self) -> SocketAddr {
72 SocketAddr::new(self.host.parse().unwrap(), self.port)
73 }
74
75 pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
76 let addr: SocketAddr = self.get_addr();
77 let listener = TcpListener::bind(addr).await?;
78
79 println!("Kit server running on http://{}", addr);
80
81 let router = self.router;
82 let middleware = Arc::new(self.middleware);
83
84 loop {
85 let (stream, _) = listener.accept().await?;
86 let io = TokioIo::new(stream);
87 let router = router.clone();
88 let middleware = middleware.clone();
89
90 tokio::spawn(async move {
91 let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
92 let router = router.clone();
93 let middleware = middleware.clone();
94 async move { Ok::<_, Infallible>(handle_request(router, middleware, req).await) }
95 });
96
97 if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
98 eprintln!("Error serving connection: {:?}", err);
99 }
100 });
101 }
102 }
103}
104
105async fn handle_request(
106 router: Arc<Router>,
107 middleware_registry: Arc<MiddlewareRegistry>,
108 req: hyper::Request<hyper::body::Incoming>,
109) -> hyper::Response<Full<Bytes>> {
110 let method = req.method().clone();
111 let path = req.uri().path().to_string();
112
113 let is_inertia = req
115 .headers()
116 .get("X-Inertia")
117 .and_then(|v| v.to_str().ok())
118 .map(|v| v == "true")
119 .unwrap_or(false);
120
121 let inertia_version = req
122 .headers()
123 .get("X-Inertia-Version")
124 .and_then(|v| v.to_str().ok())
125 .map(|v| v.to_string());
126
127 InertiaContext::set(InertiaContext {
128 path: path.clone(),
129 is_inertia,
130 version: inertia_version,
131 });
132
133 let response = match router.match_route(&method, &path) {
134 Some((handler, params)) => {
135 let request = Request::new(req).with_params(params);
136
137 let mut chain = MiddlewareChain::new();
139
140 chain.extend(middleware_registry.global_middleware().iter().cloned());
142
143 let route_middleware = router.get_route_middleware(&path);
145 chain.extend(route_middleware);
146
147 let response = chain.execute(request, handler).await;
149
150 let http_response = response.unwrap_or_else(|e| e);
152 http_response.into_hyper()
153 }
154 None => {
155 HttpResponse::text("404 Not Found")
156 .status(404)
157 .into_hyper()
158 }
159 };
160
161 InertiaContext::clear();
163
164 response
165}