1use std::future::Future;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::extract::Request;
7use axum::handler::Handler;
8use axum::middleware::Next;
9use axum::response::Response;
10use axum::Router;
11use tokio::net::TcpListener;
12use tower_http::catch_panic::CatchPanicLayer;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::info;
15
16use crate::config::AppConfig;
17use crate::controller::Controller;
18use crate::logging;
19use crate::auth::{AuthConfig, AuthLayer};
20use crate::middleware::{self, InjectStateLayer, RequestTimeoutLayer};
21use crate::rate_limit::RateLimitLayer;
22use crate::router::{Method, OxideRouter};
23use crate::state::{AppState, TypeMap};
24
/// A deferred transformation applied to the assembled axum [`Router`] while the app is
/// being built.
///
/// `App::before`, `App::after`, `App::scoped_state`, and `App::layer` each push one of
/// these; `App::build_router` applies them in the order they were registered.
type RouterTransform = Box<dyn FnOnce(Router) -> Router>;
26
/// Builder for an Oxide application.
///
/// The builder methods collect the following:
/// - routes
/// - controllers
/// - shared state
/// - middleware
///
/// Nothing is turned into an axum [`Router`] until the app is served ([`App::serve`] /
/// [`App::run`]) or converted into a [`TestServer`].
pub struct App {
    /// Configuration used to build the router.
    ///
    /// `serve` overwrites it with the result of `AppConfig::load`; `into_test_server` uses it
    /// as-is.
    config: AppConfig,
    /// Routes registered directly on the app (`route`, `get`, `routes`, `nest`, ...).
    router: OxideRouter,
    /// Optional config file path that `serve` passes to `AppConfig::load`.
    config_path: Option<String>,
    /// Values registered via `state`; moved into the `AppState` when the router is built.
    type_map: TypeMap,
    /// Whether the request-logging middleware is installed (`true` unless disabled).
    request_logging: bool,
    /// Rate-limit settings as `(max_requests, window)`, if enabled.
    rate_limit: Option<(u64, Duration)>,
    /// CORS layer configured by `cors_permissive` / `cors_origins`, if any.
    cors: Option<CorsLayer>,
    /// Per-request timeout, if configured.
    request_timeout: Option<Duration>,
    /// Deferred controller constructors.
    ///
    /// Each one runs with the `AppState` once it exists and returns that controller's routes.
    controller_factories: Vec<Box<dyn FnOnce(AppState) -> OxideRouter>>,
    /// Middleware transforms registered by `before`, `after`, `scoped_state`, and `layer`.
    user_layers: Vec<RouterTransform>,
    /// Settings used to build an `AuthLayer`, if `auth` was called.
    auth: Option<AuthConfig>,
}
72
73impl App {
74 pub fn new() -> Self {
78 logging::init();
79
80 Self {
81 config: AppConfig::default(),
82 router: OxideRouter::new(),
83 config_path: None,
84 type_map: TypeMap::default(),
85 request_logging: true,
86 rate_limit: None,
87 cors: None,
88 request_timeout: None,
89 controller_factories: Vec::new(),
90 user_layers: Vec::new(),
91 auth: None,
92 }
93 }
94
95 pub fn config(mut self, path: &str) -> Self {
100 self.config_path = Some(path.to_string());
101 self
102 }
103
104 pub fn state<T: Send + Sync + 'static>(mut self, value: T) -> Self {
108 self.type_map.insert(value);
109 self
110 }
111
112 pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
116 where
117 H: Handler<T, ()>,
118 T: 'static,
119 {
120 self.router = self.router.route(method, path, handler);
121 self
122 }
123
124 pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
127 where
128 H: Handler<T, ()>,
129 T: 'static,
130 {
131 self.router = self.router.get(path, handler);
132 self
133 }
134
135 pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
136 where
137 H: Handler<T, ()>,
138 T: 'static,
139 {
140 self.router = self.router.post(path, handler);
141 self
142 }
143
144 pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
145 where
146 H: Handler<T, ()>,
147 T: 'static,
148 {
149 self.router = self.router.put(path, handler);
150 self
151 }
152
153 pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
154 where
155 H: Handler<T, ()>,
156 T: 'static,
157 {
158 self.router = self.router.delete(path, handler);
159 self
160 }
161
162 pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
163 where
164 H: Handler<T, ()>,
165 T: 'static,
166 {
167 self.router = self.router.patch(path, handler);
168 self
169 }
170
171 pub fn controller<C: Controller>(mut self) -> Self {
183 self.controller_factories.push(Box::new(|state: AppState| {
184 let instance = Arc::new(C::from_state(&state));
185 let routes = C::register(instance);
186 let inner = C::configure_router(routes.into_inner());
187 OxideRouter::from_router(inner).nest_self(C::PREFIX)
188 }));
189 self
190 }
191
192 pub fn routes(mut self, router: OxideRouter) -> Self {
196 self.router = self.router.merge(router);
197 self
198 }
199
200 pub fn nest(mut self, prefix: &str, router: OxideRouter) -> Self {
202 self.router = self.router.nest(prefix, router);
203 self
204 }
205
206 pub fn rate_limit(mut self, max_requests: u64, window_secs: u64) -> Self {
212 self.rate_limit = Some((max_requests, Duration::from_secs(window_secs)));
213 self
214 }
215
216 pub fn cors_permissive(mut self) -> Self {
218 self.cors = Some(
219 CorsLayer::new()
220 .allow_origin(Any)
221 .allow_methods(Any)
222 .allow_headers(Any),
223 );
224 self
225 }
226
227 pub fn cors_origins<I, S>(mut self, origins: I) -> Self
229 where
230 I: IntoIterator<Item = S>,
231 S: AsRef<str>,
232 {
233 let origins: Vec<_> = origins
234 .into_iter()
235 .filter_map(|o| o.as_ref().parse().ok())
236 .collect();
237
238 self.cors = Some(
239 CorsLayer::new()
240 .allow_origin(origins)
241 .allow_methods(Any)
242 .allow_headers(Any),
243 );
244 self
245 }
246
247 pub fn request_timeout(mut self, secs: u64) -> Self {
249 self.request_timeout = Some(Duration::from_secs(secs));
250 self
251 }
252
253 pub fn disable_request_logging(mut self) -> Self {
255 self.request_logging = false;
256 self
257 }
258
259 pub fn auth(mut self, config: AuthConfig) -> Self {
267 assert!(
268 !config.secret.is_empty(),
269 "AuthConfig.secret must not be empty"
270 );
271 self.auth = Some(config);
272 self
273 }
274
275 pub fn before<F, Fut>(mut self, f: F) -> Self
290 where
291 F: Fn(Request, Next) -> Fut + Clone + Send + Sync + 'static,
292 Fut: Future<Output = Response> + Send + 'static,
293 {
294 self.user_layers.push(Box::new(move |router: Router| {
295 router.layer(axum::middleware::from_fn(f))
296 }));
297 self
298 }
299
300 pub fn scoped_state<F, Fut, T>(mut self, factory: F) -> Self
306 where
307 F: Fn(&axum::http::request::Parts) -> Fut + Send + Sync + 'static,
308 Fut: Future<Output = T> + Send + 'static,
309 T: Clone + Send + Sync + 'static,
310 {
311 let factory = Arc::new(factory);
312 self.user_layers.push(Box::new(move |router: Router| {
313 let f = factory.clone();
314 router.layer(axum::middleware::from_fn(move |req: Request, next: Next| {
315 let f = f.clone();
316 async move {
317 let (mut parts, body) = req.into_parts();
318 let val = f(&parts).await;
319 parts.extensions.insert(val);
320 let req = axum::extract::Request::from_parts(parts, body);
321 next.run(req).await
322 }
323 }))
324 }));
325 self
326 }
327
328 pub fn after<F, Fut>(mut self, f: F) -> Self
337 where
338 F: Fn(Response) -> Fut + Clone + Send + Sync + 'static,
339 Fut: Future<Output = Response> + Send + 'static,
340 {
341 self.user_layers.push(Box::new(move |router: Router| {
342 router.layer(axum::middleware::map_response(f))
343 }));
344 self
345 }
346
347 pub fn layer<L>(mut self, layer: L) -> Self
352 where
353 L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
354 L::Service: tower::Service<Request, Response = Response, Error = std::convert::Infallible>
355 + Clone
356 + Send
357 + Sync
358 + 'static,
359 <L::Service as tower::Service<Request>>::Future: Send + 'static,
360 {
361 self.user_layers.push(Box::new(move |router: Router| {
362 router.layer(layer)
363 }));
364 self
365 }
366
    /// Assembles the final axum [`Router`] together with the [`AppState`] it shares.
    ///
    /// Steps, in order:
    /// 1. Build the `AppState` from `config` and the registered type map.
    /// 2. Run every controller factory with a clone of that state and merge its routes into
    ///    the app's own routes.
    /// 3. Apply user middleware (`before`, `after`, `scoped_state`, `layer`) in registration
    ///    order.
    /// 4. Wrap everything in the framework layers below.
    ///
    /// Each `Router::layer` call wraps everything applied before it, so the last layer added
    /// is the outermost one. An incoming request passes through these layers in order:
    /// 1. request logging
    /// 2. CORS
    /// 3. request timeout
    /// 4. rate limit
    /// 5. panic catcher
    /// 6. state injection
    /// 7. auth
    /// 8. user middleware (most recently registered first)
    /// 9. the handler
    ///
    /// Layers that were not configured are skipped. Reordering these statements changes
    /// which layer sees, and can reject, a request first.
    ///
    /// Both callers in this file currently discard the returned `AppState`.
    fn build_router(self, config: AppConfig) -> (Router, AppState) {
        let app_state = AppState::new(config, self.type_map);

        // Controllers can only be built now, because their factories need the finished state.
        let mut base = self.router;
        for factory in self.controller_factories {
            let ctrl_router = factory(app_state.clone());
            base = base.merge(ctrl_router);
        }
        let mut router = base.into_inner();

        // Innermost layers: user middleware, closest to the handlers.
        for transform in self.user_layers {
            router = transform(router);
        }

        if let Some(auth_cfg) = self.auth {
            router = router.layer(AuthLayer::new(auth_cfg));
        }

        // Added after auth, so it wraps auth: the state is injected before `AuthLayer` runs.
        router = router.layer(InjectStateLayer::new(app_state.clone()));

        // Converts panics from everything above (handlers, user middleware, auth) into a
        // response via `middleware::panic_json_response`.
        router = router.layer(CatchPanicLayer::custom(middleware::panic_json_response));

        if let Some((max, window)) = self.rate_limit {
            router = router.layer(RateLimitLayer::new(max, window));
        }

        if let Some(timeout) = self.request_timeout {
            router = router.layer(RequestTimeoutLayer::new(timeout));
        }

        if let Some(cors) = self.cors {
            router = router.layer(cors);
        }

        // Outermost: the logger wraps every other layer.
        if self.request_logging {
            router = router.layer(axum::middleware::from_fn(middleware::request_logger));
        }

        (router, app_state)
    }
423
424 pub fn run(self) {
428 let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime");
429 rt.block_on(self.serve());
430 }
431
432 pub async fn serve(mut self) {
434 self.config = AppConfig::load(self.config_path.as_deref());
435
436 let addr = format!("{}:{}", self.config.host, self.config.port);
437 let app_name = if self.config.app_name.is_empty() {
438 "oxide-app".to_string()
439 } else {
440 self.config.app_name.clone()
441 };
442
443 let config = self.config.clone();
444 let (router, _state) = self.build_router(config);
445
446 let listener = TcpListener::bind(&addr)
447 .await
448 .unwrap_or_else(|e| panic!("failed to bind to {addr}: {e}"));
449
450 info!(
451 name = %app_name,
452 address = %addr,
453 "Oxide server started"
454 );
455
456 axum::serve(
457 listener,
458 router.into_make_service_with_connect_info::<SocketAddr>(),
459 )
460 .with_graceful_shutdown(shutdown_signal())
461 .await
462 .expect("server error");
463
464 info!("Oxide server shut down gracefully");
465 }
466
467 pub async fn into_test_server(self) -> TestServer {
474 let config = self.config.clone();
475 let (router, _state) = self.build_router(config);
476
477 let listener = TcpListener::bind("127.0.0.1:0")
478 .await
479 .expect("failed to bind test server");
480 let addr = listener.local_addr().unwrap();
481
482 let handle = tokio::spawn(async move {
483 axum::serve(
484 listener,
485 router.into_make_service_with_connect_info::<SocketAddr>(),
486 )
487 .await
488 .ok();
489 });
490
491 TestServer { addr, handle }
492 }
493}
494
/// A running instance of an [`App`] for tests, created by [`App::into_test_server`].
///
/// The server runs in a spawned Tokio task, which is aborted when this value is dropped.
pub struct TestServer {
    /// Address the server's listener is bound to.
    addr: SocketAddr,
    /// Handle of the task driving `axum::serve`; aborted in `Drop`.
    handle: tokio::task::JoinHandle<()>,
}
502
503impl TestServer {
504 pub fn addr(&self) -> SocketAddr {
505 self.addr
506 }
507
508 pub fn url(&self, path: &str) -> String {
510 format!("http://{}{}", self.addr, path)
511 }
512}
513
514impl Drop for TestServer {
515 fn drop(&mut self) {
516 self.handle.abort();
517 }
518}
519
520async fn shutdown_signal() {
521 let ctrl_c = tokio::signal::ctrl_c();
522
523 #[cfg(unix)]
524 {
525 let mut sigterm =
526 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
527 .expect("failed to install SIGTERM handler");
528 tokio::select! {
529 _ = ctrl_c => info!("received Ctrl+C, shutting down…"),
530 _ = sigterm.recv() => info!("received SIGTERM, shutting down…"),
531 }
532 }
533
534 #[cfg(not(unix))]
535 {
536 ctrl_c.await.expect("failed to listen for Ctrl+C");
537 info!("received Ctrl+C, shutting down…");
538 }
539}
540