use crate::dep::{AnyArc, DepEnv, DepFactory, DepResolver, TaskContext};
use crate::error::{Error, Result};
use crate::extract::{BodyLane, RequestCtx};
use crate::handler::BoxHandlerFn;
use crate::middleware::{Middleware, Next};
use crate::module::{FlatRoute, Module};
use crate::response::{IntoResponse, Response};
use crate::router::{Endpoint, MethodRouter, RouteMatch, Trie};
use crate::serve;
#[cfg(test)]
use crate::serve::is_transient_accept_error;
#[cfg(test)]
use bytes::Bytes;
use std::any::TypeId;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Boxed, one-shot constructor for a background task registered with
/// [`App::on_serve`].
///
/// It is called once with a [`TaskContext`] and a `watch::Receiver<bool>`
/// (named `shutdown` at the registration site). It returns the task's future.
/// Both the closure and the future it yields must be `Send` and own their data
/// (`'static`; the explicit bounds below match the implicit defaults for
/// `Box<dyn ...>`).
pub(crate) type BackgroundFactory = Box<
    dyn FnOnce(TaskContext, tokio::sync::watch::Receiver<bool>)
            -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
        + Send
        + 'static,
>;
/// Builder for an application.
///
/// Collects routes, mounted modules, dependencies, middleware, and server
/// settings until [`App::build`] (or one of the `serve*` methods) consumes it.
pub struct App {
    /// App-level routes as `(path, methods)`, in registration order.
    routes: Vec<(String, MethodRouter)>,
    /// Modules to mount, as `(prefix, module)`, in registration order.
    mounts: Vec<(String, Module)>,
    /// App-wide dependency environment (values and factories).
    env: DepEnv,
    /// App-wide middleware, in registration order.
    middleware: Vec<Arc<dyn Middleware>>,
    /// Whether default security headers are added to responses.
    security_headers: bool,
    /// Optional CORS policy; validated in [`App::build`].
    cors: Option<Arc<crate::cors::CorsConfig>>,
    /// Time budget for running middleware plus handler per request.
    handler_timeout: std::time::Duration,
    /// Body-read time budget, copied into the built app.
    body_read_timeout: std::time::Duration,
    /// Write-stall time budget, copied into the built app.
    write_stall_timeout: std::time::Duration,
    /// Named background task factories registered via [`App::on_serve`].
    background: Vec<(&'static str, BackgroundFactory)>,
}
impl Default for App {
fn default() -> Self {
let mut env = DepEnv::default();
env.insert_value(crate::clock::Clock::system());
Self {
routes: Vec::new(),
mounts: Vec::new(),
env,
middleware: Vec::new(),
security_headers: true,
cors: None,
handler_timeout: std::time::Duration::from_secs(30),
body_read_timeout: std::time::Duration::from_secs(30),
write_stall_timeout: std::time::Duration::from_secs(30),
background: Vec::new(),
}
}
}
/// A reusable bundle of app configuration, applied with [`App::extend`].
pub trait Extension {
/// Consumes the extension and the app, returning the (possibly modified) app.
fn register(self, app: App) -> App;
}
impl App {
pub fn new() -> Self {
Self::default()
}
pub fn extend<E: Extension>(self, extension: E) -> App {
extension.register(self)
}
pub fn security_headers(mut self, on: bool) -> Self {
self.security_headers = on;
self
}
pub fn cors(mut self, config: crate::cors::CorsConfig) -> Self {
self.cors = Some(std::sync::Arc::new(config));
self
}
pub fn handler_timeout(mut self, budget: std::time::Duration) -> Self {
self.handler_timeout = budget;
self
}
pub fn body_read_timeout(mut self, budget: std::time::Duration) -> Self {
self.body_read_timeout = budget;
self
}
pub fn write_stall_timeout(mut self, budget: std::time::Duration) -> Self {
self.write_stall_timeout = budget;
self
}
pub fn route(mut self, path: &str, methods: MethodRouter) -> Self {
self.routes.push((path.to_string(), methods));
self
}
pub fn mount(mut self, prefix: &str, module: Module) -> Self {
self.mounts.push((prefix.to_string(), module));
self
}
pub fn provide<T: Send + Sync + 'static>(mut self, value: T) -> Self {
self.env.insert_value(value);
self
}
pub fn provide_dep<F, Args, T>(mut self, factory: F) -> Self
where
F: DepFactory<Args, T>,
T: Send + Sync + 'static,
{
self.env.insert_factory(factory);
self
}
pub fn middleware<M: Middleware>(mut self, mw: M) -> Self {
self.middleware.push(Arc::new(mw));
self
}
pub fn on_serve<F, Fut>(mut self, name: &'static str, f: F) -> App
where
F: FnOnce(TaskContext, tokio::sync::watch::Receiver<bool>) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let factory: BackgroundFactory = Box::new(move |ctx, shutdown| Box::pin(f(ctx, shutdown)));
self.background.push((name, factory));
self
}
pub(crate) fn take_background(&mut self) -> Vec<(&'static str, BackgroundFactory)> {
std::mem::take(&mut self.background)
}
pub fn build(self) -> Result<BuiltApp> {
if let Some(c) = &self.cors {
c.validate()?;
}
let mut trie = Trie::default();
let app_env = Arc::new(self.env.clone());
let app_mw: Arc<[Arc<dyn Middleware>]> = Arc::from(self.middleware.clone());
for (path, methods) in self.routes {
let body_limit = methods.body_limit;
insert_flat(
&mut trie,
FlatRoute {
path,
methods,
env: app_env.clone(),
middleware: app_mw.clone(),
body_limit,
},
)?;
}
for (prefix, module) in self.mounts {
for flat in module.flatten(&prefix, &self.env, &self.middleware) {
insert_flat(&mut trie, flat)?;
}
}
Ok(BuiltApp {
trie,
app_env,
overrides: Arc::new(HashMap::new()),
security_headers: self.security_headers,
cors: self.cors.clone(),
handler_timeout: self.handler_timeout,
body_read_timeout: self.body_read_timeout,
write_stall_timeout: self.write_stall_timeout,
})
}
pub async fn serve(self) -> Result<()> {
let addr = std::env::var("JERRYCAN_ADDR").unwrap_or_else(|_| "127.0.0.1:8000".to_string());
let listener = tokio::net::TcpListener::bind(&addr)
.await
.map_err(|e| Error::internal(format!("failed to bind {addr}: {e}")))?;
self.serve_with_shutdown(listener, serve::shutdown_signal())
.await
}
pub async fn serve_with(self, listener: tokio::net::TcpListener) -> Result<()> {
self.serve_with_shutdown(listener, std::future::pending())
.await
}
pub async fn serve_with_shutdown(
self,
listener: tokio::net::TcpListener,
shutdown: impl std::future::Future<Output = ()> + Send,
) -> Result<()> {
serve::run_with_shutdown(self, listener, shutdown).await
}
}
fn insert_flat(trie: &mut Trie, flat: FlatRoute) -> Result<()> {
let stream_body = flat.methods.stream_body;
let mut methods = HashMap::new();
for (m, h) in flat.methods.handlers {
if methods.insert(m.clone(), h).is_some() {
return Err(Error::internal(format!(
"duplicate method {m} for `{}`",
flat.path
)));
}
}
trie.insert(
&flat.path,
Endpoint {
methods,
env: flat.env,
middleware: flat.middleware,
body_limit: flat.body_limit,
stream_body,
},
)
}
/// An [`App`] compiled into a routing trie, ready to dispatch requests.
pub struct BuiltApp {
    /// Routing trie mapping paths and methods to endpoints.
    pub(crate) trie: Trie,
    /// App-wide dependency environment shared by app-level routes.
    pub(crate) app_env: Arc<DepEnv>,
    /// Dependency overrides by type; empty when created by [`App::build`].
    pub(crate) overrides: Arc<HashMap<TypeId, AnyArc>>,
    /// Whether default security headers are stamped onto responses.
    pub(crate) security_headers: bool,
    /// Optional, already-validated CORS policy.
    pub(crate) cors: Option<Arc<crate::cors::CorsConfig>>,
    /// Time budget for middleware plus handler per request.
    pub(crate) handler_timeout: std::time::Duration,
    /// Body-read time budget.
    pub(crate) body_read_timeout: std::time::Duration,
    /// Write-stall time budget.
    pub(crate) write_stall_timeout: std::time::Duration,
}
impl std::fmt::Debug for BuiltApp {
    /// Prints only the type name, as `BuiltApp { .. }`; no fields are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BuiltApp");
        out.finish_non_exhaustive()
    }
}
/// Default request-body size limit in bytes (1 MiB).
///
/// Used when an endpoint does not set its own `body_limit`.
pub(crate) const BODY_LIMIT: usize = 1 << 20;
/// Outcome of the pre-body routing check in `BuiltApp::route_policy`.
pub(crate) enum Policy {
/// A route matched: read the body up to `limit` bytes; `stream` is copied
/// from the endpoint's `stream_body` flag.
Route { limit: usize, stream: bool },
/// Answer immediately with this response (CORS preflight, 404, 405, or 400).
Reject(Response),
}
/// Adds conservative security headers to `res`.
///
/// A header is added only if the response does not already set it, so
/// handler-supplied values always win.
pub(crate) fn apply_security_headers(res: &mut Response) {
    const DEFAULTS: [(&str, &str); 5] = [
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("referrer-policy", "no-referrer"),
        ("content-security-policy", "default-src 'none'"),
        ("cache-control", "no-store"),
    ];
    let headers = res.headers_mut();
    for &(name, value) in DEFAULTS.iter() {
        // `or_insert_with` leaves any existing value(s) untouched.
        headers
            .entry(http::HeaderName::from_static(name))
            .or_insert_with(|| http::HeaderValue::from_static(value));
    }
}
impl BuiltApp {
pub fn task_context(&self) -> crate::dep::TaskContext {
crate::dep::TaskContext::new(DepResolver::new(
self.app_env.clone(),
self.overrides.clone(),
))
}
pub(crate) fn route_policy(&self, parts: &http::request::Parts) -> Policy {
let path = parts.uri.path();
if let Some(config) = &self.cors
&& crate::cors::is_preflight(parts)
{
let origin = parts
.headers
.get(http::header::ORIGIN)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if config.allows_origin(origin)
&& let Some(methods) = self.trie.methods_for(path)
{
let acrh = parts
.headers
.get(http::header::ACCESS_CONTROL_REQUEST_HEADERS)
.and_then(|v| v.to_str().ok());
return Policy::Reject(crate::cors::preflight_response(
config, origin, acrh, &methods,
));
}
let mut r = http::Response::new(crate::response::JcBody::empty());
*r.status_mut() = http::StatusCode::NO_CONTENT;
return Policy::Reject(r);
}
let reject = |response: Response| -> Policy {
let mut response = response;
if self.security_headers {
apply_security_headers(&mut response);
}
if let Some(config) = &self.cors {
crate::cors::apply_cors(
&mut response,
parts.headers.get(http::header::ORIGIN),
config,
);
}
Policy::Reject(response)
};
match self.trie.find(path, &parts.method) {
RouteMatch::Found { endpoint, .. } => Policy::Route {
limit: endpoint.body_limit.unwrap_or(BODY_LIMIT),
stream: endpoint.stream_body,
},
RouteMatch::NotFound => reject(Error::not_found().into_response()),
RouteMatch::MethodMissing => reject(Error::method_not_allowed().into_response()),
RouteMatch::Malformed => {
reject(Error::bad_request("malformed percent-encoding in path").into_response())
}
}
}
pub(crate) async fn dispatch(&self, parts: http::request::Parts, lane: BodyLane) -> Response {
let origin = parts.headers.get(http::header::ORIGIN).cloned();
let mut response = self.dispatch_inner(parts, lane).await;
if self.security_headers {
apply_security_headers(&mut response);
}
if let Some(config) = &self.cors {
crate::cors::apply_cors(&mut response, origin.as_ref(), config);
}
response
}
async fn dispatch_inner(&self, parts: http::request::Parts, lane: BodyLane) -> Response {
let method = parts.method.clone();
let path = parts.uri.path().to_string();
match self.trie.find(&path, &method) {
RouteMatch::NotFound => Error::not_found().into_response(),
RouteMatch::MethodMissing => Error::method_not_allowed().into_response(),
RouteMatch::Malformed => {
Error::bad_request("malformed percent-encoding in path").into_response()
}
RouteMatch::Found { endpoint, params } => {
let mut ctx = RequestCtx::with_lane(
parts,
lane,
DepResolver::new(endpoint.env.clone(), self.overrides.clone()),
);
ctx.params = params;
let handler: &BoxHandlerFn = endpoint
.methods
.get(&method)
.expect("find() checked the method");
let run = Next {
chain: &endpoint.middleware,
endpoint: handler,
}
.run(&mut ctx);
match tokio::time::timeout(self.handler_timeout, run).await {
Ok(response) => response,
Err(_) => Error::handler_timeout().into_response(),
}
}
}
}
}
#[cfg(test)]
// Unit tests: app building, in-process dispatch, extensions, and accept-error
// classification.
mod tests {
use super::*;
use crate::response::Json;
use crate::router::get;
use crate::{Dep, Path};
use std::sync::Mutex;
// In-memory item store shared by the CRUD handlers below.
#[derive(Default)]
struct Store {
items: Mutex<Vec<String>>,
}
// GET: returns every stored item.
async fn list(store: Dep<Store>) -> Json<Vec<String>> {
Json(store.items.lock().unwrap().clone())
}
// POST: appends the JSON string body and returns the new item count.
async fn create(store: Dep<Store>, Json(item): Json<String>) -> crate::Result<Json<usize>> {
let mut items = store.items.lock().unwrap();
items.push(item);
Ok(Json(items.len()))
}
// GET by index: returns the item at `ix`, or a not-found error when out of range.
async fn show(store: Dep<Store>, Path(ix): Path<usize>) -> crate::Result<Json<String>> {
store
.items
.lock()
.unwrap()
.get(ix)
.cloned()
.map(Json)
.ok_or_else(Error::not_found)
}
// App with the CRUD handlers mounted under `/todos`.
fn crud_app() -> App {
App::new().provide(Store::default()).mount(
"/todos",
Module::new("todos")
.route("/", get(list).post(create))
.route("/{ix}", get(show)),
)
}
// Dispatches a request in-process, supplying `body` as an already-buffered body.
async fn dispatch(built: &BuiltApp, method: http::Method, path: &str, body: &str) -> Response {
let req = http::Request::builder()
.method(method)
.uri(path)
.body(())
.unwrap();
let (parts, ()) = req.into_parts();
built
.dispatch(parts, BodyLane::Buffered(Bytes::from(body.to_string())))
.await
}
// Create, read, missing index (404), wrong method (405), unknown path (404).
#[tokio::test]
async fn crud_round_trip_in_process() {
let built = crud_app().build().unwrap();
let r = dispatch(&built, http::Method::POST, "/todos/", r#""write spike""#).await;
assert_eq!(r.status(), http::StatusCode::OK);
let r = dispatch(&built, http::Method::GET, "/todos/0", "").await;
assert_eq!(r.status(), http::StatusCode::OK);
let r = dispatch(&built, http::Method::GET, "/todos/9", "").await;
assert_eq!(r.status(), http::StatusCode::NOT_FOUND);
let r = dispatch(&built, http::Method::PATCH, "/todos/", "").await;
assert_eq!(r.status(), http::StatusCode::METHOD_NOT_ALLOWED);
let r = dispatch(&built, http::Method::GET, "/nope", "").await;
assert_eq!(r.status(), http::StatusCode::NOT_FOUND);
}
// Registering the same path twice must fail in `build`, naming the path.
#[test]
fn conflicting_routes_fail_at_build_not_at_request_time() {
let app = App::new()
.route("/x", get(|| async { "a" }))
.route("/x", get(|| async { "b" }));
let err = app.build().unwrap_err();
assert!(err.message().contains("/x"));
}
// A wildcard CORS origin combined with credentials is rejected by `build`.
#[test]
fn wildcard_origin_with_credentials_is_a_build_error() {
let err = App::new()
.cors(
crate::cors::CorsConfig::new(crate::cors::CorsOrigins::any())
.allow_credentials(true),
)
.build()
.unwrap_err();
assert!(
err.to_string().to_lowercase().contains("credential"),
"{err}"
);
}
// An explicit origin allowlist may be combined with credentials.
#[test]
fn allowlist_origin_with_credentials_builds() {
assert!(
App::new()
.cors(
crate::cors::CorsConfig::new(crate::cors::CorsOrigins::list([
"https://app.example"
]))
.allow_credentials(true)
)
.build()
.is_ok()
);
}
// Dependencies provided by an `Extension` are visible to handlers.
#[tokio::test]
async fn extensions_register_through_extend() {
struct Greeting(&'static str);
struct GreetingExt;
impl Extension for GreetingExt {
fn register(self, app: App) -> App {
app.provide(Greeting("from-extension"))
}
}
async fn read(g: crate::Dep<Greeting>) -> String {
(*g).0.to_string()
}
let t = App::new()
.extend(GreetingExt)
.route("/", crate::router::get(read))
.into_test();
assert_eq!(t.get("/").await.text(), "from-extension");
}
// Accept errors that should be retried vs. ones that should not.
#[test]
fn accept_error_classification_matches_unix_reality() {
use std::io::{Error as IoError, ErrorKind};
// NOTE(review): raw OS errors 24 and 23 are EMFILE and ENFILE on Linux/macOS
// (per-process / system-wide descriptor exhaustion).
for transient in [
IoError::from(ErrorKind::ConnectionAborted),
IoError::from(ErrorKind::ConnectionReset),
IoError::from(ErrorKind::Interrupted),
IoError::from_raw_os_error(24), IoError::from_raw_os_error(23), ] {
assert!(is_transient_accept_error(&transient), "{transient:?}");
}
assert!(!is_transient_accept_error(&IoError::from(
ErrorKind::InvalidInput
)));
assert!(!is_transient_accept_error(&IoError::from(
ErrorKind::PermissionDenied
)));
}
}