#[cfg(feature = "open-api")]
use crate::api::http::default_api_routes;
#[cfg(not(feature = "open-api"))]
use crate::api::http::default_routes;
use crate::app::context::AppContext;
use crate::error::RoadsterResult;
use crate::service::ServiceBuilder;
use crate::service::http::initializer::Initializer;
use crate::service::http::initializer::default::default_initializers;
use crate::service::http::middleware::Middleware;
use crate::service::http::middleware::default::default_middleware;
use crate::service::http::service::{HttpService, NAME, enabled};
#[cfg(feature = "open-api")]
use aide::axum::ApiRouter;
#[cfg(feature = "open-api")]
use aide::openapi::OpenApi;
#[cfg(feature = "open-api")]
use aide::transform::TransformOpenApi;
use async_trait::async_trait;
#[cfg(feature = "open-api")]
use axum::Extension;
use axum::Router;
use axum_core::extract::FromRef;
use itertools::Itertools;
use std::collections::BTreeMap;
#[cfg(feature = "open-api")]
use std::sync::Arc;
use tracing::info;
/// Boxed callback used to customize the generated OpenAPI document.
///
/// It receives a [`TransformOpenApi`] and returns the (possibly modified) transform. The builder
/// hands it to `ApiRouter::finish_api_with` when the service is built.
#[cfg(feature = "open-api")]
type ApiDocs = Box<dyn Fn(TransformOpenApi) -> TransformOpenApi + Send>;
/// Builder that collects the routes, middleware, and initializers needed to assemble an
/// [`HttpService`].
pub struct HttpServiceBuilder<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
/// Clone of the state given to the constructor; used to ask each middleware/initializer whether
/// it is enabled at registration time.
state: S,
/// Plain axum routes; additional routers are merged into this one.
router: Router<S>,
/// Routes that are also documented via OpenAPI; merged into `router` when the service is built.
#[cfg(feature = "open-api")]
api_router: ApiRouter<S>,
/// Callback applied to the OpenAPI document when the API router is finalized.
#[cfg(feature = "open-api")]
api_docs: ApiDocs,
/// Registered middleware, keyed by the name each middleware reports.
middleware: BTreeMap<String, Box<dyn Middleware<S>>>,
/// Registered initializers, keyed by the name each initializer reports.
initializers: BTreeMap<String, Box<dyn Initializer<S>>>,
}
impl<S> HttpServiceBuilder<S>
where
    S: Clone + Send + Sync + 'static,
    AppContext: FromRef<S>,
{
    /// Creates a builder pre-populated with the default routes, middleware, and initializers.
    ///
    /// `path_root` is forwarded to the default route constructor; `None` is treated as the empty
    /// string. With the `open-api` feature the default routes are registered on the API router
    /// and the plain router starts out empty; without it they are registered on the plain router.
    ///
    /// With `open-api` enabled, the default docs callback sets the API title to the configured
    /// app name, the description to a heading containing that name, and the version when the
    /// app metadata provides one.
    pub fn new(path_root: Option<&str>, state: &S) -> Self {
        let path_root = path_root.unwrap_or_default();

        #[cfg(not(feature = "open-api"))]
        let router = default_routes(path_root, state);
        #[cfg(feature = "open-api")]
        let router = Router::<S>::new();

        // Owned by the docs callback below, which is stored in the builder.
        #[cfg(feature = "open-api")]
        let context = AppContext::from_ref(state);

        Self {
            state: state.clone(),
            router,
            #[cfg(feature = "open-api")]
            api_router: default_api_routes(path_root, state),
            #[cfg(feature = "open-api")]
            api_docs: Box::new(move |api| {
                let api = api
                    .title(&context.config().app.name)
                    .description(&format!("# {}", context.config().app.name));
                match context.metadata().version.as_ref() {
                    Some(version) => api.version(version),
                    None => api,
                }
            }),
            middleware: default_middleware(state),
            initializers: default_initializers(state),
        }
    }

    /// Creates a builder with no routes, middleware, or initializers. Only available in tests.
    #[cfg(test)]
    fn empty(state: &S) -> Self {
        Self {
            state: state.clone(),
            router: Router::new(),
            #[cfg(feature = "open-api")]
            api_router: ApiRouter::new(),
            #[cfg(feature = "open-api")]
            api_docs: Box::new(|api| api),
            middleware: BTreeMap::new(),
            initializers: BTreeMap::new(),
        }
    }

    /// Merges `router` into the routes collected so far.
    pub fn router(self, router: Router<S>) -> Self {
        Self {
            router: self.router.merge(router),
            ..self
        }
    }

    /// Merges `router` into the OpenAPI-documented routes collected so far.
    #[cfg(feature = "open-api")]
    pub fn api_router(self, router: ApiRouter<S>) -> Self {
        Self {
            api_router: self.api_router.merge(router),
            ..self
        }
    }

    /// Replaces the callback used to customize the generated OpenAPI document.
    #[cfg(feature = "open-api")]
    pub fn api_docs(self, api_docs: ApiDocs) -> Self {
        Self { api_docs, ..self }
    }

    /// Registers `initializer` under the name it reports.
    ///
    /// An initializer that reports itself as disabled for the builder's state is silently
    /// skipped and the builder is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if an initializer with the same name was already registered.
    pub fn initializer<T>(mut self, initializer: T) -> RoadsterResult<Self>
    where
        T: Initializer<S> + 'static,
    {
        if initializer.enabled(&self.state) {
            let name = initializer.name();
            let previous = self.initializers.insert(name.clone(), Box::new(initializer));
            if previous.is_some() {
                return Err(crate::error::other::OtherError::Message(format!(
                    "Initializer `{name}` was already registered"
                ))
                .into());
            }
        }
        Ok(self)
    }

    /// Registers `middleware` under the name it reports.
    ///
    /// A middleware that reports itself as disabled for the builder's state is silently skipped
    /// and the builder is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if a middleware with the same name was already registered.
    pub fn middleware<T>(mut self, middleware: T) -> RoadsterResult<Self>
    where
        T: Middleware<S> + 'static,
    {
        if middleware.enabled(&self.state) {
            let name = middleware.name();
            let previous = self.middleware.insert(name.clone(), Box::new(middleware));
            if previous.is_some() {
                return Err(crate::error::other::OtherError::Message(format!(
                    "Middleware `{name}` was already registered"
                ))
                .into());
            }
        }
        Ok(self)
    }
}
#[async_trait]
impl<S> ServiceBuilder<S, HttpService> for HttpServiceBuilder<S>
where
    S: Clone + Send + Sync + 'static,
    AppContext: FromRef<S>,
{
    /// Returns the name of the HTTP service.
    fn name(&self) -> String {
        NAME.to_string()
    }

    /// Reports whether the HTTP service is enabled for the [`AppContext`] derived from `state`.
    fn enabled(&self, state: &S) -> bool {
        let context = AppContext::from_ref(state);
        enabled(&context)
    }

    /// Assembles the final [`HttpService`].
    ///
    /// Steps, in order:
    /// 1. With the `open-api` feature: finalize the API router (producing the OpenAPI document),
    ///    merge it into the plain router, and expose the document via an [`Extension`] layer.
    /// 2. Attach `state` to the router.
    /// 3. Run the `after_router` and then the `before_middleware` hooks of every enabled
    ///    initializer, in ascending priority order.
    /// 4. Install every enabled middleware in descending priority order.
    /// 5. Run the `after_middleware` and then the `before_serve` initializer hooks.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an initializer hook or a middleware installation.
    async fn build(self, state: &S) -> RoadsterResult<HttpService> {
        let router = self.router;

        #[cfg(feature = "open-api")]
        let (router, api) = {
            let mut open_api = OpenApi::default();
            let documented_routes = self
                .api_router
                .finish_api_with(&mut open_api, self.api_docs);
            let open_api = Arc::new(open_api);
            let router = router
                .merge(documented_routes)
                .layer(Extension(Arc::clone(&open_api)));
            (router, open_api)
        };

        let mut router = router.with_state::<()>(state.clone());

        // Resolved once and reused by all four initializer hook passes below.
        let initializers = self
            .initializers
            .values()
            .filter(|initializer| initializer.enabled(state))
            .sorted_by(|lhs, rhs| Ord::cmp(&lhs.priority(state), &rhs.priority(state)))
            .collect_vec();

        for initializer in &initializers {
            info!(http_initializer.name=%initializer.name(), "Running Initializer::after_router");
            router = initializer.after_router(router, state)?;
        }

        for initializer in &initializers {
            info!(http_initializer.name=%initializer.name(), "Running Initializer::before_middleware");
            router = initializer.before_middleware(router, state)?;
        }

        info!(
            "Installing middleware. Note: the order of installation is the inverse of the order middleware will run when handling a request."
        );
        // Sort ascending by priority, then walk the result backwards to install.
        let install_order = self
            .middleware
            .values()
            .filter(|middleware| middleware.enabled(state))
            .sorted_by(|lhs, rhs| Ord::cmp(&lhs.priority(state), &rhs.priority(state)))
            .rev();
        for middleware in install_order {
            info!(http_middleware.name=%middleware.name(), "Installing middleware");
            router = middleware.install(router, state)?;
        }

        for initializer in &initializers {
            info!(http_initializer.name=%initializer.name(), "Running Initializer::after_middleware");
            router = initializer.after_middleware(router, state)?;
        }

        for initializer in &initializers {
            info!(http_initializer.name=%initializer.name(), "Running Initializer::before_serve");
            router = initializer.before_serve(router, state)?;
        }

        Ok(HttpService {
            router,
            #[cfg(feature = "open-api")]
            api,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::context::AppContext;
    use crate::service::http::initializer::MockInitializer;
    use crate::service::http::middleware::MockMiddleware;

    /// An enabled middleware is stored under the name it reports.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn middleware() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut middleware = MockMiddleware::default();
        middleware.expect_enabled().returning(|_| true);
        middleware.expect_name().returning(|| "test".to_string());

        let builder = builder.middleware(middleware).unwrap();

        assert_eq!(builder.middleware.len(), 1);
        assert!(builder.middleware.contains_key("test"));
    }

    /// A disabled middleware is skipped without error.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn middleware_not_enabled() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut middleware = MockMiddleware::default();
        middleware.expect_enabled().returning(|_| false);

        let builder = builder.middleware(middleware).unwrap();

        assert!(builder.middleware.is_empty());
    }

    /// Registering a second middleware with an already-used name is rejected.
    ///
    /// Both mocks must expect `enabled`: the builder calls it before `name`, and a mock without
    /// that expectation would panic on the first registration, never reaching the duplicate-name
    /// check this test exists to cover.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn middleware_already_registered() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut middleware = MockMiddleware::default();
        middleware.expect_enabled().returning(|_| true);
        middleware.expect_name().returning(|| "test".to_string());
        // The first registration must succeed.
        let builder = builder.middleware(middleware).unwrap();

        let mut middleware = MockMiddleware::default();
        middleware.expect_enabled().returning(|_| true);
        middleware.expect_name().returning(|| "test".to_string());

        assert!(builder.middleware(middleware).is_err());
    }

    /// An enabled initializer is stored under the name it reports.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn initializer() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut initializer = MockInitializer::default();
        initializer.expect_enabled().returning(|_| true);
        initializer.expect_name().returning(|| "test".to_string());

        let builder = builder.initializer(initializer).unwrap();

        assert_eq!(builder.initializers.len(), 1);
        assert!(builder.initializers.contains_key("test"));
    }

    /// A disabled initializer is skipped without error.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn initializer_not_enabled() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut initializer = MockInitializer::default();
        initializer.expect_enabled().returning(|_| false);

        let builder = builder.initializer(initializer).unwrap();

        assert!(builder.initializers.is_empty());
    }

    /// Registering a second initializer with an already-used name is rejected.
    ///
    /// Both mocks must expect `enabled`: the builder calls it before `name`, and a mock without
    /// that expectation would panic on the first registration, never reaching the duplicate-name
    /// check this test exists to cover.
    #[test]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn initializer_already_registered() {
        let context = AppContext::test(None, None, None).unwrap();
        let builder = HttpServiceBuilder::<AppContext>::empty(&context);

        let mut initializer = MockInitializer::default();
        initializer.expect_enabled().returning(|_| true);
        initializer.expect_name().returning(|| "test".to_string());
        // The first registration must succeed.
        let builder = builder.initializer(initializer).unwrap();

        let mut initializer = MockInitializer::default();
        initializer.expect_enabled().returning(|_| true);
        initializer.expect_name().returning(|| "test".to_string());

        assert!(builder.initializer(initializer).is_err());
    }
}