1use crate::cache::Cache;
2use crate::config::{Config, ServerConfig};
3use crate::container::App;
4use crate::http::{HttpResponse, Request};
5use crate::inertia::InertiaContext;
6use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
7use crate::routing::Router;
8use bytes::Bytes;
9use http_body_util::Full;
10use hyper::server::conn::http1;
11use hyper::service::service_fn;
12use hyper_util::rt::TokioIo;
13use std::convert::Infallible;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::TcpListener;
17
18pub struct Server {
19 router: Arc<Router>,
20 middleware: MiddlewareRegistry,
21 host: String,
22 port: u16,
23}
24
25impl Server {
26 pub fn new(router: impl Into<Router>) -> Self {
27 Self {
28 router: Arc::new(router.into()),
29 middleware: MiddlewareRegistry::new(),
30 host: "127.0.0.1".to_string(),
31 port: 8000,
32 }
33 }
34
35 pub fn from_config(router: impl Into<Router>) -> Self {
36 App::init();
38
39 App::boot_services();
41
42 let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
43 Self {
44 router: Arc::new(router.into()),
45 middleware: MiddlewareRegistry::from_global(),
47 host: config.host,
48 port: config.port,
49 }
50 }
51
52 pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
66 self.middleware = self.middleware.append(middleware);
67 self
68 }
69
70 pub fn host(mut self, host: &str) -> Self {
71 self.host = host.to_string();
72 self
73 }
74
75 pub fn port(mut self, port: u16) -> Self {
76 self.port = port;
77 self
78 }
79
80 fn get_addr(&self) -> SocketAddr {
81 SocketAddr::new(self.host.parse().unwrap(), self.port)
82 }
83
84 pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85 Cache::bootstrap().await;
87
88 let addr: SocketAddr = self.get_addr();
89 let listener = TcpListener::bind(addr).await?;
90
91 println!("Kit server running on http://{}", addr);
92
93 let router = self.router;
94 let middleware = Arc::new(self.middleware);
95
96 loop {
97 let (stream, _) = listener.accept().await?;
98 let io = TokioIo::new(stream);
99 let router = router.clone();
100 let middleware = middleware.clone();
101
102 tokio::spawn(async move {
103 let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
104 let router = router.clone();
105 let middleware = middleware.clone();
106 async move { Ok::<_, Infallible>(handle_request(router, middleware, req).await) }
107 });
108
109 if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
110 eprintln!("Error serving connection: {:?}", err);
111 }
112 });
113 }
114 }
115}
116
117async fn handle_request(
118 router: Arc<Router>,
119 middleware_registry: Arc<MiddlewareRegistry>,
120 req: hyper::Request<hyper::body::Incoming>,
121) -> hyper::Response<Full<Bytes>> {
122 let method = req.method().clone();
123 let path = req.uri().path().to_string();
124 let query = req.uri().query().unwrap_or("");
125
126 if path == "/_kit/health" && method == hyper::Method::GET {
129 return health_response(query).await;
130 }
131
132 let is_inertia = req
134 .headers()
135 .get("X-Inertia")
136 .and_then(|v| v.to_str().ok())
137 .map(|v| v == "true")
138 .unwrap_or(false);
139
140 let inertia_version = req
141 .headers()
142 .get("X-Inertia-Version")
143 .and_then(|v| v.to_str().ok())
144 .map(|v| v.to_string());
145
146 InertiaContext::set(InertiaContext {
147 path: path.clone(),
148 is_inertia,
149 version: inertia_version,
150 });
151
152 let response = match router.match_route(&method, &path) {
153 Some((handler, params)) => {
154 let request = Request::new(req).with_params(params);
155
156 let mut chain = MiddlewareChain::new();
158
159 chain.extend(middleware_registry.global_middleware().iter().cloned());
161
162 let route_middleware = router.get_route_middleware(&path);
164 chain.extend(route_middleware);
165
166 let response = chain.execute(request, handler).await;
168
169 let http_response = response.unwrap_or_else(|e| e);
171 http_response.into_hyper()
172 }
173 None => {
174 if let Some((fallback_handler, fallback_middleware)) = router.get_fallback() {
176 let request = Request::new(req).with_params(std::collections::HashMap::new());
177
178 let mut chain = MiddlewareChain::new();
180
181 chain.extend(middleware_registry.global_middleware().iter().cloned());
183
184 chain.extend(fallback_middleware);
186
187 let response = chain.execute(request, fallback_handler).await;
189
190 let http_response = response.unwrap_or_else(|e| e);
192 http_response.into_hyper()
193 } else {
194 HttpResponse::text("404 Not Found").status(404).into_hyper()
196 }
197 }
198 };
199
200 InertiaContext::clear();
202
203 response
204}
205
206async fn health_response(query: &str) -> hyper::Response<Full<Bytes>> {
210 use chrono::Utc;
211 use serde_json::json;
212
213 let timestamp = Utc::now().to_rfc3339();
214 let check_db = query.contains("db=true");
215
216 let mut response = json!({
217 "status": "ok",
218 "timestamp": timestamp
219 });
220
221 if check_db {
222 match check_database_health().await {
224 Ok(_) => {
225 response["database"] = json!("connected");
226 }
227 Err(e) => {
228 response["database"] = json!("error");
229 response["database_error"] = json!(e);
230 }
231 }
232 }
233
234 let body = serde_json::to_string(&response).unwrap_or_else(|_| r#"{"status":"ok"}"#.to_string());
235
236 hyper::Response::builder()
237 .status(200)
238 .header("Content-Type", "application/json")
239 .body(Full::new(Bytes::from(body)))
240 .unwrap()
241}
242
243async fn check_database_health() -> Result<(), String> {
245 use crate::database::DB;
246 use sea_orm::ConnectionTrait;
247
248 if !DB::is_connected() {
249 return Err("Database not initialized".to_string());
250 }
251
252 let conn = DB::connection().map_err(|e| e.to_string())?;
253
254 conn.inner()
256 .execute_unprepared("SELECT 1")
257 .await
258 .map_err(|e| format!("Database query failed: {}", e))?;
259
260 Ok(())
261}