systemprompt-api 0.15.1

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
//! Gateway dispatch entry point: route resolution, policy and quota checks,
//! upstream send, and response finalization.
#![expect(
    clippy::clone_on_ref_ptr,
    reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
)]

mod finalize;

use std::sync::Arc;

use anyhow::{Result, anyhow};
use axum::body::Body;
use axum::response::Response;
use bytes::Bytes;
use systemprompt_ai::SafetyConfig;
use systemprompt_database::DbPool;
use systemprompt_identifiers::AiRequestId;
use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};

use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
use super::audit::{GatewayAudit, GatewayRequestContext};
use super::policy::PolicyResolver;
use super::protocol::canonical::CanonicalRequest;
use super::protocol::inbound::InboundAdapter;
use super::protocol::outbound::OutboundCtx;
use super::quota;
use super::registry::GatewayUpstreamRegistry;

/// Name of the HTTP header that carries the gateway's AI request id.
///
/// NOTE(review): presumably this is the header `finalize::attach_request_id`
/// sets on every dispatched response — confirm in `finalize`.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";

/// Stateless entry point for gateway dispatch.
///
/// All state (config, provider registry, database pool) is passed explicitly
/// to [`GatewayService::dispatch`], so this is a zero-sized marker type.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;

/// Per-request inputs consumed by [`GatewayService::dispatch`].
#[derive(Debug)]
pub struct DispatchInputs {
    /// The request in the gateway's canonical form; its `model` drives routing
    /// and its `stream` flag is logged at dispatch time.
    pub request: CanonicalRequest,
    /// Raw request body; `dispatch` only uses it when opening the audit record.
    pub raw_body: Bytes,
    /// Request context (session binding, AI request id, user id, wire
    /// protocol) used for validation, logging, quota and auditing.
    pub ctx: GatewayRequestContext,
    /// Inbound protocol adapter, handed to response finalization.
    pub inbound: Arc<dyn InboundAdapter>,
}

/// Error returned by [`GatewayService::dispatch`], split by whether the
/// request's gateway audit record had been initialised when it failed.
///
/// Both variants are `transparent`: their `Display` is the inner error's.
#[derive(Debug, thiserror::Error)]
pub enum DispatchError {
    /// Failure before the audit record was initialised (missing session
    /// binding, [`PolicyDenied`], routing, secret or adapter lookup, policy
    /// resolver or audit initialisation).
    #[error(transparent)]
    PreAudit(anyhow::Error),
    /// Failure after the audit record was initialised. The inner error may be
    /// a [`QuotaExceeded`] or [`SafetyBlocked`]; downcast to tell them apart.
    #[error(transparent)]
    Recorded(anyhow::Error),
}

/// Request rejected by gateway policy; `dispatch` raises it when the requested
/// model is not exposed. The payload is the human-readable reason and is also
/// the `Display` output.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);

/// Request rejected because a quota window's limit was reached.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Human-readable description of the exceeded window and usage; this is
    /// the `Display` output.
    pub message: String,
    /// Suggested retry delay in seconds. `dispatch` fills it with the exceeded
    /// window's `window_seconds`, i.e. the window length rather than the time
    /// remaining until it resets.
    pub retry_after_seconds: i32,
}

/// Request rejected because a safety-scan finding fell into one of the
/// policy's blocked categories.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct SafetyBlocked {
    /// Category of the blocking finding.
    pub category: String,
    /// Human-readable rejection message; this is the `Display` output.
    pub message: String,
}

impl GatewayService {
    pub async fn dispatch(
        config: &GatewayConfig,
        registry: &ProviderRegistry,
        db: &DbPool,
        inputs: DispatchInputs,
    ) -> Result<Response<Body>, DispatchError> {
        let DispatchInputs {
            request,
            raw_body,
            ctx,
            inbound,
        } = inputs;
        if ctx.session_id.is_none() {
            return Err(DispatchError::PreAudit(anyhow!(
                "gateway dispatch missing conversation binding (session_id)"
            )));
        }

        let ai_request_id = ctx.ai_request_id.clone();

        if !config.is_model_exposed(registry, &request.model) {
            tracing::warn!(
                ai_request_id = %ai_request_id,
                model = %request.model,
                "Gateway denied: model not exposed by gateway policy or registry"
            );
            return Err(DispatchError::PreAudit(
                PolicyDenied(format!(
                    "model '{}' is not permitted by gateway policy",
                    request.model
                ))
                .into(),
            ));
        }

        let route = config
            .resolve_route(registry, &request.model)
            .ok_or_else(|| {
                DispatchError::PreAudit(anyhow!(
                    "No gateway route matches model '{}'",
                    request.model
                ))
            })?;

        let provider = route.resolve(registry).ok_or_else(|| {
            DispatchError::PreAudit(anyhow!(
                "Gateway route '{}' provider '{}' is not declared in profile.providers",
                route.id.as_str(),
                route.provider.as_str()
            ))
        })?;

        let secrets = systemprompt_config::SecretsBootstrap::get()
            .map_err(|e| DispatchError::PreAudit(anyhow!("Secrets not available: {e}")))?;

        let upstream_api_key = secrets
            .get(provider.api_key_secret.as_str())
            .ok_or_else(|| {
                DispatchError::PreAudit(anyhow!(
                    "Gateway API key secret '{}' not configured",
                    provider.api_key_secret.as_str()
                ))
            })?;

        let upstream = GatewayUpstreamRegistry::global()
            .get(provider.wire.as_tag())
            .ok_or_else(|| {
                DispatchError::PreAudit(anyhow!(
                    "Gateway has no outbound adapter for wire protocol '{}'",
                    provider.wire.as_tag()
                ))
            })?;

        let is_streaming = request.stream;

        tracing::info!(
            ai_request_id = %ai_request_id,
            user_id = %ctx.user_id,
            model = %request.model,
            provider = %route.provider,
            upstream = %provider.endpoint,
            wire_protocol = %ctx.wire_protocol,
            streaming = is_streaming,
            "Gateway request dispatched"
        );

        let resolver = PolicyResolver::new(db).map_err(DispatchError::PreAudit)?;
        let policy = resolver.resolve().await;

        let audit = Arc::new(
            GatewayAudit::new(db, ctx.clone())
                .map_err(|e| DispatchError::PreAudit(anyhow!("audit init failed: {e}")))?,
        );

        if let Err(e) = audit.open(&request, &raw_body).await {
            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
        }

        let reservation = quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows)
            .await
            .map_err(DispatchError::Recorded)?;
        if let Some(decision) = reservation {
            if !decision.allow {
                let msg = format!(
                    "quota exceeded for window {}s (used {}/{:?})",
                    decision.window_seconds, decision.state.requests, decision.limit_requests
                );
                if let Err(e) = audit.fail(&msg).await {
                    tracing::warn!(error = %e, "quota audit fail failed");
                }
                return Err(DispatchError::Recorded(
                    QuotaExceeded {
                        message: msg,
                        retry_after_seconds: decision.window_seconds,
                    }
                    .into(),
                ));
            }
        }

        enforce_request_safety(db, &ai_request_id, &request, &policy.safety, &audit).await?;

        let upstream_model = route.effective_upstream_model(&request.model).to_owned();
        let model_limits = provider.find_model(&upstream_model).map(|m| m.limits);
        let outbound_ctx = OutboundCtx {
            route: route.as_ref(),
            endpoint: &provider.endpoint,
            api_key: upstream_api_key,
            request: &request,
            upstream_model: &upstream_model,
            model_limits,
        };

        let outcome = match upstream.send(outbound_ctx).await {
            Ok(o) => o,
            Err(e) => {
                audit_upstream_failure(&audit, upstream.provider_tag(), &request.model, &e).await;
                return Err(DispatchError::Recorded(e));
            },
        };

        let response = finalize(
            outcome,
            FinalizeCtx {
                audit: Arc::clone(&audit),
                db: db.clone(),
                ai_request_id: ai_request_id.clone(),
                policy,
                inbound,
                request_model: request.model.clone(),
            },
        )
        .await;
        Ok(attach_request_id(response, &ai_request_id))
    }
}

async fn enforce_request_safety(
    db: &DbPool,
    ai_request_id: &AiRequestId,
    request: &CanonicalRequest,
    safety: &SafetyConfig,
    audit: &GatewayAudit,
) -> Result<(), DispatchError> {
    let findings = run_request_safety_scan(db, ai_request_id, request, safety).await;
    let Some(finding) = findings
        .iter()
        .find(|f| safety.block_categories.contains(&f.category))
    else {
        return Ok(());
    };
    let msg = format!(
        "request blocked by safety policy: category '{}'",
        finding.category
    );
    tracing::warn!(
        ai_request_id = %ai_request_id,
        category = %finding.category,
        scanner = %finding.scanner,
        "Gateway blocked request by safety policy"
    );
    if let Err(e) = audit.fail(&msg).await {
        tracing::warn!(error = %e, "safety-block audit fail failed");
    }
    Err(DispatchError::Recorded(
        SafetyBlocked {
            category: finding.category.clone(),
            message: msg,
        }
        .into(),
    ))
}

/// Logs a failed upstream call and records the failure on the audit record.
///
/// The audit write is best-effort. If it fails, the audit error is logged and
/// swallowed so the caller can still return the original upstream error.
async fn audit_upstream_failure(
    audit: &GatewayAudit,
    provider: &str,
    model: &str,
    error: &anyhow::Error,
) {
    tracing::warn!(
        provider = %provider,
        model = %model,
        error = %error,
        "gateway upstream call failed"
    );

    let reason = error.to_string();
    if let Some(audit_err) = audit.fail(&reason).await.err() {
        tracing::warn!(error = %audit_err, "upstream audit fail failed");
    }
}