Skip to main content

ferro_rs/
server.rs

1use crate::cache::Cache;
2use crate::config::{Config, ServerConfig};
3use crate::container::App;
4use crate::http::{HttpResponse, Request};
5use crate::middleware::{Middleware, MiddlewareChain, MiddlewareRegistry};
6use crate::routing::Router;
7use crate::websocket::handle_ws_upgrade;
8use bytes::Bytes;
9use http_body_util::Full;
10use hyper::server::conn::http1;
11use hyper::service::service_fn;
12use hyper_util::rt::TokioIo;
13use std::convert::Infallible;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::TcpListener;
17
18/// Pre-routing WebSocket interceptor.
19///
20/// Called for every WS upgrade request before Ferro routing.
21/// Returns `Ok(Response)` to handle the request, `Err(Request)` to decline
22/// and pass to normal routing (including the built-in `/_ferro/ws` check).
23type WsInterceptor = Box<
24    dyn Fn(
25            hyper::Request<hyper::body::Incoming>,
26        ) -> Result<hyper::Response<Full<Bytes>>, hyper::Request<hyper::body::Incoming>>
27        + Send
28        + Sync,
29>;
30
31/// HTTP server that binds routes, middleware, and optional WebSocket handling.
32pub struct Server {
33    router: Arc<Router>,
34    middleware: MiddlewareRegistry,
35    host: String,
36    port: u16,
37    ws_interceptor: Option<Arc<WsInterceptor>>,
38}
39
40impl Server {
41    /// Create a server with default host/port and no global middleware.
42    pub fn new(router: impl Into<Router>) -> Self {
43        Self {
44            router: Arc::new(router.into()),
45            middleware: MiddlewareRegistry::new(),
46            host: "127.0.0.1".to_string(),
47            port: 8080,
48            ws_interceptor: None,
49        }
50    }
51
52    /// Create a server from environment configuration, booting all services.
53    pub fn from_config(router: impl Into<Router>) -> Self {
54        // Initialize the App container
55        App::init();
56
57        // Boot all auto-registered services from #[service(ConcreteType)]
58        App::boot_services();
59
60        let config = Config::get::<ServerConfig>().unwrap_or_else(ServerConfig::from_env);
61        Self {
62            router: Arc::new(router.into()),
63            // Pull global middleware registered via global_middleware! in bootstrap.rs
64            middleware: MiddlewareRegistry::from_global(),
65            host: config.host,
66            port: config.port,
67            ws_interceptor: None,
68        }
69    }
70
71    /// Set a WebSocket interceptor that runs before all routing.
72    ///
73    /// The interceptor receives every WS upgrade request first.
74    /// Return `Ok(response)` to handle the connection; return `Err(request)` to
75    /// decline and let normal routing (including `/_ferro/ws`) proceed.
76    ///
77    /// # Example
78    ///
79    /// ```rust,ignore
80    /// Server::from_config(router)
81    ///     .ws_interceptor(|req| {
82    ///         if req.uri().path().starts_with("/sessions/") {
83    ///             Ok(my_ws_handler(req))
84    ///         } else {
85    ///             Err(req) // pass to /_ferro/ws
86    ///         }
87    ///     })
88    ///     .run()
89    ///     .await;
90    /// ```
91    pub fn ws_interceptor<F>(mut self, handler: F) -> Self
92    where
93        F: Fn(
94                hyper::Request<hyper::body::Incoming>,
95            )
96                -> Result<hyper::Response<Full<Bytes>>, hyper::Request<hyper::body::Incoming>>
97            + Send
98            + Sync
99            + 'static,
100    {
101        self.ws_interceptor = Some(Arc::new(Box::new(handler)));
102        self
103    }
104
105    /// Add global middleware (runs on every request)
106    ///
107    /// For route-specific middleware, use `.middleware(M)` on the route itself.
108    ///
109    /// # Example
110    ///
111    /// ```rust,ignore
112    /// Server::from_config(router)
113    ///     .middleware(LoggingMiddleware)  // Global
114    ///     .middleware(CorsMiddleware)     // Global
115    ///     .run()
116    ///     .await;
117    /// ```
118    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
119        self.middleware = self.middleware.append(middleware);
120        self
121    }
122
123    /// Override the listen host address.
124    pub fn host(mut self, host: &str) -> Self {
125        self.host = host.to_string();
126        self
127    }
128
129    /// Override the listen port.
130    pub fn port(mut self, port: u16) -> Self {
131        self.port = port;
132        self
133    }
134
135    fn get_addr(&self) -> SocketAddr {
136        SocketAddr::new(self.host.parse().unwrap(), self.port)
137    }
138
139    /// Start listening and serving requests until the process is terminated.
140    pub async fn run(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
141        // Bootstrap cache (Redis with in-memory fallback)
142        Cache::bootstrap().await;
143
144        let addr: SocketAddr = self.get_addr();
145        let listener = TcpListener::bind(addr).await?;
146
147        println!("Ferro server running on http://{addr}");
148
149        let router = self.router;
150        let middleware = Arc::new(self.middleware);
151        let ws_interceptor = self.ws_interceptor;
152
153        loop {
154            let (stream, _) = listener.accept().await?;
155            let io = TokioIo::new(stream);
156            let router = router.clone();
157            let middleware = middleware.clone();
158            let ws_interceptor = ws_interceptor.clone();
159
160            tokio::spawn(async move {
161                let service = service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
162                    let router = router.clone();
163                    let middleware = middleware.clone();
164                    let ws_interceptor = ws_interceptor.clone();
165                    async move {
166                        Ok::<_, Infallible>(
167                            handle_request(router, middleware, ws_interceptor, req).await,
168                        )
169                    }
170                });
171
172                if let Err(err) = http1::Builder::new()
173                    .serve_connection(io, service)
174                    .with_upgrades()
175                    .await
176                {
177                    eprintln!("Error serving connection: {err:?}");
178                }
179            });
180        }
181    }
182}
183
184async fn handle_request(
185    router: Arc<Router>,
186    middleware_registry: Arc<MiddlewareRegistry>,
187    ws_interceptor: Option<Arc<WsInterceptor>>,
188    mut req: hyper::Request<hyper::body::Incoming>,
189) -> hyper::Response<Full<Bytes>> {
190    // Application WS interceptor runs before /_ferro/ws
191    if let Some(ref interceptor) = ws_interceptor {
192        if hyper_tungstenite::is_upgrade_request(&req) {
193            match interceptor(req) {
194                Ok(response) => return response,
195                Err(returned_req) => {
196                    // Interceptor declined — continue with returned request
197                    req = returned_req;
198                }
199            }
200        }
201    }
202
203    let method = req.method().clone();
204    let path = req.uri().path().to_string();
205    let query = req.uri().query().unwrap_or("");
206
207    // WebSocket upgrade at /_ferro/ws (must run before middleware/routing)
208    if path == "/_ferro/ws" && hyper_tungstenite::is_upgrade_request(&req) {
209        return handle_ws_upgrade(req);
210    }
211
212    // Built-in framework endpoints at /_ferro/*
213    // Uses framework prefix to avoid conflicts with user-defined routes
214    if path.starts_with("/_ferro/") && method == hyper::Method::GET {
215        return match path.as_str() {
216            "/_ferro/health" => health_response(query).await,
217            "/_ferro/routes" => crate::debug::handle_routes(),
218            "/_ferro/middleware" => crate::debug::handle_middleware(),
219            "/_ferro/services" => crate::debug::handle_services(),
220            "/_ferro/metrics" => crate::debug::handle_metrics(),
221            "/_ferro/queue/jobs" => crate::debug::handle_queue_jobs().await,
222            "/_ferro/queue/stats" => crate::debug::handle_queue_stats().await,
223            "/_ferro/ferro-base.css" => {
224                #[cfg(feature = "json-ui")]
225                {
226                    serve_ferro_base_css()
227                }
228                #[cfg(not(feature = "json-ui"))]
229                {
230                    HttpResponse::text("404 Not Found").status(404).into_hyper()
231                }
232            }
233            _ => HttpResponse::text("404 Not Found").status(404).into_hyper(),
234        };
235    }
236
237    // Note: Inertia context is now read directly from Request headers
238    // via req.is_inertia(), req.inertia_version(), etc.
239    // No thread-local storage needed - this is async-safe.
240
241    // Run pre-route middleware (path rewrites that affect route matching).
242    // PreRouteMiddleware runs before match_route so set_path() calls influence routing.
243    let mut ferro_request = Request::new(req);
244    for mw in &crate::middleware::get_pre_route_middleware() {
245        ferro_request = match mw.rewrite(ferro_request).await {
246            Ok(r) => r,
247            Err(response) => {
248                // Short-circuit: middleware rejected the request (e.g. unknown domain → 404).
249                return response.into_hyper();
250            }
251        };
252    }
253    // Use the (possibly rewritten) path for route matching and static file serving.
254    let routing_path = ferro_request.path().to_string();
255
256    // Extract host before request is consumed by routing.
257    let request_host = ferro_request
258        .header("host")
259        .unwrap_or_default()
260        .split(':')
261        .next()
262        .unwrap_or("")
263        .to_ascii_lowercase();
264
265    let response = match router.match_route(&method, &routing_path) {
266        Some((handler, params, route_pattern)) => {
267            let request = ferro_request
268                .with_params(params)
269                .with_route_pattern(route_pattern.clone());
270
271            // Build middleware chain
272            let mut chain = MiddlewareChain::new();
273
274            // 1. Add global middleware
275            chain.extend(middleware_registry.global_middleware().iter().cloned());
276
277            // 2. Add route-level middleware (already boxed)
278            let route_middleware = router.get_route_middleware(&route_pattern);
279            chain.extend(route_middleware);
280
281            // 3. Execute chain with handler inside request host context
282            let response = crate::http::request_context::REQUEST_HOST
283                .scope(request_host, chain.execute(request, handler))
284                .await;
285
286            // Unwrap the Result - both Ok and Err contain HttpResponse
287            let http_response = response.unwrap_or_else(|e| e);
288            http_response.into_hyper()
289        }
290        None => {
291            // Try static file serving before fallback (only GET/HEAD)
292            if method == hyper::Method::GET || method == hyper::Method::HEAD {
293                if let Some(response) =
294                    crate::static_files::try_serve_static_file(&routing_path).await
295                {
296                    return response;
297                }
298            }
299
300            // Check for fallback handler
301            if let Some((fallback_handler, fallback_middleware)) = router.get_fallback() {
302                let request = ferro_request.with_params(std::collections::HashMap::new());
303
304                // Build middleware chain for fallback
305                let mut chain = MiddlewareChain::new();
306
307                // 1. Add global middleware
308                chain.extend(middleware_registry.global_middleware().iter().cloned());
309
310                // 2. Add fallback-specific middleware
311                chain.extend(fallback_middleware);
312
313                // 3. Execute chain with fallback handler
314                let response = chain.execute(request, fallback_handler).await;
315
316                // Unwrap the Result - both Ok and Err contain HttpResponse
317                let http_response = response.unwrap_or_else(|e| e);
318                http_response.into_hyper()
319            } else {
320                // No fallback defined, return default 404
321                HttpResponse::text("404 Not Found").status(404).into_hyper()
322            }
323        }
324    };
325
326    response
327}
328
329/// Built-in health check endpoint at /_ferro/health
330/// Returns {"status": "ok", "timestamp": "..."} by default
331/// Add ?db=true to also check database connectivity (/_ferro/health?db=true)
332async fn health_response(query: &str) -> hyper::Response<Full<Bytes>> {
333    use chrono::Utc;
334    use serde_json::json;
335
336    let timestamp = Utc::now().to_rfc3339();
337    let check_db = query.contains("db=true");
338
339    let mut response = json!({
340        "status": "ok",
341        "timestamp": timestamp
342    });
343
344    if check_db {
345        // Try to check database connection
346        match check_database_health().await {
347            Ok(_) => {
348                response["database"] = json!("connected");
349            }
350            Err(e) => {
351                response["database"] = json!("error");
352                response["database_error"] = json!(e);
353            }
354        }
355    }
356
357    let body =
358        serde_json::to_string(&response).unwrap_or_else(|_| r#"{"status":"ok"}"#.to_string());
359
360    hyper::Response::builder()
361        .status(200)
362        .header("Content-Type", "application/json")
363        .body(Full::new(Bytes::from(body)))
364        .unwrap()
365}
366
/// Respond with the ferro-json-ui base stylesheet.
///
/// The stylesheet comes from the `ferro_json_ui::FERRO_BASE_CSS` constant and
/// is handed to the body without copying. The response is `200` with
/// `Content-Type: text/css; charset=utf-8`, an explicit `Content-Length`, and
/// `Cache-Control: public, max-age=31536000, immutable` (one year).
///
/// The function takes no arguments, so no request data influences the output.
#[cfg(feature = "json-ui")]
fn serve_ferro_base_css() -> hyper::Response<Full<Bytes>> {
    let stylesheet = ferro_json_ui::FERRO_BASE_CSS;
    let content_length = stylesheet.len().to_string();
    let body = Full::new(Bytes::from_static(stylesheet.as_bytes()));

    hyper::Response::builder()
        .status(200)
        .header("Content-Type", "text/css; charset=utf-8")
        .header("Content-Length", content_length)
        .header("Cache-Control", "public, max-age=31536000, immutable")
        .body(body)
        .unwrap()
}
383
384/// Check database health by attempting a simple query
385async fn check_database_health() -> Result<(), String> {
386    use crate::database::DB;
387    use sea_orm::ConnectionTrait;
388
389    if !DB::is_connected() {
390        return Err("Database not initialized".to_string());
391    }
392
393    let conn = DB::connection().map_err(|e| e.to_string())?;
394
395    // Execute a simple query to verify connection is alive
396    conn.inner()
397        .execute_unprepared("SELECT 1")
398        .await
399        .map_err(|e| format!("Database query failed: {e}"))?;
400
401    Ok(())
402}
403
/// Tests for the `/_ferro/ferro-base.css` handler; compiled only for test
/// builds with the `json-ui` feature enabled.
#[cfg(all(test, feature = "json-ui"))]
mod ferro_base_css_route_tests {
    use super::*;
    use http_body_util::BodyExt;

    /// Status line and the three headers the handler sets must match exactly.
    #[tokio::test]
    async fn serve_ferro_base_css_returns_200_with_text_css_content_type() {
        let response = serve_ferro_base_css();

        assert_eq!(response.status(), 200, "expected 200 OK");

        let headers = response.headers();

        let content_type = headers
            .get("Content-Type")
            .expect("Content-Type header missing")
            .to_str()
            .unwrap();
        assert_eq!(content_type, "text/css; charset=utf-8");

        let cache_control = headers
            .get("Cache-Control")
            .expect("Cache-Control header missing")
            .to_str()
            .unwrap();
        assert_eq!(cache_control, "public, max-age=31536000, immutable");

        let content_length = headers
            .get("Content-Length")
            .expect("Content-Length header missing")
            .to_str()
            .unwrap()
            .parse::<usize>()
            .expect("Content-Length must be an integer");
        assert_eq!(content_length, ferro_json_ui::FERRO_BASE_CSS.len());
    }

    /// The body must be the embedded stylesheet, byte for byte, and non-empty.
    #[tokio::test]
    async fn serve_ferro_base_css_body_equals_embedded_constant() {
        let collected = serve_ferro_base_css()
            .into_body()
            .collect()
            .await
            .expect("body collect");
        let body_bytes = collected.to_bytes();

        assert_eq!(
            body_bytes.as_ref(),
            ferro_json_ui::FERRO_BASE_CSS.as_bytes()
        );
        assert!(!body_bytes.is_empty());
    }
}