// kit_rs/server.rs

1use crate::config::{Config, ServerConfig};
2use crate::http::{HttpResponse, Request};
3use crate::inertia::InertiaContext;
4use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
5use crate::routing::Router;
6use bytes::Bytes;
7use http_body_util::Full;
8use hyper::server::conn::http1;
9use hyper::service::service_fn;
10use hyper_util::rt::TokioIo;
11use std::convert::Infallible;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use tokio::net::TcpListener;
15
16pub struct Server {
17    router: Arc<Router>,
18    middleware: MiddlewareRegistry,
19    host: String,
20    port: u16,
21}
22
23impl Server {
24    pub fn new(router: impl Into<Router>) -> Self {
25        Self {
26            router: Arc::new(router.into()),
27            middleware: MiddlewareRegistry::new(),
28            host: "127.0.0.1".to_string(),
29            port: 8000,
30        }
31    }
32
33    pub fn from_config(router: impl Into<Router>) -> Self {
34        let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
35        Self {
36            router: Arc::new(router.into()),
37            middleware: MiddlewareRegistry::new(),
38            host: config.host,
39            port: config.port,
40        }
41    }
42
43    /// Add global middleware (runs on every request)
44    ///
45    /// For route-specific middleware, use `.middleware(M)` on the route itself.
46    ///
47    /// # Example
48    ///
49    /// ```rust,ignore
50    /// Server::from_config(router)
51    ///     .middleware(LoggingMiddleware)  // Global
52    ///     .middleware(CorsMiddleware)     // Global
53    ///     .run()
54    ///     .await;
55    /// ```
56    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
57        self.middleware = self.middleware.append(middleware);
58        self
59    }
60
61    pub fn host(mut self, host: &str) -> Self {
62        self.host = host.to_string();
63        self
64    }
65
66    pub fn port(mut self, port: u16) -> Self {
67        self.port = port;
68        self
69    }
70
71    fn get_addr(&self) -> SocketAddr {
72        SocketAddr::new(self.host.parse().unwrap(), self.port)
73    }
74
75    pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
76        let addr: SocketAddr = self.get_addr();
77        let listener = TcpListener::bind(addr).await?;
78
79        println!("Kit server running on http://{}", addr);
80
81        let router = self.router;
82        let middleware = Arc::new(self.middleware);
83
84        loop {
85            let (stream, _) = listener.accept().await?;
86            let io = TokioIo::new(stream);
87            let router = router.clone();
88            let middleware = middleware.clone();
89
90            tokio::spawn(async move {
91                let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
92                    let router = router.clone();
93                    let middleware = middleware.clone();
94                    async move { Ok::<_, Infallible>(handle_request(router, middleware, req).await) }
95                });
96
97                if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
98                    eprintln!("Error serving connection: {:?}", err);
99                }
100            });
101        }
102    }
103}
104
105async fn handle_request(
106    router: Arc<Router>,
107    middleware_registry: Arc<MiddlewareRegistry>,
108    req: hyper::Request<hyper::body::Incoming>,
109) -> hyper::Response<Full<Bytes>> {
110    let method = req.method().clone();
111    let path = req.uri().path().to_string();
112
113    // Set up Inertia context from request headers
114    let is_inertia = req
115        .headers()
116        .get("X-Inertia")
117        .and_then(|v| v.to_str().ok())
118        .map(|v| v == "true")
119        .unwrap_or(false);
120
121    let inertia_version = req
122        .headers()
123        .get("X-Inertia-Version")
124        .and_then(|v| v.to_str().ok())
125        .map(|v| v.to_string());
126
127    InertiaContext::set(InertiaContext {
128        path: path.clone(),
129        is_inertia,
130        version: inertia_version,
131    });
132
133    let response = match router.match_route(&method, &path) {
134        Some((handler, params)) => {
135            let request = Request::new(req).with_params(params);
136
137            // Build middleware chain
138            let mut chain = MiddlewareChain::new();
139
140            // 1. Add global middleware
141            chain.extend(middleware_registry.global_middleware().iter().cloned());
142
143            // 2. Add route-level middleware (already boxed)
144            let route_middleware = router.get_route_middleware(&path);
145            chain.extend(route_middleware);
146
147            // 3. Execute chain with handler
148            let response = chain.execute(request, handler).await;
149
150            // Unwrap the Result - both Ok and Err contain HttpResponse
151            let http_response = response.unwrap_or_else(|e| e);
152            http_response.into_hyper()
153        }
154        None => {
155            HttpResponse::text("404 Not Found")
156                .status(404)
157                .into_hyper()
158        }
159    };
160
161    // Clear context after request
162    InertiaContext::clear();
163
164    response
165}