kit_rs/
server.rs

1use crate::config::{Config, ServerConfig};
2use crate::container::App;
3use crate::http::{HttpResponse, Request};
4use crate::inertia::InertiaContext;
5use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
6use crate::routing::Router;
7use bytes::Bytes;
8use http_body_util::Full;
9use hyper::server::conn::http1;
10use hyper::service::service_fn;
11use hyper_util::rt::TokioIo;
12use std::convert::Infallible;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::TcpListener;
16
17pub struct Server {
18    router: Arc<Router>,
19    middleware: MiddlewareRegistry,
20    host: String,
21    port: u16,
22}
23
24impl Server {
25    pub fn new(router: impl Into<Router>) -> Self {
26        Self {
27            router: Arc::new(router.into()),
28            middleware: MiddlewareRegistry::new(),
29            host: "127.0.0.1".to_string(),
30            port: 8000,
31        }
32    }
33
34    pub fn from_config(router: impl Into<Router>) -> Self {
35        // Initialize the App container
36        App::init();
37
38        // Boot all auto-registered services from #[service(ConcreteType)]
39        App::boot_services();
40
41        let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
42        Self {
43            router: Arc::new(router.into()),
44            // Pull global middleware registered via global_middleware! in bootstrap.rs
45            middleware: MiddlewareRegistry::from_global(),
46            host: config.host,
47            port: config.port,
48        }
49    }
50
51    /// Add global middleware (runs on every request)
52    ///
53    /// For route-specific middleware, use `.middleware(M)` on the route itself.
54    ///
55    /// # Example
56    ///
57    /// ```rust,ignore
58    /// Server::from_config(router)
59    ///     .middleware(LoggingMiddleware)  // Global
60    ///     .middleware(CorsMiddleware)     // Global
61    ///     .run()
62    ///     .await;
63    /// ```
64    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
65        self.middleware = self.middleware.append(middleware);
66        self
67    }
68
69    pub fn host(mut self, host: &str) -> Self {
70        self.host = host.to_string();
71        self
72    }
73
74    pub fn port(mut self, port: u16) -> Self {
75        self.port = port;
76        self
77    }
78
79    fn get_addr(&self) -> SocketAddr {
80        SocketAddr::new(self.host.parse().unwrap(), self.port)
81    }
82
83    pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84        let addr: SocketAddr = self.get_addr();
85        let listener = TcpListener::bind(addr).await?;
86
87        println!("Kit server running on http://{}", addr);
88
89        let router = self.router;
90        let middleware = Arc::new(self.middleware);
91
92        loop {
93            let (stream, _) = listener.accept().await?;
94            let io = TokioIo::new(stream);
95            let router = router.clone();
96            let middleware = middleware.clone();
97
98            tokio::spawn(async move {
99                let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
100                    let router = router.clone();
101                    let middleware = middleware.clone();
102                    async move { Ok::<_, Infallible>(handle_request(router, middleware, req).await) }
103                });
104
105                if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
106                    eprintln!("Error serving connection: {:?}", err);
107                }
108            });
109        }
110    }
111}
112
113async fn handle_request(
114    router: Arc<Router>,
115    middleware_registry: Arc<MiddlewareRegistry>,
116    req: hyper::Request<hyper::body::Incoming>,
117) -> hyper::Response<Full<Bytes>> {
118    let method = req.method().clone();
119    let path = req.uri().path().to_string();
120
121    // Set up Inertia context from request headers
122    let is_inertia = req
123        .headers()
124        .get("X-Inertia")
125        .and_then(|v| v.to_str().ok())
126        .map(|v| v == "true")
127        .unwrap_or(false);
128
129    let inertia_version = req
130        .headers()
131        .get("X-Inertia-Version")
132        .and_then(|v| v.to_str().ok())
133        .map(|v| v.to_string());
134
135    InertiaContext::set(InertiaContext {
136        path: path.clone(),
137        is_inertia,
138        version: inertia_version,
139    });
140
141    let response = match router.match_route(&method, &path) {
142        Some((handler, params)) => {
143            let request = Request::new(req).with_params(params);
144
145            // Build middleware chain
146            let mut chain = MiddlewareChain::new();
147
148            // 1. Add global middleware
149            chain.extend(middleware_registry.global_middleware().iter().cloned());
150
151            // 2. Add route-level middleware (already boxed)
152            let route_middleware = router.get_route_middleware(&path);
153            chain.extend(route_middleware);
154
155            // 3. Execute chain with handler
156            let response = chain.execute(request, handler).await;
157
158            // Unwrap the Result - both Ok and Err contain HttpResponse
159            let http_response = response.unwrap_or_else(|e| e);
160            http_response.into_hyper()
161        }
162        None => {
163            HttpResponse::text("404 Not Found")
164                .status(404)
165                .into_hyper()
166        }
167    };
168
169    // Clear context after request
170    InertiaContext::clear();
171
172    response
173}