use axum::{
Router,
http::HeaderMap,
routing::MethodRouter,
};
use bytes::Bytes;
use ::core::time::Duration;
use tower_http::{
LatencyUnit,
classify::ServerErrorsFailureClass,
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
};
use tracing::{Level, Span, debug, error};
#[cfg(feature = "utoipa")]
use ::{
utoipa::openapi::OpenApi,
utoipa_rapidoc::RapiDoc,
utoipa_redoc::{Redoc, Servable as _},
utoipa_swagger_ui::SwaggerUi,
};
/// Builder-style extension methods for an axum [`Router`].
///
/// Every method consumes the router and hands back the extended one, so calls can be chained.
/// `S` is the state type of the router being extended.
pub trait RouterExt<S: Clone + Send + Sync + 'static> {
    /// Wraps the router in HTTP request/response tracing middleware.
    #[must_use]
    fn add_http_logging(self) -> Self;

    /// Mounts interactive documentation UIs for `openapi` underneath the path `prefix`.
    ///
    /// Only compiled when the `utoipa` feature is enabled.
    #[must_use]
    #[cfg(feature = "utoipa")]
    fn add_openapi<P: AsRef<str>>(self, prefix: P, openapi: OpenApi) -> Self;

    /// Registers each `(path, method_router)` pair in `routes` on the router.
    #[must_use]
    fn public_routes(self, routes: Vec<(&str, MethodRouter<S>)>) -> Self;
}
#[expect(clippy::similar_names, reason = "Not too similar")]
impl<S> RouterExt<S> for Router<S>
where
S: Clone + Send + Sync + 'static,
{
fn add_http_logging(self) -> Self {
self.layer(TraceLayer::new_for_http()
.on_request(
DefaultOnRequest::new()
.level(Level::INFO)
)
.on_response(
DefaultOnResponse::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros)
)
.on_body_chunk(|chunk: &Bytes, _latency: Duration, _span: &Span| {
debug!("Sending {} bytes", chunk.len());
})
.on_eos(|_trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span| {
debug!("Stream closed after {:?}", stream_duration);
})
.on_failure(|_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
error!("Something went wrong");
})
)
}
#[cfg(feature = "utoipa")]
fn add_openapi<P: AsRef<str>>(self, prefix: P, openapi: OpenApi) -> Self {
self
.merge(RapiDoc::new(format!("{}/openapi.json", prefix.as_ref()))
.path(format!("{}/rapidoc", prefix.as_ref()))
)
.merge(Redoc::with_url(format!("{}/redoc", prefix.as_ref()), openapi.clone()))
.merge(SwaggerUi::new(format!("{}/swagger", prefix.as_ref()))
.url(format!("{}/openapi.json", prefix.as_ref()), openapi)
)
}
fn public_routes(self, routes: Vec<(&str, MethodRouter<S>)>) -> Self {
let mut router = self;
for (path, method_router) in routes {
router = router.route(path, method_router);
}
router
}
}