// Install the process-wide rustls crypto provider as early as possible.
// `.ok()` ignores the error returned if a default provider was already
// installed (e.g. by another crate or an earlier initializer).
#[ctor::ctor]
fn install_crypto_provider() {
rustls::crypto::aws_lc_rs::default_provider().install_default().ok();
}
pub mod api;
pub mod auth;
pub mod config;
mod config_watcher;
pub mod connections;
mod crypto;
pub mod db;
mod email;
pub mod encryption;
mod error_enrichment;
pub mod errors;
mod leader_election;
pub mod limits;
mod metrics;
mod notifications;
mod openapi;
mod payment_providers;
mod probes;
mod request_logging;
pub mod sample_files;
mod static_assets;
mod sync;
pub mod tasks;
pub mod telemetry;
pub mod tool_executor;
pub mod tool_injection;
mod types;
pub mod webhooks;
#[cfg(test)]
mod test;
use crate::{
api::models::{
deployments::{DeployedModelCreate, StandardModelCreate},
users::Role,
},
auth::password,
config::CorsOrigin,
db::{
handlers::{Deployments, Groups, Repository, Users},
models::{deployments::DeploymentCreateDBRequest, users::UserCreateDBRequest},
},
metrics::GenAiMetrics,
openapi::{AdminApiDoc, AiApiDoc},
request_logging::serializers::{parse_ai_request, parse_ai_response},
};
use anyhow::Context;
use sqlx_pool_router::{DbPools, PoolProvider};
use auth::middleware::admin_ai_proxy_middleware;
use axum::extract::DefaultBodyLimit;
use axum::http::HeaderValue;
use axum::{
Router, ServiceExt, http, middleware,
routing::{delete, get, patch, post},
};
use axum_prometheus::PrometheusMetricLayerBuilder;
use bon::Builder;
pub use config::Config;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use opentelemetry::trace::TraceContextExt;
use outlet::{MultiHandler, RequestLoggerConfig, RequestLoggerLayer};
use outlet_postgres::PostgresHandler;
use request_logging::{AiResponse, ParsedAIRequest};
use sqlx::{ConnectOptions, Executor, PgPool, postgres::PgConnectOptions};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::{Arc, OnceLock};
use tokio::net::TcpListener;
use tower::Layer;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing::{debug, info, instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use utoipa::OpenApi;
use utoipa_scalar::{Scalar, Servable};
use uuid::Uuid;
pub use types::{ApiKeyId, DeploymentId, GroupId, InferenceEndpointId, UserId};
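/// Hot-reloadable configuration handle shared across the application.
///
/// Wraps `arc_swap::ArcSwap<Config>` so readers take cheap lock-free
/// snapshots while the config watcher atomically swaps in a reloaded
/// `Config`. A minimal usage sketch (hedged: `config` and `new_config`
/// are placeholder values, not real constructors):
///
/// ```ignore
/// let shared = SharedConfig::new(config);
/// let snapshot = shared.snapshot(); // Arc<Config>, stays valid across store()
/// shared.store(new_config);         // later snapshot() calls see new_config
/// ```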
#[derive(Clone)]
pub struct SharedConfig(Arc<arc_swap::ArcSwap<Config>>);
impl SharedConfig {
pub fn new(config: Config) -> Self {
Self(Arc::new(arc_swap::ArcSwap::from_pointee(config)))
}
pub fn snapshot(&self) -> Arc<Config> {
self.0.load_full()
}
pub fn store(&self, config: Config) {
self.0.store(Arc::new(config));
}
}
impl From<Config> for SharedConfig {
fn from(config: Config) -> Self {
Self::new(config)
}
}
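/// Per-request application state, generic over the pool provider `P` so
/// tests can swap `DbPools` for a read/write-enforcing test implementation.
///
/// Constructed with the `bon`-generated builder; a sketch mirroring the real
/// wiring in `Application::new_with_pool_and_config_path` (field values here
/// are illustrative):
///
/// ```ignore
/// let state = AppState::builder()
///     .db(db_pools)
///     .config(shared_config)
///     .request_manager(request_manager)
///     .task_runner(task_runner)
///     .limiters(limiters)
///     .build(); // optional fields (outlet_db, metrics_recorder, ...) default to None
/// ```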
#[derive(Clone, Builder)]
pub struct AppState<P = DbPools>
where
P: PoolProvider + Clone,
{
pub db: P,
pub config: SharedConfig,
pub outlet_db: Option<DbPools>,
pub metrics_recorder: Option<GenAiMetrics>,
#[builder(default = false)]
pub is_leader: bool,
pub request_manager: Arc<fusillade::PostgresRequestManager<P, fusillade::ReqwestHttpClient>>,
pub task_runner: Arc<tasks::TaskRunner<P>>,
pub limiters: limits::Limiters,
pub connections_encryption_key: Option<Vec<u8>>,
}
impl<P> AppState<P>
where
P: PoolProvider + Clone,
{
pub fn current_config(&self) -> Arc<Config> {
self.config.snapshot()
}
}
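/// Embedded sqlx migrator for the `./migrations` directory, baked into the
/// binary at compile time. Run at startup via `migrator().run(&pool).await?`
/// (see `setup_database`).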
pub fn migrator() -> sqlx::migrate::Migrator {
sqlx::migrate!("./migrations")
}
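// Process-wide Prometheus state: the metrics recorder can only be installed
// once per process, and the axum-prometheus "dwctl" prefix must likewise be
// set only once, so both are guarded by OnceLocks (relevant when, e.g.,
// tests build several routers in one process).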
static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
static AXUM_PROMETHEUS_PREFIX_SET: OnceLock<()> = OnceLock::new();
fn get_or_install_prometheus_handle() -> PrometheusHandle {
PROMETHEUS_HANDLE
.get_or_init(|| {
const ANALYTICS_LAG_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0, 300.0, 600.0];
const CACHE_SYNC_LAG_BUCKETS: &[f64] = &[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0];
const RETRY_ATTEMPTS_BUCKETS: &[f64] = &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
PrometheusBuilder::new()
.set_buckets_for_metric(Matcher::Full("dwctl_analytics_lag_seconds".to_string()), ANALYTICS_LAG_BUCKETS)
.expect("Failed to set custom buckets for dwctl_analytics_lag_seconds")
.set_buckets_for_metric(Matcher::Full("dwctl_cache_sync_lag_seconds".to_string()), CACHE_SYNC_LAG_BUCKETS)
.expect("Failed to set custom buckets for dwctl_cache_sync_lag_seconds")
.set_buckets_for_metric(
Matcher::Full("fusillade_retry_attempts_on_success".to_string()),
RETRY_ATTEMPTS_BUCKETS,
)
.expect("Failed to set custom buckets for fusillade_retry_attempts_on_success")
.install_recorder()
.expect("Failed to install Prometheus recorder")
})
.clone()
}
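/// Creates the initial admin user, or refreshes the stored password hash if
/// a user with `email` already exists. Runs in a single transaction and
/// returns the user's id either way, so it is safe to call on every startup.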
#[instrument(skip_all)]
pub async fn create_initial_admin_user(
email: &str,
password: Option<&str>,
argon2_params: password::Argon2Params,
db: &PgPool,
) -> Result<UserId, sqlx::Error> {
let password_hash = if let Some(pwd) = password {
Some(
password::hash_string_with_params(pwd, Some(argon2_params))
.map_err(|e| sqlx::Error::Encode(format!("Failed to hash admin password: {e}").into()))?,
)
} else {
None
};
let mut tx = db.begin().await?;
let mut user_repo = Users::new(&mut tx);
if let Some(existing_user) = user_repo
.get_user_by_email(email)
.await
.map_err(|e| sqlx::Error::Protocol(format!("Failed to check existing user: {e}")))?
{
if let Some(password_hash) = password_hash {
sqlx::query!("UPDATE users SET password_hash = $1 WHERE id = $2", password_hash, existing_user.id)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
return Ok(existing_user.id);
}
let user_create = UserCreateDBRequest {
username: email.to_string(),
email: email.to_string(),
display_name: None,
avatar_url: None,
is_admin: true,
roles: vec![Role::PlatformManager],
auth_source: "system".to_string(),
password_hash,
external_user_id: None,
};
let created_user = user_repo
.create(&user_create)
.await
.map_err(|e| sqlx::Error::Protocol(format!("Failed to create admin user: {e}")))?;
tx.commit().await?;
Ok(created_user.id)
}
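/// Seeds inference endpoints and default models from `config.model_sources`.
/// Guarded by the `endpoints_seeded` flag in `system_config`, so it is a
/// no-op on every startup after the first. The nil UUID is used throughout
/// as the system user / system API key / "everyone" group sentinel.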
#[instrument(skip_all)]
pub async fn seed_database(sources: &[config::ModelSource], db: &PgPool) -> Result<(), anyhow::Error> {
let mut tx = db.begin().await?;
let seeded = sqlx::query_scalar!("SELECT value FROM system_config WHERE key = 'endpoints_seeded'")
.fetch_optional(&mut *tx)
.await?;
if let Some(true) = seeded {
info!("Database already seeded, skipping seeding operations");
tx.commit().await?;
return Ok(());
}
info!("Seeding database with initial configuration");
let system_user_id = Uuid::nil();
for source in sources {
if let Some(endpoint_id) = sqlx::query_scalar!(
"INSERT INTO inference_endpoints (name, description, url, created_by)
VALUES ($1, $2, $3, $4)
ON CONFLICT (name) DO NOTHING
RETURNING id",
source.name,
None::<String>, // description
source.url.as_str(),
system_user_id,
)
.fetch_optional(&mut *tx)
.await?
{
for model in source.default_models.as_deref().unwrap_or(&[]) {
let mut model_repo = Deployments::new(&mut tx);
if let Ok(row) = model_repo
.create(&DeploymentCreateDBRequest::from_api_create(
Uuid::nil(),
DeployedModelCreate::Standard(StandardModelCreate {
model_name: model.name.clone(),
alias: Some(model.name.clone()),
display_name: None,
hosted_on: endpoint_id,
description: None,
model_type: None,
capabilities: None,
requests_per_second: None,
burst_size: None,
capacity: None,
batch_capacity: None,
throughput: None,
tariffs: None,
provider_pricing: None,
sanitize_responses: None,
trusted: None,
open_responses_adapter: None,
traffic_routing_rules: None,
allowed_batch_completion_windows: None,
metadata: None,
}),
))
.await
&& model.add_to_everyone_group
{
let mut groups_repo = Groups::new(&mut tx);
if let Err(e) = groups_repo.add_deployment_to_group(row.id, Uuid::nil(), Uuid::nil()).await {
debug!(
"Failed to add deployed model {} to 'everyone' group during seeding: {}",
model.name, e
);
}
}
}
}
}
let system_api_key_id = Uuid::nil();
let new_secret = crypto::generate_api_key();
sqlx::query!(
"UPDATE api_keys SET secret = $1, purpose = 'platform' WHERE id = $2",
new_secret,
system_api_key_id
)
.execute(&mut *tx)
.await?;
sqlx::query!(
"UPDATE system_config SET value = true, updated_at = NOW()
WHERE key = 'endpoints_seeded'"
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
debug!("Database seeded successfully");
Ok(())
}
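/// Builds all database pools: the optional embedded Postgres, the main
/// read/write `DbPools`, the fusillade batch-processing pools, and (when
/// request logging is enabled) the outlet pools. Also runs all migrations,
/// creates the initial admin user, and seeds the database.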
async fn setup_database(
config: &Config,
pool: Option<PgPool>,
) -> anyhow::Result<(Option<db::embedded::EmbeddedDatabase>, DbPools, DbPools, Option<DbPools>)> {
let slow_threshold = std::time::Duration::from_millis(config.slow_statement_threshold_ms);
let (embedded_db, pool, test_replica_pool) = if let Some(existing_pool) = pool {
info!("Using provided database pool with TestDbPools for read/write separation");
let test_pools = sqlx_pool_router::TestDbPools::new(existing_pool.clone())
.await
.context("Failed to create TestDbPools")?;
let replica_pool = test_pools.read().clone();
(None, existing_pool, Some(replica_pool))
} else {
let (_embedded_db, database_url) = match &config.database {
config::DatabaseConfig::Embedded { .. } => {
let persistent = config.database.embedded_persistent();
info!("Starting with embedded database (persistent: {})", persistent);
if !persistent {
info!("persistent=false: database will be ephemeral and data will be lost on shutdown");
}
#[cfg(feature = "embedded-db")]
{
let data_dir = config.database.embedded_data_dir();
let embedded_db = db::embedded::EmbeddedDatabase::start(data_dir, persistent).await?;
let url = embedded_db.connection_string().to_string();
(Some(embedded_db), url)
}
#[cfg(not(feature = "embedded-db"))]
{
anyhow::bail!(
"Embedded database is configured but the feature is not enabled. \
Rebuild with --features embedded-db to use embedded database."
);
}
}
config::DatabaseConfig::External { url, .. } => {
info!("Using external database");
(None::<db::embedded::EmbeddedDatabase>, url.clone())
}
};
let main_settings = config.database.main_pool_settings();
let connect_opts = PgConnectOptions::from_str(&database_url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(main_settings.max_connections)
.min_connections(main_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(main_settings.acquire_timeout_secs))
.idle_timeout(if main_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(main_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if main_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(main_settings.max_lifetime_secs))
} else {
None
})
.connect_with(connect_opts)
.await?;
(_embedded_db, pool, None)
};
migrator().run(&pool).await?;
let db_pools = if let Some(test_replica) = test_replica_pool {
info!("Using test replica pool with read-only enforcement");
DbPools::with_replica(pool, test_replica)
} else if let Some(replica_url) = config.database.external_replica_url() {
info!("Setting up read replica pool");
let replica_settings = config.database.main_replica_pool_settings();
let replica_opts = PgConnectOptions::from_str(replica_url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let replica_pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(replica_settings.max_connections)
.min_connections(replica_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(replica_settings.acquire_timeout_secs))
.idle_timeout(if replica_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(replica_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if replica_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(replica_settings.max_lifetime_secs))
} else {
None
})
.connect_with(replica_opts)
.await?;
DbPools::with_replica(pool, replica_pool)
} else {
DbPools::new(pool)
};
let main_connect_opts = db_pools.connect_options().as_ref().clone();
async fn create_schema_pool(
schema: String,
opts: sqlx::postgres::PgConnectOptions,
settings: &config::PoolSettings,
) -> Result<sqlx::PgPool, sqlx::Error> {
let search_path_key = "search_path".to_string();
let search_path_value = schema.clone();
info!("Setting search_path={} via connection options for schema pool", schema);
let opts_with_schema = opts.options([(search_path_key, search_path_value)]);
sqlx::postgres::PgPoolOptions::new()
.max_connections(settings.max_connections)
.min_connections(settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(settings.acquire_timeout_secs))
.idle_timeout(if settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(settings.max_lifetime_secs))
} else {
None
})
.connect_with(opts_with_schema)
.await
}
info!("Setting up fusillade batch processing pool");
let fusillade_pools = match config.database.fusillade() {
config::ComponentDb::Schema {
name, pool: pool_settings, ..
} => {
let primary = create_schema_pool(name.clone(), main_connect_opts.clone(), pool_settings).await?;
primary.execute(&*format!("CREATE SCHEMA IF NOT EXISTS {name}")).await?;
if db_pools.has_replica() {
info!("Setting up fusillade read replica (schema mode)");
let replica_opts = db_pools.read().connect_options().as_ref().clone();
let replica_pool_settings = config.database.fusillade().replica_pool_settings();
let replica = create_schema_pool(name.clone(), replica_opts, replica_pool_settings).await?;
DbPools::with_replica(primary, replica)
} else {
DbPools::new(primary)
}
}
config::ComponentDb::Dedicated {
url,
replica_url,
pool: pool_settings,
..
} => {
info!("Using dedicated database for fusillade");
let connect_opts = PgConnectOptions::from_str(url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let primary = sqlx::postgres::PgPoolOptions::new()
.max_connections(pool_settings.max_connections)
.min_connections(pool_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(pool_settings.acquire_timeout_secs))
.idle_timeout(if pool_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(pool_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if pool_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(pool_settings.max_lifetime_secs))
} else {
None
})
.connect_with(connect_opts)
.await?;
if let Some(replica_url) = replica_url {
info!("Setting up fusillade read replica");
let replica_pool_settings = config.database.fusillade().replica_pool_settings();
let replica_opts = PgConnectOptions::from_str(replica_url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let replica = sqlx::postgres::PgPoolOptions::new()
.max_connections(replica_pool_settings.max_connections)
.min_connections(replica_pool_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(replica_pool_settings.acquire_timeout_secs))
.idle_timeout(if replica_pool_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(replica_pool_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if replica_pool_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(replica_pool_settings.max_lifetime_secs))
} else {
None
})
.connect_with(replica_opts)
.await?;
DbPools::with_replica(primary, replica)
} else {
DbPools::new(primary)
}
}
};
fusillade::migrator().run(&*fusillade_pools).await?;
underway::run_migrations(&*db_pools).await?;
let outlet_pools = if config.enable_request_logging {
info!("Setting up outlet request logging pool (logging enabled)");
let pools = match config.database.outlet() {
config::ComponentDb::Schema {
name, pool: pool_settings, ..
} => {
let primary = create_schema_pool(name.clone(), main_connect_opts.clone(), pool_settings).await?;
primary.execute(&*format!("CREATE SCHEMA IF NOT EXISTS {name}")).await?;
if db_pools.has_replica() {
info!("Setting up outlet read replica (schema mode)");
let replica_opts = db_pools.read().connect_options().as_ref().clone();
let replica_pool_settings = config.database.outlet().replica_pool_settings();
let replica = create_schema_pool(name.clone(), replica_opts, replica_pool_settings).await?;
DbPools::with_replica(primary, replica)
} else {
DbPools::new(primary)
}
}
config::ComponentDb::Dedicated {
url,
replica_url,
pool: pool_settings,
..
} => {
info!("Using dedicated database for outlet");
let connect_opts = PgConnectOptions::from_str(url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let primary = sqlx::postgres::PgPoolOptions::new()
.max_connections(pool_settings.max_connections)
.min_connections(pool_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(pool_settings.acquire_timeout_secs))
.idle_timeout(if pool_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(pool_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if pool_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(pool_settings.max_lifetime_secs))
} else {
None
})
.connect_with(connect_opts)
.await?;
if let Some(replica_url) = replica_url {
info!("Setting up outlet read replica");
let replica_pool_settings = config.database.outlet().replica_pool_settings();
let replica_opts = PgConnectOptions::from_str(replica_url)?.log_slow_statements(log::LevelFilter::Warn, slow_threshold);
let replica = sqlx::postgres::PgPoolOptions::new()
.max_connections(replica_pool_settings.max_connections)
.min_connections(replica_pool_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(replica_pool_settings.acquire_timeout_secs))
.idle_timeout(if replica_pool_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(replica_pool_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if replica_pool_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(replica_pool_settings.max_lifetime_secs))
} else {
None
})
.connect_with(replica_opts)
.await?;
DbPools::with_replica(primary, replica)
} else {
DbPools::new(primary)
}
}
};
outlet_postgres::migrator().run(&*pools).await?;
Some(pools)
} else {
info!("Skipping outlet pool setup (logging disabled)");
None
};
let argon2_params = password::Argon2Params {
memory_kib: config.auth.native.password.argon2_memory_kib,
iterations: config.auth.native.password.argon2_iterations,
parallelism: config.auth.native.password.argon2_parallelism,
};
create_initial_admin_user(&config.admin_email, config.admin_password.as_deref(), argon2_params, &db_pools)
.await
.map_err(|e| anyhow::anyhow!("Failed to create initial admin user: {}", e))?;
seed_database(&config.model_sources, &db_pools).await?;
Ok((embedded_db, db_pools, fusillade_pools, outlet_pools))
}
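/// Translates the CORS settings from config into a
/// `tower_http::cors::CorsLayer`, normalizing URL origins by trimming any
/// trailing slash.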
fn create_cors_layer(config: &Config) -> anyhow::Result<CorsLayer> {
let mut origins = Vec::new();
for origin in &config.auth.security.cors.allowed_origins {
let header_value = match origin {
CorsOrigin::Wildcard => "*".parse::<HeaderValue>()?,
CorsOrigin::Url(url) => {
let url_str = url.as_str().trim_end_matches('/');
url_str.parse::<HeaderValue>()?
}
};
origins.push(header_value);
}
info!("Configuring CORS with allowed origins: {:?}", origins);
let exposed: Vec<http::HeaderName> = config
.auth
.security
.cors
.exposed_headers
.iter()
.filter_map(|h| h.parse().ok())
.collect();
let mut cors = CorsLayer::new()
.allow_origin(origins)
.allow_methods([
http::Method::GET,
http::Method::POST,
http::Method::PUT,
http::Method::DELETE,
http::Method::PATCH,
http::Method::OPTIONS,
])
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION, http::header::ACCEPT])
.allow_credentials(config.auth.security.cors.allow_credentials)
.expose_headers(exposed);
if let Some(max_age) = config.auth.security.cors.max_age {
cors = cors.max_age(std::time::Duration::from_secs(max_age));
}
Ok(cors)
}
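/// Assembles the full axum router: auth routes, the admin API under
/// `/admin/api/v1`, the AI proxy (and optional batches API) under `/ai/v1`,
/// OpenAPI docs, CORS, optional Prometheus metrics, and tracing layers.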
#[instrument(skip_all)]
pub async fn build_router(
state: &mut AppState,
onwards_router: Router,
analytics_sender: Option<request_logging::batcher::AnalyticsSender>,
metrics_recorder: Option<GenAiMetrics>,
strict_mode: bool,
) -> anyhow::Result<Router> {
let config = state.current_config();
let request_logging_enabled = state.outlet_db.is_some() && config.enable_request_logging;
let analytics_enabled = config.enable_analytics;
let outlet_layer = if request_logging_enabled || analytics_enabled {
state.metrics_recorder = metrics_recorder;
let mut multi_handler = MultiHandler::new();
if request_logging_enabled {
let outlet_pool = state.outlet_db.as_ref().expect("outlet_db checked above");
let postgres_handler = PostgresHandler::<DbPools, ParsedAIRequest, AiResponse>::from_pool_provider(outlet_pool.clone())
.await
.context("Failed to create PostgresHandler for request logging")?
.with_request_serializer(parse_ai_request)
.with_response_serializer(parse_ai_response);
multi_handler = multi_handler.with(postgres_handler);
}
if let Some(sender) = analytics_sender {
let analytics_handler = request_logging::AnalyticsHandler::new(sender, uuid::Uuid::new_v4(), config.as_ref().clone());
multi_handler = multi_handler.with(analytics_handler);
}
if multi_handler.is_empty() {
None
} else {
let outlet_config = RequestLoggerConfig {
capture_request_body: true,
capture_response_body: true,
path_filter: None,
..Default::default()
};
Some(RequestLoggerLayer::new(outlet_config, multi_handler))
}
} else {
None
};
let auth_routes = Router::new()
.route(
"/authentication/register",
get(api::handlers::auth::get_registration_info).post(api::handlers::auth::register),
)
.route(
"/authentication/login",
get(api::handlers::auth::get_login_info).post(api::handlers::auth::login),
)
.route("/authentication/logout", post(api::handlers::auth::logout))
.route("/authentication/password-resets", post(api::handlers::auth::request_password_reset))
.route(
"/authentication/password-resets/{token_id}/confirm",
post(api::handlers::auth::confirm_password_reset),
)
.route("/authentication/password-change", post(api::handlers::auth::change_password))
.with_state(state.clone());
let api_routes = Router::new()
.route("/config", get(api::handlers::config::get_config))
.route("/auth/cli-callback", get(api::handlers::auth::cli_callback))
.route("/users", get(api::handlers::users::list_users))
.route("/users", post(api::handlers::users::create_user))
.route("/users/{id}", get(api::handlers::users::get_user))
.route("/users/{id}", patch(api::handlers::users::update_user))
.route("/users/{id}", delete(api::handlers::users::delete_user))
.route("/users/{user_id}/api-keys", get(api::handlers::api_keys::list_user_api_keys))
.route("/users/{user_id}/api-keys", post(api::handlers::api_keys::create_user_api_key))
.route("/users/{user_id}/api-keys/{id}", get(api::handlers::api_keys::get_user_api_key))
.route(
"/users/{user_id}/api-keys/{id}",
delete(api::handlers::api_keys::delete_user_api_key),
)
.route("/users/{user_id}/webhooks", get(api::handlers::webhooks::list_webhooks))
.route("/users/{user_id}/webhooks", post(api::handlers::webhooks::create_webhook))
.route("/users/{user_id}/webhooks/{webhook_id}", get(api::handlers::webhooks::get_webhook))
.route(
"/users/{user_id}/webhooks/{webhook_id}",
patch(api::handlers::webhooks::update_webhook),
)
.route(
"/users/{user_id}/webhooks/{webhook_id}",
delete(api::handlers::webhooks::delete_webhook),
)
.route(
"/users/{user_id}/webhooks/{webhook_id}/rotate-secret",
post(api::handlers::webhooks::rotate_secret),
)
.route("/users/{user_id}/groups", get(api::handlers::groups::get_user_groups))
.route("/users/{user_id}/groups/{group_id}", post(api::handlers::groups::add_group_to_user))
.route(
"/users/{user_id}/groups/{group_id}",
delete(api::handlers::groups::remove_group_from_user),
)
.route("/transactions", post(api::handlers::transactions::create_transaction))
.route("/transactions/{transaction_id}", get(api::handlers::transactions::get_transaction))
.route("/transactions", get(api::handlers::transactions::list_transactions))
.route("/payments", post(api::handlers::payments::create_payment))
.route("/payments/{id}", patch(api::handlers::payments::process_payment))
.route("/billing-portal", post(api::handlers::payments::create_billing_portal_session))
.route("/auto-topup/enable", post(api::handlers::payments::enable_auto_topup))
.route("/auto-topup/disable", post(api::handlers::payments::disable_auto_topup))
.route("/endpoints", get(api::handlers::inference_endpoints::list_inference_endpoints))
.route("/endpoints", post(api::handlers::inference_endpoints::create_inference_endpoint))
.route(
"/endpoints/validate",
post(api::handlers::inference_endpoints::validate_inference_endpoint),
)
.route("/endpoints/{id}", get(api::handlers::inference_endpoints::get_inference_endpoint))
.route(
"/endpoints/{id}",
patch(api::handlers::inference_endpoints::update_inference_endpoint),
)
.route(
"/endpoints/{id}",
delete(api::handlers::inference_endpoints::delete_inference_endpoint),
)
.route(
"/endpoints/{id}/synchronize",
post(api::handlers::inference_endpoints::synchronize_endpoint),
)
.route("/models", get(api::handlers::deployments::list_deployed_models))
.route("/models", post(api::handlers::deployments::create_deployed_model))
.route("/models/{id}", get(api::handlers::deployments::get_deployed_model))
.route("/models/{id}", patch(api::handlers::deployments::update_deployed_model))
.route("/models/{id}", delete(api::handlers::deployments::delete_deployed_model))
.route(
"/provider-display-configs",
get(api::handlers::provider_display_configs::list_provider_display_configs),
)
.route(
"/provider-display-configs",
post(api::handlers::provider_display_configs::create_provider_display_config),
)
.route(
"/provider-display-configs/{provider_key}",
get(api::handlers::provider_display_configs::get_provider_display_config),
)
.route(
"/provider-display-configs/{provider_key}",
patch(api::handlers::provider_display_configs::update_provider_display_config),
)
.route(
"/provider-display-configs/{provider_key}",
delete(api::handlers::provider_display_configs::delete_provider_display_config),
)
.route("/models/{id}/components", get(api::handlers::deployments::get_model_components))
.route(
"/models/{id}/components/{component_id}",
post(api::handlers::deployments::add_model_component),
)
.route(
"/models/{id}/components/{component_id}",
patch(api::handlers::deployments::update_model_component),
)
.route(
"/models/{id}/components/{component_id}",
delete(api::handlers::deployments::remove_model_component),
)
.route("/groups", get(api::handlers::groups::list_groups))
.route("/groups", post(api::handlers::groups::create_group))
.route("/groups/{id}", get(api::handlers::groups::get_group))
.route("/groups/{id}", patch(api::handlers::groups::update_group))
.route("/groups/{id}", delete(api::handlers::groups::delete_group))
.route("/groups/{group_id}/users", get(api::handlers::groups::get_group_users))
.route("/groups/{group_id}/users/{user_id}", post(api::handlers::groups::add_user_to_group))
.route(
"/groups/{group_id}/users/{user_id}",
delete(api::handlers::groups::remove_user_from_group),
)
.route("/groups/{group_id}/models", get(api::handlers::groups::get_group_deployments))
.route(
"/groups/{group_id}/models/{deployment_id}",
post(api::handlers::groups::add_deployment_to_group),
)
.route(
"/groups/{group_id}/models/{deployment_id}",
delete(api::handlers::groups::remove_deployment_from_group),
)
.route("/models/{deployment_id}/groups", get(api::handlers::groups::get_deployment_groups))
.route("/organizations", get(api::handlers::organizations::list_organizations))
.route("/organizations", post(api::handlers::organizations::create_organization))
.route("/organizations/{id}", get(api::handlers::organizations::get_organization))
.route("/organizations/{id}", patch(api::handlers::organizations::update_organization))
.route("/organizations/{id}", delete(api::handlers::organizations::delete_organization))
.route("/organizations/{id}/members", get(api::handlers::organizations::list_members))
.route("/organizations/{id}/members", post(api::handlers::organizations::add_member))
.route(
"/organizations/{id}/members/{user_id}",
patch(api::handlers::organizations::update_member_role),
)
.route(
"/organizations/{id}/members/{user_id}",
delete(api::handlers::organizations::remove_member),
)
.route("/organizations/{id}/leave", post(api::handlers::organizations::leave_organization))
.route("/organizations/{id}/invites", post(api::handlers::organizations::invite_member))
.route(
"/organizations/{id}/invites/{invite_id}",
delete(api::handlers::organizations::cancel_invite),
)
.route(
"/organizations/invites/{token}",
get(api::handlers::organizations::get_invite_details),
)
.route(
"/organizations/invites/{token}/accept",
post(api::handlers::organizations::accept_invite),
)
.route(
"/organizations/invites/{token}/decline",
post(api::handlers::organizations::decline_invite),
)
.route(
"/users/{user_id}/organizations",
get(api::handlers::organizations::list_user_organizations),
)
.route("/session/organization", post(api::handlers::organizations::set_active_organization))
.route("/support/requests", post(api::handlers::support::submit_support_request))
.route("/batches/requests", get(api::handlers::batch_requests::list_batch_requests))
.route(
"/batches/requests/{request_id}",
get(api::handlers::batch_requests::get_batch_request),
)
.route("/requests", get(api::handlers::requests::list_requests))
.route("/requests/aggregate", get(api::handlers::requests::aggregate_requests))
.route("/requests/aggregate-by-user", get(api::handlers::requests::aggregate_by_user))
.route("/usage", get(api::handlers::requests::get_usage))
.route("/probes", get(api::handlers::probes::list_probes))
.route("/probes", post(api::handlers::probes::create_probe))
.route("/probes/test/{deployment_id}", post(api::handlers::probes::test_probe))
.route("/probes/{id}", get(api::handlers::probes::get_probe))
.route("/probes/{id}", patch(api::handlers::probes::update_probe))
.route("/probes/{id}", delete(api::handlers::probes::delete_probe))
.route("/probes/{id}/activate", patch(api::handlers::probes::activate_probe))
.route("/probes/{id}/deactivate", patch(api::handlers::probes::deactivate_probe))
.route("/probes/{id}/execute", post(api::handlers::probes::execute_probe))
.route("/probes/{id}/results", get(api::handlers::probes::get_probe_results))
.route("/probes/{id}/statistics", get(api::handlers::probes::get_statistics))
.route(
"/monitoring/pending-request-counts",
get(api::handlers::queue::get_pending_request_counts),
)
.route("/tool-sources", get(api::handlers::tool_sources::list_tool_sources))
.route("/tool-sources", post(api::handlers::tool_sources::create_tool_source))
.route("/tool-sources/{id}", get(api::handlers::tool_sources::get_tool_source))
.route("/tool-sources/{id}", patch(api::handlers::tool_sources::update_tool_source))
.route("/tool-sources/{id}", delete(api::handlers::tool_sources::delete_tool_source))
.route(
"/deployments/{id}/tool-sources",
get(api::handlers::tool_sources::list_deployment_tool_sources),
)
.route(
"/deployments/{id}/tool-sources/{source_id}",
axum::routing::put(api::handlers::tool_sources::attach_tool_source_to_deployment),
)
.route(
"/deployments/{id}/tool-sources/{source_id}",
delete(api::handlers::tool_sources::detach_tool_source_from_deployment),
)
.route(
"/groups/{id}/tool-sources",
get(api::handlers::tool_sources::list_group_tool_sources),
)
.route(
"/groups/{id}/tool-sources/{source_id}",
axum::routing::put(api::handlers::tool_sources::attach_tool_source_to_group),
)
.route(
"/groups/{id}/tool-sources/{source_id}",
delete(api::handlers::tool_sources::detach_tool_source_from_group),
)
.route("/connections", post(api::handlers::connections::create_connection))
.route("/connections", get(api::handlers::connections::list_connections))
.route("/connections/{connection_id}", get(api::handlers::connections::get_connection))
.route(
"/connections/{connection_id}",
delete(api::handlers::connections::delete_connection),
)
.route(
"/connections/{connection_id}/test",
post(api::handlers::connections::test_connection),
)
.route(
"/connections/{connection_id}/files",
get(api::handlers::connections::list_connection_files),
)
.route(
"/connections/{connection_id}/synced-keys",
get(api::handlers::connections::list_synced_keys),
)
.route("/connections/{connection_id}/sync", post(api::handlers::connections::trigger_sync))
.route("/connections/{connection_id}/syncs", get(api::handlers::connections::list_syncs))
.route(
"/connections/{connection_id}/syncs/{sync_id}",
get(api::handlers::connections::get_sync),
)
.route(
"/connections/{connection_id}/syncs/{sync_id}/entries",
get(api::handlers::connections::list_sync_entries),
);
let api_routes_with_state = api_routes.with_state(state.clone());
let batches_routes = if config.batches.enabled {
let file_upload_limit = config.limits.files.max_file_size;
let body_limit_layer = if file_upload_limit == 0 {
DefaultBodyLimit::disable()
} else {
let body_limit_u64 = file_upload_limit.saturating_add(limits::MULTIPART_OVERHEAD);
let body_limit = usize::try_from(body_limit_u64).unwrap_or(usize::MAX);
DefaultBodyLimit::max(body_limit)
};
let file_router = Router::new().route("/files", post(api::handlers::files::upload_file).layer(body_limit_layer));
Some(
Router::new()
.merge(file_router)
.route("/files", get(api::handlers::files::list_files))
.route("/files/{file_id}", get(api::handlers::files::get_file))
.route("/files/{file_id}", delete(api::handlers::files::delete_file))
.route("/files/{file_id}/content", get(api::handlers::files::get_file_content))
.route("/files/{file_id}/cost-estimate", get(api::handlers::files::get_file_cost_estimate))
.route("/batches", post(api::handlers::batches::create_batch))
.route("/batches", get(api::handlers::batches::list_batches))
.route("/batches/{batch_id}", get(api::handlers::batches::get_batch))
.route("/batches/{batch_id}", delete(api::handlers::batches::delete_batch))
.route("/batches/{batch_id}/analytics", get(api::handlers::batches::get_batch_analytics))
.route("/batches/{batch_id}/results", get(api::handlers::batches::get_batch_results))
.route("/batches/{batch_id}/cancel", post(api::handlers::batches::cancel_batch))
.route(
"/batches/{batch_id}/retry",
post(api::handlers::batches::retry_failed_batch_requests),
)
.route(
"/batches/{batch_id}/retry-requests",
post(api::handlers::batches::retry_specific_requests),
)
.route("/daemons", get(api::handlers::daemons::list_daemons))
.with_state(state.clone()),
)
} else {
None
};
let fallback = get(api::handlers::static_assets::serve_embedded_asset)
.fallback(get(api::handlers::static_assets::spa_fallback));
let tool_injection_state = crate::tool_injection::ToolInjectionState {
db: state.db.write().clone(),
};
let onwards_router = onwards_router.layer(middleware::from_fn_with_state(
tool_injection_state,
crate::tool_injection::tool_injection_middleware,
));
let onwards_router = onwards_router.layer(middleware::from_fn_with_state(
state.db.write().clone(),
error_enrichment::error_enrichment_middleware,
));
let onwards_router = if let Some(outlet_layer) = outlet_layer.clone() {
onwards_router.layer(outlet_layer)
} else {
onwards_router
};
let mut router = Router::new()
.route("/healthz", get(|| async { "OK" }))
.route("/webhooks/payments", post(api::handlers::payments::webhook_handler))
.with_state(state.clone())
.merge(auth_routes);
if strict_mode {
router = router.nest("/ai/v1", onwards_router);
if let Some(batches) = batches_routes {
let batches_with_fallback = batches.fallback(|| async {
(
axum::http::StatusCode::NOT_FOUND,
axum::Json(serde_json::json!({
"error": {
"message": "Unknown endpoint",
"type": "invalid_request_error",
"code": "not_found"
}
})),
)
});
router = router.nest("/ai/v1", batches_with_fallback);
}
} else {
let ai_router = if let Some(batches) = batches_routes {
batches.merge(onwards_router)
} else {
onwards_router
};
router = router.nest("/ai/v1", ai_router);
}
let router = router
.nest("/admin/api/v1", api_routes_with_state)
.route("/admin/openapi.json", get(|| async { axum::Json(AdminApiDoc::openapi()) }))
.route("/ai/openapi.json", get(|| async { axum::Json(AiApiDoc::openapi()) }))
.merge(Scalar::with_url("/admin/docs", AdminApiDoc::openapi()))
.merge(Scalar::with_url("/ai/docs", AiApiDoc::openapi()))
.fallback_service(fallback.with_state(state.clone()));
let cors_layer = create_cors_layer(&config)?;
let mut router = router.layer(cors_layer);
if config.enable_metrics {
let metric_handle = get_or_install_prometheus_handle();
let prometheus_layer = if AXUM_PROMETHEUS_PREFIX_SET.set(()).is_ok() {
PrometheusMetricLayerBuilder::new()
.with_prefix("dwctl")
.with_metrics_from_fn(move || metric_handle.clone())
.build_pair()
.0
} else {
PrometheusMetricLayerBuilder::new()
.with_metrics_from_fn(move || metric_handle.clone())
.build_pair()
.0
};
let gen_ai_registry = if let Some(ref recorder) = state.metrics_recorder {
recorder.registry().clone()
} else {
prometheus::Registry::new()
};
let endpoint_handle = get_or_install_prometheus_handle();
router = router
.route(
"/internal/metrics",
get(|| async move {
use prometheus::{Encoder, TextEncoder};
let mut axum_metrics = endpoint_handle.render();
let encoder = TextEncoder::new();
let gen_ai_families = gen_ai_registry.gather();
let mut gen_ai_buffer = vec![];
encoder
.encode(&gen_ai_families, &mut gen_ai_buffer)
.expect("encoding Prometheus metrics into an in-memory buffer should not fail");
axum_metrics.push_str(&String::from_utf8_lossy(&gen_ai_buffer));
axum_metrics
}),
)
.layer(prometheus_layer);
}
let router = router.layer(middleware::from_fn(inject_trace_id)).layer(
TraceLayer::new_for_http()
.make_span_with(|request: &http::Request<_>| {
let path = request.uri().path();
let route = request
.extensions()
.get::<axum::extract::MatchedPath>()
.map(|mp| mp.as_str().to_owned());
let span_name = if let Some(ref route) = route {
format!("{} {}", request.method(), route)
} else {
format!("{} {}", request.method(), path)
};
let api_type = if path.starts_with("/ai/") {
"ai_proxy"
} else if path.starts_with("/admin/") {
"admin"
} else {
"other"
};
let span = tracing::info_span!(
"request",
trace_id = tracing::field::Empty,
otel.name = %span_name,
);
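// Manually adopt a W3C `traceparent` header of the form
// `version-trace_id-span_id-flags` (e.g. `00-<32 hex>-<16 hex>-01`)
// as the remote parent of this request span.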
if let Some(traceparent) = request.headers().get("traceparent")
&& let Ok(tp) = traceparent.to_str()
{
let parts: Vec<&str> = tp.split('-').collect();
if parts.len() == 4
&& let (Ok(trace_id), Ok(span_id)) = (
opentelemetry::trace::TraceId::from_hex(parts[1]),
opentelemetry::trace::SpanId::from_hex(parts[2]),
)
{
let flags = u8::from_str_radix(parts[3], 16).unwrap_or(1);
let parent_ctx = opentelemetry::trace::SpanContext::new(
trace_id,
span_id,
opentelemetry::trace::TraceFlags::new(flags),
true, // is_remote
opentelemetry::trace::TraceState::default(),
);
let parent = opentelemetry::Context::new().with_remote_span_context(parent_ctx);
span.set_parent(parent);
}
}
span.set_attribute("otel.kind", "Server");
span.set_attribute("api.type", api_type.to_string());
span.set_attribute("http.request.method", request.method().to_string());
span.set_attribute("http.route", route.unwrap_or_default());
span.set_attribute("url.path", path.to_string());
span.set_attribute("url.query", request.uri().query().unwrap_or("").to_string());
span
})
.on_request(tower_http::trace::DefaultOnRequest::new().level(tracing::Level::TRACE))
.on_response(|response: &http::Response<_>, latency: std::time::Duration, span: &tracing::Span| {
let status = response.status().as_u16();
span.set_attribute("http.response.status_code", i64::from(status));
if status >= 500 {
span.set_attribute("otel.status_code", "ERROR");
span.set_attribute("error.type", status.to_string());
} else if status >= 400 {
span.set_attribute("error.type", status.to_string());
}
tracing::info!(
http.response.status_code = status,
latency_ms = latency.as_millis() as u64,
"finished processing request"
);
})
.on_failure(
|error: tower_http::classify::ServerErrorsFailureClass, latency: std::time::Duration, span: &tracing::Span| {
span.set_attribute("otel.status_code", "ERROR");
span.set_attribute("error.type", error.to_string());
tracing::error!(
error = %error,
latency_ms = latency.as_millis() as u64,
"request failed"
);
},
),
);
Ok(router)
}
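/// Middleware that records the OpenTelemetry trace id onto the current
/// tracing span's `trace_id` field (declared as `Empty` in `make_span_with`),
/// so log lines can be correlated with traces.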
async fn inject_trace_id(request: axum::extract::Request, next: middleware::Next) -> axum::response::Response {
let span = tracing::Span::current();
let sc = span.context().span().span_context().clone();
if sc.is_valid() {
span.record("trace_id", tracing::field::display(sc.trace_id()));
}
next.run(request).await
}
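/// Handle to every long-running background task plus the shared shutdown
/// token. Owning code `select!`s on `wait_for_failure` while the server runs
/// and calls `shutdown` to drain tasks on exit.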
pub struct BackgroundServices {
request_manager: Arc<fusillade::PostgresRequestManager<DbPools, fusillade::ReqwestHttpClient>>,
task_runner: Arc<tasks::TaskRunner>,
is_leader: bool,
onwards_targets: onwards::target::Targets,
#[cfg_attr(not(test), allow(dead_code))]
onwards_sender: Option<tokio::sync::watch::Sender<onwards::target::Targets>>,
#[allow(dead_code)]
strict_mode: bool,
analytics_sender: Option<request_logging::batcher::AnalyticsSender>,
background_tasks: tokio::task::JoinSet<anyhow::Result<()>>,
task_names: std::collections::HashMap<tokio::task::Id, &'static str>,
shutdown_token: tokio_util::sync::CancellationToken,
pub drop_guard: Option<tokio_util::sync::DropGuard>,
connections_encryption_key: Option<Vec<u8>>,
}
impl BackgroundServices {
fn spawn<F>(&mut self, name: &'static str, future: F)
where
F: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
{
let abort_handle = self.background_tasks.spawn(future);
self.task_names.insert(abort_handle.id(), name);
}
pub async fn wait_for_failure(&mut self) -> anyhow::Result<std::convert::Infallible> {
loop {
match self.background_tasks.join_next_with_id().await {
None => {
// All tasks have completed (e.g. none were spawned); nothing can fail
// anymore, so park forever instead of returning and let the caller's
// select! decide when to stop.
futures::future::pending::<()>().await;
unreachable!()
}
Some(Ok((task_id, Ok(())))) if self.shutdown_token.is_cancelled() => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::debug!(task = task_name, "Background task completed during shutdown");
}
Some(Ok((task_id, Ok(())))) => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::warn!(task = task_name, "Background task completed unexpectedly");
anyhow::bail!("Background task '{}' completed early", task_name)
}
Some(Ok((task_id, Err(e)))) if self.shutdown_token.is_cancelled() => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::debug!(task = task_name, error = %e, "Background task exited with error during shutdown");
}
Some(Ok((task_id, Err(e)))) => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::error!(task = task_name, error = %e, "Background task failed");
anyhow::bail!("Background task '{}' failed: {}", task_name, e)
}
Some(Err(e)) if self.shutdown_token.is_cancelled() => {
let task_id = e.id();
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::debug!(task = task_name, error = %e, "Background task panicked during shutdown");
}
Some(Err(e)) => {
let task_id = e.id();
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::error!(task = task_name, error = %e, "Background task panicked");
anyhow::bail!("Background task '{}' panicked: {}", task_name, e)
}
}
}
}
pub fn shutdown_token(&self) -> tokio_util::sync::CancellationToken {
self.shutdown_token.clone()
}
pub async fn shutdown(mut self) {
self.shutdown_token.cancel();
while let Some(result) = self.background_tasks.join_next_with_id().await {
match result {
Ok((task_id, Ok(()))) => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::debug!(task = task_name, "Background task completed successfully");
}
Ok((task_id, Err(e))) => {
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::error!(task = task_name, error = %e, "Background task failed");
}
Err(e) => {
let task_id = e.id();
let task_name = self.task_names.get(&task_id).copied().unwrap_or("unknown");
tracing::error!(task = task_name, error = %e, "Background task panicked");
}
}
}
}
#[cfg(test)]
pub async fn sync_onwards_config(&self, pool: &sqlx::PgPool) -> anyhow::Result<()> {
let sender = self
.onwards_sender
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Onwards sync not enabled"))?;
let new_targets = crate::sync::onwards_config::load_targets_from_db(pool, &[], self.strict_mode).await?;
sender
.send(new_targets)
.map_err(|_| anyhow::anyhow!("Failed to send targets update"))?;
Ok(())
}
}
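/// Builder-phase counterpart of `BackgroundServices::spawn`: collects tasks
/// into a `JoinSet` and remembers a human-readable name per task id for log
/// messages.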
struct BackgroundTaskBuilder {
tasks: tokio::task::JoinSet<anyhow::Result<()>>,
names: std::collections::HashMap<tokio::task::Id, &'static str>,
}
impl BackgroundTaskBuilder {
fn new() -> Self {
Self {
tasks: tokio::task::JoinSet::new(),
names: std::collections::HashMap::new(),
}
}
fn spawn<F>(&mut self, name: &'static str, future: F)
where
F: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
{
let abort_handle = self.tasks.spawn(future);
self.names.insert(abort_handle.id(), name);
}
fn into_parts(
self,
) -> (
tokio::task::JoinSet<anyhow::Result<()>>,
std::collections::HashMap<tokio::task::Id, &'static str>,
) {
(self.tasks, self.names)
}
}
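/// Spawns all background services: onwards config sync, probe scheduling,
/// the fusillade batch daemon, the batch-completion notification poller,
/// pool-metrics sampling, the analytics batcher, and the underway task
/// runner. Leader election (when enabled) gates the leader-only subset.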
async fn setup_background_services(
pool: PgPool,
fusillade_pools: DbPools,
outlet_pool: Option<PgPool>,
config: Config,
shared_config: SharedConfig,
shutdown_token: tokio_util::sync::CancellationToken,
metrics_recorder: Option<GenAiMetrics>,
) -> anyhow::Result<BackgroundServices> {
use fusillade::manager::postgres::BatchInsertStrategy;
let drop_guard = shutdown_token.clone().drop_guard();
let mut background_tasks = BackgroundTaskBuilder::new();
let model_capacity_limits = Arc::new(dashmap::DashMap::new());
#[cfg_attr(not(test), allow(unused_variables))]
let (initial_targets, onwards_sender) = if config.background_services.onwards_sync.enabled {
let escalation_models: Vec<String> = config
.background_services
.batch_daemon
.model_escalations
.values()
.map(|e| e.escalation_model.clone())
.collect();
let (onwards_config_sync, initial_targets, onwards_stream) = sync::onwards_config::OnwardsConfigSync::new_with_daemon_limits(
pool.clone(),
Some(model_capacity_limits.clone()),
config.background_services.batch_daemon.default_model_concurrency,
escalation_models,
config.onwards.strict_mode,
)
.await?;
let sender = onwards_config_sync.sender();
initial_targets
.receive_updates(onwards_stream)
.await
.map_err(anyhow::Error::from)
.context("Onwards target updates failed")?;
let onwards_shutdown = shutdown_token.clone();
let fallback_interval = config.background_services.onwards_sync.fallback_interval_milliseconds;
background_tasks.spawn("onwards-config-sync", async move {
info!(
"Starting onwards configuration listener (fallback sync every {}ms)",
fallback_interval
);
let sync_config = sync::onwards_config::SyncConfig {
status_tx: None,
fallback_interval_milliseconds: fallback_interval,
};
onwards_config_sync
.start(sync_config, onwards_shutdown)
.await
.context("Onwards configuration listener failed")
});
(initial_targets, Some(sender))
} else {
info!("Onwards config sync disabled - AI proxy will not receive config updates");
let empty_config = onwards::target::ConfigFile {
targets: std::collections::HashMap::new(),
auth: None,
strict_mode: false,
http_pool: None,
};
(onwards::target::Targets::from_config(empty_config)?, None)
};
// Advisory-lock key for leader election; the hex bytes spell "DWCTPROB".
const LEADER_LOCK_ID: i64 = 0x4457_4354_5052_4F42_i64;
let probe_scheduler = probes::ProbeScheduler::new(pool.clone(), config.clone());
let fusillade_pool_for_metrics = fusillade_pools.write().clone();
let request_manager = Arc::new(
fusillade::PostgresRequestManager::new(
fusillade_pools,
config
.background_services
.batch_daemon
.to_fusillade_config_with_limits(Some(model_capacity_limits.clone())),
)
.with_download_buffer_size(config.batches.files.download_buffer_size)
.with_batch_insert_strategy(BatchInsertStrategy::Batched {
batch_size: config.batches.files.batch_insert_size,
}),
);
let is_leader: bool;
if !config.background_services.leader_election.enabled {
info!("Launching without leader election: running as leader");
is_leader = true;
if config.background_services.probe_scheduler.enabled {
probe_scheduler.initialize(shutdown_token.clone()).await?;
let daemon_scheduler = probe_scheduler.clone();
let daemon_shutdown = shutdown_token.clone();
background_tasks.spawn("probe-scheduler", async move {
let use_listen_notify = !cfg!(test);
daemon_scheduler.run_daemon(daemon_shutdown, use_listen_notify, 300).await;
Ok(())
});
} else {
info!("Probe scheduler disabled by configuration");
}
use crate::config::DaemonEnabled;
use fusillade::DaemonExecutor;
match config.background_services.batch_daemon.enabled {
DaemonEnabled::Always | DaemonEnabled::Leader => {
let daemon_handle = request_manager.clone().run(shutdown_token.clone())?;
background_tasks.spawn("fusillade-daemon", async move {
match daemon_handle.await {
Ok(Ok(())) => {
tracing::info!("Fusillade daemon exited normally");
}
Ok(Err(e)) => {
tracing::error!(error = %e, "Fusillade daemon failed");
anyhow::bail!("Fusillade daemon error: {}", e);
}
Err(e) => {
tracing::error!(error = %e, "Fusillade daemon task panicked");
anyhow::bail!("Fusillade daemon panic: {}", e);
}
}
Ok(())
});
info!("Skipping leader election - running as leader with probe scheduler and fusillade daemon");
}
DaemonEnabled::Never => {
info!("Skipping leader election - running as leader with probe scheduler (fusillade daemon disabled)");
}
}
{
let daemon_config = config.clone();
let daemon_request_manager = request_manager.clone();
let daemon_pool = pool.clone();
let daemon_shutdown = shutdown_token.clone();
background_tasks.spawn("batch-completion", async move {
notifications::run_notification_poller(
daemon_config.background_services.notifications.clone(),
daemon_config,
daemon_request_manager,
daemon_pool,
daemon_shutdown,
)
.await;
Ok(())
});
}
} else {
is_leader = false;
info!("Starting leader election - will attempt to acquire leadership");
use crate::config::DaemonEnabled;
if config.background_services.batch_daemon.enabled == DaemonEnabled::Always {
use fusillade::DaemonExecutor;
let daemon_handle = request_manager.clone().run(shutdown_token.clone())?;
background_tasks.spawn("fusillade-daemon", async move {
match daemon_handle.await {
Ok(Ok(())) => {
tracing::info!("Fusillade daemon exited normally");
}
Ok(Err(e)) => {
tracing::error!(error = %e, "Fusillade daemon failed");
anyhow::bail!("Fusillade daemon error: {}", e);
}
Err(e) => {
tracing::error!(error = %e, "Fusillade daemon task panicked");
anyhow::bail!("Fusillade daemon panic: {}", e);
}
}
Ok(())
});
info!("Fusillade batch daemon started (configured to always run)");
}
let is_leader_flag = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let leader_election_pool = pool.clone();
let leader_election_scheduler_gain = probe_scheduler.clone();
let leader_election_scheduler_lose = probe_scheduler.clone();
let leader_election_request_manager_gain = request_manager.clone();
let leader_election_config = config.clone();
let leader_election_flag = is_leader_flag.clone();
let daemon_handle: Arc<tokio::sync::Mutex<Option<tokio::task::JoinHandle<fusillade::Result<()>>>>> =
Arc::new(tokio::sync::Mutex::new(None));
let daemon_handle_gain = daemon_handle.clone();
let daemon_handle_lose = daemon_handle.clone();
let leadership_shutdown: Arc<tokio::sync::Mutex<Option<tokio_util::sync::CancellationToken>>> =
Arc::new(tokio::sync::Mutex::new(None));
let leadership_shutdown_gain = leadership_shutdown.clone();
let leadership_shutdown_lose = leadership_shutdown.clone();
let leader_election_shutdown = shutdown_token.clone();
background_tasks.spawn("leader-election", async move {
leader_election::leader_election_task(
leader_election_pool,
leader_election_config,
leader_election_flag,
LEADER_LOCK_ID,
leader_election_shutdown,
move |pool, config| {
let scheduler = leader_election_scheduler_gain.clone();
let request_manager = leader_election_request_manager_gain.clone();
let daemon_handle = daemon_handle_gain.clone();
let leadership_shutdown = leadership_shutdown_gain.clone();
async move {
// Short grace period after winning the election, presumably to let the
// previous leader's services wind down before starting ours.
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let session_token = tokio_util::sync::CancellationToken::new();
*leadership_shutdown.lock().await = Some(session_token.clone());
if config.background_services.probe_scheduler.enabled {
scheduler
.initialize(session_token.clone())
.await
.map_err(|e| anyhow::anyhow!("Failed to initialize probe scheduler: {}", e))?;
let daemon_scheduler = scheduler.clone();
let daemon_session_token = session_token.clone();
tokio::spawn(async move {
let use_listen_notify = !cfg!(test);
daemon_scheduler.run_daemon(daemon_session_token, use_listen_notify, 300).await;
});
} else {
tracing::info!("Probe scheduler disabled by configuration");
}
let notification_request_manager = request_manager.clone();
use crate::config::DaemonEnabled;
use fusillade::DaemonExecutor;
match config.background_services.batch_daemon.enabled {
DaemonEnabled::Leader => {
let handle = request_manager
.run(session_token.clone())
.map_err(|e| anyhow::anyhow!("Failed to start fusillade daemon: {}", e))?;
*daemon_handle.lock().await = Some(handle);
tracing::info!("Fusillade batch daemon started on elected leader");
}
DaemonEnabled::Always => {
// Daemon was already started unconditionally before leader election;
// nothing extra to do on gaining leadership.
}
DaemonEnabled::Never => {
tracing::info!("Fusillade batch daemon disabled by configuration");
}
}
{
let daemon_config = config.clone();
let daemon_session_token = session_token.clone();
tokio::spawn(async move {
notifications::run_notification_poller(
daemon_config.background_services.notifications.clone(),
daemon_config,
notification_request_manager,
pool,
daemon_session_token,
)
.await;
});
tracing::info!("Batch completion poller started on elected leader");
}
Ok(())
}
},
move |_pool, config| {
let scheduler = leader_election_scheduler_lose.clone();
let daemon_handle = daemon_handle_lose.clone();
let leadership_shutdown = leadership_shutdown_lose.clone();
async move {
if let Some(token) = leadership_shutdown.lock().await.take() {
token.cancel();
}
if config.background_services.probe_scheduler.enabled {
scheduler
.stop_all()
.await
.map_err(|e| anyhow::anyhow!("Failed to stop probe scheduler: {}", e))?;
}
if let Some(handle) = daemon_handle.lock().await.take() {
handle.abort();
tracing::info!("Fusillade batch daemon stopped (lost leadership)");
}
Ok(())
}
},
)
.await;
Ok(())
});
}
let underway_settings = config.database.underway_pool_settings();
let underway_pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(underway_settings.max_connections)
.min_connections(underway_settings.min_connections)
.acquire_timeout(std::time::Duration::from_secs(underway_settings.acquire_timeout_secs))
.idle_timeout(if underway_settings.idle_timeout_secs > 0 {
Some(std::time::Duration::from_secs(underway_settings.idle_timeout_secs))
} else {
None
})
.max_lifetime(if underway_settings.max_lifetime_secs > 0 {
Some(std::time::Duration::from_secs(underway_settings.max_lifetime_secs))
} else {
None
})
.connect_with(pool.connect_options().as_ref().clone())
.await?;
if config.enable_metrics {
let mut pools = vec![
db::LabeledPool {
name: "main",
pool: pool.clone(),
},
db::LabeledPool {
name: "fusillade",
pool: fusillade_pool_for_metrics,
},
db::LabeledPool {
name: "main_underway",
pool: underway_pool.clone(),
},
];
if let Some(outlet) = outlet_pool {
pools.push(db::LabeledPool {
name: "outlet",
pool: outlet,
});
}
let metrics_shutdown = shutdown_token.clone();
let metrics_config = db::PoolMetricsConfig {
sample_interval: config.background_services.pool_metrics.sample_interval,
};
background_tasks.spawn("pool-metrics-sampler", async move {
db::run_pool_metrics_sampler(pools, metrics_config, metrics_shutdown).await
});
}
let analytics_sender = if config.enable_analytics {
let (batcher, sender) = request_logging::AnalyticsBatcher::new(pool.clone(), config.clone(), metrics_recorder);
let batcher_shutdown = shutdown_token.clone();
background_tasks.spawn("analytics-batcher", async move {
batcher.run(batcher_shutdown).await;
Ok(())
});
Some(sender)
} else {
None
};
let encryption_key = match config.connections.encryption_key.as_deref().or(config.secret_key.as_deref()) {
Some(secret) if !secret.trim().is_empty() => Some(encryption::derive_encryption_key(secret.trim())),
Some(_) => {
tracing::warn!("Encryption key is empty/whitespace — connection features will be unavailable");
None
}
None => {
tracing::info!("No encryption key configured for connections (set secret_key or connections.encryption_key)");
None
}
};
let task_state = tasks::TaskState {
request_manager: request_manager.clone(),
dwctl_pool: pool.clone(),
config: shared_config.clone(),
encryption_key: encryption_key.clone(),
ingest_file_job: Arc::new(std::sync::OnceLock::new()),
activate_batch_job: Arc::new(std::sync::OnceLock::new()),
create_batch_job: Arc::new(std::sync::OnceLock::new()),
};
let task_runner = Arc::new(tasks::TaskRunner::new(underway_pool, task_state).await?);
for (name, handle) in task_runner.start(shutdown_token.clone(), &config.background_services.sync_workers) {
background_tasks.spawn(name, async move { handle.await.map_err(|e| anyhow::anyhow!("{}", e)) });
}
let (background_tasks, task_names) = background_tasks.into_parts();
Ok(BackgroundServices {
request_manager,
task_runner,
is_leader,
onwards_targets: initial_targets,
onwards_sender,
strict_mode: config.onwards.strict_mode,
analytics_sender,
background_tasks,
task_names,
shutdown_token,
drop_guard: Some(drop_guard),
connections_encryption_key: encryption_key.clone(),
})
}
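/// Fully constructed application: the router, shared state, all database
/// pools (held so they are not dropped early), the optional embedded
/// database and tracer provider, and the background services that `serve`
/// supervises.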
pub struct Application {
router: Router,
app_state: AppState,
config: Config,
db_pools: DbPools,
_fusillade_pools: DbPools,
_outlet_pools: Option<DbPools>,
_embedded_db: Option<db::embedded::EmbeddedDatabase>,
_tracer_provider: Option<telemetry::SdkTracerProvider>,
bg_services: BackgroundServices,
}
impl Application {
pub async fn new(config: Config, tracer_provider: Option<telemetry::SdkTracerProvider>) -> anyhow::Result<Self> {
Self::new_with_pool_and_config_path(config, None, None, tracer_provider).await
}
pub async fn new_with_config_path(
config: Config,
config_path: Option<PathBuf>,
tracer_provider: Option<telemetry::SdkTracerProvider>,
) -> anyhow::Result<Self> {
Self::new_with_pool_and_config_path(config, config_path, None, tracer_provider).await
}
pub async fn new_with_pool(
config: Config,
pool: Option<PgPool>,
tracer_provider: Option<telemetry::SdkTracerProvider>,
) -> anyhow::Result<Self> {
Self::new_with_pool_and_config_path(config, None, pool, tracer_provider).await
}
pub async fn new_with_pool_and_config_path(
config: Config,
config_path: Option<PathBuf>,
pool: Option<PgPool>,
tracer_provider: Option<telemetry::SdkTracerProvider>,
) -> anyhow::Result<Self> {
debug!("Starting control layer with configuration: {:#?}", config);
let (_embedded_db, db_pools, fusillade_pools, outlet_pools) = setup_database(&config, pool).await?;
if config.enable_metrics {
get_or_install_prometheus_handle();
}
let shutdown_token = tokio_util::sync::CancellationToken::new();
let metrics_recorder = if config.enable_metrics && config.enable_analytics {
let gen_ai_registry = prometheus::Registry::new();
Some(GenAiMetrics::new(&gen_ai_registry).map_err(|e| anyhow::anyhow!("Failed to create GenAI metrics: {}", e))?)
} else {
None
};
let shared_config = SharedConfig::new(config.clone());
let mut bg_services = setup_background_services(
(*db_pools).clone(),
fusillade_pools.clone(),
outlet_pools.as_ref().map(|p| (**p).clone()),
config.clone(),
shared_config.clone(),
shutdown_token.clone(),
metrics_recorder.clone(),
)
.await?;
let body_transform: onwards::BodyTransformFn = Arc::new(request_logging::stream_usage::stream_usage_transform);
let reqwest_client = reqwest::Client::new();
let tool_executor = crate::tool_executor::HttpToolExecutor::new(reqwest_client, Some(Arc::new(db_pools.write().clone())));
let onwards_app_state = onwards::AppState::with_transform(bg_services.onwards_targets.clone(), body_transform)
.with_response_transform(onwards::create_openai_sanitizer())
.with_streaming_header("x-fusillade-stream")
.with_tool_executor(Arc::new(tool_executor));
let onwards_router = if bg_services.onwards_targets.strict_mode {
tracing::info!("Strict mode enabled - using typed request validation");
onwards::strict::build_strict_router(onwards_app_state)
} else {
onwards::build_router(onwards_app_state)
};
let limiters = limits::Limiters::new(&config.limits);
let mut app_state = AppState::builder()
.db(db_pools.clone())
.config(shared_config.clone())
.is_leader(bg_services.is_leader)
.request_manager(bg_services.request_manager.clone())
.task_runner(bg_services.task_runner.clone())
.maybe_outlet_db(outlet_pools.clone())
.limiters(limiters)
.maybe_connections_encryption_key(bg_services.connections_encryption_key.clone())
.build();
if let Some(config_path) = config_path {
bg_services.spawn(
"config-watcher",
config_watcher::watch_config_file(config_path, shared_config, bg_services.shutdown_token()),
);
}
let router = build_router(
&mut app_state,
onwards_router,
bg_services.analytics_sender.clone(),
metrics_recorder,
bg_services.onwards_targets.strict_mode,
)
.await?;
Ok(Self {
router,
app_state,
config,
db_pools,
_fusillade_pools: fusillade_pools,
_outlet_pools: outlet_pools,
_embedded_db,
_tracer_provider: tracer_provider,
bg_services,
})
}
#[cfg(test)]
pub fn into_test_server(self) -> (axum_test::TestServer, BackgroundServices) {
let middleware = middleware::from_fn_with_state(self.app_state, admin_ai_proxy_middleware);
let service = middleware.layer(self.router).into_make_service();
let server = axum_test::TestServer::new(service).expect("Failed to create test server");
(server, self.bg_services)
}
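/// Binds the listener and runs the server until either the provided
/// `shutdown` future resolves or a background service fails, then tears
/// everything down in order: background tasks, database pools, telemetry
/// flush, embedded database.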
pub async fn serve<F>(mut self, shutdown: F) -> anyhow::Result<()>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let bind_addr = self.config.bind_address();
let listener = TcpListener::bind(&bind_addr).await?;
info!(
"Control layer listening on http://{}, available at http://localhost:{}",
bind_addr, self.config.port
);
let middleware = middleware::from_fn_with_state(self.app_state, admin_ai_proxy_middleware);
let service = middleware.layer(self.router);
let shutdown_token = self.bg_services.shutdown_token();
let shutdown = async move {
shutdown.await;
shutdown_token.cancel();
};
let server_error: Option<anyhow::Error> = tokio::select! {
result = axum::serve(listener, service.into_make_service()).with_graceful_shutdown(shutdown) => {
result.err().map(Into::into)
}
result = self.bg_services.wait_for_failure() => {
match result {
Ok(_infallible) => unreachable!("wait_for_failure never returns Ok"),
Err(e) => Some(e),
}
}
};
info!("Shutting down background services...");
self.bg_services.shutdown().await;
info!("Closing database connections...");
self.db_pools.close().await;
if let Some(ref provider) = self._tracer_provider {
info!("Flushing telemetry...");
if let Err(e) = provider.force_flush() {
tracing::error!("Failed to flush tracer provider: {}", e);
}
}
if let Some(embedded_db) = self._embedded_db {
info!("Shutting down embedded database...");
embedded_db.stop().await?;
}
if let Some(e) = server_error {
return Err(e);
}
Ok(())
}
}