Skip to main content

oxide_framework_core/
app.rs

1use std::future::Future;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::extract::Request;
7use axum::handler::Handler;
8use axum::middleware::Next;
9use axum::response::Response;
10use axum::Router;
11use tokio::net::TcpListener;
12use tower_http::catch_panic::CatchPanicLayer;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::info;
15
16use crate::config::AppConfig;
17use crate::controller::Controller;
18use crate::logging;
19use crate::auth::{AuthConfig, AuthLayer};
20use crate::middleware::{self, InjectStateLayer, RequestTimeoutLayer};
21use crate::rate_limit::RateLimitLayer;
22use crate::router::{Method, OxideRouter};
23use crate::state::{AppState, TypeMap};
24
25type RouterTransform = Box<dyn FnOnce(Router) -> Router>;
26
27/// Primary entry point for building an Oxide application.
28///
29/// Uses a builder pattern to configure routes, state, middleware, and then
30/// start the server.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use oxide_framework_core::{App, ApiResponse, Config};
36/// use serde::Serialize;
37///
38/// #[derive(Serialize)]
39/// struct Msg { text: String }
40///
41/// async fn index(Config(cfg): Config) -> ApiResponse<Msg> {
42///     ApiResponse::ok(Msg { text: format!("Hello from {}!", cfg.app_name) })
43/// }
44///
45/// fn main() {
46///     App::new()
47///         .config("app.yaml")
48///         .rate_limit(100, 60)
49///         .cors_permissive()
50///         .request_timeout(30)
51///         .get("/", index)
52///         .run();
53/// }
54/// ```
55pub struct App {
56    config: AppConfig,
57    router: OxideRouter,
58    config_path: Option<String>,
59    type_map: TypeMap,
60    request_logging: bool,
61    rate_limit: Option<(u64, Duration)>,
62    cors: Option<CorsLayer>,
63    request_timeout: Option<Duration>,
64    controller_factories: Vec<Box<dyn FnOnce(AppState) -> OxideRouter>>,
65    /// User-registered middleware (before/after hooks, custom layers).
66    /// Applied between State injection and CatchPanic (can access state,
67    /// panics are still caught).
68    user_layers: Vec<RouterTransform>,
69    /// Optional JWT / session-cookie auth (runs after user hooks, before state injection).
70    auth: Option<AuthConfig>,
71}
72
73impl App {
74    /// Create a new `App` with default configuration.
75    ///
76    /// Initialises structured logging on first call.
77    pub fn new() -> Self {
78        logging::init();
79
80        Self {
81            config: AppConfig::default(),
82            router: OxideRouter::new(),
83            config_path: None,
84            type_map: TypeMap::default(),
85            request_logging: true,
86            rate_limit: None,
87            cors: None,
88            request_timeout: None,
89            controller_factories: Vec::new(),
90            user_layers: Vec::new(),
91            auth: None,
92        }
93    }
94
95    // -- Configuration --------------------------------------------------------
96
97    /// Point the application at a YAML config file.
98    /// Config is loaded (and merged with env vars) when `.run()` is called.
99    pub fn config(mut self, path: &str) -> Self {
100        self.config_path = Some(path.to_string());
101        self
102    }
103
104    // -- State ----------------------------------------------------------------
105
106    /// Register a shared value accessible in handlers via the [`Data<T>`](crate::Data) extractor.
107    pub fn state<T: Send + Sync + 'static>(mut self, value: T) -> Self {
108        self.type_map.insert(value);
109        self
110    }
111
112    // -- Generic route registration -------------------------------------------
113
114    /// Register a route for the given method and path.
115    pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
116    where
117        H: Handler<T, ()>,
118        T: 'static,
119    {
120        self.router = self.router.route(method, path, handler);
121        self
122    }
123
124    // -- Convenience methods --------------------------------------------------
125
126    pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
127    where
128        H: Handler<T, ()>,
129        T: 'static,
130    {
131        self.router = self.router.get(path, handler);
132        self
133    }
134
135    pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
136    where
137        H: Handler<T, ()>,
138        T: 'static,
139    {
140        self.router = self.router.post(path, handler);
141        self
142    }
143
144    pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
145    where
146        H: Handler<T, ()>,
147        T: 'static,
148    {
149        self.router = self.router.put(path, handler);
150        self
151    }
152
153    pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
154    where
155        H: Handler<T, ()>,
156        T: 'static,
157    {
158        self.router = self.router.delete(path, handler);
159        self
160    }
161
162    pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
163    where
164        H: Handler<T, ()>,
165        T: 'static,
166    {
167        self.router = self.router.patch(path, handler);
168        self
169    }
170
171    // -- Controller registration ----------------------------------------------
172
173    /// Register a `#[controller]`-annotated struct.
174    ///
175    /// At startup the framework will:
176    /// 1. Construct the controller via `C::from_state(&app_state)`.
177    /// 2. Call `C::register(Arc::new(instance))` to build its routes.
178    /// 3. Nest those routes under `C::PREFIX`.
179    ///
180    /// Dependencies are resolved eagerly — a missing `Data<T>` will panic at
181    /// startup, not at request time (fail-fast).
182    pub fn controller<C: Controller>(mut self) -> Self {
183        self.controller_factories.push(Box::new(|state: AppState| {
184            let instance = Arc::new(C::from_state(&state));
185            let routes = C::register(instance);
186            let inner = C::configure_router(routes.into_inner());
187            OxideRouter::from_router(inner).nest_self(C::PREFIX)
188        }));
189        self
190    }
191
192    // -- Router composition ---------------------------------------------------
193
194    /// Merge a pre-built `OxideRouter` into the application (flat, no prefix).
195    pub fn routes(mut self, router: OxideRouter) -> Self {
196        self.router = self.router.merge(router);
197        self
198    }
199
200    /// Nest a pre-built `OxideRouter` under the given path prefix.
201    pub fn nest(mut self, prefix: &str, router: OxideRouter) -> Self {
202        self.router = self.router.nest(prefix, router);
203        self
204    }
205
206    // -- Scalability & production middleware -----------------------------------
207
208    /// Enable per-IP rate limiting.
209    ///
210    /// Returns HTTP 429 with `Retry-After` header when the limit is exceeded.
211    pub fn rate_limit(mut self, max_requests: u64, window_secs: u64) -> Self {
212        self.rate_limit = Some((max_requests, Duration::from_secs(window_secs)));
213        self
214    }
215
216    /// Enable permissive CORS (allow any origin, method, and header).
217    pub fn cors_permissive(mut self) -> Self {
218        self.cors = Some(
219            CorsLayer::new()
220                .allow_origin(Any)
221                .allow_methods(Any)
222                .allow_headers(Any),
223        );
224        self
225    }
226
227    /// Enable CORS with a specific set of allowed origins.
228    pub fn cors_origins<I, S>(mut self, origins: I) -> Self
229    where
230        I: IntoIterator<Item = S>,
231        S: AsRef<str>,
232    {
233        let origins: Vec<_> = origins
234            .into_iter()
235            .filter_map(|o| o.as_ref().parse().ok())
236            .collect();
237
238        self.cors = Some(
239            CorsLayer::new()
240                .allow_origin(origins)
241                .allow_methods(Any)
242                .allow_headers(Any),
243        );
244        self
245    }
246
247    /// Set a maximum duration for request processing.
248    pub fn request_timeout(mut self, secs: u64) -> Self {
249        self.request_timeout = Some(Duration::from_secs(secs));
250        self
251    }
252
253    /// Disable the built-in per-request logging middleware.
254    pub fn disable_request_logging(mut self) -> Self {
255        self.request_logging = false;
256        self
257    }
258
259    /// Enable JWT authentication from `Authorization: Bearer` and/or a session cookie.
260    ///
261    /// Inserts [`crate::auth::AuthClaims`] into request extensions when the token is valid.
262    /// Invalid or expired tokens return **401** before your handler runs.
263    ///
264    /// Relative to state and hooks: application state is injected first, then JWT is validated, then
265    /// [`App::before`](Self::before) / [`App::layer`](Self::layer) hooks, then the route handler.
266    pub fn auth(mut self, config: AuthConfig) -> Self {
267        assert!(
268            !config.secret.is_empty(),
269            "AuthConfig.secret must not be empty"
270        );
271        self.auth = Some(config);
272        self
273    }
274
275    // -- Lifecycle hooks & custom middleware -----------------------------------
276
277    /// Register a "before" hook that runs on every request.
278    ///
279    /// The hook receives the request and a [`Next`] handle, and must produce a
280    /// response. Use it for auth checks, request mutation, short-circuit
281    /// responses, etc.
282    ///
283    /// ```rust,ignore
284    /// app.before(|req: Request, next: Next| async move {
285    ///     println!("incoming: {} {}", req.method(), req.uri());
286    ///     next.run(req).await
287    /// })
288    /// ```
289    pub fn before<F, Fut>(mut self, f: F) -> Self
290    where
291        F: Fn(Request, Next) -> Fut + Clone + Send + Sync + 'static,
292        Fut: Future<Output = Response> + Send + 'static,
293    {
294        self.user_layers.push(Box::new(move |router: Router| {
295            router.layer(axum::middleware::from_fn(f))
296        }));
297        self
298    }
299
300    /// Register a request-scoped dependency factory.
301    /// 
302    /// The factory closure is called on *every* incoming request, and its output
303    /// is automatically injected into the request extensions, making it available
304    /// to handlers via the `Scoped<T>` extractor.
305    pub fn scoped_state<F, Fut, T>(mut self, factory: F) -> Self
306    where
307        F: Fn(&axum::http::request::Parts) -> Fut + Send + Sync + 'static,
308        Fut: Future<Output = T> + Send + 'static,
309        T: Clone + Send + Sync + 'static,
310    {
311        let factory = Arc::new(factory);
312        self.user_layers.push(Box::new(move |router: Router| {
313            let f = factory.clone();
314            router.layer(axum::middleware::from_fn(move |req: Request, next: Next| {
315                let f = f.clone();
316                async move {
317                    let (mut parts, body) = req.into_parts();
318                    let val = f(&parts).await;
319                    parts.extensions.insert(val);
320                    let req = axum::extract::Request::from_parts(parts, body);
321                    next.run(req).await
322                }
323            }))
324        }));
325        self
326    }
327
328    /// Register an "after" hook that can transform every outgoing response.
329    ///
330    /// ```rust,ignore
331    /// app.after(|mut res: Response| async move {
332    ///     res.headers_mut().insert("X-Powered-By", "Oxide".parse().unwrap());
333    ///     res
334    /// })
335    /// ```
336    pub fn after<F, Fut>(mut self, f: F) -> Self
337    where
338        F: Fn(Response) -> Fut + Clone + Send + Sync + 'static,
339        Fut: Future<Output = Response> + Send + 'static,
340    {
341        self.user_layers.push(Box::new(move |router: Router| {
342            router.layer(axum::middleware::map_response(f))
343        }));
344        self
345    }
346
347    /// Add an arbitrary Tower `Layer` to the middleware stack.
348    ///
349    /// The layer is positioned between state injection and the panic catcher,
350    /// so it has access to `AppState` and any panics it causes are caught.
351    pub fn layer<L>(mut self, layer: L) -> Self
352    where
353        L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
354        L::Service: tower::Service<Request, Response = Response, Error = std::convert::Infallible>
355            + Clone
356            + Send
357            + Sync
358            + 'static,
359        <L::Service as tower::Service<Request>>::Future: Send + 'static,
360    {
361        self.user_layers.push(Box::new(move |router: Router| {
362            router.layer(layer)
363        }));
364        self
365    }
366
367    // -- Internal: build the layered router -----------------------------------
368
369    fn build_router(self, config: AppConfig) -> (Router, AppState) {
370        let app_state = AppState::new(config, self.type_map);
371
372        let mut base = self.router;
373        for factory in self.controller_factories {
374            let ctrl_router = factory(app_state.clone());
375            base = base.merge(ctrl_router);
376        }
377        let mut router = base.into_inner();
378
379        // Layer application order: first applied = innermost = closest to the route handler.
380        // Request flow (outer → inner): Logger → CORS → Timeout → RateLimit → CatchPanic →
381        // InjectState → JwtAuth → UserHooks → Route handler
382        //
383        // User hooks run after JWT validation so `OptionalAuth` / [`AuthClaims`] are visible in `before` / custom layers.
384
385        // 1. User-registered hooks / layers (innermost)
386        for transform in self.user_layers {
387            router = transform(router);
388        }
389
390        // 2. JWT / session cookie auth
391        if let Some(auth_cfg) = self.auth {
392            router = router.layer(AuthLayer::new(auth_cfg));
393        }
394
395        // 3. State injection
396        router = router.layer(InjectStateLayer::new(app_state.clone()));
397
398        // 4. Panic recovery — catches panics in hooks AND handlers
399        router = router.layer(CatchPanicLayer::custom(middleware::panic_json_response));
400
401        // 5. Rate limiting
402        if let Some((max, window)) = self.rate_limit {
403            router = router.layer(RateLimitLayer::new(max, window));
404        }
405
406        // 6. Request timeout
407        if let Some(timeout) = self.request_timeout {
408            router = router.layer(RequestTimeoutLayer::new(timeout));
409        }
410
411        // 7. CORS (wraps everything — headers on ALL responses including 429/408/500)
412        if let Some(cors) = self.cors {
413            router = router.layer(cors);
414        }
415
416        // 8. Request logging (outermost)
417        if self.request_logging {
418            router = router.layer(axum::middleware::from_fn(middleware::request_logger));
419        }
420
421        (router, app_state)
422    }
423
424    // -- Server lifecycle -----------------------------------------------------
425
426    /// Build and start the HTTP server. Blocks the current thread, creating a new Tokio runtime.
427    pub fn run(self) {
428        let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime");
429        rt.block_on(self.serve());
430    }
431
432    /// Build and start the HTTP server using the current Tokio runtime.
433    pub async fn serve(mut self) {
434        self.config = AppConfig::load(self.config_path.as_deref());
435
436        let addr = format!("{}:{}", self.config.host, self.config.port);
437        let app_name = if self.config.app_name.is_empty() {
438            "oxide-app".to_string()
439        } else {
440            self.config.app_name.clone()
441        };
442
443        let config = self.config.clone();
444        let (router, _state) = self.build_router(config);
445
446        let listener = TcpListener::bind(&addr)
447            .await
448            .unwrap_or_else(|e| panic!("failed to bind to {addr}: {e}"));
449
450        info!(
451            name = %app_name,
452            address = %addr,
453            "Oxide server started"
454        );
455
456        axum::serve(
457            listener,
458            router.into_make_service_with_connect_info::<SocketAddr>(),
459        )
460        .with_graceful_shutdown(shutdown_signal())
461        .await
462        .expect("server error");
463
464        info!("Oxide server shut down gracefully");
465    }
466
467    // -- Testing --------------------------------------------------------------
468
469    /// Start the server on a random port for integration testing.
470    ///
471    /// Returns a [`TestServer`] with the bound address. The server runs in a
472    /// background tokio task and is stopped when the `TestServer` is dropped.
473    pub async fn into_test_server(self) -> TestServer {
474        let config = self.config.clone();
475        let (router, _state) = self.build_router(config);
476
477        let listener = TcpListener::bind("127.0.0.1:0")
478            .await
479            .expect("failed to bind test server");
480        let addr = listener.local_addr().unwrap();
481
482        let handle = tokio::spawn(async move {
483            axum::serve(
484                listener,
485                router.into_make_service_with_connect_info::<SocketAddr>(),
486            )
487            .await
488            .ok();
489        });
490
491        TestServer { addr, handle }
492    }
493}
494
495/// A running test server bound to a random port.
496///
497/// The server is automatically stopped when this value is dropped.
498pub struct TestServer {
499    addr: SocketAddr,
500    handle: tokio::task::JoinHandle<()>,
501}
502
503impl TestServer {
504    pub fn addr(&self) -> SocketAddr {
505        self.addr
506    }
507
508    /// Build a full URL for the given path.
509    pub fn url(&self, path: &str) -> String {
510        format!("http://{}{}", self.addr, path)
511    }
512}
513
514impl Drop for TestServer {
515    fn drop(&mut self) {
516        self.handle.abort();
517    }
518}
519
520async fn shutdown_signal() {
521    let ctrl_c = tokio::signal::ctrl_c();
522
523    #[cfg(unix)]
524    {
525        let mut sigterm =
526            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
527                .expect("failed to install SIGTERM handler");
528        tokio::select! {
529            _ = ctrl_c => info!("received Ctrl+C, shutting down…"),
530            _ = sigterm.recv() => info!("received SIGTERM, shutting down…"),
531        }
532    }
533
534    #[cfg(not(unix))]
535    {
536        ctrl_c.await.expect("failed to listen for Ctrl+C");
537        info!("received Ctrl+C, shutting down…");
538    }
539}
540