kit_rs/
server.rs

1use crate::cache::Cache;
2use crate::config::{Config, ServerConfig};
3use crate::container::App;
4use crate::http::{HttpResponse, Request};
5use crate::inertia::InertiaContext;
6use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
7use crate::routing::Router;
8use bytes::Bytes;
9use http_body_util::Full;
10use hyper::server::conn::http1;
11use hyper::service::service_fn;
12use hyper_util::rt::TokioIo;
13use std::convert::Infallible;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::TcpListener;
17
18pub struct Server {
19    router: Arc<Router>,
20    middleware: MiddlewareRegistry,
21    host: String,
22    port: u16,
23}
24
25impl Server {
26    pub fn new(router: impl Into<Router>) -> Self {
27        Self {
28            router: Arc::new(router.into()),
29            middleware: MiddlewareRegistry::new(),
30            host: "127.0.0.1".to_string(),
31            port: 8000,
32        }
33    }
34
35    pub fn from_config(router: impl Into<Router>) -> Self {
36        // Initialize the App container
37        App::init();
38
39        // Boot all auto-registered services from #[service(ConcreteType)]
40        App::boot_services();
41
42        let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
43        Self {
44            router: Arc::new(router.into()),
45            // Pull global middleware registered via global_middleware! in bootstrap.rs
46            middleware: MiddlewareRegistry::from_global(),
47            host: config.host,
48            port: config.port,
49        }
50    }
51
52    /// Add global middleware (runs on every request)
53    ///
54    /// For route-specific middleware, use `.middleware(M)` on the route itself.
55    ///
56    /// # Example
57    ///
58    /// ```rust,ignore
59    /// Server::from_config(router)
60    ///     .middleware(LoggingMiddleware)  // Global
61    ///     .middleware(CorsMiddleware)     // Global
62    ///     .run()
63    ///     .await;
64    /// ```
65    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
66        self.middleware = self.middleware.append(middleware);
67        self
68    }
69
70    pub fn host(mut self, host: &str) -> Self {
71        self.host = host.to_string();
72        self
73    }
74
75    pub fn port(mut self, port: u16) -> Self {
76        self.port = port;
77        self
78    }
79
80    fn get_addr(&self) -> SocketAddr {
81        SocketAddr::new(self.host.parse().unwrap(), self.port)
82    }
83
84    pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85        // Bootstrap cache (Redis with in-memory fallback)
86        Cache::bootstrap().await;
87
88        let addr: SocketAddr = self.get_addr();
89        let listener = TcpListener::bind(addr).await?;
90
91        println!("Kit server running on http://{}", addr);
92
93        let router = self.router;
94        let middleware = Arc::new(self.middleware);
95
96        loop {
97            let (stream, _) = listener.accept().await?;
98            let io = TokioIo::new(stream);
99            let router = router.clone();
100            let middleware = middleware.clone();
101
102            tokio::spawn(async move {
103                let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
104                    let router = router.clone();
105                    let middleware = middleware.clone();
106                    async move { Ok::<_, Infallible>(handle_request(router, middleware, req).await) }
107                });
108
109                if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
110                    eprintln!("Error serving connection: {:?}", err);
111                }
112            });
113        }
114    }
115}
116
117async fn handle_request(
118    router: Arc<Router>,
119    middleware_registry: Arc<MiddlewareRegistry>,
120    req: hyper::Request<hyper::body::Incoming>,
121) -> hyper::Response<Full<Bytes>> {
122    let method = req.method().clone();
123    let path = req.uri().path().to_string();
124    let query = req.uri().query().unwrap_or("");
125
126    // Built-in health check endpoint at /_kit/health
127    // Uses framework prefix to avoid conflicts with user-defined routes
128    if path == "/_kit/health" && method == hyper::Method::GET {
129        return health_response(query).await;
130    }
131
132    // Set up Inertia context from request headers
133    let is_inertia = req
134        .headers()
135        .get("X-Inertia")
136        .and_then(|v| v.to_str().ok())
137        .map(|v| v == "true")
138        .unwrap_or(false);
139
140    let inertia_version = req
141        .headers()
142        .get("X-Inertia-Version")
143        .and_then(|v| v.to_str().ok())
144        .map(|v| v.to_string());
145
146    InertiaContext::set(InertiaContext {
147        path: path.clone(),
148        is_inertia,
149        version: inertia_version,
150    });
151
152    let response = match router.match_route(&method, &path) {
153        Some((handler, params)) => {
154            let request = Request::new(req).with_params(params);
155
156            // Build middleware chain
157            let mut chain = MiddlewareChain::new();
158
159            // 1. Add global middleware
160            chain.extend(middleware_registry.global_middleware().iter().cloned());
161
162            // 2. Add route-level middleware (already boxed)
163            let route_middleware = router.get_route_middleware(&path);
164            chain.extend(route_middleware);
165
166            // 3. Execute chain with handler
167            let response = chain.execute(request, handler).await;
168
169            // Unwrap the Result - both Ok and Err contain HttpResponse
170            let http_response = response.unwrap_or_else(|e| e);
171            http_response.into_hyper()
172        }
173        None => {
174            // Check for fallback handler
175            if let Some((fallback_handler, fallback_middleware)) = router.get_fallback() {
176                let request = Request::new(req).with_params(std::collections::HashMap::new());
177
178                // Build middleware chain for fallback
179                let mut chain = MiddlewareChain::new();
180
181                // 1. Add global middleware
182                chain.extend(middleware_registry.global_middleware().iter().cloned());
183
184                // 2. Add fallback-specific middleware
185                chain.extend(fallback_middleware);
186
187                // 3. Execute chain with fallback handler
188                let response = chain.execute(request, fallback_handler).await;
189
190                // Unwrap the Result - both Ok and Err contain HttpResponse
191                let http_response = response.unwrap_or_else(|e| e);
192                http_response.into_hyper()
193            } else {
194                // No fallback defined, return default 404
195                HttpResponse::text("404 Not Found").status(404).into_hyper()
196            }
197        }
198    };
199
200    // Clear context after request
201    InertiaContext::clear();
202
203    response
204}
205
206/// Built-in health check endpoint at /_kit/health
207/// Returns {"status": "ok", "timestamp": "..."} by default
208/// Add ?db=true to also check database connectivity (/_kit/health?db=true)
209async fn health_response(query: &str) -> hyper::Response<Full<Bytes>> {
210    use chrono::Utc;
211    use serde_json::json;
212
213    let timestamp = Utc::now().to_rfc3339();
214    let check_db = query.contains("db=true");
215
216    let mut response = json!({
217        "status": "ok",
218        "timestamp": timestamp
219    });
220
221    if check_db {
222        // Try to check database connection
223        match check_database_health().await {
224            Ok(_) => {
225                response["database"] = json!("connected");
226            }
227            Err(e) => {
228                response["database"] = json!("error");
229                response["database_error"] = json!(e);
230            }
231        }
232    }
233
234    let body = serde_json::to_string(&response).unwrap_or_else(|_| r#"{"status":"ok"}"#.to_string());
235
236    hyper::Response::builder()
237        .status(200)
238        .header("Content-Type", "application/json")
239        .body(Full::new(Bytes::from(body)))
240        .unwrap()
241}
242
243/// Check database health by attempting a simple query
244async fn check_database_health() -> Result<(), String> {
245    use crate::database::DB;
246    use sea_orm::ConnectionTrait;
247
248    if !DB::is_connected() {
249        return Err("Database not initialized".to_string());
250    }
251
252    let conn = DB::connection().map_err(|e| e.to_string())?;
253
254    // Execute a simple query to verify connection is alive
255    conn.inner()
256        .execute_unprepared("SELECT 1")
257        .await
258        .map_err(|e| format!("Database query failed: {}", e))?;
259
260    Ok(())
261}