Skip to main content

ferro_rs/
server.rs

1use crate::cache::Cache;
2use crate::config::{Config, ServerConfig};
3use crate::container::App;
4use crate::http::FerroBody;
5use crate::http::{HttpResponse, Request};
6use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
7use crate::routing::Router;
8use crate::websocket::handle_ws_upgrade;
9use bytes::Bytes;
10use http_body_util::Full;
11use hyper::server::conn::http1;
12use hyper::service::service_fn;
13use hyper_util::rt::TokioIo;
14use std::convert::Infallible;
15use std::net::SocketAddr;
16use std::sync::Arc;
17use tokio::net::TcpListener;
18
19/// Pre-routing WebSocket interceptor.
20///
21/// Called for every WS upgrade request before Ferro routing.
22/// Returns `Ok(Response)` to handle the request, `Err(Request)` to decline
23/// and pass to normal routing (including the built-in `/_ferro/ws` check).
24type WsInterceptor = Box<
25    dyn Fn(
26            hyper::Request<hyper::body::Incoming>,
27        ) -> Result<hyper::Response<FerroBody>, hyper::Request<hyper::body::Incoming>>
28        + Send
29        + Sync,
30>;
31
32/// HTTP server that binds routes, middleware, and optional WebSocket handling.
33pub struct Server {
34    router: Arc<Router>,
35    middleware: MiddlewareRegistry,
36    host: String,
37    port: u16,
38    ws_interceptor: Option<Arc<WsInterceptor>>,
39}
40
41impl Server {
42    /// Create a server with default host/port and no global middleware.
43    pub fn new(router: impl Into<Router>) -> Self {
44        Self {
45            router: Arc::new(router.into()),
46            middleware: MiddlewareRegistry::new(),
47            host: "127.0.0.1".to_string(),
48            port: 8080,
49            ws_interceptor: None,
50        }
51    }
52
53    /// Create a server from environment configuration, booting all services.
54    pub fn from_config(router: impl Into<Router>) -> Self {
55        // Initialize the App container
56        App::init();
57
58        // Boot all auto-registered services from #[service(ConcreteType)]
59        App::boot_services();
60
61        let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
62        Self {
63            router: Arc::new(router.into()),
64            // Pull global middleware registered via global_middleware! in bootstrap.rs
65            middleware: MiddlewareRegistry::from_global(),
66            host: config.host,
67            port: config.port,
68            ws_interceptor: None,
69        }
70    }
71
72    /// Set a WebSocket interceptor that runs before all routing.
73    ///
74    /// The interceptor receives every WS upgrade request first.
75    /// Return `Ok(response)` to handle the connection; return `Err(request)` to
76    /// decline and let normal routing (including `/_ferro/ws`) proceed.
77    ///
78    /// # Example
79    ///
80    /// ```rust,ignore
81    /// Server::from_config(router)
82    ///     .ws_interceptor(|req| {
83    ///         if req.uri().path().starts_with("/sessions/") {
84    ///             Ok(my_ws_handler(req))
85    ///         } else {
86    ///             Err(req) // pass to /_ferro/ws
87    ///         }
88    ///     })
89    ///     .run()
90    ///     .await;
91    /// ```
92    pub fn ws_interceptor<F>(mut self, handler: F) -> Self
93    where
94        F: Fn(
95                hyper::Request<hyper::body::Incoming>,
96            )
97                -> Result<hyper::Response<FerroBody>, hyper::Request<hyper::body::Incoming>>
98            + Send
99            + Sync
100            + 'static,
101    {
102        self.ws_interceptor = Some(Arc::new(Box::new(handler)));
103        self
104    }
105
106    /// Add global middleware (runs on every request)
107    ///
108    /// For route-specific middleware, use `.middleware(M)` on the route itself.
109    ///
110    /// # Example
111    ///
112    /// ```rust,ignore
113    /// Server::from_config(router)
114    ///     .middleware(LoggingMiddleware)  // Global
115    ///     .middleware(CorsMiddleware)     // Global
116    ///     .run()
117    ///     .await;
118    /// ```
119    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
120        self.middleware = self.middleware.append(middleware);
121        self
122    }
123
124    /// Override the listen host address.
125    pub fn host(mut self, host: &str) -> Self {
126        self.host = host.to_string();
127        self
128    }
129
130    /// Override the listen port.
131    pub fn port(mut self, port: u16) -> Self {
132        self.port = port;
133        self
134    }
135
136    fn get_addr(&self) -> SocketAddr {
137        SocketAddr::new(self.host.parse().unwrap(), self.port)
138    }
139
140    /// Start listening and serving requests until the process is terminated.
141    pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
142        // Bootstrap cache (Redis with in-memory fallback)
143        Cache::bootstrap().await;
144
145        let addr: SocketAddr = self.get_addr();
146        let listener = TcpListener::bind(addr).await?;
147
148        println!("Ferro server running on http://{addr}");
149
150        let router = self.router;
151        let middleware = Arc::new(self.middleware);
152        let ws_interceptor = self.ws_interceptor;
153
154        loop {
155            let (stream, _) = listener.accept().await?;
156            let io = TokioIo::new(stream);
157            let router = router.clone();
158            let middleware = middleware.clone();
159            let ws_interceptor = ws_interceptor.clone();
160
161            tokio::spawn(async move {
162                let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
163                    let router = router.clone();
164                    let middleware = middleware.clone();
165                    let ws_interceptor = ws_interceptor.clone();
166                    async move {
167                        Ok::<_, Infallible>(
168                            handle_request(router, middleware, ws_interceptor, req).await,
169                        )
170                    }
171                });
172
173                if let Err(err) = http1::Builder::new()
174                    .serve_connection(io, service)
175                    .with_upgrades()
176                    .await
177                {
178                    eprintln!("Error serving connection: {err:?}");
179                }
180            });
181        }
182    }
183}
184
185async fn handle_request(
186    router: Arc<Router>,
187    middleware_registry: Arc<MiddlewareRegistry>,
188    ws_interceptor: Option<Arc<WsInterceptor>>,
189    mut req: hyper::Request<hyper::body::Incoming>,
190) -> hyper::Response<FerroBody> {
191    // Application WS interceptor runs before /_ferro/ws
192    if let Some(ref interceptor) = ws_interceptor {
193        if hyper_tungstenite::is_upgrade_request(&req) {
194            match interceptor(req) {
195                Ok(response) => return response,
196                Err(returned_req) => {
197                    // Interceptor declined — continue with returned request
198                    req = returned_req;
199                }
200            }
201        }
202    }
203
204    let method = req.method().clone();
205    let path = req.uri().path().to_string();
206    let query = req.uri().query().unwrap_or("");
207
208    // WebSocket upgrade at /_ferro/ws (must run before middleware/routing)
209    if path == "/_ferro/ws" && hyper_tungstenite::is_upgrade_request(&req) {
210        return handle_ws_upgrade(req);
211    }
212
213    // Built-in framework endpoints at /_ferro/*
214    // Uses framework prefix to avoid conflicts with user-defined routes
215    if path.starts_with("/_ferro/") && method == hyper::Method::GET {
216        return match path.as_str() {
217            "/_ferro/health" => health_response(query).await,
218            "/_ferro/routes" => crate::debug::handle_routes(),
219            "/_ferro/middleware" => crate::debug::handle_middleware(),
220            "/_ferro/services" => crate::debug::handle_services(),
221            "/_ferro/metrics" => crate::debug::handle_metrics(),
222            "/_ferro/queue/jobs" => crate::debug::handle_queue_jobs().await,
223            "/_ferro/queue/stats" => crate::debug::handle_queue_stats().await,
224            "/_ferro/ferro-base.css" => {
225                #[cfg(feature = "json-ui")]
226                {
227                    serve_ferro_base_css()
228                }
229                #[cfg(not(feature = "json-ui"))]
230                {
231                    HttpResponse::text("404 Not Found").status(404).into_hyper()
232                }
233            }
234            _ => HttpResponse::text("404 Not Found").status(404).into_hyper(),
235        };
236    }
237
238    // Run pre-route middleware (e.g. custom domain path rewriting) before routing.
239    let pre_route = crate::middleware::get_pre_route_middleware();
240    for hook in &pre_route {
241        match hook.handle(req).await {
242            Ok(rewritten) => req = rewritten,
243            Err(response) => return response,
244        }
245    }
246
247    let method = req.method().clone();
248    let path = req.uri().path().to_string();
249
250    let ferro_request = Request::new(req);
251    let routing_path = path.clone();
252
253    // Extract host before request is consumed by routing.
254    let request_host = ferro_request
255        .header("host")
256        .unwrap_or_default()
257        .split(':')
258        .next()
259        .unwrap_or("")
260        .to_ascii_lowercase();
261
262    let response = match router.match_route(&method, &routing_path) {
263        Some((handler, params, route_pattern)) => {
264            let request = ferro_request
265                .with_params(params)
266                .with_route_pattern(route_pattern.clone());
267
268            // Build middleware chain
269            let mut chain = MiddlewareChain::new();
270
271            // 1. Add global middleware
272            chain.extend(middleware_registry.global_middleware().iter().cloned());
273
274            // 2. Add route-level middleware (already boxed)
275            let route_middleware = router.get_route_middleware(&route_pattern);
276            chain.extend(route_middleware);
277
278            // 3. Execute chain with handler inside request host context
279            let response = crate::http::request_context::REQUEST_HOST
280                .scope(request_host, chain.execute(request, handler))
281                .await;
282
283            // Unwrap the Result - both Ok and Err contain HttpResponse
284            let http_response = response.unwrap_or_else(|e| e);
285            http_response.into_hyper()
286        }
287        None => {
288            // Try static file serving before fallback (only GET/HEAD)
289            if method == hyper::Method::GET || method == hyper::Method::HEAD {
290                if let Some(response) =
291                    crate::static_files::try_serve_static_file(&routing_path).await
292                {
293                    return response;
294                }
295            }
296
297            // Check for fallback handler
298            if let Some((fallback_handler, fallback_middleware)) = router.get_fallback() {
299                let request = ferro_request.with_params(std::collections::HashMap::new());
300
301                // Build middleware chain for fallback
302                let mut chain = MiddlewareChain::new();
303
304                // 1. Add global middleware
305                chain.extend(middleware_registry.global_middleware().iter().cloned());
306
307                // 2. Add fallback-specific middleware
308                chain.extend(fallback_middleware);
309
310                // 3. Execute chain with fallback handler
311                let response = chain.execute(request, fallback_handler).await;
312
313                // Unwrap the Result - both Ok and Err contain HttpResponse
314                let http_response = response.unwrap_or_else(|e| e);
315                http_response.into_hyper()
316            } else {
317                // No fallback defined, return default 404
318                HttpResponse::text("404 Not Found").status(404).into_hyper()
319            }
320        }
321    };
322
323    response
324}
325
326/// Built-in health check endpoint at /_ferro/health
327/// Returns {"status": "ok", "timestamp": "..."} by default
328/// Add ?db=true to also check database connectivity (/_ferro/health?db=true)
329async fn health_response(query: &str) -> hyper::Response<FerroBody> {
330    use chrono::Utc;
331    use serde_json::json;
332
333    let timestamp = Utc::now().to_rfc3339();
334    let check_db = query.contains("db=true");
335
336    let mut response = json!({
337        "status": "ok",
338        "timestamp": timestamp
339    });
340
341    if check_db {
342        // Try to check database connection
343        match check_database_health().await {
344            Ok(_) => {
345                response["database"] = json!("connected");
346            }
347            Err(e) => {
348                response["database"] = json!("error");
349                response["database_error"] = json!(e);
350            }
351        }
352    }
353
354    let body =
355        serde_json::to_string(&response).unwrap_or_else(|_| r#"{"status":"ok"}"#.to_string());
356
357    hyper::Response::builder()
358        .status(200)
359        .header("Content-Type", "application/json")
360        .body(FerroBody::Full(Full::new(Bytes::from(body))))
361        .unwrap()
362}
363
/// Serve the pre-built ferro-json-ui base CSS.
///
/// The bytes are embedded at compile time via `ferro_json_ui::FERRO_BASE_CSS`.
/// The response is `200` with `text/css; charset=utf-8`, an explicit
/// `Content-Length`, and `Cache-Control: public, max-age=31536000, immutable`,
/// i.e. a one-year immutable cache.
///
/// No user input reaches this handler: the match arm is an exact string, and
/// the body is static framework content.
#[cfg(feature = "json-ui")]
fn serve_ferro_base_css() -> hyper::Response<FerroBody> {
    let css = ferro_json_ui::FERRO_BASE_CSS;
    hyper::Response::builder()
        .status(200)
        .header("Content-Type", "text/css; charset=utf-8")
        .header("Content-Length", css.len().to_string())
        .header("Cache-Control", "public, max-age=31536000, immutable")
        .body(FerroBody::Full(Full::new(Bytes::from_static(
            css.as_bytes(),
        ))))
        .expect("static status and header values are always valid")
}
382
383/// Check database health by attempting a simple query
384async fn check_database_health() -> Result<(), String> {
385    use crate::database::DB;
386    use sea_orm::ConnectionTrait;
387
388    if !DB::is_connected() {
389        return Err("Database not initialized".to_string());
390    }
391
392    let conn = DB::connection().map_err(|e| e.to_string())?;
393
394    // Execute a simple query to verify connection is alive
395    conn.inner()
396        .execute_unprepared("SELECT 1")
397        .await
398        .map_err(|e| format!("Database query failed: {e}"))?;
399
400    Ok(())
401}
402
#[cfg(all(test, feature = "json-ui"))]
mod ferro_base_css_route_tests {
    use super::*;
    use http_body_util::BodyExt;

    /// Read header `name` from `response` as text.
    ///
    /// Panics with "<name> header missing" when the header is absent.
    fn header_text<'a>(response: &'a hyper::Response<FerroBody>, name: &str) -> &'a str {
        response
            .headers()
            .get(name)
            .unwrap_or_else(|| panic!("{name} header missing"))
            .to_str()
            .unwrap()
    }

    #[tokio::test]
    async fn serve_ferro_base_css_returns_200_with_text_css_content_type() {
        let resp = serve_ferro_base_css();

        assert_eq!(resp.status(), 200, "expected 200 OK");

        let content_type = header_text(&resp, "Content-Type");
        assert_eq!(content_type, "text/css; charset=utf-8");

        let cache_control = header_text(&resp, "Cache-Control");
        assert_eq!(cache_control, "public, max-age=31536000, immutable");

        let declared_len: usize = header_text(&resp, "Content-Length")
            .parse()
            .expect("Content-Length must be an integer");
        assert_eq!(declared_len, ferro_json_ui::FERRO_BASE_CSS.len());
    }

    #[tokio::test]
    async fn serve_ferro_base_css_body_equals_embedded_constant() {
        let collected = serve_ferro_base_css()
            .into_body()
            .collect()
            .await
            .expect("body collect");
        let bytes = collected.to_bytes();

        assert_eq!(bytes.as_ref(), ferro_json_ui::FERRO_BASE_CSS.as_bytes());
        assert!(!bytes.is_empty());
    }
}