use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::context::RequestContext;
use crate::middleware::{BoxFuture, Handler, Middleware, MiddlewareStack};
use crate::request::{Method, Request};
use crate::response::{Response, StatusCode};
use crate::shutdown::ShutdownController;
use fastapi_router::{Route, RouteLookup, Router};
/// A one-shot hook executed while the application starts up.
///
/// A hook is either a plain synchronous closure or a factory that produces a
/// boxed future carrying asynchronous initialization work.
pub enum StartupHook {
    /// Synchronous hook; invoked directly by [`StartupHook::run`].
    Sync(Box<dyn FnOnce() -> Result<(), StartupHookError> + Send>),
    /// Asynchronous hook; [`StartupHook::run`] calls the factory and hands the
    /// resulting future back to the caller to await.
    AsyncFactory(
        Box<
            dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<(), StartupHookError>> + Send>>
                + Send,
        >,
    ),
}

impl StartupHook {
    /// Wraps a synchronous closure as a startup hook.
    pub fn sync<F>(f: F) -> Self
    where
        F: FnOnce() -> Result<(), StartupHookError> + Send + 'static,
    {
        let boxed: Box<dyn FnOnce() -> Result<(), StartupHookError> + Send> = Box::new(f);
        Self::Sync(boxed)
    }

    /// Wraps a closure returning a future as a startup hook.
    ///
    /// The closure is not called here; it is stored and only invoked by
    /// [`StartupHook::run`].
    pub fn async_fn<F, Fut>(f: F) -> Self
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), StartupHookError>> + Send + 'static,
    {
        Self::AsyncFactory(Box::new(move || {
            let pending = f();
            Box::pin(pending)
        }))
    }

    /// Consumes the hook and executes it.
    ///
    /// # Returns
    /// * `Ok(None)` - a synchronous hook ran to completion successfully.
    /// * `Ok(Some(future))` - an asynchronous hook was started; the caller must
    ///   await the future to obtain its result.
    ///
    /// # Errors
    /// Returns the error produced by a failing synchronous hook.
    pub fn run(
        self,
    ) -> Result<
        Option<Pin<Box<dyn Future<Output = Result<(), StartupHookError>> + Send>>>,
        StartupHookError,
    > {
        match self {
            Self::Sync(hook) => {
                hook()?;
                Ok(None)
            }
            Self::AsyncFactory(make_future) => Ok(Some(make_future())),
        }
    }
}
/// Error reported by a startup hook.
#[derive(Debug)]
pub struct StartupHookError {
    /// Name of the hook that failed, when one has been attached.
    pub hook_name: Option<String>,
    /// Human-readable description of the failure.
    pub message: String,
    /// Whether the failure is fatal (`true`) or only a warning (`false`).
    pub abort: bool,
}

impl StartupHookError {
    /// Creates a fatal (`abort == true`) error without a hook name.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            hook_name: None,
            message,
            abort: true,
        }
    }

    /// Attaches the name of the failing hook.
    #[must_use]
    pub fn with_hook_name(self, name: impl Into<String>) -> Self {
        Self {
            hook_name: Some(name.into()),
            ..self
        }
    }

    /// Overrides whether this error is fatal.
    #[must_use]
    pub fn with_abort(self, abort: bool) -> Self {
        Self { abort, ..self }
    }

    /// Creates a non-fatal (`abort == false`) error without a hook name.
    pub fn non_fatal(message: impl Into<String>) -> Self {
        Self::new(message).with_abort(false)
    }
}

impl std::fmt::Display for StartupHookError {
    /// Formats the error, including the hook name when it is known.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.hook_name.as_deref() {
            Some(name) => write!(f, "Startup hook '{}' failed: {}", name, self.message),
            None => write!(f, "Startup hook failed: {}", self.message),
        }
    }
}

impl std::error::Error for StartupHookError {}
/// Aggregate result of running the startup hooks.
#[derive(Debug)]
pub enum StartupOutcome {
    /// Every hook succeeded.
    Success,
    /// Startup may continue, but some hooks reported non-fatal errors.
    PartialSuccess {
        /// Number of non-fatal hook failures.
        warnings: usize,
    },
    /// A hook failed fatally; carries the error that stopped startup.
    Aborted(StartupHookError),
}

impl StartupOutcome {
    /// Returns `true` unless startup was aborted.
    #[must_use]
    pub fn can_proceed(&self) -> bool {
        match self {
            Self::Success | Self::PartialSuccess { .. } => true,
            Self::Aborted(_) => false,
        }
    }

    /// Extracts the aborting error, or `None` for the non-aborted outcomes.
    pub fn into_error(self) -> Option<StartupHookError> {
        if let Self::Aborted(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
/// Type-erased HTTP route handler.
///
/// A `Send + Sync` callable that receives the request context and a mutable
/// request, and returns a boxed `Send` future resolving to a [`Response`].
pub type BoxHandler = Box<
dyn Fn(
&RequestContext,
&mut Request,
) -> std::pin::Pin<Box<dyn Future<Output = Response> + Send>>
+ Send
+ Sync,
>;
/// Type-erased WebSocket route handler.
///
/// A `Send + Sync` callable that receives the request context, a mutable
/// request and the `WebSocket` value, and returns a boxed `Send` future
/// resolving to `Ok(())` or a `WebSocketError`.
pub type BoxWebSocketHandler = Box<
dyn Fn(
&RequestContext,
&mut Request,
crate::websocket::WebSocket,
) -> std::pin::Pin<
Box<dyn Future<Output = Result<(), crate::websocket::WebSocketError>> + Send>,
> + Send
+ Sync,
>;
/// A single HTTP route registration: method, path and the handler to invoke.
///
/// The handler is stored behind an `Arc`, so clones of an entry share it.
#[derive(Clone)]
pub struct RouteEntry {
/// HTTP method this route is registered for.
pub method: Method,
/// Path this route is registered under.
pub path: String,
// Router metadata the entry was built from; `None` unless created via `from_route`.
meta: Option<fastapi_router::Route>,
// Type-erased handler, shared between clones of the entry.
handler: Arc<BoxHandler>,
}
impl RouteEntry {
    /// Creates a route entry for `method` + `path` with no router metadata.
    ///
    /// The handler is type-erased into a [`BoxHandler`] and stored behind an `Arc`.
    pub fn new<H, Fut>(method: Method, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        let erased: BoxHandler = Box::new(move |ctx, req| Box::pin(handler(ctx, req)));
        Self {
            method,
            path: path.into(),
            meta: None,
            handler: Arc::new(erased),
        }
    }

    /// Creates a route entry from router metadata.
    ///
    /// The entry's method and path are taken from `route`, and `route` itself is
    /// retained so it can be returned by [`RouteEntry::route_meta`].
    pub fn from_route<H, Fut>(route: fastapi_router::Route, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        let base = Self::new(route.method, route.path.clone(), handler);
        Self {
            meta: Some(route),
            ..base
        }
    }

    /// Returns the router metadata, if the entry was built with `from_route`.
    pub fn route_meta(&self) -> Option<&fastapi_router::Route> {
        self.meta.as_ref()
    }

    /// Invokes the handler and awaits its response.
    pub async fn call(&self, ctx: &RequestContext, req: &mut Request) -> Response {
        let pending = (self.handler)(ctx, req);
        pending.await
    }
}
impl std::fmt::Debug for RouteEntry {
    /// Prints the method, the path and the operation id of the attached
    /// metadata (if any); the handler is elided as a non-exhaustive struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let operation_id = self
            .meta
            .as_ref()
            .map(|route| route.operation_id.as_str());
        let mut out = f.debug_struct("RouteEntry");
        out.field("method", &self.method);
        out.field("path", &self.path);
        out.field("meta", &operation_id);
        out.finish_non_exhaustive()
    }
}
/// A single WebSocket route registration: path and the handler to invoke.
///
/// The handler is stored behind an `Arc`, so clones of an entry share it.
#[derive(Clone)]
pub struct WebSocketRouteEntry {
/// Path this WebSocket route is registered under.
pub path: String,
// Type-erased WebSocket handler, shared between clones of the entry.
handler: Arc<BoxWebSocketHandler>,
}
impl WebSocketRouteEntry {
    /// Creates a WebSocket route entry for `path`.
    ///
    /// The handler is type-erased into a [`BoxWebSocketHandler`] and stored
    /// behind an `Arc`.
    pub fn new<H, Fut>(path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request, crate::websocket::WebSocket) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = Result<(), crate::websocket::WebSocketError>> + Send + 'static,
    {
        let erased: BoxWebSocketHandler =
            Box::new(move |ctx, req, ws| Box::pin(handler(ctx, req, ws)));
        Self {
            path: path.into(),
            handler: Arc::new(erased),
        }
    }

    /// Invokes the handler with the given `WebSocket` and awaits its result.
    ///
    /// # Errors
    /// Propagates whatever `WebSocketError` the handler returns.
    pub async fn call(
        &self,
        ctx: &RequestContext,
        req: &mut Request,
        ws: crate::websocket::WebSocket,
    ) -> Result<(), crate::websocket::WebSocketError> {
        let session = (self.handler)(ctx, req, ws);
        session.await
    }
}
impl std::fmt::Debug for WebSocketRouteEntry {
    /// Prints the path only; the handler is elided as a non-exhaustive struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WebSocketRouteEntry");
        out.field("path", &self.path);
        out.finish_non_exhaustive()
    }
}
/// Type-indexed storage for shared application state.
///
/// At most one value is kept per concrete type; values are stored behind an
/// `Arc` and keyed by their `TypeId`.
#[derive(Default)]
pub struct StateContainer {
    // One shared value per concrete type.
    state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl StateContainer {
    /// Creates an empty container.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value`, replacing any previously stored value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        let shared: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.state.insert(TypeId::of::<T>(), shared);
    }

    /// Returns a shared handle to the stored value of type `T`, if any.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        let stored = self.state.get(&TypeId::of::<T>())?;
        Arc::clone(stored).downcast::<T>().ok()
    }

    /// Reports whether a value of type `T` is stored.
    pub fn contains<T: 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.state.contains_key(&key)
    }

    /// Number of stored values (one per distinct type).
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Returns `true` when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }
}

impl std::fmt::Debug for StateContainer {
    /// Prints only the number of stored values; the values themselves are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.state.len();
        let mut out = f.debug_struct("StateContainer");
        out.field("count", &count);
        out.finish()
    }
}
/// Type-erased exception handler.
///
/// A `Send + Sync` callable that receives the request context and a boxed
/// error and synchronously produces a [`Response`].
pub type BoxExceptionHandler = Box<
dyn Fn(&RequestContext, Box<dyn std::error::Error + Send + Sync>) -> Response + Send + Sync,
>;
/// Registry of exception handlers keyed by the concrete error type.
///
/// Holds at most one handler per error `TypeId`.
#[derive(Default)]
pub struct ExceptionHandlers {
// Handler for each registered error type.
handlers: HashMap<TypeId, BoxExceptionHandler>,
}
impl ExceptionHandlers {
#[must_use]
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
}
}
#[must_use]
pub fn with_defaults() -> Self {
let mut handlers = Self::new();
handlers.register::<crate::HttpError>(|_ctx, err| {
use crate::IntoResponse;
err.into_response()
});
handlers.register::<crate::ValidationErrors>(|_ctx, err| {
use crate::IntoResponse;
err.into_response()
});
handlers
}
pub fn register<E>(
&mut self,
handler: impl Fn(&RequestContext, E) -> Response + Send + Sync + 'static,
) where
E: std::error::Error + Send + Sync + 'static,
{
let boxed_handler: BoxExceptionHandler = Box::new(move |ctx, err| {
match err.downcast::<E>() {
Ok(typed_err) => handler(ctx, *typed_err),
Err(_) => {
Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
}
}
});
self.handlers.insert(TypeId::of::<E>(), boxed_handler);
}
#[must_use]
pub fn handler<E>(
mut self,
handler: impl Fn(&RequestContext, E) -> Response + Send + Sync + 'static,
) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
self.register::<E>(handler);
self
}
pub fn handle<E>(&self, ctx: &RequestContext, err: E) -> Option<Response>
where
E: std::error::Error + Send + Sync + 'static,
{
let type_id = TypeId::of::<E>();
self.handlers
.get(&type_id)
.map(|handler| handler(ctx, Box::new(err)))
}
pub fn handle_or_default<E>(&self, ctx: &RequestContext, err: E) -> Response
where
E: std::error::Error + Send + Sync + 'static,
{
self.handle(ctx, err)
.unwrap_or_else(|| Response::with_status(StatusCode::INTERNAL_SERVER_ERROR))
}
pub fn has_handler<E: 'static>(&self) -> bool {
self.handlers.contains_key(&TypeId::of::<E>())
}
pub fn len(&self) -> usize {
self.handlers.len()
}
pub fn is_empty(&self) -> bool {
self.handlers.is_empty()
}
pub fn merge(&mut self, other: ExceptionHandlers) {
self.handlers.extend(other.handlers);
}
}
impl std::fmt::Debug for ExceptionHandlers {
    /// Prints only the number of registered handlers; the closures are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.handlers.len();
        let mut out = f.debug_struct("ExceptionHandlers");
        out.field("count", &count);
        out.finish()
    }
}
/// Application-level configuration.
///
/// Fields are public and can also be set through the builder-style methods on
/// `AppConfig`.
#[derive(Debug, Clone)]
pub struct AppConfig {
/// Application name.
pub name: String,
/// Application version string.
pub version: String,
/// Debug flag.
pub debug: bool,
/// Root path; the `root_path` builder method stores it without trailing `/`.
pub root_path: String,
/// NOTE(review): presumably controls whether `root_path` is advertised in the
/// OpenAPI `servers` list; not consumed in this part of the file, so confirm.
pub root_path_in_servers: bool,
/// Trailing-slash handling mode.
pub trailing_slash_mode: crate::routing::TrailingSlashMode,
/// Debug configuration.
pub debug_config: crate::error::DebugConfig,
/// Maximum request body size; presumably in bytes (TODO confirm).
pub max_body_size: usize,
/// Request timeout; milliseconds according to the field name.
pub request_timeout_ms: u64,
}
impl Default for AppConfig {
    /// Baseline configuration: name `fastapi_rust`, version `0.1.0`, debug off,
    /// empty root path, strict trailing-slash mode, a `1024 * 1024` body-size
    /// limit and a `30_000` request timeout.
    fn default() -> Self {
        let name = "fastapi_rust".to_owned();
        let version = "0.1.0".to_owned();
        Self {
            name,
            version,
            debug: false,
            root_path: String::new(),
            root_path_in_servers: false,
            trailing_slash_mode: crate::routing::TrailingSlashMode::Strict,
            debug_config: crate::error::DebugConfig::default(),
            max_body_size: 1024 * 1024,
            request_timeout_ms: 30_000,
        }
    }
}
impl AppConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
#[must_use]
pub fn version(mut self, version: impl Into<String>) -> Self {
self.version = version.into();
self
}
#[must_use]
pub fn debug(mut self, debug: bool) -> Self {
self.debug = debug;
self
}
#[must_use]
pub fn root_path(mut self, root_path: impl Into<String>) -> Self {
let mut rp = root_path.into();
while rp.ends_with('/') {
rp.pop();
}
self.root_path = rp;
self
}
#[must_use]
pub fn root_path_in_servers(mut self, enabled: bool) -> Self {
self.root_path_in_servers = enabled;
self
}
#[must_use]
pub fn trailing_slash_mode(mut self, mode: crate::routing::TrailingSlashMode) -> Self {
self.trailing_slash_mode = mode;
self
}
#[must_use]
pub fn debug_config(mut self, config: crate::error::DebugConfig) -> Self {
self.debug_config = config;
self
}
#[must_use]
pub fn max_body_size(mut self, size: usize) -> Self {
self.max_body_size = size;
self
}
#[must_use]
pub fn request_timeout_ms(mut self, timeout: u64) -> Self {
self.request_timeout_ms = timeout;
self
}
}
/// Settings controlling OpenAPI document generation.
#[derive(Debug, Clone)]
pub struct OpenApiConfig {
    /// Whether the OpenAPI document is generated at all.
    pub enabled: bool,
    /// Title of the API.
    pub title: String,
    /// Version of the API.
    pub version: String,
    /// Optional description of the API.
    pub description: Option<String>,
    /// Path under which the OpenAPI document is served.
    pub openapi_path: String,
    /// `(url, description)` pairs for the document's server list.
    pub servers: Vec<(String, Option<String>)>,
    /// `(name, description)` pairs for the document's tag list.
    pub tags: Vec<(String, Option<String>)>,
}

impl Default for OpenApiConfig {
    /// Enabled, titled `FastAPI Rust`, version `0.1.0`, served at
    /// `/openapi.json`, with no description, servers or tags.
    fn default() -> Self {
        Self {
            enabled: true,
            title: String::from("FastAPI Rust"),
            version: String::from("0.1.0"),
            description: None,
            openapi_path: String::from("/openapi.json"),
            servers: Vec::new(),
            tags: Vec::new(),
        }
    }
}

impl OpenApiConfig {
    /// Creates a configuration with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the API title.
    #[must_use]
    pub fn title(self, title: impl Into<String>) -> Self {
        Self {
            title: title.into(),
            ..self
        }
    }

    /// Sets the API version.
    #[must_use]
    pub fn version(self, version: impl Into<String>) -> Self {
        Self {
            version: version.into(),
            ..self
        }
    }

    /// Sets the API description.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Sets the path under which the OpenAPI document is served.
    #[must_use]
    pub fn path(self, path: impl Into<String>) -> Self {
        Self {
            openapi_path: path.into(),
            ..self
        }
    }

    /// Appends a server entry; entries keep their insertion order.
    #[must_use]
    pub fn server(mut self, url: impl Into<String>, description: Option<String>) -> Self {
        let entry = (url.into(), description);
        self.servers.push(entry);
        self
    }

    /// Appends a tag entry; entries keep their insertion order.
    #[must_use]
    pub fn tag(mut self, name: impl Into<String>, description: Option<String>) -> Self {
        let entry = (name.into(), description);
        self.tags.push(entry);
        self
    }

    /// Turns OpenAPI generation off.
    #[must_use]
    pub fn disable(self) -> Self {
        Self {
            enabled: false,
            ..self
        }
    }
}
/// Builder that accumulates everything needed to construct an `App`.
///
/// Collects configuration, routes, middleware, shared state, exception
/// handlers and lifecycle hooks; `build` turns the collected parts into an `App`.
pub struct AppBuilder {
// Application configuration.
config: AppConfig,
// HTTP routes in registration order.
routes: Vec<RouteEntry>,
// WebSocket routes in registration order.
ws_routes: Vec<WebSocketRouteEntry>,
// Middleware in registration order.
middleware: Vec<Arc<dyn Middleware>>,
// Shared application state, one value per type.
state: StateContainer,
// Exception handlers keyed by error type.
exception_handlers: ExceptionHandlers,
// Hooks to run at startup, in registration order.
startup_hooks: Vec<StartupHook>,
// Synchronous shutdown hooks.
shutdown_hooks: Vec<Box<dyn FnOnce() + Send>>,
// Asynchronous shutdown hooks (factories returning boxed futures).
async_shutdown_hooks: Vec<Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>>,
// OpenAPI generation settings; `None` means no OpenAPI document is produced.
openapi_config: Option<OpenApiConfig>,
// Documentation UI settings; `None` means no docs routes are registered.
docs_config: Option<crate::docs::DocsConfig>,
}
impl Default for AppBuilder {
    /// An empty builder: default configuration, no routes, middleware, state,
    /// exception handlers or hooks, and neither OpenAPI nor docs configured.
    fn default() -> Self {
        Self {
            config: AppConfig::default(),
            routes: vec![],
            ws_routes: vec![],
            middleware: vec![],
            state: StateContainer::default(),
            exception_handlers: ExceptionHandlers::default(),
            startup_hooks: vec![],
            shutdown_hooks: vec![],
            async_shutdown_hooks: vec![],
            openapi_config: None,
            docs_config: None,
        }
    }
}
impl AppBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the application configuration.
    #[must_use]
    pub fn config(self, config: AppConfig) -> Self {
        Self { config, ..self }
    }

    /// Sets the OpenAPI generation settings.
    #[must_use]
    pub fn openapi(self, config: OpenApiConfig) -> Self {
        Self {
            openapi_config: Some(config),
            ..self
        }
    }

    /// Enables the documentation UI routes.
    ///
    /// * If the docs title still equals the `DocsConfig` default title, it is
    ///   replaced by the application name.
    /// * An already configured OpenAPI config has its path overwritten with the
    ///   docs config's `openapi_path`; otherwise a new OpenAPI config is created
    ///   from the application name, version and that path.
    #[must_use]
    pub fn enable_docs(mut self, mut config: crate::docs::DocsConfig) -> Self {
        let default_title = crate::docs::DocsConfig::default().title;
        if config.title == default_title {
            config.title.clone_from(&self.config.name);
        }

        let openapi = match self.openapi_config.take() {
            Some(mut existing) => {
                existing.openapi_path.clone_from(&config.openapi_path);
                existing
            }
            None => OpenApiConfig::new()
                .title(self.config.name.clone())
                .version(self.config.version.clone())
                .path(config.openapi_path.clone()),
        };
        self.openapi_config = Some(openapi);
        self.docs_config = Some(config);
        self
    }

    /// Registers an HTTP route for `method` + `path`.
    #[must_use]
    pub fn route<H, Fut>(self, path: impl Into<String>, method: Method, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(method, path, handler))
    }

    /// Registers an already constructed route entry.
    #[must_use]
    pub fn route_entry(mut self, entry: RouteEntry) -> Self {
        self.routes.push(entry);
        self
    }

    /// Registers a WebSocket route for `path`.
    #[must_use]
    pub fn websocket<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request, crate::websocket::WebSocket) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = Result<(), crate::websocket::WebSocketError>> + Send + 'static,
    {
        let entry = WebSocketRouteEntry::new(path, handler);
        self.ws_routes.push(entry);
        self
    }

    /// Registers a `GET` route.
    #[must_use]
    pub fn get<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(Method::Get, path, handler))
    }

    /// Registers a `POST` route.
    #[must_use]
    pub fn post<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(Method::Post, path, handler))
    }

    /// Registers a `PUT` route.
    #[must_use]
    pub fn put<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(Method::Put, path, handler))
    }

    /// Registers a `DELETE` route.
    #[must_use]
    pub fn delete<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(Method::Delete, path, handler))
    }

    /// Registers a `PATCH` route.
    #[must_use]
    pub fn patch<H, Fut>(self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Response> + Send + 'static,
    {
        self.route_entry(RouteEntry::new(Method::Patch, path, handler))
    }

    /// Appends a middleware; middleware is kept in registration order.
    #[must_use]
    pub fn middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
        let layer: Arc<dyn Middleware> = Arc::new(middleware);
        self.middleware.push(layer);
        self
    }

    /// Stores a shared state value, replacing any earlier value of the same type.
    #[must_use]
    pub fn state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
        self.state.insert(state);
        self
    }

    /// Registers an exception handler for errors of type `E`.
    #[must_use]
    pub fn exception_handler<E, H>(mut self, handler: H) -> Self
    where
        E: std::error::Error + Send + Sync + 'static,
        H: Fn(&RequestContext, E) -> Response + Send + Sync + 'static,
    {
        self.exception_handlers.register::<E>(handler);
        self
    }

    /// Replaces the whole exception-handler registry.
    #[must_use]
    pub fn exception_handlers(self, handlers: ExceptionHandlers) -> Self {
        Self {
            exception_handlers: handlers,
            ..self
        }
    }

    /// Replaces the exception-handler registry with
    /// `ExceptionHandlers::with_defaults()`; previously registered handlers
    /// are discarded.
    #[must_use]
    pub fn with_default_exception_handlers(self) -> Self {
        Self {
            exception_handlers: ExceptionHandlers::with_defaults(),
            ..self
        }
    }

    /// Adds a synchronous startup hook.
    #[must_use]
    pub fn on_startup<F>(mut self, hook: F) -> Self
    where
        F: FnOnce() -> Result<(), StartupHookError> + Send + 'static,
    {
        self.startup_hooks.push(StartupHook::sync(hook));
        self
    }

    /// Adds an asynchronous startup hook.
    #[must_use]
    pub fn on_startup_async<F, Fut>(mut self, hook: F) -> Self
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), StartupHookError>> + Send + 'static,
    {
        self.startup_hooks.push(StartupHook::async_fn(hook));
        self
    }

    /// Adds a synchronous shutdown hook.
    #[must_use]
    pub fn on_shutdown<F>(mut self, hook: F) -> Self
    where
        F: FnOnce() + Send + 'static,
    {
        let boxed: Box<dyn FnOnce() + Send> = Box::new(hook);
        self.shutdown_hooks.push(boxed);
        self
    }

    /// Adds an asynchronous shutdown hook.
    #[must_use]
    pub fn on_shutdown_async<F, Fut>(mut self, hook: F) -> Self
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.async_shutdown_hooks.push(Box::new(move || {
            let pending = hook();
            Box::pin(pending)
        }));
        self
    }

    /// Number of registered startup hooks.
    #[must_use]
    pub fn startup_hook_count(&self) -> usize {
        self.startup_hooks.len()
    }

    /// Number of registered shutdown hooks, synchronous and asynchronous combined.
    #[must_use]
    pub fn shutdown_hook_count(&self) -> usize {
        let sync_count = self.shutdown_hooks.len();
        let async_count = self.async_shutdown_hooks.len();
        sync_count + async_count
    }

    /// Consumes the builder and assembles the [`App`].
    ///
    /// Steps, in order:
    /// 1. If OpenAPI is configured and enabled, the document is generated from
    ///    the routes registered so far, serialized to pretty JSON (falling back
    ///    to `{}` if serialization fails) and a `GET` route serving it is added.
    /// 2. If an OpenAPI path is available and docs are configured, the Swagger
    ///    UI route, its `oauth2-redirect` companion and the ReDoc route are
    ///    added for whichever of the docs paths are set.
    /// 3. The middleware stack and the HTTP / WebSocket routers are built.
    ///
    /// # Panics
    /// Panics if the HTTP router or the WebSocket router rejects a route.
    #[must_use]
    #[allow(clippy::too_many_lines)]
    pub fn build(mut self) -> App {
        // Generate and serialize the OpenAPI document once, up front.
        let generated = self
            .openapi_config
            .as_ref()
            .filter(|cfg| cfg.enabled)
            .map(|cfg| {
                let document = self.generate_openapi_spec(cfg);
                let json = serde_json::to_string_pretty(&document)
                    .unwrap_or_else(|_| "{}".to_string());
                (Arc::new(json), cfg.openapi_path.clone())
            });
        let (openapi_spec, openapi_path) = match generated {
            Some((spec, path)) => (Some(spec), Some(path)),
            None => (None, None),
        };

        // Route that serves the serialized OpenAPI document.
        if let (Some(spec), Some(path)) = (&openapi_spec, &openapi_path) {
            let shared_spec = Arc::clone(spec);
            let spec_route = RouteEntry::new(
                Method::Get,
                path.clone(),
                move |_ctx: &RequestContext, _req: &mut Request| {
                    let document = Arc::clone(&shared_spec);
                    async move {
                        Response::ok()
                            .header("content-type", b"application/json".to_vec())
                            .body(crate::response::ResponseBody::Bytes(
                                document.as_bytes().to_vec(),
                            ))
                    }
                },
            );
            self.routes.push(spec_route);
        }

        // Documentation UI routes; they need both an OpenAPI URL and a docs config.
        if let (Some(openapi_url), Some(docs_config)) =
            (openapi_path.clone(), self.docs_config.take())
        {
            let docs_config = Arc::new(docs_config);
            let openapi_url = Arc::new(openapi_url);

            if let Some(docs_path) = docs_config.docs_path.clone() {
                // The redirect path lives next to the docs path, ignoring any
                // trailing slashes of the latter.
                let oauth2_redirect_path = match docs_path.trim_end_matches('/') {
                    "" => "/oauth2-redirect".to_string(),
                    docs_prefix => format!("{docs_prefix}/oauth2-redirect"),
                };

                let cfg = Arc::clone(&docs_config);
                let url = Arc::clone(&openapi_url);
                self.routes.push(RouteEntry::new(
                    Method::Get,
                    docs_path,
                    move |_ctx: &RequestContext, _req: &mut Request| {
                        let cfg = Arc::clone(&cfg);
                        let url = Arc::clone(&url);
                        async move { crate::docs::swagger_ui_response(&cfg, &url) }
                    },
                ));
                self.routes.push(RouteEntry::new(
                    Method::Get,
                    oauth2_redirect_path,
                    |_ctx: &RequestContext, _req: &mut Request| async move {
                        crate::docs::oauth2_redirect_response()
                    },
                ));
            }

            if let Some(redoc_path) = docs_config.redoc_path.clone() {
                let cfg = Arc::clone(&docs_config);
                let url = Arc::clone(&openapi_url);
                self.routes.push(RouteEntry::new(
                    Method::Get,
                    redoc_path,
                    move |_ctx: &RequestContext, _req: &mut Request| {
                        let cfg = Arc::clone(&cfg);
                        let url = Arc::clone(&url);
                        async move { crate::docs::redoc_response(&cfg, &url) }
                    },
                ));
            }
        }

        let mut middleware_stack = MiddlewareStack::with_capacity(self.middleware.len());
        for layer in self.middleware {
            middleware_stack.push_arc(layer);
        }

        // HTTP router: reuse attached metadata, otherwise derive a bare route.
        let mut router = Router::new();
        for entry in &self.routes {
            let route = match entry.route_meta() {
                Some(meta) => meta.clone(),
                None => Route::new(entry.method, &entry.path),
            };
            router
                .add(route)
                .expect("route conflict during App::build()");
        }

        // WebSocket routes are registered as `GET` routes in their own router.
        let mut ws_router = Router::new();
        for entry in &self.ws_routes {
            let route = Route::new(Method::Get, &entry.path);
            ws_router
                .add(route)
                .expect("websocket route conflict during App::build()");
        }

        App {
            config: self.config,
            routes: self.routes,
            ws_routes: self.ws_routes,
            router,
            ws_router,
            middleware: middleware_stack,
            state: Arc::new(self.state),
            exception_handlers: Arc::new(self.exception_handlers),
            dependency_overrides: Arc::new(crate::dependency::DependencyOverrides::new()),
            startup_hooks: parking_lot::Mutex::new(self.startup_hooks),
            shutdown_hooks: parking_lot::Mutex::new(self.shutdown_hooks),
            async_shutdown_hooks: parking_lot::Mutex::new(self.async_shutdown_hooks),
            openapi_spec,
        }
    }

    /// Builds the OpenAPI document for the currently registered HTTP routes.
    ///
    /// Routes carrying router metadata are added through that metadata. Routes
    /// without it get a minimal operation with a single `200` response and an
    /// operation id derived from the lowercased method and the path.
    fn generate_openapi_spec(&self, config: &OpenApiConfig) -> fastapi_openapi::OpenApi {
        use fastapi_openapi::{OpenApiBuilder, Operation, Response as OAResponse};
        use std::collections::HashMap;

        let mut builder = OpenApiBuilder::new(&config.title, &config.version);
        if let Some(desc) = &config.description {
            builder = builder.description(desc);
        }
        for (url, desc) in &config.servers {
            builder = builder.server(url, desc.clone());
        }
        for (name, desc) in &config.tags {
            builder = builder.tag(name, desc.clone());
        }

        for entry in &self.routes {
            if let Some(route) = entry.route_meta() {
                builder.add_route(route);
            } else {
                // e.g. `GET /users/{id}` becomes `get_users_id`.
                let slug = entry.path.replace('/', "_").replace(['{', '}'], "");
                let operation_id = format!(
                    "{}_{}",
                    entry.method.as_str().to_lowercase(),
                    slug.trim_matches('_')
                );

                let ok_response = OAResponse {
                    description: "Successful response".to_string(),
                    content: HashMap::new(),
                };
                let mut responses = HashMap::new();
                responses.insert("200".to_string(), ok_response);

                let operation = Operation {
                    operation_id: Some(operation_id),
                    summary: None,
                    description: None,
                    tags: Vec::new(),
                    parameters: Vec::new(),
                    request_body: None,
                    responses,
                    deprecated: false,
                };
                builder = builder.operation(entry.method.as_str(), &entry.path, operation);
            }
        }
        builder.build()
    }
}
impl std::fmt::Debug for AppBuilder {
    /// Prints a summary of the builder: the configuration, state and exception
    /// handlers via their own `Debug` output, and plain counts for routes,
    /// WebSocket routes, middleware and hooks (closures cannot be printed).
    ///
    /// `shutdown_hooks` is the combined count of synchronous and asynchronous
    /// shutdown hooks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AppBuilder")
            .field("config", &self.config)
            .field("routes", &self.routes.len())
            // WebSocket routes are registered separately from HTTP routes, so
            // report them too instead of silently hiding them.
            .field("ws_routes", &self.ws_routes.len())
            .field("middleware", &self.middleware.len())
            .field("state", &self.state)
            .field("exception_handlers", &self.exception_handlers)
            .field("startup_hooks", &self.startup_hooks.len())
            .field("shutdown_hooks", &self.shutdown_hook_count())
            .finish()
    }
}
/// A fully assembled application: routes, routers, middleware, shared state,
/// exception handlers, dependency overrides and pending lifecycle hooks.
///
/// Hook lists sit behind mutexes so they can be drained through `&self`.
pub struct App {
// Application configuration.
config: AppConfig,
// HTTP route entries; the matching handler is looked up here after routing.
routes: Vec<RouteEntry>,
// WebSocket route entries.
ws_routes: Vec<WebSocketRouteEntry>,
// Router used to match HTTP requests.
router: Router,
// Router used to match WebSocket requests (routes registered as GET).
ws_router: Router,
// Middleware stack wrapped around every matched HTTP handler.
middleware: MiddlewareStack,
// Shared application state.
state: Arc<StateContainer>,
// Exception handlers keyed by error type.
exception_handlers: Arc<ExceptionHandlers>,
// Dependency overrides; shared so they can be handed out to callers.
dependency_overrides: Arc<crate::dependency::DependencyOverrides>,
// Startup hooks that have not run yet.
startup_hooks: parking_lot::Mutex<Vec<StartupHook>>,
// Synchronous shutdown hooks that have not run or been transferred yet.
shutdown_hooks: parking_lot::Mutex<Vec<Box<dyn FnOnce() + Send>>>,
// Asynchronous shutdown hooks that have not run or been transferred yet.
async_shutdown_hooks: parking_lot::Mutex<
Vec<Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>>,
>,
// Serialized OpenAPI document, when generation was enabled at build time.
openapi_spec: Option<Arc<String>>,
}
impl App {
    /// Starts building an application.
    #[must_use]
    pub fn builder() -> AppBuilder {
        AppBuilder::default()
    }

    /// Creates a test client driving this application.
    #[must_use]
    pub fn test_client(self: Arc<Self>) -> crate::testing::TestClient<Arc<Self>> {
        crate::testing::TestClient::new(self)
    }

    /// Creates a test client driving this application, constructed with `seed`.
    #[must_use]
    pub fn test_client_with_seed(
        self: Arc<Self>,
        seed: u64,
    ) -> crate::testing::TestClient<Arc<Self>> {
        crate::testing::TestClient::with_seed(self, seed)
    }

    /// Returns the application configuration.
    #[must_use]
    pub fn config(&self) -> &AppConfig {
        &self.config
    }

    /// Number of HTTP routes, including any OpenAPI / docs routes added at build time.
    #[must_use]
    pub fn route_count(&self) -> usize {
        self.routes.len()
    }

    /// Number of WebSocket routes.
    #[must_use]
    pub fn websocket_route_count(&self) -> usize {
        self.ws_routes.len()
    }

    /// Reports whether `path` matches a registered WebSocket route.
    #[must_use]
    pub fn has_websocket_route(&self, path: &str) -> bool {
        let lookup = self.ws_router.lookup(path, Method::Get);
        matches!(lookup, RouteLookup::Match(_))
    }

    /// Iterates over `(method, path)` of every HTTP route in registration order.
    pub fn routes(&self) -> impl Iterator<Item = (Method, &str)> {
        self.routes
            .iter()
            .map(|entry| (entry.method, entry.path.as_str()))
    }

    /// Returns the serialized OpenAPI document, if one was generated.
    #[must_use]
    pub fn openapi_spec(&self) -> Option<&str> {
        self.openapi_spec.as_deref().map(String::as_str)
    }

    /// Returns the shared state container.
    #[must_use]
    pub fn state(&self) -> &Arc<StateContainer> {
        &self.state
    }

    /// Returns the stored state value of type `T`, if any.
    pub fn get_state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.state.get::<T>()
    }

    /// Returns the exception-handler registry.
    #[must_use]
    pub fn exception_handlers(&self) -> &Arc<ExceptionHandlers> {
        &self.exception_handlers
    }

    /// Registers `value` as a dependency override via
    /// `DependencyOverrides::insert_value`.
    pub fn override_dependency_value<T>(&self, value: T)
    where
        T: crate::dependency::FromDependency,
    {
        self.dependency_overrides.insert_value(value);
    }

    /// Clears all dependency overrides.
    pub fn clear_dependency_overrides(&self) {
        self.dependency_overrides.clear();
    }

    /// Returns a new handle to the shared dependency overrides.
    #[must_use]
    pub fn dependency_overrides(&self) -> Arc<crate::dependency::DependencyOverrides> {
        Arc::clone(&self.dependency_overrides)
    }

    /// Takes the `BackgroundTasks` extension out of the request, if present.
    pub fn take_background_tasks(req: &mut Request) -> Option<crate::request::BackgroundTasks> {
        req.take_extension::<crate::request::BackgroundTasks>()
    }

    /// Runs the exception handler registered for `E`; `None` if there is none.
    pub fn handle_error<E>(&self, ctx: &RequestContext, err: E) -> Option<Response>
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        self.exception_handlers.handle(ctx, err)
    }

    /// Runs the exception handler registered for `E`, using the registry's
    /// default response when there is none.
    pub fn handle_error_or_default<E>(&self, ctx: &RequestContext, err: E) -> Response
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        self.exception_handlers.handle_or_default(ctx, err)
    }

    /// Routes an HTTP request and produces its response.
    ///
    /// * No matching path: `404 Not Found`.
    /// * Path matches but not the method: `405 Method Not Allowed` with an
    ///   `allow` header, except for `OPTIONS`, which is answered with
    ///   `204 No Content` and an `allow` header that also lists `OPTIONS`.
    /// * Match: extracted path parameters (if any) are stored on the request as
    ///   a `PathParams` extension, then the route handler runs inside the
    ///   middleware stack. A match without a corresponding route entry yields
    ///   `500 Internal Server Error`.
    pub async fn handle(&self, ctx: &RequestContext, req: &mut Request) -> Response {
        match self.router.lookup(req.path(), req.method()) {
            RouteLookup::NotFound => Response::with_status(StatusCode::NOT_FOUND),
            RouteLookup::MethodNotAllowed { allowed } => {
                if req.method() != Method::Options {
                    return Response::with_status(StatusCode::METHOD_NOT_ALLOWED)
                        .header("allow", allowed.header_value().as_bytes().to_vec());
                }

                // Automatic OPTIONS reply: advertise the allowed methods plus OPTIONS.
                let mut methods = allowed.methods().to_vec();
                if !methods.contains(&Method::Options) {
                    methods.push(Method::Options);
                }
                let allow = fastapi_router::AllowedMethods::new(methods);
                Response::with_status(StatusCode::NO_CONTENT)
                    .header("allow", allow.header_value().as_bytes().to_vec())
            }
            RouteLookup::Match(route_match) => {
                let target = self.routes.iter().find(|candidate| {
                    candidate.method == route_match.route.method
                        && candidate.path == route_match.route.path
                });
                let entry = match target {
                    Some(entry) => entry,
                    None => return Response::with_status(StatusCode::INTERNAL_SERVER_ERROR),
                };

                if !route_match.params.is_empty() {
                    let pairs = route_match
                        .params
                        .iter()
                        .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
                        .collect();
                    req.insert_extension(crate::extract::PathParams::from_pairs(pairs));
                }

                let handler = RouteHandler { entry };
                self.middleware.execute(&handler, ctx, req).await
            }
        }
    }

    /// Routes a WebSocket request to its handler.
    ///
    /// Extracted path parameters (if any) are stored on the request as a
    /// `PathParams` extension before the handler is called.
    ///
    /// # Errors
    /// Returns `WebSocketError::Protocol` when no WebSocket route matches or
    /// the matched route has no entry; otherwise propagates the handler's result.
    pub async fn handle_websocket(
        &self,
        ctx: &RequestContext,
        req: &mut Request,
        ws: crate::websocket::WebSocket,
    ) -> Result<(), crate::websocket::WebSocketError> {
        let RouteLookup::Match(route_match) = self.ws_router.lookup(req.path(), Method::Get)
        else {
            return Err(crate::websocket::WebSocketError::Protocol(
                "no websocket route matched",
            ));
        };

        let target = self
            .ws_routes
            .iter()
            .find(|candidate| candidate.path == route_match.route.path);
        let entry = match target {
            Some(entry) => entry,
            None => {
                return Err(crate::websocket::WebSocketError::Protocol(
                    "websocket route missing handler",
                ));
            }
        };

        if !route_match.params.is_empty() {
            let pairs = route_match
                .params
                .iter()
                .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
                .collect();
            req.insert_extension(crate::extract::PathParams::from_pairs(pairs));
        }

        entry.call(ctx, req, ws).await
    }

    /// Drains and runs the startup hooks in registration order.
    ///
    /// The first hook failing with `abort == true` stops processing and yields
    /// `StartupOutcome::Aborted`; remaining hooks are dropped without running.
    /// Failures with `abort == false` are only counted as warnings.
    pub async fn run_startup_hooks(&self) -> StartupOutcome {
        // The lock is released at the end of this statement, before any await.
        let hooks: Vec<StartupHook> = std::mem::take(&mut *self.startup_hooks.lock());
        let mut warnings = 0;

        for hook in hooks {
            let result = match hook.run() {
                Ok(Some(pending)) => pending.await,
                Ok(None) => Ok(()),
                Err(failure) => Err(failure),
            };
            if let Err(failure) = result {
                if failure.abort {
                    return StartupOutcome::Aborted(failure);
                }
                warnings += 1;
            }
        }

        if warnings == 0 {
            StartupOutcome::Success
        } else {
            StartupOutcome::PartialSuccess { warnings }
        }
    }

    /// Drains and runs the shutdown hooks.
    ///
    /// Asynchronous hooks run first, then synchronous ones; each group runs in
    /// reverse registration order.
    pub async fn run_shutdown_hooks(&self) {
        let mut async_hooks: Vec<_> = std::mem::take(&mut *self.async_shutdown_hooks.lock());
        while let Some(hook) = async_hooks.pop() {
            hook().await;
        }

        let mut sync_hooks: Vec<_> = std::mem::take(&mut *self.shutdown_hooks.lock());
        while let Some(hook) = sync_hooks.pop() {
            hook();
        }
    }

    /// Drains the shutdown hooks and registers them with `controller`:
    /// synchronous hooks first, then asynchronous ones, each in registration
    /// order. Afterwards this application holds no pending shutdown hooks.
    pub fn transfer_shutdown_hooks(&self, controller: &ShutdownController) {
        let sync_hooks: Vec<_> = std::mem::take(&mut *self.shutdown_hooks.lock());
        sync_hooks.into_iter().for_each(|hook| {
            controller.register_hook(hook);
        });

        let async_hooks: Vec<_> = std::mem::take(&mut *self.async_shutdown_hooks.lock());
        async_hooks.into_iter().for_each(|hook| {
            controller.register_async_hook(move || hook());
        });
    }

    /// Number of startup hooks that have not been run yet.
    #[must_use]
    pub fn pending_startup_hooks(&self) -> usize {
        self.startup_hooks.lock().len()
    }

    /// Number of shutdown hooks (synchronous plus asynchronous) still held by
    /// this application.
    #[must_use]
    pub fn pending_shutdown_hooks(&self) -> usize {
        self.shutdown_hooks.lock().len() + self.async_shutdown_hooks.lock().len()
    }
}
impl std::fmt::Debug for App {
    /// Prints a summary of the application: the configuration, state and
    /// exception handlers via their own `Debug` output, and plain counts for
    /// routes, WebSocket routes, middleware and pending hooks.
    ///
    /// Reading the hook counts briefly locks the hook mutexes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("App")
            .field("config", &self.config)
            .field("routes", &self.routes.len())
            // WebSocket routes live in a separate list; report them as well so
            // the summary reflects everything the application serves.
            .field("ws_routes", &self.ws_routes.len())
            .field("middleware", &self.middleware.len())
            .field("state", &self.state)
            .field("exception_handlers", &self.exception_handlers)
            .field("startup_hooks", &self.startup_hooks.lock().len())
            .field("shutdown_hooks", &self.pending_shutdown_hooks())
            .finish()
    }
}
impl Handler for App {
    /// Handles a request by delegating to [`App::handle`].
    fn call<'a>(
        &'a self,
        ctx: &'a RequestContext,
        req: &'a mut Request,
    ) -> BoxFuture<'a, Response> {
        let pending = async move { self.handle(ctx, req).await };
        Box::pin(pending)
    }

    /// Exposes the application's dependency overrides; always `Some`.
    fn dependency_overrides(&self) -> Option<Arc<crate::dependency::DependencyOverrides>> {
        let shared = Arc::clone(&self.dependency_overrides);
        Some(shared)
    }
}
/// Adapter that lets a borrowed `RouteEntry` be used where a `Handler` is
/// expected (it is what `App::handle` passes to the middleware stack).
struct RouteHandler<'a> {
// The matched route whose handler will be invoked.
entry: &'a RouteEntry,
}
impl<'a> Handler for RouteHandler<'a> {
fn call<'b>(
&'b self,
ctx: &'b RequestContext,
req: &'b mut Request,
) -> BoxFuture<'b, Response> {
let handler = self.entry.handler.clone();
Box::pin(async move { handler(ctx, req).await })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::response::ResponseBody;
fn test_handler(_ctx: &RequestContext, _req: &mut Request) -> std::future::Ready<Response> {
std::future::ready(Response::ok().body(ResponseBody::Bytes(b"Hello, World!".to_vec())))
}
fn health_handler(_ctx: &RequestContext, _req: &mut Request) -> std::future::Ready<Response> {
std::future::ready(Response::ok().body(ResponseBody::Bytes(b"OK".to_vec())))
}
fn test_context() -> RequestContext {
let cx = asupersync::Cx::for_testing();
RequestContext::new(cx, 1)
}
#[test]
fn app_builder_creates_app() {
let app = App::builder()
.config(AppConfig::new().name("Test App"))
.get("/", test_handler)
.get("/health", health_handler)
.build();
assert_eq!(app.route_count(), 2);
assert_eq!(app.config().name, "Test App");
}
#[test]
fn app_config_builder() {
let config = AppConfig::new()
.name("My API")
.version("1.0.0")
.debug(true)
.max_body_size(2 * 1024 * 1024)
.request_timeout_ms(60_000);
assert_eq!(config.name, "My API");
assert_eq!(config.version, "1.0.0");
assert!(config.debug);
assert_eq!(config.max_body_size, 2 * 1024 * 1024);
assert_eq!(config.request_timeout_ms, 60_000);
}
#[test]
fn state_container_insert_and_get() {
#[derive(Debug, PartialEq)]
struct MyState {
value: i32,
}
let mut container = StateContainer::new();
container.insert(MyState { value: 42 });
let state = container.get::<MyState>();
assert!(state.is_some());
assert_eq!(state.unwrap().value, 42);
}
#[test]
fn state_container_multiple_types() {
struct TypeA(i32);
struct TypeB(String);
let mut container = StateContainer::new();
container.insert(TypeA(1));
container.insert(TypeB("hello".to_string()));
assert!(container.contains::<TypeA>());
assert!(container.contains::<TypeB>());
assert!(!container.contains::<i64>());
assert_eq!(container.get::<TypeA>().unwrap().0, 1);
assert_eq!(container.get::<TypeB>().unwrap().0, "hello");
}
#[test]
fn app_builder_with_state() {
struct DbPool {
connection_count: usize,
}
let app = App::builder()
.state(DbPool {
connection_count: 10,
})
.get("/", test_handler)
.build();
let pool = app.get_state::<DbPool>();
assert!(pool.is_some());
assert_eq!(pool.unwrap().connection_count, 10);
}
#[test]
fn app_handles_get_request() {
let app = App::builder().get("/", test_handler).build();
let ctx = test_context();
let mut req = Request::new(Method::Get, "/");
let response = futures_executor::block_on(app.handle(&ctx, &mut req));
assert_eq!(response.status().as_u16(), 200);
}
#[test]
fn app_returns_404_for_unknown_path() {
let app = App::builder().get("/", test_handler).build();
let ctx = test_context();
let mut req = Request::new(Method::Get, "/unknown");
let response = futures_executor::block_on(app.handle(&ctx, &mut req));
assert_eq!(response.status().as_u16(), 404);
}
/// A known path requested with a method it was not registered for yields 405.
#[test]
fn app_returns_405_for_wrong_method() {
    let app = App::builder().get("/", test_handler).build();
    let ctx = test_context();
    // The route only accepts GET; send a POST to the same path.
    let mut request = Request::new(Method::Post, "/");

    let status = futures_executor::block_on(app.handle(&ctx, &mut request))
        .status()
        .as_u16();
    assert_eq!(status, 405);
}
/// Registering one route through each HTTP-method helper yields five routes.
#[test]
fn app_builder_all_methods() {
    let builder = App::builder()
        .get("/get", test_handler)
        .post("/post", test_handler)
        .put("/put", test_handler)
        .delete("/delete", test_handler)
        .patch("/patch", test_handler);

    let app = builder.build();
    assert_eq!(app.route_count(), 5);
}
/// The `Debug` rendering of a `RouteEntry` mentions the type name, the
/// method and the path.
#[test]
fn route_entry_debug() {
    let rendered = format!("{:?}", RouteEntry::new(Method::Get, "/test", test_handler));

    for fragment in ["RouteEntry", "Get", "/test"] {
        assert!(rendered.contains(fragment));
    }
}
/// A request still yields 200 when two `NoopMiddleware` layers are installed.
#[test]
fn app_with_middleware() {
    use crate::middleware::NoopMiddleware;

    let builder = App::builder()
        .middleware(NoopMiddleware)
        .middleware(NoopMiddleware)
        .get("/", test_handler);
    let app = builder.build();

    let ctx = test_context();
    let mut request = Request::new(Method::Get, "/");
    let status = futures_executor::block_on(app.handle(&ctx, &mut request))
        .status()
        .as_u16();
    assert_eq!(status, 200);
}
/// Error fixture carrying a message and a numeric code, used by the
/// exception-handler tests in this module.
#[derive(Debug)]
struct TestError {
    message: String,
    code: u32,
}

impl std::fmt::Display for TestError {
    /// Renders as `TestError(<code>): <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message, code } = self;
        write!(f, "TestError({}): {}", code, message)
    }
}

impl std::error::Error for TestError {}
/// Second error fixture, wrapping a single message string; the tests use it
/// as an error type distinct from `TestError`.
#[derive(Debug)]
struct AnotherError(String);

impl std::fmt::Display for AnotherError {
    /// Renders as `AnotherError: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let AnotherError(detail) = self;
        write!(f, "AnotherError: {}", detail)
    }
}

impl std::error::Error for AnotherError {}
/// A freshly created registry reports itself as empty with length zero.
#[test]
fn exception_handlers_new_is_empty() {
    let registry = ExceptionHandlers::new();

    assert!(registry.is_empty());
    assert_eq!(registry.len(), 0);
}
/// Registering one handler makes it discoverable for that error type only.
#[test]
fn exception_handlers_register_single() {
    let mut registry = ExceptionHandlers::new();
    registry.register::<TestError>(|_ctx, err| {
        // Echo the error message back as the response body.
        let payload = err.message.clone().into_bytes();
        Response::with_status(StatusCode::BAD_REQUEST).body(ResponseBody::Bytes(payload))
    });

    assert!(registry.has_handler::<TestError>());
    assert!(!registry.has_handler::<AnotherError>());
    assert_eq!(registry.len(), 1);
}
/// Handlers registered for two different error types are both tracked.
#[test]
fn exception_handlers_register_multiple() {
    let mut registry = ExceptionHandlers::new();

    registry.register::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    registry.register::<AnotherError>(|_ctx, _err| {
        Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
    });

    assert!(registry.has_handler::<TestError>());
    assert!(registry.has_handler::<AnotherError>());
    assert_eq!(registry.len(), 2);
}
/// The chained `handler` calls register handlers just like `register` does.
#[test]
fn exception_handlers_builder_pattern() {
    let registry = ExceptionHandlers::new();
    let registry = registry.handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let registry = registry.handler::<AnotherError>(|_ctx, _err| {
        Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
    });

    assert!(registry.has_handler::<TestError>());
    assert!(registry.has_handler::<AnotherError>());
    assert_eq!(registry.len(), 2);
}
/// `with_defaults` pre-registers exactly two handlers: one for `HttpError`
/// and one for `ValidationErrors`.
#[test]
fn exception_handlers_with_defaults() {
    let registry = ExceptionHandlers::with_defaults();

    assert!(registry.has_handler::<crate::HttpError>());
    assert!(registry.has_handler::<crate::ValidationErrors>());
    assert_eq!(registry.len(), 2);
}
/// Merging two registries with disjoint error types keeps both handlers.
#[test]
fn exception_handlers_merge() {
    let mut target = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let incoming = ExceptionHandlers::new().handler::<AnotherError>(|_ctx, _err| {
        Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
    });

    target.merge(incoming);

    assert!(target.has_handler::<TestError>());
    assert!(target.has_handler::<AnotherError>());
    assert_eq!(target.len(), 2);
}
/// `handle` produces the registered handler's response for a matching error.
#[test]
fn exception_handlers_handle_registered_error() {
    let registry = ExceptionHandlers::new().handler::<TestError>(|_ctx, err| {
        let payload = err.message.clone().into_bytes();
        Response::with_status(StatusCode::BAD_REQUEST).body(ResponseBody::Bytes(payload))
    });
    let ctx = test_context();
    let failure = TestError {
        message: "test error".into(),
        code: 42,
    };

    // `Some(400)` proves both that a response was produced and its status.
    let status = registry
        .handle(&ctx, failure)
        .map(|resp| resp.status().as_u16());
    assert_eq!(status, Some(400));
}
/// `handle` yields `None` for an error type that has no registered handler.
#[test]
fn exception_handlers_handle_unregistered_error() {
    let registry = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let ctx = test_context();

    // Only `TestError` is registered, so `AnotherError` must fall through.
    let unhandled = AnotherError("unhandled".into());
    assert!(registry.handle(&ctx, unhandled).is_none());
}
/// `handle_or_default` uses the registered handler when one exists.
#[test]
fn exception_handlers_handle_or_default_registered() {
    let registry = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };

    let status = registry.handle_or_default(&ctx, failure).status().as_u16();
    assert_eq!(status, 400);
}
/// `handle_or_default` answers with a 500 when no handler is registered.
#[test]
fn exception_handlers_handle_or_default_unregistered() {
    let registry = ExceptionHandlers::new();
    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };

    let status = registry.handle_or_default(&ctx, failure).status().as_u16();
    assert_eq!(status, 500);
}
/// The handler receives the actual error value, not just its type: the code
/// it observes must equal the code of the error passed to `handle`.
#[test]
fn exception_handlers_error_values_passed_to_handler() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Shared cell the handler writes the observed error code into.
    let observed = Arc::new(AtomicU32::new(0));
    let sink = Arc::clone(&observed);

    let registry = ExceptionHandlers::new().handler::<TestError>(move |_ctx, err| {
        sink.store(err.code, Ordering::SeqCst);
        Response::with_status(StatusCode::BAD_REQUEST)
    });

    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 12345,
    };
    let _ = registry.handle(&ctx, failure);

    assert_eq!(observed.load(Ordering::SeqCst), 12345);
}
/// A handler registered through the app builder is visible on the built app.
#[test]
fn app_builder_exception_handler_single() {
    let builder = App::builder().exception_handler::<TestError, _>(|_ctx, err| {
        let payload = err.message.clone().into_bytes();
        Response::with_status(StatusCode::BAD_REQUEST).body(ResponseBody::Bytes(payload))
    });
    let app = builder.get("/", test_handler).build();

    assert!(app.exception_handlers().has_handler::<TestError>());
}
/// Several handlers registered through the builder all reach the built app.
#[test]
fn app_builder_exception_handler_multiple() {
    let builder = App::builder()
        .exception_handler::<TestError, _>(|_ctx, _err| {
            Response::with_status(StatusCode::BAD_REQUEST)
        })
        .exception_handler::<AnotherError, _>(|_ctx, _err| {
            Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
        });
    let app = builder.get("/", test_handler).build();

    let registry = app.exception_handlers();
    assert!(registry.has_handler::<TestError>());
    assert!(registry.has_handler::<AnotherError>());
}
/// `with_default_exception_handlers` installs handlers for `HttpError` and
/// `ValidationErrors` on the built app.
#[test]
fn app_builder_with_default_exception_handlers() {
    let builder = App::builder().with_default_exception_handlers();
    let app = builder.get("/", test_handler).build();

    let registry = app.exception_handlers();
    assert!(registry.has_handler::<crate::HttpError>());
    assert!(registry.has_handler::<crate::ValidationErrors>());
}
/// A pre-built registry handed to the builder ends up on the built app.
#[test]
fn app_builder_exception_handlers_registry() {
    let prepared = ExceptionHandlers::new()
        .handler::<TestError>(|_ctx, _err| Response::with_status(StatusCode::BAD_REQUEST))
        .handler::<AnotherError>(|_ctx, _err| {
            Response::with_status(StatusCode::INTERNAL_SERVER_ERROR)
        });

    let app = App::builder()
        .exception_handlers(prepared)
        .get("/", test_handler)
        .build();

    let registry = app.exception_handlers();
    assert!(registry.has_handler::<TestError>());
    assert!(registry.has_handler::<AnotherError>());
}
/// `App::handle_error` returns the registered handler's response.
#[test]
fn app_handle_error_registered() {
    let builder = App::builder().exception_handler::<TestError, _>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let app = builder.get("/", test_handler).build();

    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };

    // `Some(400)` proves both that a response was produced and its status.
    let status = app
        .handle_error(&ctx, failure)
        .map(|resp| resp.status().as_u16());
    assert_eq!(status, Some(400));
}
/// `App::handle_error` yields `None` when no handler matches the error type.
#[test]
fn app_handle_error_unregistered() {
    let app = App::builder().get("/", test_handler).build();
    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };

    assert!(app.handle_error(&ctx, failure).is_none());
}
/// `App::handle_error_or_default` answers with a 500 when nothing is registered.
#[test]
fn app_handle_error_or_default() {
    let app = App::builder().get("/", test_handler).build();
    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };

    let status = app
        .handle_error_or_default(&ctx, failure)
        .status()
        .as_u16();
    assert_eq!(status, 500);
}
/// Registering a second handler for the same error type replaces the first:
/// the count stays at one and the later handler's status (422) is used.
#[test]
fn exception_handlers_override_on_register() {
    let registry = ExceptionHandlers::new();
    let registry = registry.handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let registry = registry.handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::UNPROCESSABLE_ENTITY)
    });
    assert_eq!(registry.len(), 1);

    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };
    let status = registry
        .handle(&ctx, failure)
        .map(|resp| resp.status().as_u16());
    assert_eq!(status, Some(422));
}
/// When both registries handle the same error type, the merged-in handler
/// wins: the count stays at one and its status (422) is used.
#[test]
fn exception_handlers_merge_overrides() {
    let mut target = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let incoming = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::UNPROCESSABLE_ENTITY)
    });

    target.merge(incoming);
    assert_eq!(target.len(), 1);

    let ctx = test_context();
    let failure = TestError {
        message: "test".into(),
        code: 1,
    };
    let status = target
        .handle(&ctx, failure)
        .map(|resp| resp.status().as_u16());
    assert_eq!(status, Some(422));
}
/// A custom `HttpError` handler replaces the default one: the handler count
/// is unchanged and the response carries the custom marker header.
#[test]
fn exception_handlers_override_default_http_error() {
    let mut registry = ExceptionHandlers::with_defaults();
    registry.register::<crate::HttpError>(|_ctx, err| {
        let text = err.detail.as_deref().unwrap_or("Unknown error");
        let body = ResponseBody::Bytes(text.as_bytes().to_vec());
        Response::with_status(err.status)
            .header("x-custom-error", b"true".to_vec())
            .body(body)
    });
    // Still two handlers: the override replaced an entry, it did not add one.
    assert_eq!(registry.len(), 2);

    let ctx = test_context();
    let failure = crate::HttpError::bad_request().with_detail("test error");
    let outcome = registry.handle(&ctx, failure);
    assert!(outcome.is_some());
    let response = outcome.unwrap();
    assert_eq!(response.status().as_u16(), 400);

    // Look the marker header up case-insensitively.
    let marker_entry = response
        .headers()
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("x-custom-error"));
    let marker_value = marker_entry.map(|(_, v)| v.as_slice());
    assert_eq!(marker_value, Some(b"true".as_slice()));
}
/// A custom `ValidationErrors` handler replaces the default one and sees the
/// full error collection (its length is echoed back in a header).
#[test]
fn exception_handlers_override_default_validation_errors() {
    let mut registry = ExceptionHandlers::with_defaults();
    registry.register::<crate::ValidationErrors>(|_ctx, errs| {
        let count = errs.len().to_string().into_bytes();
        Response::with_status(StatusCode::BAD_REQUEST).header("x-error-count", count)
    });

    let ctx = test_context();

    // Two missing-field errors, so the handler should report a count of 2.
    let mut failures = crate::ValidationErrors::new();
    for field in ["name", "email"] {
        failures.push(crate::ValidationError::missing(
            crate::error::loc::body_field(field),
        ));
    }

    let outcome = registry.handle(&ctx, failures);
    assert!(outcome.is_some());
    let response = outcome.unwrap();
    assert_eq!(response.status().as_u16(), 400);

    // Look the count header up case-insensitively.
    let count_entry = response
        .headers()
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("x-error-count"));
    let count_value = count_entry.map(|(_, v)| v.as_slice());
    assert_eq!(count_value, Some(b"2".as_slice()));
}
/// The `Debug` rendering of a registry names the type and includes a
/// `count` entry along with the digit `1` for the single handler.
#[test]
fn exception_handlers_debug_format() {
    let registry = ExceptionHandlers::new().handler::<TestError>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let rendered = format!("{:?}", registry);

    for fragment in ["ExceptionHandlers", "count", "1"] {
        assert!(rendered.contains(fragment));
    }
}
/// The `Debug` rendering of an `App` mentions its exception handlers.
#[test]
fn app_debug_includes_exception_handlers() {
    let builder = App::builder().exception_handler::<TestError, _>(|_ctx, _err| {
        Response::with_status(StatusCode::BAD_REQUEST)
    });
    let app = builder.get("/", test_handler).build();

    let rendered = format!("{:?}", app);
    assert!(rendered.contains("exception_handlers"));
}
/// The `Debug` rendering of the app builder mentions its exception handlers.
#[test]
fn app_builder_debug_includes_exception_handlers() {
    let rendered = format!(
        "{:?}",
        App::builder().exception_handler::<TestError, _>(|_ctx, _err| {
            Response::with_status(StatusCode::BAD_REQUEST)
        })
    );

    assert!(rendered.contains("exception_handlers"));
}
/// Each `on_startup` call adds one startup hook to the builder.
#[test]
fn app_builder_startup_hook_registration() {
    let builder = App::builder();
    let builder = builder.on_startup(|| Ok(()));
    let builder = builder.on_startup(|| Ok(()));

    assert_eq!(builder.startup_hook_count(), 2);
}
/// Each `on_shutdown` call adds one shutdown hook to the builder.
#[test]
fn app_builder_shutdown_hook_registration() {
    let builder = App::builder();
    let builder = builder.on_shutdown(|| {});
    let builder = builder.on_shutdown(|| {});

    assert_eq!(builder.shutdown_hook_count(), 2);
}
/// Startup and shutdown hooks registered in alternation are counted
/// independently of each other.
#[test]
fn app_builder_mixed_hooks() {
    let builder = App::builder()
        .on_startup(|| Ok(()))
        .on_shutdown(|| {})
        .on_startup(|| Ok(()))
        .on_shutdown(|| {});

    let counts = (builder.startup_hook_count(), builder.shutdown_hook_count());
    assert_eq!(counts.0, 2);
    assert_eq!(counts.1, 2);
}
/// Hooks registered on the builder are reported as pending on the built app.
#[test]
fn app_pending_hooks_count() {
    let builder = App::builder()
        .on_startup(|| Ok(()))
        .on_startup(|| Ok(()))
        .on_shutdown(|| {});
    let app = builder.get("/", test_handler).build();

    assert_eq!(app.pending_startup_hooks(), 2);
    assert_eq!(app.pending_shutdown_hooks(), 1);
}
/// Startup hooks execute in registration order and are drained afterwards.
#[test]
fn startup_hooks_run_in_fifo_order() {
    // Each hook appends its id here, recording the execution order.
    let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

    let mut builder = App::builder();
    for id in 1..=3 {
        let sink = Arc::clone(&log);
        builder = builder.on_startup(move || {
            sink.lock().push(id);
            Ok(())
        });
    }
    let app = builder.get("/", test_handler).build();

    let outcome = futures_executor::block_on(app.run_startup_hooks());

    assert!(outcome.can_proceed());
    assert_eq!(*log.lock(), vec![1, 2, 3]);
    assert_eq!(app.pending_startup_hooks(), 0);
}
/// Shutdown hooks execute in reverse registration order and are drained
/// afterwards.
#[test]
fn shutdown_hooks_run_in_lifo_order() {
    // Each hook appends its id here, recording the execution order.
    let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

    let mut builder = App::builder();
    for id in 1..=3 {
        let sink = Arc::clone(&log);
        builder = builder.on_shutdown(move || {
            sink.lock().push(id);
        });
    }
    let app = builder.get("/", test_handler).build();

    futures_executor::block_on(app.run_shutdown_hooks());

    assert_eq!(*log.lock(), vec![3, 2, 1]);
    assert_eq!(app.pending_shutdown_hooks(), 0);
}
/// When every startup hook succeeds the outcome is `Success` and startup
/// may proceed.
#[test]
fn startup_hooks_success_outcome() {
    let builder = App::builder().on_startup(|| Ok(())).on_startup(|| Ok(()));
    let app = builder.get("/", test_handler).build();

    let result = futures_executor::block_on(app.run_startup_hooks());

    assert!(matches!(result, StartupOutcome::Success));
    assert!(result.can_proceed());
}
/// A hook failing with a fatal error produces an `Aborted` outcome that
/// carries the original error.
#[test]
fn startup_hooks_fatal_error_aborts() {
    let builder = App::builder()
        .on_startup(|| Ok(()))
        .on_startup(|| Err(StartupHookError::new("database connection failed")))
        .on_startup(|| Ok(()));
    let app = builder.get("/", test_handler).build();

    let result = futures_executor::block_on(app.run_startup_hooks());
    assert!(!result.can_proceed());

    match result {
        StartupOutcome::Aborted(failure) => {
            assert!(failure.message.contains("database connection failed"));
            assert!(failure.abort);
        }
        _ => panic!("Expected Aborted outcome"),
    }
}
/// A non-fatal hook failure is counted as a warning and startup may proceed.
#[test]
fn startup_hooks_non_fatal_error_continues() {
    let builder = App::builder()
        .on_startup(|| Ok(()))
        .on_startup(|| Err(StartupHookError::non_fatal("optional feature unavailable")))
        .on_startup(|| Ok(()));
    let app = builder.get("/", test_handler).build();

    let result = futures_executor::block_on(app.run_startup_hooks());
    assert!(result.can_proceed());

    match result {
        StartupOutcome::PartialSuccess { warnings } => assert_eq!(warnings, 1),
        _ => panic!("Expected PartialSuccess outcome"),
    }
}
/// The `with_*` builder methods set the hook name and the abort flag while
/// leaving the message untouched.
#[test]
fn startup_hook_error_builder() {
    let built = StartupHookError::new("test error");
    let built = built.with_hook_name("database_init");
    let built = built.with_abort(false);

    // Take the error apart so each field is checked on its own.
    let StartupHookError {
        hook_name,
        message,
        abort,
    } = built;
    assert_eq!(hook_name.as_deref(), Some("database_init"));
    assert_eq!(message, "test error");
    assert!(!abort);
}
/// The `Display` output of a named error contains both the hook name and
/// the message.
#[test]
fn startup_hook_error_display() {
    let rendered = StartupHookError::new("connection failed")
        .with_hook_name("redis_init")
        .to_string();

    assert!(rendered.contains("redis_init"));
    assert!(rendered.contains("connection failed"));
}
/// `non_fatal` builds an error whose `abort` flag is cleared.
#[test]
fn startup_hook_error_non_fatal() {
    let warning = StartupHookError::non_fatal("optional feature");

    assert!(!warning.abort);
}
/// Transferring shutdown hooks empties the app's pending list and hands the
/// hooks to the controller, which pops them in reverse registration order.
#[test]
fn transfer_shutdown_hooks_to_controller() {
    // Each hook appends its id here, recording the execution order.
    let log = Arc::new(parking_lot::Mutex::new(Vec::new()));

    let mut builder = App::builder();
    for id in [1, 2] {
        let sink = Arc::clone(&log);
        builder = builder.on_shutdown(move || {
            sink.lock().push(id);
        });
    }
    let app = builder.get("/", test_handler).build();

    let controller = ShutdownController::new();
    app.transfer_shutdown_hooks(&controller);

    // Ownership of both hooks has moved from the app to the controller.
    assert_eq!(app.pending_shutdown_hooks(), 0);
    assert_eq!(controller.hook_count(), 2);

    // Drain the controller, running each hook as it is popped.
    while let Some(pending) = controller.pop_hook() {
        pending.run();
    }
    assert_eq!(*log.lock(), vec![2, 1]);
}
/// The `Debug` rendering of an `App` mentions both hook collections.
#[test]
fn app_debug_includes_hooks() {
    let builder = App::builder().on_startup(|| Ok(())).on_shutdown(|| {});
    let app = builder.get("/", test_handler).build();

    let rendered = format!("{:?}", app);
    for field in ["startup_hooks", "shutdown_hooks"] {
        assert!(rendered.contains(field));
    }
}
/// The `Debug` rendering of the app builder mentions both hook collections.
#[test]
fn app_builder_debug_includes_hooks() {
    let rendered = format!(
        "{:?}",
        App::builder().on_startup(|| Ok(())).on_shutdown(|| {})
    );

    for field in ["startup_hooks", "shutdown_hooks"] {
        assert!(rendered.contains(field));
    }
}
/// `Success` allows startup to proceed and carries no error.
#[test]
fn startup_outcome_success() {
    let result = StartupOutcome::Success;

    assert!(result.can_proceed());
    let failure = result.into_error();
    assert!(failure.is_none());
}
/// `PartialSuccess` still allows startup to proceed and carries no error.
#[test]
fn startup_outcome_partial_success() {
    let result = StartupOutcome::PartialSuccess { warnings: 2 };

    assert!(result.can_proceed());
    let failure = result.into_error();
    assert!(failure.is_none());
}
/// `Aborted` blocks startup, and `into_error` gives back the wrapped error.
#[test]
fn startup_outcome_aborted() {
    let result = StartupOutcome::Aborted(StartupHookError::new("fatal"));
    assert!(!result.can_proceed());

    // `Some("fatal")` proves both that an error came back and which one.
    let message = result.into_error().map(|failure| failure.message);
    assert_eq!(message.as_deref(), Some("fatal"));
}
/// Every non-fatal failure adds one warning; successful hooks add none.
#[test]
fn startup_hooks_multiple_non_fatal_errors() {
    let builder = App::builder()
        .on_startup(|| Err(StartupHookError::non_fatal("warning 1")))
        .on_startup(|| Ok(()))
        .on_startup(|| Err(StartupHookError::non_fatal("warning 2")))
        .on_startup(|| Err(StartupHookError::non_fatal("warning 3")));
    let app = builder.get("/", test_handler).build();

    let result = futures_executor::block_on(app.run_startup_hooks());
    assert!(result.can_proceed());

    match result {
        StartupOutcome::PartialSuccess { warnings } => assert_eq!(warnings, 3),
        _ => panic!("Expected PartialSuccess"),
    }
}
/// Running startup hooks on an app that has none reports `Success`.
#[test]
fn empty_startup_hooks() {
    let app = App::builder().get("/", test_handler).build();

    let result = futures_executor::block_on(app.run_startup_hooks());

    assert!(matches!(result, StartupOutcome::Success));
}
/// Running shutdown hooks on an app that has none completes without panicking.
#[test]
fn empty_shutdown_hooks() {
    let app = App::builder().get("/", test_handler).build();

    // Nothing to assert on: the test passes as long as this call returns.
    futures_executor::block_on(app.run_shutdown_hooks());
}
/// Startup hooks run only once: a second `run_startup_hooks` call must not
/// invoke the hook again.
#[test]
fn startup_hooks_consumed_after_run() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Counts how many times the hook has been invoked.
    let invocations = Arc::new(AtomicU32::new(0));
    let hook_counter = Arc::clone(&invocations);

    let builder = App::builder().on_startup(move || {
        hook_counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    });
    let app = builder.get("/", test_handler).build();

    // First run executes the hook.
    futures_executor::block_on(app.run_startup_hooks());
    assert_eq!(invocations.load(Ordering::SeqCst), 1);

    // Second run leaves the counter untouched.
    futures_executor::block_on(app.run_startup_hooks());
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
/// Shutdown hooks run only once: a second `run_shutdown_hooks` call must not
/// invoke the hook again.
#[test]
fn shutdown_hooks_consumed_after_run() {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Counts how many times the hook has been invoked.
    let invocations = Arc::new(AtomicU32::new(0));
    let hook_counter = Arc::clone(&invocations);

    let builder = App::builder().on_shutdown(move || {
        hook_counter.fetch_add(1, Ordering::SeqCst);
    });
    let app = builder.get("/", test_handler).build();

    // First run executes the hook.
    futures_executor::block_on(app.run_shutdown_hooks());
    assert_eq!(invocations.load(Ordering::SeqCst), 1);

    // Second run leaves the counter untouched.
    futures_executor::block_on(app.run_shutdown_hooks());
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
/// `root_path` stores the path without trailing slashes; a path with none
/// and the empty path are stored unchanged.
#[test]
fn root_path_strips_trailing_slashes() {
    // (input given to `root_path`, value expected in the config)
    let cases = [
        ("/api/", "/api"),
        ("/api///", "/api"),
        ("/api", "/api"),
        ("", ""),
    ];

    for (input, expected) in cases {
        let config = AppConfig::new().root_path(input);
        assert_eq!(config.root_path, expected);
    }
}
}