systemprompt-api 0.14.6

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
use axum::http::StatusCode;
use std::collections::BTreeMap;
use systemprompt_identifiers::{Actor, JwtToken, SessionId, TraceId, UserId};
use systemprompt_runtime::AppContext;
use systemprompt_users::{API_KEY_PREFIX, ApiKeyService};

use crate::services::middleware::JwtContextExtractor;

/// Caller identity produced by `authenticate`.
///
/// The variant records which credential kind was accepted: a bearer JWT
/// (`Jwt`) or a credential starting with `API_KEY_PREFIX` (`ApiKey`). Only the
/// JWT variant carries roles, attributes, an actor chain, and a session id that
/// can be bound to the `X-Session-ID` header.
pub(super) enum AuthedPrincipal {
    /// Authenticated by decoding a bearer JWT.
    Jwt(JwtPrincipal),
    /// Authenticated by verifying an API key against the key store.
    ApiKey(ApiKeyPrincipal),
}

/// Identity extracted from a bearer JWT by `authenticate_jwt`.
pub(super) struct JwtPrincipal {
    /// Subject user, taken from the token's claims.
    pub user_id: UserId,
    /// Trace id generated fresh when the principal is built (not read from the token).
    pub trace_id: TraceId,
    /// Roles from the user record that `decode_for_gateway` returns alongside the claims.
    pub roles: Vec<String>,
    /// Attribute map copied from the token's claims.
    pub attributes: BTreeMap<String, serde_json::Value>,
    /// Actor chain copied from the token's claims
    /// (presumably the delegation/`act` chain — TODO confirm against claim schema).
    pub act_chain: Vec<Actor>,
    /// Session id carried in the token's claims. Private: code outside this
    /// module reads it via `AuthedPrincipal::attested_session` or checks it via
    /// `enforce_session_binding`.
    attested_session: SessionId,
}

/// Identity established by a verified API key in `authenticate_api_key`.
///
/// Carries no roles, attributes, actor chain, or session id: `AuthedPrincipal`
/// reports empty authz attributes and skips session binding for this variant.
pub(super) struct ApiKeyPrincipal {
    /// User that owns the verified key record.
    pub user_id: UserId,
    /// Trace id generated fresh when the principal is built.
    pub trace_id: TraceId,
}

impl AuthedPrincipal {
    pub(super) const fn user_id(&self) -> &UserId {
        match self {
            Self::Jwt(p) => &p.user_id,
            Self::ApiKey(p) => &p.user_id,
        }
    }

    pub(super) const fn trace_id(&self) -> &TraceId {
        match self {
            Self::Jwt(p) => &p.trace_id,
            Self::ApiKey(p) => &p.trace_id,
        }
    }

    pub(super) const fn attested_session(&self) -> Option<&SessionId> {
        match self {
            Self::Jwt(p) => Some(&p.attested_session),
            Self::ApiKey(_) => None,
        }
    }

    pub(super) fn authz_attributes(
        &self,
    ) -> (Vec<String>, BTreeMap<String, serde_json::Value>, Vec<Actor>) {
        match self {
            Self::Jwt(p) => (p.roles.clone(), p.attributes.clone(), p.act_chain.clone()),
            Self::ApiKey(_) => (Vec::new(), BTreeMap::new(), Vec::new()),
        }
    }

    pub(super) fn enforce_session_binding(
        &self,
        header: &SessionId,
    ) -> Result<(), (StatusCode, String)> {
        match self {
            Self::Jwt(p) => p.enforce_session_binding(header),
            Self::ApiKey(_) => Ok(()),
        }
    }
}

impl JwtPrincipal {
    /// Rejects the request unless `header` equals the session id carried in
    /// the JWT.
    ///
    /// A mismatch is logged at `warn` level with both session ids and the
    /// user id before the error is returned.
    ///
    /// # Errors
    /// `401 Unauthorized` with a fixed message when the ids differ.
    fn enforce_session_binding(&self, header: &SessionId) -> Result<(), (StatusCode, String)> {
        let attested = self.attested_session.as_str();
        let presented = header.as_str();

        if attested != presented {
            tracing::warn!(
                header_session = %presented,
                jwt_session = %attested,
                user_id = %self.user_id,
                "X-Session-ID header does not match bearer JWT session_id; rejecting"
            );
            return Err((
                StatusCode::UNAUTHORIZED,
                "X-Session-ID does not match authenticated session".to_owned(),
            ));
        }

        Ok(())
    }
}

/// Authenticates a raw credential string, choosing the scheme by its shape.
///
/// Credentials that start with `API_KEY_PREFIX` are verified as API keys
/// against the database reachable from `ctx`. Everything else is decoded as a
/// JWT with `jwt_extractor`.
///
/// # Errors
/// Propagates the `(StatusCode, message)` produced by whichever scheme ran.
pub(super) async fn authenticate(
    credential: &str,
    jwt_extractor: &JwtContextExtractor,
    ctx: &AppContext,
) -> Result<AuthedPrincipal, (StatusCode, String)> {
    let is_api_key = credential.starts_with(API_KEY_PREFIX);
    if is_api_key {
        authenticate_api_key(credential, ctx).await
    } else {
        authenticate_jwt(credential, jwt_extractor).await
    }
}

async fn authenticate_api_key(
    credential: &str,
    ctx: &AppContext,
) -> Result<AuthedPrincipal, (StatusCode, String)> {
    let service = ApiKeyService::new(ctx.db_pool()).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("API key service unavailable: {e}"),
        )
    })?;
    let record = service.verify(credential).await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("API key verification failed: {e}"),
        )
    })?;
    match record {
        Some(rec) => Ok(AuthedPrincipal::ApiKey(ApiKeyPrincipal {
            user_id: rec.user_id,
            trace_id: TraceId::generate(),
        })),
        None => Err((
            StatusCode::UNAUTHORIZED,
            "Invalid or revoked API key".to_owned(),
        )),
    }
}

/// Decodes a bearer JWT through the gateway extractor and builds a
/// `JwtPrincipal`.
///
/// `user_id`, `attributes`, `act_chain`, and the attested session come from
/// the token's claims. `roles` come from the user record returned by
/// `decode_for_gateway`. A fresh trace id is generated for the principal.
///
/// # Errors
/// `401 Unauthorized` carrying the extractor's error text if decoding fails.
async fn authenticate_jwt(
    credential: &str,
    jwt_extractor: &JwtContextExtractor,
) -> Result<AuthedPrincipal, (StatusCode, String)> {
    let token = JwtToken::new(credential);
    let decoded = jwt_extractor.decode_for_gateway(&token).await;
    let (claims, user) = match decoded {
        Ok(pair) => pair,
        Err(err) => return Err((StatusCode::UNAUTHORIZED, err.to_string())),
    };

    let principal = JwtPrincipal {
        user_id: claims.user_id,
        trace_id: TraceId::generate(),
        attested_session: claims.session_id,
        attributes: claims.attributes,
        act_chain: claims.act_chain,
        roles: user.roles,
    };
    Ok(AuthedPrincipal::Jwt(principal))
}