use std::collections::HashMap;
use crate::BoxError;
use crate::config::{AppConfig, AppConfigDefinition};
use crate::pipeline::MiddlewareSlot;
use crate::BoxFuture;
use super::{
ActionArgs, ActionResult, BoxAction, InitContext, LayerContext, Plugin, PluginHandle,
PrepareContext, ReadyContext, RouteContext, ShutdownContext, TaggedLayer, TaggedRoute,
ordering::topological_sort,
};
/// Default cap on request body size, in bytes (8 MiB), used by
/// `GasketApp::builder` unless overridden via `request_body_limit`.
pub const DEFAULT_REQUEST_BODY_LIMIT: usize = 8 << 20;
/// A fully built application: ordered plugins, the actions they registered,
/// the resolved configuration, and the extensions filled in during `prepare`.
///
/// Obtain one through [`GasketApp::builder`] followed by `build().await`.
pub struct GasketApp {
    /// Plugins in the order chosen by the builder's topological sort.
    plugins: Vec<PluginHandle>,
    /// Actions collected from plugins' `init` hooks, keyed by action name.
    actions: HashMap<String, BoxAction>,
    /// Configuration after every plugin's `configure` pass has run.
    pub(crate) config: AppConfig,
    /// Extensions accumulated while plugins ran `prepare`.
    pub(crate) extensions: http::Extensions,
    /// Maximum accepted request body size, in bytes, for public and protected routes.
    pub(crate) request_body_limit: usize,
}
impl std::fmt::Debug for GasketApp {
    /// Formats the config, plugin names, registered action names, and the
    /// request-body limit; remaining internals are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let plugin_names: Vec<&str> = self.plugins.iter().map(|p| p.name()).collect();
        // HashMap iteration order is unspecified; sort so the output is stable
        // across runs (useful for log diffing and snapshot tests).
        let mut action_names: Vec<&String> = self.actions.keys().collect();
        action_names.sort_unstable();
        f.debug_struct("GasketApp")
            .field("config", &self.config)
            .field("plugins", &plugin_names)
            .field("actions", &action_names)
            .field("request_body_limit", &self.request_body_limit)
            .finish_non_exhaustive()
    }
}
impl GasketApp {
    /// Starts a [`GasketAppBuilder`] with no plugins, no explicit config
    /// definition, and a request-body limit of [`DEFAULT_REQUEST_BODY_LIMIT`].
    pub fn builder() -> GasketAppBuilder {
        GasketAppBuilder {
            plugins: Vec::new(),
            config_def: None,
            request_body_limit: DEFAULT_REQUEST_BODY_LIMIT,
        }
    }

    /// The configuration resolved at build time.
    #[must_use]
    pub const fn config(&self) -> &AppConfig {
        &self.config
    }

    /// Extensions populated by plugins during `prepare`.
    #[must_use]
    pub const fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Plugins in the order they were sorted into at build time.
    #[must_use]
    pub fn plugins(&self) -> &[PluginHandle] {
        &self.plugins
    }

    /// Assembles the HTTP router from every plugin's routes and layers.
    ///
    /// - `Bare` routes are merged as-is: no logging, no body limit, no plugin layers.
    /// - `Public` routes get the logging middleware and the request-body limit.
    /// - `Protected` routes get every non-`TransportSecurity` plugin layer, then
    ///   the logging middleware and the request-body limit.
    ///
    /// `TransportSecurity` layers wrap the combined public + protected router,
    /// but not the bare routes.
    pub fn build_router(&self) -> axum::Router {
        let mut bare = axum::Router::new();
        let mut public = axum::Router::new();
        let mut protected = axum::Router::new();
        for tagged in self.collect_routes() {
            match tagged.group {
                super::RouteGroup::Bare => bare = bare.merge(tagged.router),
                super::RouteGroup::Public => public = public.merge(tagged.router),
                super::RouteGroup::Protected => protected = protected.merge(tagged.router),
            }
        }

        // `collect_layers` returns layers sorted by slot; `partition` keeps that
        // relative order inside each half.
        let (transport_layers, protected_layers): (Vec<TaggedLayer>, Vec<TaggedLayer>) = self
            .collect_layers()
            .into_iter()
            .partition(|tagged| tagged.slot == MiddlewareSlot::TransportSecurity);

        // Layers are applied in reverse slot order, so the lowest slot is applied
        // last (presumably making it the outermost wrapper — depends on `apply`).
        let protected = protected_layers
            .into_iter()
            .rev()
            .fold(protected, |router, tagged| tagged.layer.apply(router));

        let public = self.with_logging_and_body_limit(public);
        let protected = self.with_logging_and_body_limit(protected);

        let instrumented = transport_layers.into_iter().rev().fold(
            axum::Router::new().merge(public).merge(protected),
            |router, tagged| tagged.layer.apply(router),
        );

        axum::Router::new().merge(bare).merge(instrumented)
    }

    /// Adds the request-logging middleware and then a request-body limit of
    /// `self.request_body_limit` bytes to `router` (the limit layer is added last).
    fn with_logging_and_body_limit(&self, router: axum::Router) -> axum::Router {
        router
            .layer(axum::middleware::from_fn(
                crate::observability::logging_middleware,
            ))
            .layer(tower_http::limit::RequestBodyLimitLayer::new(
                self.request_body_limit,
            ))
    }

    /// Looks up the action registered under `name` and starts it with `args`.
    ///
    /// # Errors
    ///
    /// Returns an error if no action with that name was registered.
    pub fn invoke_action(
        &self,
        name: &str,
        args: ActionArgs,
    ) -> Result<BoxFuture<'static, ActionResult>, BoxError> {
        let action = self
            .actions
            .get(name)
            .ok_or_else(|| format!("Action '{name}' not found"))?;
        Ok(action(args))
    }

    /// Runs the action `name` and downcasts its boxed result to `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the action is not registered, if the action itself
    /// fails, or if its result is not a `T`.
    pub async fn invoke<T>(&self, name: &str, args: ActionArgs) -> Result<T, BoxError>
    where
        T: std::any::Any + Send + 'static,
    {
        let result = self.invoke_action(name, args)?.await?;
        result.downcast::<T>().map(|boxed| *boxed).map_err(|_| {
            format!(
                "Action '{name}' returned a different type than {}",
                std::any::type_name::<T>()
            )
            .into()
        })
    }

    /// Gathers every plugin's layers and sorts them by middleware slot.
    ///
    /// `sort_by_key` is stable, so layers sharing a slot keep plugin order.
    #[must_use]
    pub fn collect_layers(&self) -> Vec<TaggedLayer> {
        let ctx = LayerContext::new(self.config.clone(), self.extensions.clone());
        let mut layers: Vec<TaggedLayer> = Vec::new();
        for plugin in &self.plugins {
            layers.extend(plugin.layers(&ctx));
        }
        layers.sort_by_key(|l| l.slot);
        layers
    }

    /// Gathers every plugin's routes, in plugin order.
    #[must_use]
    pub fn collect_routes(&self) -> Vec<TaggedRoute> {
        let ctx = RouteContext::new(self.config.clone(), self.extensions.clone());
        let mut routes = Vec::new();
        for plugin in &self.plugins {
            routes.extend(plugin.routes(&ctx));
        }
        routes
    }

    /// Calls each plugin's `ready` hook in order with the bound `local_addr`.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from a plugin's `ready` hook.
    pub async fn ready(&self, local_addr: std::net::SocketAddr) -> Result<(), BoxError> {
        let ctx = ReadyContext::new(self.config.clone(), self.extensions.clone(), local_addr);
        for plugin in &self.plugins {
            plugin.ready(&ctx).await?;
        }
        Ok(())
    }

    /// Calls each plugin's `shutdown` hook in reverse order.
    ///
    /// Failures are logged and do not stop the remaining plugins from shutting down.
    pub async fn shutdown(&self) {
        let ctx = ShutdownContext::new(self.extensions.clone());
        for plugin in self.plugins.iter().rev() {
            if let Err(e) = plugin.shutdown(&ctx).await {
                tracing::error!(plugin = plugin.name(), error = %e, "Plugin shutdown failed");
            }
        }
    }
}
/// Collects plugins and settings, then produces a [`GasketApp`] via `build`.
///
/// Created by [`GasketApp::builder`].
#[must_use = "GasketAppBuilder must be consumed by .build() to produce a GasketApp"]
pub struct GasketAppBuilder {
    /// Plugins in registration order; `build` re-sorts them.
    plugins: Vec<PluginHandle>,
    /// Explicit config definition; `None` means `build` uses the default one.
    config_def: Option<AppConfigDefinition>,
    /// Request-body size limit, in bytes, handed to the built app.
    request_body_limit: usize,
}
impl std::fmt::Debug for GasketAppBuilder {
    /// Formats registered plugin names and the configured request-body limit;
    /// the config definition is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let plugin_names: Vec<&str> = self.plugins.iter().map(|p| p.name()).collect();
        f.debug_struct("GasketAppBuilder")
            .field("plugins", &plugin_names)
            .field("request_body_limit", &self.request_body_limit)
            .finish_non_exhaustive()
    }
}
impl GasketAppBuilder {
pub fn plugin(mut self, plugin: impl Plugin) -> Self {
self.plugins.push(PluginHandle::new(plugin));
self
}
pub fn plugin_handle(mut self, plugin: PluginHandle) -> Self {
self.plugins.push(plugin);
self
}
pub fn plugin_boxed(self, plugin: PluginHandle) -> Self {
self.plugin_handle(plugin)
}
pub fn preset(mut self, plugins: Vec<PluginHandle>) -> Self {
self.plugins.extend(plugins);
self
}
pub fn config(mut self, config_def: AppConfigDefinition) -> Self {
self.config_def = Some(config_def);
self
}
pub const fn request_body_limit(mut self, bytes: usize) -> Self {
self.request_body_limit = bytes;
self
}
pub async fn build(mut self) -> Result<GasketApp, BoxError> {
let plugin_names: Vec<&str> = self.plugins.iter().map(|p| p.name()).collect();
{
let mut seen = std::collections::HashSet::new();
for name in &plugin_names {
if !seen.insert(*name) {
return Err(format!("Duplicate plugin name: '{name}'").into());
}
}
}
for plugin in &self.plugins {
for dep in plugin.dependencies() {
if !plugin_names.contains(&dep) {
return Err(format!(
"Plugin '{}' requires missing dependency '{dep}'",
plugin.name()
)
.into());
}
}
}
let sorted_indices = topological_sort(&self.plugins)?;
let mut sorted_plugins: Vec<PluginHandle> = Vec::with_capacity(self.plugins.len());
let mut old_plugins: Vec<Option<PluginHandle>> =
self.plugins.into_iter().map(Some).collect();
for idx in sorted_indices {
if let Some(p) = old_plugins[idx].take() {
sorted_plugins.push(p);
}
}
self.plugins = sorted_plugins;
let mut init_ctx = InitContext::new();
for plugin in &self.plugins {
plugin.init(&mut init_ctx);
}
let config_def = self.config_def.unwrap_or_default();
let mut config = config_def.resolve()?;
for plugin in &self.plugins {
config = plugin.configure(config);
}
let mut prepare_ctx = PrepareContext::new(config.clone(), http::Extensions::new());
for (prepared_count, plugin) in self.plugins.iter().enumerate() {
if let Err(e) = plugin.prepare(&mut prepare_ctx).await {
tracing::error!(
plugin = plugin.name(),
error = %e,
"Plugin prepare failed, rolling back"
);
let shutdown_ctx = ShutdownContext::new(prepare_ctx.extensions.clone());
for prev_plugin in self.plugins[..prepared_count].iter().rev() {
if let Err(shutdown_err) = prev_plugin.shutdown(&shutdown_ctx).await {
tracing::warn!(
plugin = prev_plugin.name(),
error = %shutdown_err,
"Plugin shutdown failed during rollback"
);
}
}
return Err(e);
}
}
let app = GasketApp {
plugins: self.plugins,
actions: init_ctx.into_actions(),
config,
extensions: prepare_ctx.extensions,
request_body_limit: self.request_body_limit,
};
let plugin_names: Vec<&str> = app.plugins.iter().map(|p| p.name()).collect();
tracing::info!(?plugin_names, "GasketApp built successfully");
Ok(app)
}
}