use std::{
any::Any,
collections::{BTreeMap, HashSet},
convert::Infallible,
};
use axum::{
body::Body,
extract::OriginalUri,
http::{
header::{self, HeaderValue},
StatusCode,
},
response::IntoResponse,
routing::{MethodRouter, Router},
BoxError,
};
use dyn_clone::clone_box;
use http::{Request, Response};
use okapi::{openapi3, schemars::gen::SchemaGenerator};
use thiserror::Error;
#[cfg(feature = "grpc")]
use tonic::{
body::Body as GrpcBody,
server::NamedService,
service::{Routes as GrpcRoutes, RoutesBuilder as GrpcRoutesBuilder},
};
#[cfg(feature = "grpc")]
use tower::Service;
use tower::{
builder::ServiceBuilder,
util::{BoxCloneSyncService, MapRequestLayer},
ServiceExt,
};
use tower_http::{
catch_panic::CatchPanicLayer,
request_id::MakeRequestUuid,
set_header::SetResponseHeaderLayer,
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
LatencyUnit, ServiceBuilderExt,
};
use tracing::{debug, debug_span, info, info_span, warn};
use crate::{
apidoc::{ApiDocBuilder, ApiDocError},
auth::{
AuthExtractor, AuthLayer, AuthProvider, AuthSetupError, BasicAuthExtractor,
ConfigAuthProvider, HeaderAuthExtractor, NoOpAuthExtractor, NoOpAuthProvider,
},
config::AppConfig,
errors,
http_client::{HttpClientConfig, HttpClientError},
layers::{
ext::HandlerName, rate::RateLimitError, request_id::RecordRequestIdLayer,
timeout::TimeoutError,
},
logging::span::CustomMakeSpan,
metrics::{MetricsBuilder, MetricsError, MetricsState},
state,
tracing::TracingError,
util::ResponseExtension,
};
/// Errors that can occur while configuring or building an application with [`AppBuilder`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AppBuilderError {
    /// Two inventory-registered handlers report the same name.
    #[error("Duplicate handler name: {0}")]
    DuplicateHandlerName(&'static str),

    /// Error while setting up the authentication framework.
    #[error("Auth framework error: {0}")]
    Auth(#[from] AuthSetupError),

    /// Error from the API documentation subsystem.
    #[error(transparent)]
    ApiDoc(#[from] ApiDocError),

    /// Error from the metrics subsystem.
    #[error(transparent)]
    Metrics(#[from] MetricsError),

    /// Error from the tracing subsystem.
    #[error(transparent)]
    Tracing(#[from] TracingError),

    /// Error while constructing an HTTP client.
    #[error("HTTP client error: {0}")]
    HttpClient(#[from] HttpClientError),

    /// No HTTP client configuration exists under the requested name.
    #[error("HTTP client is absent from configuration: {0}")]
    HttpClientAbsent(String),
}
/// Builder that assembles an axum [`Router`] from inventory-registered handlers,
/// application configuration, authentication, metrics and (with the `grpc` feature)
/// gRPC services.
#[derive(Debug)]
#[non_exhaustive]
pub struct AppBuilder {
    // Auth provider; a clone of it is handed to every auth layer and to the probe router.
    auth_provider: Box<dyn AuthProvider>,
    // Auth extractor; cloned into auth layers and the probe router, and also consulted
    // for sensitive headers and OpenAPI security schemes.
    auth_extractor: Box<dyn AuthExtractor>,
    // Application configuration (handlers, tracing, metrics, probes, API docs, clients).
    config: AppConfig,
    // Metrics state, if the configuration carried one at construction time.
    metrics: Option<MetricsState>,
    // gRPC services accumulated via `with_grpc_service`, merged into the router on build.
    #[cfg(feature = "grpc")]
    grpc_services: GrpcRoutesBuilder,
}
impl TryFrom<AppConfig> for AppBuilder {
type Error = AppBuilderError;
fn try_from(mut value: AppConfig) -> Result<Self, Self::Error> {
let auth_provider = value.auth.make_provider()?;
let auth_extractor = value.auth.extractor.make_extractor()?;
Ok(Self {
auth_provider,
auth_extractor,
metrics: value.metrics_state.take(),
config: value,
#[cfg(feature = "grpc")]
grpc_services: GrpcRoutes::builder(),
})
}
}
impl Default for AppBuilder {
fn default() -> Self {
Self {
auth_provider: Box::new(NoOpAuthProvider),
auth_extractor: Box::new(NoOpAuthExtractor),
config: AppConfig::default(),
metrics: None,
#[cfg(feature = "grpc")]
grpc_services: GrpcRoutes::builder(),
}
}
}
impl AppBuilder {
    /// Creates a builder with the default configuration and no-op authentication.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a builder from a borrowed application configuration.
    ///
    /// The configuration is cloned and converted through `TryFrom<AppConfig>`.
    ///
    /// # Errors
    ///
    /// Returns an error if the auth provider or auth extractor described by the
    /// configuration cannot be constructed.
    pub fn from_config(cfg: &AppConfig) -> Result<Self, AppBuilderError> {
        cfg.clone().try_into()
    }

    /// Builds an authentication layer requiring `perms`.
    ///
    /// The layer receives its own clones of the builder's current auth provider and
    /// extractor, so later `with_auth_*` calls do not affect layers created earlier.
    #[must_use]
    pub fn auth_layer<S>(&self, perms: &'static [&'static str]) -> AuthLayer<S> {
        AuthLayer::new(
            perms,
            clone_box(self.auth_provider.as_ref()),
            clone_box(self.auth_extractor.as_ref()),
        )
    }

    /// Replaces the auth extractor used when building the application.
    pub fn with_auth_extractor(&mut self, extractor: impl AuthExtractor) -> &mut Self {
        self.auth_extractor = Box::new(extractor);
        self
    }

    /// Replaces the auth provider used when building the application.
    pub fn with_auth_provider(&mut self, provider: impl AuthProvider) -> &mut Self {
        self.auth_provider = Box::new(provider);
        self
    }

    /// Switches to HTTP Basic authentication.
    ///
    /// The provider is a [`ConfigAuthProvider`] built from `config.auth`; the extractor
    /// is a [`BasicAuthExtractor`] constructed with `None`.
    pub fn with_basic_auth(&mut self) -> &mut Self {
        self.auth_provider = Box::new(ConfigAuthProvider::from(self.config.auth.clone()));
        self.auth_extractor = Box::new(BasicAuthExtractor::new(None::<&str>));
        self
    }

    /// Switches to header-based authentication.
    ///
    /// The provider is a [`ConfigAuthProvider`] built from `config.auth`; the extractor
    /// is a [`HeaderAuthExtractor`] constructed with `None` for both arguments.
    pub fn with_header_auth(&mut self) -> &mut Self {
        self.auth_provider = Box::new(ConfigAuthProvider::from(self.config.auth.clone()));
        self.auth_extractor = Box::new(HeaderAuthExtractor::new(None::<&str>, None::<&str>));
        self
    }

    /// Sets the API documentation builder, replacing any configured one.
    pub fn with_api_doc(&mut self, api_doc: ApiDocBuilder) -> &mut Self {
        self.config.api_doc = Some(api_doc);
        self
    }

    /// Stores `state` through the crate's `state` module.
    ///
    /// The builder itself is not modified; the method only exists for chaining.
    pub fn with_state<S>(&mut self, state: S) -> &mut Self
    where
        S: Clone + Send + 'static,
    {
        state::put(state);
        self
    }

    /// Sets the metrics configuration, replacing any configured one.
    pub fn with_metrics_config(&mut self, metrics: MetricsBuilder) -> &mut Self {
        self.config.metrics = Some(metrics);
        self
    }

    /// Replaces the metrics configuration with `modifier(current)`, consuming the builder.
    #[must_use]
    pub fn with_configured_metrics_config<F>(mut self, modifier: F) -> Self
    where
        F: FnOnce(Option<MetricsBuilder>) -> Option<MetricsBuilder>,
    {
        self.config.metrics = modifier(self.config.metrics);
        self
    }

    /// Applies `modifier` to the current API documentation builder (or a default one
    /// if none is configured) and stores the result.
    pub fn configure_api_doc<F>(&mut self, modifier: F)
    where
        F: FnOnce(ApiDocBuilder) -> ApiDocBuilder,
    {
        self.config.api_doc = Some(modifier(self.config.api_doc.take().unwrap_or_default()));
    }

    /// Returns the metrics state, if the builder was created with one.
    pub fn metrics(&self) -> Option<&MetricsState> {
        self.metrics.as_ref()
    }

    /// Assembles the application router.
    ///
    /// Every handler registered through `inventory` is grouped by path, and handlers
    /// disabled in configuration are skipped. The router then gains, in order:
    /// gRPC services (with the `grpc` feature), the not-found fallback, the HTTP trace
    /// layer, the metrics router (when both metrics state and configuration exist),
    /// probe routes, API documentation routes, and finally the global layers.
    ///
    /// # Errors
    ///
    /// - [`AppBuilderError::DuplicateHandlerName`] if two handlers report the same name.
    /// - Any error returned while building the API documentation router.
    ///
    /// # Panics
    ///
    /// Panics if an enabled handler uses an HTTP method other than GET, HEAD, POST,
    /// PUT, DELETE, OPTIONS, TRACE or PATCH.
    pub fn build(mut self) -> Result<Router, AppBuilderError> {
        let _build_span = debug_span!("build_app").entered();
        let mut rtr = Router::new();
        let mut handler_names = HashSet::new();
        // BTreeMap keeps route registration order deterministic (sorted by path).
        let mut grouped: BTreeMap<&str, Vec<&dyn HandlerExt>> = BTreeMap::new();
        for handler in inventory::iter::<&dyn HandlerExt> {
            let name = handler.name();
            let _record_span = debug_span!("iter_handler", name).entered();
            if !handler_names.insert(name) {
                return Err(AppBuilderError::DuplicateHandlerName(name));
            }
            grouped.entry(handler.path()).or_default().push(*handler);
            debug!("handler recorded");
        }
        for (path, handlers) in grouped {
            if let Some(method_rtr) = self.register_path(path, handlers) {
                rtr = rtr.route(path, method_rtr.handle_error(error_handler));
            }
        }
        #[cfg(feature = "grpc")]
        {
            rtr = rtr.merge(self.grpc_services.clone().routes().into_axum_router());
        }
        rtr = rtr.fallback(fallback_handler);

        let tracing_config = self.config.tracing.as_ref();
        let include_headers = tracing_config.is_some_and(|t| t.include_headers());
        let request_level =
            tracing_config.map_or(tracing::Level::DEBUG, |t| t.request_level().into());
        let response_level =
            tracing_config.map_or(tracing::Level::INFO, |t| t.response_level().into());
        // `Router::layer` only wraps what is registered so far, so the trace layer covers
        // handlers, gRPC services and the fallback, but not the routers merged below.
        rtr = rtr.layer(
            TraceLayer::new_for_http()
                .make_span_with(CustomMakeSpan::new().include_headers(include_headers))
                .on_request(DefaultOnRequest::new().level(request_level))
                .on_response(
                    DefaultOnResponse::new()
                        .level(response_level)
                        .include_headers(include_headers)
                        .latency_unit(LatencyUnit::Micros),
                ),
        );

        let metrics_state = self.metrics().cloned();
        if let (Some(m_state), Some(m_cfg)) = (&metrics_state, &self.config.metrics) {
            rtr = rtr.merge(m_cfg.build_router(m_state));
        }
        rtr = rtr.merge(self.config.probes.build_router(
            clone_box(self.auth_provider.as_ref()),
            clone_box(self.auth_extractor.as_ref()),
        ));
        if let Some(ref mut api_doc) = self.config.api_doc {
            let disabled = self
                .config
                .handlers
                .iter()
                .filter(|(_, v)| v.disabled)
                .map(|(k, _)| k.clone());
            api_doc.set_disabled_handlers(disabled);
            api_doc.set_app_defaults(
                self.config.app_name.as_deref(),
                self.config.app_version.as_deref(),
            );
            let auth = self.auth_extractor.security_schemes();
            rtr = rtr.merge(api_doc.build_router(auth)?);
        }

        let final_rtr = self.wrap_global_layers(rtr, metrics_state);
        info!("finished building application");
        Ok(final_rtr)
    }

    /// Creates the HTTP client configured under `name`.
    ///
    /// The configured app name/version and per-client metrics (when metrics state
    /// exists) are applied before the client is built.
    ///
    /// # Errors
    ///
    /// - [`AppBuilderError::HttpClientAbsent`] if no client is configured under `name`.
    /// - Any client construction error, converted into [`AppBuilderError`].
    pub async fn http_client(
        &self,
        name: impl AsRef<str>,
    ) -> Result<reqwest_middleware::ClientWithMiddleware, AppBuilderError> {
        let name = name.as_ref();
        let cfg = self
            .config
            .http_clients
            .get(name)
            .cloned()
            .ok_or_else(|| AppBuilderError::HttpClientAbsent(name.to_string()))?;
        self.client_from_config(name, cfg).await
    }

    /// Like [`Self::http_client`], but falls back to [`HttpClientConfig::default`] when
    /// no client is configured under `name`.
    ///
    /// # Errors
    ///
    /// Any client construction error, converted into [`AppBuilderError`].
    pub async fn http_client_or_default(
        &self,
        name: impl AsRef<str>,
    ) -> Result<reqwest_middleware::ClientWithMiddleware, AppBuilderError> {
        let name = name.as_ref();
        match self.http_client(name).await {
            Err(AppBuilderError::HttpClientAbsent(_)) => {
                self.client_from_config(name, HttpClientConfig::default()).await
            }
            result => result,
        }
    }

    /// Applies app identity and client metrics to `cfg`, then builds the client.
    async fn client_from_config(
        &self,
        name: &str,
        mut cfg: HttpClientConfig,
    ) -> Result<reqwest_middleware::ClientWithMiddleware, AppBuilderError> {
        if let Some(app_name) = &self.config.app_name {
            cfg.with_app_name(app_name);
        }
        if let Some(app_version) = &self.config.app_version {
            cfg.with_app_version(app_version);
        }
        let metrics = self.metrics().map(|m| m.client_metrics(name));
        cfg.to_client(metrics).await.map_err(Into::into)
    }

    /// Wraps the whole router in application-wide layers.
    ///
    /// With `ServiceBuilder`, layers listed first wrap the ones listed after them:
    /// request-ID generation is outermost and panic catching is innermost.
    fn wrap_global_layers(&self, rtr: Router, metrics: Option<MetricsState>) -> Router {
        let mut sensitive_headers = vec![header::AUTHORIZATION];
        sensitive_headers.extend(self.auth_extractor.sensitive_headers());
        let global_layers = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .layer(RecordRequestIdLayer::new())
            .sensitive_headers(sensitive_headers)
            .option_layer(metrics)
            .option_layer(
                self.config
                    .tracing
                    .is_some()
                    .then(|| MapRequestLayer::new(crate::logging::span::register_request)),
            )
            .propagate_x_request_id()
            .layer(SetResponseHeaderLayer::if_not_present(
                header::SERVER,
                self.server_header(),
            ))
            .layer(CatchPanicLayer::custom(panic_handler));
        rtr.layer(global_layers)
    }

    /// Builds a method router for all enabled handlers sharing `path`.
    ///
    /// Returns `None` when every handler on the path is disabled in configuration.
    #[must_use]
    fn register_path(
        &self,
        path: &str,
        handlers: Vec<&dyn HandlerExt>,
    ) -> Option<MethodRouter<(), BoxError>> {
        let _register_span = info_span!("register_path", path).entered();
        let mut path_has_handlers = false;
        let mut method_rtr = MethodRouter::new();
        for handler in handlers {
            let name = handler.name();
            let _span = info_span!("register_handler", name, method = ?handler.method()).entered();
            if self.config.handlers.get(name).is_some_and(|cfg| cfg.disabled) {
                info!("skipping disabled handler");
                continue;
            }
            method_rtr = self.register_handler(method_rtr, handler);
            path_has_handlers = true;
            info!("handler registered");
        }
        path_has_handlers.then_some(method_rtr)
    }

    /// Mounts `handler`'s service on `method_rtr` under the handler's HTTP method.
    ///
    /// # Panics
    ///
    /// Panics on any method other than GET, HEAD, POST, PUT, DELETE, OPTIONS, TRACE
    /// or PATCH.
    fn register_handler(
        &self,
        method_rtr: MethodRouter<(), BoxError>,
        handler: &dyn HandlerExt,
    ) -> MethodRouter<(), BoxError> {
        let service = self.handler_service(handler);
        match handler.method() {
            http::Method::GET => method_rtr.get_service(service),
            http::Method::HEAD => method_rtr.head_service(service),
            http::Method::POST => method_rtr.post_service(service),
            http::Method::PUT => method_rtr.put_service(service),
            http::Method::DELETE => method_rtr.delete_service(service),
            http::Method::OPTIONS => method_rtr.options_service(service),
            http::Method::TRACE => method_rtr.trace_service(service),
            http::Method::PATCH => method_rtr.patch_service(service),
            other => panic!("Unsupported HTTP method: {other}"),
        }
    }

    /// Wraps a handler's service in its per-handler layer stack.
    ///
    /// From outermost to innermost, the layers are:
    /// 1. boxing;
    /// 2. the handler-name response extension;
    /// 3. auth (unless `no_auth`);
    /// 4. buffer, rate limit, CORS and timeout, each only when configured.
    fn handler_service(
        &self,
        handler: &dyn HandlerExt,
    ) -> BoxCloneSyncService<Request<Body>, Response<Body>, BoxError> {
        let name = handler.name();
        let _span = info_span!("handler_service", name, method = ?handler.method()).entered();
        let service_cfg = self.config.handlers.get(name);
        // A broken CORS configuration is logged and skipped rather than failing the build.
        let cors_layer = service_cfg
            .and_then(|cfg| cfg.cors.as_ref())
            .and_then(|cors| match cors.make_layer() {
                Ok(layer) => Some(layer.allow_methods(handler.method())),
                Err(err) => {
                    warn!(error = %err, "Unable to build CORS layer");
                    None
                }
            });
        ServiceBuilder::new()
            .layer(BoxCloneSyncService::layer())
            .layer(ResponseExtension(HandlerName::new(name)))
            .option_layer((!handler.no_auth()).then(|| self.auth_layer(handler.permissions())))
            .option_layer(
                service_cfg
                    .and_then(|cfg| cfg.buffer.as_ref())
                    .map(|lcfg| lcfg.make_layer()),
            )
            .option_layer(
                service_cfg
                    .and_then(|cfg| cfg.rate_limit.as_ref())
                    .map(|rcfg| rcfg.make_layer()),
            )
            .option_layer(cors_layer)
            .option_layer(
                service_cfg
                    .map(|cfg| cfg.timeout.clone())
                    .unwrap_or_default()
                    .make_layer(),
            )
            .service(handler.service().map_err(|err| err.into()))
    }

    /// Computes the `Server` response header value.
    ///
    /// The value has the form `"[app_name[/app_version] ]<crate>/<crate-version>"`. If
    /// the application part makes the value invalid as a header, a warning is logged
    /// and the framework product token alone is used instead of dropping the header.
    #[must_use]
    fn server_header(&self) -> Option<HeaderValue> {
        const UXUM_PRODUCT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
        let Some(app_name) = &self.config.app_name else {
            return HeaderValue::from_str(UXUM_PRODUCT).ok();
        };
        let full = match &self.config.app_version {
            Some(app_version) => format!("{app_name}/{app_version} {UXUM_PRODUCT}"),
            None => format!("{app_name} {UXUM_PRODUCT}"),
        };
        HeaderValue::from_str(&full).ok().or_else(|| {
            warn!(
                value = %full,
                "application name/version is not a valid Server header value; using framework product only"
            );
            HeaderValue::from_str(UXUM_PRODUCT).ok()
        })
    }

    /// Registers a gRPC service to be merged into the router by [`Self::build`].
    #[cfg(feature = "grpc")]
    pub fn with_grpc_service<S>(&mut self, svc: S) -> &mut Self
    where
        S: Service<Request<GrpcBody>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Response: IntoResponse,
        S::Future: Send + 'static,
    {
        self.grpc_services.add_service(svc);
        self
    }
}
/// Converts an error from a handler's layer stack into an HTTP response.
///
/// Rate-limit and timeout errors render their own responses. Any other error becomes
/// a `500 Internal Server Error` problem-details body whose title is the error's
/// `Display` text.
pub(crate) async fn error_handler(err: BoxError) -> Response<Body> {
    if let Some(limited) = err.downcast_ref::<RateLimitError>() {
        return limited.clone().into_response();
    }
    if let Some(timed_out) = err.downcast_ref::<TimeoutError>() {
        return timed_out.clone().into_response();
    }
    // NOTE(review): the raw error text is sent to the client; confirm that is intended.
    let title = err.to_string();
    problemdetails::new(StatusCode::INTERNAL_SERVER_ERROR)
        .with_type(errors::TAG_UXUM_ERROR)
        .with_title(title)
        .into_response()
}
/// Router fallback for unmatched requests.
///
/// Responds with a `404 Not Found` problem-details body that echoes the originally
/// requested URI under the `uri` key.
pub(crate) async fn fallback_handler(OriginalUri(requested): OriginalUri) -> Response<Body> {
    let requested = requested.to_string();
    problemdetails::new(StatusCode::NOT_FOUND)
        .with_type(errors::TAG_UXUM_NOT_FOUND)
        .with_title("Resource not found")
        .with_value("uri", requested)
        .into_response()
}
/// Converts a panic payload caught by `CatchPanicLayer` into a `500` problem-details
/// response.
///
/// `String` and `&str` payloads are reported verbatim as the detail. Any other payload
/// type is reported as "Unknown panic format".
fn panic_handler(err: Box<dyn Any + Send + 'static>) -> Response<Body> {
    let details = match err.downcast_ref::<String>() {
        Some(owned) => owned.clone(),
        None => err
            .downcast_ref::<&str>()
            .map_or_else(|| "Unknown panic format".to_string(), |msg| (*msg).to_string()),
    };
    problemdetails::new(StatusCode::INTERNAL_SERVER_ERROR)
        .with_type(errors::TAG_UXUM_PANIC)
        .with_title("Encountered panic in handler")
        .with_detail(details)
        .into_response()
}
/// An HTTP handler registered at link time through `inventory`.
///
/// [`AppBuilder::build`] iterates every registered `&'static dyn HandlerExt`, groups the
/// handlers by [`path`](HandlerExt::path), and mounts each one's service under its
/// [`method`](HandlerExt::method).
pub trait HandlerExt: Sync {
    /// Unique handler name, used for duplicate detection and per-handler config lookup.
    fn name(&self) -> &'static str;

    /// HTTP method the handler is mounted on.
    fn method(&self) -> http::Method;

    /// Route path used when registering the handler with the router.
    fn path(&self) -> &'static str;

    /// Path variant not used by the router in this file.
    /// NOTE(review): presumably the OpenAPI-style path for the spec — confirm.
    fn spec_path(&self) -> &'static str;

    /// Permissions passed to the auth layer wrapped around this handler.
    fn permissions(&self) -> &'static [&'static str];

    /// When `true`, no auth layer is applied to this handler.
    fn no_auth(&self) -> bool;

    /// The infallible service that serves this handler's requests.
    fn service(&self) -> BoxCloneSyncService<Request<Body>, Response<Body>, Infallible>;

    /// OpenAPI operation describing this handler, generated with `generator`.
    fn openapi_spec(&self, generator: &mut SchemaGenerator) -> openapi3::Operation;
}

inventory::collect!(&'static dyn HandlerExt);