systemprompt-api 0.14.1

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
//! Gateway dispatch entry point: route resolution, policy and quota checks,
//! upstream send, and response finalization.
#![expect(
    clippy::clone_on_ref_ptr,
    reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
)]

mod finalize;

use std::sync::Arc;

use anyhow::{Result, anyhow};
use axum::body::Body;
use axum::response::Response;
use bytes::Bytes;
use systemprompt_database::DbPool;
use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};

use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
use super::audit::{GatewayAudit, GatewayRequestContext};
use super::policy::PolicyResolver;
use super::protocol::canonical::CanonicalRequest;
use super::protocol::inbound::InboundAdapter;
use super::protocol::outbound::OutboundCtx;
use super::quota;
use super::registry::GatewayUpstreamRegistry;

/// Header name `x-systemprompt-request-id`.
///
/// NOTE(review): presumably the response header that `attach_request_id`
/// fills with the AI request id — confirm against the `finalize` module.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";

/// Stateless entry point of the gateway.
///
/// The type carries no data; it only namespaces the associated function
/// [`GatewayService::dispatch`], which takes everything it needs as arguments.
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;

/// Per-request values consumed by [`GatewayService::dispatch`].
#[derive(Debug)]
pub struct DispatchInputs {
    /// The request in canonical form; `dispatch` reads its `model` for the
    /// exposure check and routing, and its `stream` flag for logging.
    pub request: CanonicalRequest,
    /// Request body bytes, handed to the audit together with `request` when
    /// the audit is opened.
    pub raw_body: Bytes,
    /// Request context; `dispatch` requires `session_id` to be set and reads
    /// `ai_request_id`, `user_id` and `wire_protocol` from it.
    pub ctx: GatewayRequestContext,
    /// Inbound adapter, passed on to response finalization.
    pub inbound: Arc<dyn InboundAdapter>,
}

/// Error returned by [`GatewayService::dispatch`] when the requested model is
/// not exposed by the gateway configuration.
///
/// The wrapped string is the full message; `Display` prints it verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);

/// Error returned by [`GatewayService::dispatch`] when the quota precheck
/// yields a decision that does not allow the request.
///
/// `Display` prints `message` only.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Description of the exhausted window and its usage.
    pub message: String,
    /// `dispatch` sets this to the length, in seconds, of the quota window
    /// that rejected the request.
    pub retry_after_seconds: i32,
}

/// Error returned by [`GatewayService::dispatch`] when a safety-scan finding
/// falls into one of the policy's blocked categories.
///
/// `Display` prints `message` only.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct SafetyBlocked {
    /// Category of the finding that caused the block.
    pub category: String,
    /// Message describing the block; includes the category.
    pub message: String,
}

impl GatewayService {
    pub async fn dispatch(
        config: &GatewayConfig,
        registry: &ProviderRegistry,
        db: &DbPool,
        inputs: DispatchInputs,
    ) -> Result<Response<Body>> {
        let DispatchInputs {
            request,
            raw_body,
            ctx,
            inbound,
        } = inputs;
        if ctx.session_id.is_none() {
            return Err(anyhow!(
                "gateway dispatch missing conversation binding (session_id)"
            ));
        }

        let ai_request_id = ctx.ai_request_id.clone();

        if !config.is_model_exposed(registry, &request.model) {
            tracing::warn!(
                ai_request_id = %ai_request_id,
                model = %request.model,
                "Gateway denied: model not exposed by gateway policy or registry"
            );
            return Err(PolicyDenied(format!(
                "model '{}' is not permitted by gateway policy",
                request.model
            ))
            .into());
        }

        let route = config
            .resolve_route(registry, &request.model)
            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;

        let provider = route.resolve(registry).ok_or_else(|| {
            anyhow!(
                "Gateway route '{}' provider '{}' is not declared in profile.providers",
                route.id.as_str(),
                route.provider.as_str()
            )
        })?;

        let secrets = systemprompt_config::SecretsBootstrap::get()
            .map_err(|e| anyhow!("Secrets not available: {e}"))?;

        let upstream_api_key = secrets
            .get(provider.api_key_secret.as_str())
            .ok_or_else(|| {
                anyhow!(
                    "Gateway API key secret '{}' not configured",
                    provider.api_key_secret.as_str()
                )
            })?;

        let upstream = GatewayUpstreamRegistry::global()
            .get(provider.protocol.as_tag())
            .ok_or_else(|| {
                anyhow!(
                    "Gateway has no outbound adapter for wire protocol '{}'",
                    provider.protocol.as_tag()
                )
            })?;

        let is_streaming = request.stream;

        tracing::info!(
            ai_request_id = %ai_request_id,
            user_id = %ctx.user_id,
            model = %request.model,
            provider = %route.provider,
            upstream = %provider.endpoint,
            wire_protocol = %ctx.wire_protocol,
            streaming = is_streaming,
            "Gateway request dispatched"
        );

        let resolver = PolicyResolver::new(db)?;
        let policy = resolver.resolve().await;

        let audit = Arc::new(
            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
        );

        if let Err(e) = audit.open(&request, &raw_body).await {
            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
        }

        if let Some(decision) =
            quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
        {
            if !decision.allow {
                let msg = format!(
                    "quota exceeded for window {}s (used {}/{:?})",
                    decision.window_seconds, decision.state.requests, decision.limit_requests
                );
                if let Err(e) = audit.fail(&msg).await {
                    tracing::warn!(error = %e, "quota audit fail failed");
                }
                return Err(QuotaExceeded {
                    message: msg,
                    retry_after_seconds: decision.window_seconds,
                }
                .into());
            }
        }

        let findings = run_request_safety_scan(db, &ai_request_id, &request, &policy.safety).await;
        if let Some(finding) = findings
            .iter()
            .find(|f| policy.safety.block_categories.contains(&f.category))
        {
            let msg = format!(
                "request blocked by safety policy: category '{}'",
                finding.category
            );
            tracing::warn!(
                ai_request_id = %ai_request_id,
                category = %finding.category,
                scanner = %finding.scanner,
                "Gateway blocked request by safety policy"
            );
            if let Err(e) = audit.fail(&msg).await {
                tracing::warn!(error = %e, "safety-block audit fail failed");
            }
            return Err(SafetyBlocked {
                category: finding.category.clone(),
                message: msg,
            }
            .into());
        }

        let upstream_model = route.effective_upstream_model(&request.model).to_owned();
        let model_limits = provider.find_model(&upstream_model).map(|m| m.limits);
        let outbound_ctx = OutboundCtx {
            route: route.as_ref(),
            endpoint: &provider.endpoint,
            api_key: upstream_api_key,
            request: &request,
            upstream_model: &upstream_model,
            model_limits,
        };

        let outcome = match upstream.send(outbound_ctx).await {
            Ok(o) => o,
            Err(e) => {
                if let Err(audit_err) = audit.fail(&e.to_string()).await {
                    tracing::warn!(error = %audit_err, "upstream audit fail failed");
                }
                return Err(e);
            },
        };

        let response = finalize(
            outcome,
            FinalizeCtx {
                audit: Arc::clone(&audit),
                db: db.clone(),
                ai_request_id: ai_request_id.clone(),
                policy,
                inbound,
                request_model: request.model.clone(),
            },
        )
        .await;
        Ok(attach_request_id(response, &ai_request_id))
    }
}