1use crate::dep::{AnyArc, DepEnv, DepFactory, DepResolver, TaskContext};
5use crate::error::{Error, Result};
6use crate::extract::{BodyLane, RequestCtx};
7use crate::handler::BoxHandlerFn;
8use crate::middleware::{Middleware, Next};
9use crate::module::{FlatRoute, Module};
10use crate::response::{IntoResponse, Response};
11use crate::router::{Endpoint, MethodRouter, RouteMatch, Trie};
12use crate::serve;
13#[cfg(test)]
14use crate::serve::is_transient_accept_error;
15#[cfg(test)]
16use bytes::Bytes;
17use std::any::TypeId;
18use std::collections::HashMap;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22
23pub(crate) type BackgroundFactory = Box<
27 dyn FnOnce(
28 TaskContext,
29 tokio::sync::watch::Receiver<bool>,
30 ) -> Pin<Box<dyn Future<Output = ()> + Send>>
31 + Send,
32>;
33
34pub struct App {
37 routes: Vec<(String, MethodRouter)>,
38 mounts: Vec<(String, Module)>,
39 env: DepEnv,
40 middleware: Vec<Arc<dyn Middleware>>,
41 security_headers: bool,
42 cors: Option<std::sync::Arc<crate::cors::CorsConfig>>,
43 handler_timeout: std::time::Duration,
44 body_read_timeout: std::time::Duration,
45 write_stall_timeout: std::time::Duration,
46 background: Vec<(&'static str, BackgroundFactory)>,
49}
50
51impl Default for App {
52 fn default() -> Self {
53 let mut env = DepEnv::default();
58 env.insert_value(crate::clock::Clock::system());
59 Self {
60 routes: Vec::new(),
61 mounts: Vec::new(),
62 env,
63 middleware: Vec::new(),
64 security_headers: true,
65 cors: None,
66 handler_timeout: std::time::Duration::from_secs(30),
67 body_read_timeout: std::time::Duration::from_secs(30),
68 write_stall_timeout: std::time::Duration::from_secs(30),
69 background: Vec::new(),
70 }
71 }
72}
73
74pub trait Extension {
77 fn register(self, app: App) -> App;
78}
79
80impl App {
81 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn extend<E: Extension>(self, extension: E) -> App {
87 extension.register(self)
88 }
89
90 pub fn security_headers(mut self, on: bool) -> Self {
93 self.security_headers = on;
94 self
95 }
96
97 pub fn cors(mut self, config: crate::cors::CorsConfig) -> Self {
102 self.cors = Some(std::sync::Arc::new(config));
103 self
104 }
105
106 pub fn handler_timeout(mut self, budget: std::time::Duration) -> Self {
109 self.handler_timeout = budget;
110 self
111 }
112
113 pub fn body_read_timeout(mut self, budget: std::time::Duration) -> Self {
115 self.body_read_timeout = budget;
116 self
117 }
118
119 pub fn write_stall_timeout(mut self, budget: std::time::Duration) -> Self {
123 self.write_stall_timeout = budget;
124 self
125 }
126
127 pub fn route(mut self, path: &str, methods: MethodRouter) -> Self {
129 self.routes.push((path.to_string(), methods));
130 self
131 }
132
133 pub fn mount(mut self, prefix: &str, module: Module) -> Self {
135 self.mounts.push((prefix.to_string(), module));
136 self
137 }
138
139 pub fn provide<T: Send + Sync + 'static>(mut self, value: T) -> Self {
141 self.env.insert_value(value);
142 self
143 }
144
145 pub fn provide_dep<F, Args, T>(mut self, factory: F) -> Self
147 where
148 F: DepFactory<Args, T>,
149 T: Send + Sync + 'static,
150 {
151 self.env.insert_factory(factory);
152 self
153 }
154
155 pub fn middleware<M: Middleware>(mut self, mw: M) -> Self {
157 self.middleware.push(Arc::new(mw));
158 self
159 }
160
161 pub fn on_serve<F, Fut>(mut self, name: &'static str, f: F) -> App
174 where
175 F: FnOnce(TaskContext, tokio::sync::watch::Receiver<bool>) -> Fut + Send + 'static,
176 Fut: Future<Output = ()> + Send + 'static,
177 {
178 let factory: BackgroundFactory = Box::new(move |ctx, shutdown| Box::pin(f(ctx, shutdown)));
179 self.background.push((name, factory));
180 self
181 }
182
183 pub(crate) fn take_background(&mut self) -> Vec<(&'static str, BackgroundFactory)> {
187 std::mem::take(&mut self.background)
188 }
189
190 pub fn build(self) -> Result<BuiltApp> {
193 if let Some(c) = &self.cors {
194 c.validate()?;
195 }
196 let mut trie = Trie::default();
197 let app_env = Arc::new(self.env.clone());
198 let app_mw: Arc<[Arc<dyn Middleware>]> = Arc::from(self.middleware.clone());
199
200 for (path, methods) in self.routes {
201 let body_limit = methods.body_limit;
202 insert_flat(
203 &mut trie,
204 FlatRoute {
205 path,
206 methods,
207 env: app_env.clone(),
208 middleware: app_mw.clone(),
209 body_limit,
210 },
211 )?;
212 }
213 for (prefix, module) in self.mounts {
214 for flat in module.flatten(&prefix, &self.env, &self.middleware) {
215 insert_flat(&mut trie, flat)?;
216 }
217 }
218 Ok(BuiltApp {
219 trie,
220 app_env,
221 overrides: Arc::new(HashMap::new()),
222 security_headers: self.security_headers,
223 cors: self.cors.clone(),
224 handler_timeout: self.handler_timeout,
225 body_read_timeout: self.body_read_timeout,
226 write_stall_timeout: self.write_stall_timeout,
227 })
228 }
229
230 pub async fn serve(self) -> Result<()> {
234 let addr = std::env::var("JERRYCAN_ADDR").unwrap_or_else(|_| "127.0.0.1:8000".to_string());
235 let listener = tokio::net::TcpListener::bind(&addr)
236 .await
237 .map_err(|e| Error::internal(format!("failed to bind {addr}: {e}")))?;
238 self.serve_with_shutdown(listener, serve::shutdown_signal())
239 .await
240 }
241
242 pub async fn serve_with(self, listener: tokio::net::TcpListener) -> Result<()> {
244 self.serve_with_shutdown(listener, std::future::pending())
245 .await
246 }
247
248 pub async fn serve_with_shutdown(
251 self,
252 listener: tokio::net::TcpListener,
253 shutdown: impl std::future::Future<Output = ()> + Send,
254 ) -> Result<()> {
255 serve::run_with_shutdown(self, listener, shutdown).await
256 }
257}
258
259fn insert_flat(trie: &mut Trie, flat: FlatRoute) -> Result<()> {
260 let stream_body = flat.methods.stream_body;
261 let mut methods = HashMap::new();
262 for (m, h) in flat.methods.handlers {
263 if methods.insert(m.clone(), h).is_some() {
264 return Err(Error::internal(format!(
265 "duplicate method {m} for `{}`",
266 flat.path
267 )));
268 }
269 }
270 trie.insert(
271 &flat.path,
272 Endpoint {
273 methods,
274 env: flat.env,
275 middleware: flat.middleware,
276 body_limit: flat.body_limit,
277 stream_body,
278 },
279 )
280}
281
282pub struct BuiltApp {
284 pub(crate) trie: Trie,
285 pub(crate) app_env: Arc<DepEnv>,
289 pub(crate) overrides: Arc<HashMap<TypeId, AnyArc>>,
290 pub(crate) security_headers: bool,
291 pub(crate) cors: Option<std::sync::Arc<crate::cors::CorsConfig>>,
294 pub(crate) handler_timeout: std::time::Duration,
295 pub(crate) body_read_timeout: std::time::Duration,
296 pub(crate) write_stall_timeout: std::time::Duration,
297}
298
299impl std::fmt::Debug for BuiltApp {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("BuiltApp").finish_non_exhaustive()
304 }
305}
306
307pub(crate) const BODY_LIMIT: usize = 1024 * 1024; pub(crate) enum Policy {
316 Route { limit: usize, stream: bool },
320 Reject(Response),
323}
324
325pub(crate) fn apply_security_headers(res: &mut Response) {
327 const DEFAULTS: [(&str, &str); 5] = [
328 ("x-content-type-options", "nosniff"),
329 ("x-frame-options", "DENY"),
330 ("referrer-policy", "no-referrer"),
331 ("content-security-policy", "default-src 'none'"),
332 ("cache-control", "no-store"),
333 ];
334 for (name, value) in DEFAULTS {
335 let header_name = http::HeaderName::from_static(name);
336 if !res.headers().contains_key(&header_name) {
337 res.headers_mut()
338 .insert(header_name, http::HeaderValue::from_static(value));
339 }
340 }
341}
342
343impl BuiltApp {
344 pub fn task_context(&self) -> crate::dep::TaskContext {
352 crate::dep::TaskContext::new(DepResolver::new(
353 self.app_env.clone(),
354 self.overrides.clone(),
355 ))
356 }
357
358 pub(crate) fn route_policy(&self, parts: &http::request::Parts) -> Policy {
370 let path = parts.uri.path();
371 if let Some(config) = &self.cors
376 && crate::cors::is_preflight(parts)
377 {
378 let origin = parts
379 .headers
380 .get(http::header::ORIGIN)
381 .and_then(|v| v.to_str().ok())
382 .unwrap_or("");
383 if config.allows_origin(origin)
384 && let Some(methods) = self.trie.methods_for(path)
385 {
386 let acrh = parts
387 .headers
388 .get(http::header::ACCESS_CONTROL_REQUEST_HEADERS)
389 .and_then(|v| v.to_str().ok());
390 return Policy::Reject(crate::cors::preflight_response(
391 config, origin, acrh, &methods,
392 ));
393 }
394 let mut r = http::Response::new(crate::response::JcBody::empty());
397 *r.status_mut() = http::StatusCode::NO_CONTENT;
398 return Policy::Reject(r);
399 }
400 let reject = |response: Response| -> Policy {
401 let mut response = response;
402 if self.security_headers {
403 apply_security_headers(&mut response);
404 }
405 if let Some(config) = &self.cors {
410 crate::cors::apply_cors(
411 &mut response,
412 parts.headers.get(http::header::ORIGIN),
413 config,
414 );
415 }
416 Policy::Reject(response)
417 };
418 match self.trie.find(path, &parts.method) {
419 RouteMatch::Found { endpoint, .. } => Policy::Route {
420 limit: endpoint.body_limit.unwrap_or(BODY_LIMIT),
421 stream: endpoint.stream_body,
422 },
423 RouteMatch::NotFound => reject(Error::not_found().into_response()),
424 RouteMatch::MethodMissing => reject(Error::method_not_allowed().into_response()),
425 RouteMatch::Malformed => {
426 reject(Error::bad_request("malformed percent-encoding in path").into_response())
427 }
428 }
429 }
430
431 pub(crate) async fn dispatch(&self, parts: http::request::Parts, lane: BodyLane) -> Response {
436 let origin = parts.headers.get(http::header::ORIGIN).cloned();
440 let mut response = self.dispatch_inner(parts, lane).await;
441 if self.security_headers {
442 apply_security_headers(&mut response);
443 }
444 if let Some(config) = &self.cors {
445 crate::cors::apply_cors(&mut response, origin.as_ref(), config);
446 }
447 response
448 }
449
450 async fn dispatch_inner(&self, parts: http::request::Parts, lane: BodyLane) -> Response {
451 let method = parts.method.clone();
452 let path = parts.uri.path().to_string();
453 match self.trie.find(&path, &method) {
454 RouteMatch::NotFound => Error::not_found().into_response(),
455 RouteMatch::MethodMissing => Error::method_not_allowed().into_response(),
456 RouteMatch::Malformed => {
457 Error::bad_request("malformed percent-encoding in path").into_response()
458 }
459 RouteMatch::Found { endpoint, params } => {
460 let mut ctx = RequestCtx::with_lane(
461 parts,
462 lane,
463 DepResolver::new(endpoint.env.clone(), self.overrides.clone()),
464 );
465 ctx.params = params;
466 let handler: &BoxHandlerFn = endpoint
467 .methods
468 .get(&method)
469 .expect("find() checked the method");
470 let run = Next {
471 chain: &endpoint.middleware,
472 endpoint: handler,
473 }
474 .run(&mut ctx);
475 match tokio::time::timeout(self.handler_timeout, run).await {
476 Ok(response) => response,
477 Err(_) => Error::handler_timeout().into_response(),
478 }
479 }
480 }
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::Json;
    use crate::router::get;
    use crate::{Dep, Path};
    use std::sync::Mutex;

    /// In-memory list of items shared with the handlers through the dependency environment.
    #[derive(Default)]
    struct Store {
        items: Mutex<Vec<String>>,
    }

    /// Returns a copy of every stored item.
    async fn list(store: Dep<Store>) -> Json<Vec<String>> {
        let snapshot = store.items.lock().unwrap().clone();
        Json(snapshot)
    }

    /// Appends `item` and returns the new item count.
    async fn create(store: Dep<Store>, Json(item): Json<String>) -> crate::Result<Json<usize>> {
        let count = {
            let mut guard = store.items.lock().unwrap();
            guard.push(item);
            guard.len()
        };
        Ok(Json(count))
    }

    /// Returns the item at index `ix`, or a not-found error.
    async fn show(store: Dep<Store>, Path(ix): Path<usize>) -> crate::Result<Json<String>> {
        let found = store.items.lock().unwrap().get(ix).cloned();
        match found {
            Some(item) => Ok(Json(item)),
            None => Err(Error::not_found()),
        }
    }

    /// App with an empty `Store` and the todo routes mounted under `/todos`.
    fn crud_app() -> App {
        let todos = Module::new("todos")
            .route("/", get(list).post(create))
            .route("/{ix}", get(show));
        App::new().provide(Store::default()).mount("/todos", todos)
    }

    /// Sends one in-process request with a buffered body straight into `BuiltApp::dispatch`.
    async fn dispatch(built: &BuiltApp, method: http::Method, path: &str, body: &str) -> Response {
        let (parts, ()) = http::Request::builder()
            .method(method)
            .uri(path)
            .body(())
            .unwrap()
            .into_parts();
        let lane = BodyLane::Buffered(Bytes::from(body.to_owned()));
        built.dispatch(parts, lane).await
    }

    #[tokio::test]
    async fn crud_round_trip_in_process() {
        let built = crud_app().build().unwrap();
        // Order matters: the POST has to land before index 0 can be read back.
        let steps = [
            (
                http::Method::POST,
                "/todos/",
                r#""write spike""#,
                http::StatusCode::OK,
            ),
            (http::Method::GET, "/todos/0", "", http::StatusCode::OK),
            (http::Method::GET, "/todos/9", "", http::StatusCode::NOT_FOUND),
            (
                http::Method::PATCH,
                "/todos/",
                "",
                http::StatusCode::METHOD_NOT_ALLOWED,
            ),
            (http::Method::GET, "/nope", "", http::StatusCode::NOT_FOUND),
        ];
        for (method, path, body, expected) in steps {
            let label = format!("{method} {path}");
            let response = dispatch(&built, method, path, body).await;
            assert_eq!(response.status(), expected, "{label}");
        }
    }

    #[test]
    fn conflicting_routes_fail_at_build_not_at_request_time() {
        let outcome = App::new()
            .route("/x", get(|| async { "a" }))
            .route("/x", get(|| async { "b" }))
            .build();
        let err = outcome.unwrap_err();
        assert!(err.message().contains("/x"));
    }

    #[test]
    fn wildcard_origin_with_credentials_is_a_build_error() {
        let config =
            crate::cors::CorsConfig::new(crate::cors::CorsOrigins::any()).allow_credentials(true);
        let err = App::new().cors(config).build().unwrap_err();
        let rendered = err.to_string().to_lowercase();
        assert!(rendered.contains("credential"), "{err}");
    }

    #[test]
    fn allowlist_origin_with_credentials_builds() {
        let origins = crate::cors::CorsOrigins::list(["https://app.example"]);
        let config = crate::cors::CorsConfig::new(origins).allow_credentials(true);
        assert!(App::new().cors(config).build().is_ok());
    }

    #[tokio::test]
    async fn extensions_register_through_extend() {
        struct Greeting(&'static str);
        struct GreetingExt;
        impl Extension for GreetingExt {
            fn register(self, app: App) -> App {
                app.provide(Greeting("from-extension"))
            }
        }
        async fn read(g: crate::Dep<Greeting>) -> String {
            // Explicit deref on purpose: `g.0` could resolve to a field of `Dep` itself.
            (*g).0.to_string()
        }
        let client = App::new()
            .extend(GreetingExt)
            .route("/", crate::router::get(read))
            .into_test();
        let body = client.get("/").await.text();
        assert_eq!(body, "from-extension");
    }

    #[test]
    fn accept_error_classification_matches_unix_reality() {
        use std::io::{Error as IoError, ErrorKind};

        let transient = [
            IoError::from(ErrorKind::ConnectionAborted),
            IoError::from(ErrorKind::ConnectionReset),
            IoError::from(ErrorKind::Interrupted),
            // Raw errno values; presumably EMFILE (24) and ENFILE (23) — confirm per target.
            IoError::from_raw_os_error(24),
            IoError::from_raw_os_error(23),
        ];
        for err in &transient {
            assert!(is_transient_accept_error(err), "{err:?}");
        }

        for kind in [ErrorKind::InvalidInput, ErrorKind::PermissionDenied] {
            assert!(!is_transient_accept_error(&IoError::from(kind)));
        }
    }
}