use std::sync::Arc;
use crate::{
controller::Controller,
error::Result,
router::{ApiRoute, Router},
};
/// A boxed, infallible `Router -> Router` transform, as consumed by
/// [`RouterPipeline::layer_all`].
///
/// Boxing lets transforms built from different closure types be collected into a
/// single homogeneous collection (e.g. a `Vec<RouterTransform>`).
pub type RouterTransform = Box<dyn FnOnce(Router<()>) -> Router<()>>;
/// Builder that assembles the application's root [`Router`] step by step.
///
/// The pipeline carries a `Result`: once a step fails, the transforms of later steps
/// are not applied, and [`RouterPipeline::build`] returns that first error.
///
/// Every builder method takes `self` by value and returns a new pipeline, so dropping
/// a returned pipeline silently discards the step; `#[must_use]` turns that mistake
/// into a compiler warning.
#[must_use = "a RouterPipeline does nothing until `build()` is called"]
pub struct RouterPipeline(Result<Router<()>>);
impl RouterPipeline {
    /// Starts a pipeline from the router returned by `crate::router::build()`.
    pub fn new() -> Self {
        Self(Ok(crate::router::build()))
    }

    /// Mounts controller `C`, handing it `state`.
    ///
    /// If an earlier step has failed, `C::mount` is not called and `state` is dropped.
    pub fn mount<C: Controller>(self, state: Arc<C::State>) -> Self {
        // Defer `C::mount` until the pipeline is known to be `Ok` (as `mount_guarded`
        // does); calling it eagerly would invoke it even after an earlier failure.
        Self(self.0.and_then(|router| C::mount(state)(router)))
    }

    /// Applies an infallible transform to the router.
    ///
    /// `f` is not called if an earlier step has failed.
    pub fn map<F>(self, f: F) -> Self
    where
        F: FnOnce(Router<()>) -> Router<()>,
    {
        Self(self.0.map(f))
    }

    /// Applies a fallible transform to the router.
    ///
    /// `f` is not called if an earlier step has failed. If `f` returns `Err`, that
    /// error becomes the pipeline's result and later steps are skipped.
    pub fn and_then<F>(self, f: F) -> Self
    where
        F: FnOnce(Router<()>) -> Result<Router<()>>,
    {
        Self(self.0.and_then(f))
    }

    /// Registers a single `handler` through [`ApiRoute::api_route`].
    ///
    /// `route_info` is forwarded unchanged; see `ApiRoute::api_route` for how its two
    /// strings are interpreted.
    pub fn route<H, T>(self, route_info: (&'static str, &'static str), handler: H) -> Self
    where
        H: axum::handler::Handler<T, ()>,
        T: 'static,
    {
        self.map(|r| r.api_route(route_info, handler))
    }

    /// Finishes the pipeline and returns the assembled router.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any step.
    pub fn build(self) -> Result<Router<()>> {
        self.0
    }

    /// Mounts controller `C` only when `condition` is `true`.
    ///
    /// When `condition` is `false`, the pipeline is returned unchanged and `state` is
    /// dropped.
    pub fn mount_if<C: Controller>(self, condition: bool, state: Arc<C::State>) -> Self {
        if condition {
            self.mount::<C>(state)
        } else {
            self
        }
    }

    /// Mounts controller `C` only if `guard` succeeds.
    ///
    /// `guard` runs only while the pipeline is still `Ok`. If it returns `Err`, that
    /// error becomes the pipeline's result and `C::mount` is never called.
    pub fn mount_guarded<C: Controller, G>(self, state: Arc<C::State>, guard: G) -> Self
    where
        G: FnOnce() -> Result<()>,
    {
        Self(self.0.and_then(|router| {
            guard()?;
            C::mount(state)(router)
        }))
    }

    /// Applies each fallible step in iteration order, as if by repeated
    /// [`Self::and_then`].
    ///
    /// The iterator is always consumed to the end, but no step after the first failing
    /// one is called.
    pub fn fold<I, F>(self, steps: I) -> Self
    where
        I: IntoIterator<Item = F>,
        F: FnOnce(Router<()>) -> Result<Router<()>>,
    {
        steps.into_iter().fold(self, |p, step| p.and_then(step))
    }

    /// Applies each boxed transform in iteration order, as if by repeated
    /// [`Self::map`].
    ///
    /// Despite the name, a transform may be any `Router -> Router` function, not only
    /// a middleware layer.
    pub fn layer_all(self, transforms: impl IntoIterator<Item = RouterTransform>) -> Self {
        transforms.into_iter().fold(self, |p, f| p.map(f))
    }

    /// Builds a sub-pipeline with `f` and mounts its router under `prefix`.
    ///
    /// `f` receives a pipeline seeded with an empty router, so only the routes it adds
    /// end up under `prefix`. A `prefix` of `""` or `"/"` merges the group at the root
    /// instead of nesting it. `f` is not called if an earlier step has failed, and an
    /// error from the sub-pipeline becomes this pipeline's error.
    ///
    /// # Panics
    ///
    /// Panics raised by the underlying axum `nest`/`merge` calls (for example on
    /// conflicting routes) propagate; they are not converted into the pipeline's error.
    pub fn group<F>(self, prefix: &str, f: F) -> Self
    where
        F: FnOnce(RouterPipeline) -> RouterPipeline,
    {
        self.and_then(move |outer| {
            // Seed from an empty router rather than `Self::new()`, so whatever
            // `crate::router::build()` installs on the root router is not
            // re-registered under every group prefix.
            let inner = f(Self(Ok(Router::new()))).build()?;
            Ok(match prefix {
                // Nesting at the root is a merge in all but name, and newer axum
                // releases panic on `nest("", ..)` / `nest("/", ..)`.
                "" | "/" => outer.merge(inner),
                _ => outer.nest(prefix, inner),
            })
        })
    }
}
impl Default for RouterPipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use axum::{body::Body, http::Request, routing::get};
    use tower::ServiceExt;

    use super::*;
    use crate::{controller::Controller, error::Result, router::Router};

    /// Minimal controller that exposes `/ping` (responding "pong") and uses `()` as
    /// its state.
    struct PingController;

    impl Controller for PingController {
        type State = ();

        fn mount(state: Arc<Self::State>) -> impl FnOnce(Router<()>) -> Result<Router<()>> {
            move |router| {
                let scoped: Router<Arc<()>> =
                    Router::new().route("/ping", get(|| async { "pong" }));
                Ok(router.merge(scoped.with_state(state)))
            }
        }
    }

    fn ping_state() -> Arc<()> {
        Arc::new(())
    }

    /// Sends one request with an empty body to `uri` and returns the response status.
    /// No method is set, so the `http` builder's default (`GET`) is used.
    async fn status(app: Router<()>, uri: &str) -> u16 {
        app.oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
            .as_u16()
    }

    /// Signature shared by the `fold` steps below, so they fit in one `Vec`.
    type Step = fn(Router<()>) -> Result<Router<()>>;

    fn add_alpha(router: Router<()>) -> Result<Router<()>> {
        Ok(router.route("/alpha", get(|| async { "alpha" })))
    }

    fn add_beta(router: Router<()>) -> Result<Router<()>> {
        Ok(router.route("/beta", get(|| async { "beta" })))
    }

    fn fail_step(_: Router<()>) -> Result<Router<()>> {
        Err(crate::error::Error::other("step failed"))
    }

    fn unreachable_step(_: Router<()>) -> Result<Router<()>> {
        panic!("steps after a failure must not run")
    }

    #[test]
    fn mount_guarded_short_circuits_on_err_guard() {
        let result = RouterPipeline::new()
            .mount_guarded::<PingController, _>(ping_state(), || {
                Err(crate::error::Error::other("guard failed"))
            })
            .build();
        assert!(result.is_err(), "build() should return Err when guard fails");
    }

    #[tokio::test]
    async fn mount_guarded_registers_route_on_ok_guard() {
        let app = RouterPipeline::new()
            .mount_guarded::<PingController, _>(ping_state(), || Ok(()))
            .build()
            .expect("build should succeed when guard passes");
        assert_eq!(status(app, "/ping").await, 200);
    }

    #[test]
    fn mount_guarded_skips_guard_after_earlier_failure() {
        let result = RouterPipeline::new()
            .and_then(|_| Err(crate::error::Error::other("earlier failure")))
            .mount_guarded::<PingController, _>(ping_state(), || {
                panic!("guard must not run once the pipeline has failed")
            })
            .build();
        assert!(result.is_err(), "the earlier error should be preserved");
    }

    #[tokio::test]
    async fn mount_if_false_route_returns_404() {
        let app = RouterPipeline::new()
            .mount_if::<PingController>(false, ping_state())
            .build()
            .expect("build should succeed even when mount_if is false");
        assert_eq!(status(app, "/ping").await, 404);
    }

    #[tokio::test]
    async fn mount_if_true_route_returns_200() {
        let app = RouterPipeline::new()
            .mount_if::<PingController>(true, ping_state())
            .build()
            .expect("build should succeed when mount_if is true");
        assert_eq!(status(app, "/ping").await, 200);
    }

    #[tokio::test]
    async fn group_prefix_is_applied_to_routes() {
        let app = RouterPipeline::new()
            .group("/v1", |g| g.mount::<PingController>(ping_state()))
            .build()
            .expect("build should succeed");
        assert_eq!(
            status(app.clone(), "/v1/ping").await,
            200,
            "/v1/ping should be 200"
        );
        assert_eq!(
            status(app, "/ping").await,
            404,
            "/ping without prefix should be 404"
        );
    }

    #[test]
    fn error_from_and_then_propagates_through_remaining_steps() {
        let result = RouterPipeline::new()
            .and_then(|_| Err(crate::error::Error::other("intentional failure")))
            .mount::<PingController>(ping_state())
            .build();
        assert!(
            result.is_err(),
            "error should propagate through the rest of the pipeline"
        );
    }

    #[tokio::test]
    async fn map_applies_infallible_transform() {
        let app = RouterPipeline::new()
            .map(|router| router.route("/mapped", get(|| async { "mapped" })))
            .build()
            .expect("build should succeed");
        assert_eq!(status(app, "/mapped").await, 200);
    }

    #[tokio::test]
    async fn layer_all_applies_every_transform() {
        let transforms: Vec<RouterTransform> = vec![
            Box::new(|router: Router<()>| router.route("/one", get(|| async { "one" })))
                as RouterTransform,
            Box::new(|router: Router<()>| router.route("/two", get(|| async { "two" })))
                as RouterTransform,
        ];
        let app = RouterPipeline::new()
            .layer_all(transforms)
            .build()
            .expect("build should succeed");
        assert_eq!(status(app.clone(), "/one").await, 200);
        assert_eq!(status(app, "/two").await, 200);
    }

    #[tokio::test]
    async fn fold_applies_every_step() {
        let steps: Vec<Step> = vec![add_alpha, add_beta];
        let app = RouterPipeline::new()
            .fold(steps)
            .build()
            .expect("build should succeed when every step succeeds");
        assert_eq!(status(app.clone(), "/alpha").await, 200);
        assert_eq!(status(app, "/beta").await, 200);
    }

    #[test]
    fn fold_stops_applying_steps_after_first_error() {
        let steps: Vec<Step> = vec![add_alpha, fail_step, unreachable_step];
        let result = RouterPipeline::new().fold(steps).build();
        assert!(result.is_err(), "fold should surface the failing step's error");
    }
}