systemprompt-api 0.11.0

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
use super::generation::{
    ClientTokenOptions, TokenExchangeRequest, TokenGenerationParams, generate_client_tokens,
    generate_tokens_by_user_id, handle_token_exchange,
};
use super::validation::{
    AuthCodeValidationParams, extract_required_field, validate_authorization_code,
    validate_client_credentials,
};
use super::{TokenError, TokenRequest, TokenResponse};
use axum::extract::{Extension, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{Form, Json};
use systemprompt_identifiers::{AuthorizationCode, ClientId, RefreshTokenId};
use systemprompt_models::RequestContext;
use systemprompt_oauth::repository::OAuthRepository;
use systemprompt_oauth::{GrantType, OAuthState};
use tracing::instrument;

use crate::routes::oauth::OAuthHttpError;
use crate::routes::oauth::extractors::OAuthRepo;

/// Axum handler for the OAuth token endpoint.
///
/// Parses the `grant_type` form field into a [`GrantType`] and hands the
/// request to the matching grant handler. A successful grant is answered with
/// `200 OK` and the [`TokenResponse`] serialized as JSON. The
/// [`RequestContext`] extension is extracted but not used here.
///
/// # Errors
///
/// Returns [`TokenError::UnsupportedGrantType`] (converted into an
/// [`OAuthHttpError`]) when `grant_type` does not parse, and otherwise
/// propagates whatever error the selected grant handler produced.
#[instrument(skip(state, _req_ctx, headers, request, repo), fields(grant_type = %request.grant_type))]
pub async fn handle_token(
    Extension(_req_ctx): Extension<RequestContext>,
    State(state): State<OAuthState>,
    OAuthRepo(repo): OAuthRepo,
    headers: HeaderMap,
    Form(request): Form<TokenRequest>,
) -> Result<Response, OAuthHttpError> {
    tracing::info!(grant_type = %request.grant_type, "Token request received");

    // Any grant type we cannot parse is rejected before touching the repository.
    let Ok(grant) = request.grant_type.parse::<GrantType>() else {
        return Err(TokenError::UnsupportedGrantType {
            grant_type: request.grant_type.clone(),
        }
        .into());
    };

    let outcome = match grant {
        GrantType::AuthorizationCode => {
            handle_authorization_code_grant(repo, request, &headers, &state).await
        },
        GrantType::RefreshToken => {
            handle_refresh_token_grant(repo, request, &headers, &state).await
        },
        GrantType::ClientCredentials => {
            handle_client_credentials_grant(repo, request, &headers, &state).await
        },
        GrantType::TokenExchange => {
            handle_token_exchange_grant(repo, request, &headers, &state).await
        },
    };
    let token = outcome?;

    Ok((StatusCode::OK, Json(token)).into_response())
}

/// Handles `grant_type=authorization_code`: redeems an authorization code for tokens.
///
/// The client is taken from the `client_id` form field. When that field is
/// absent, the client is looked up from the authorization code. The client is
/// then authenticated with `client_secret`. The code is validated against that
/// client together with the request's `redirect_uri`, `code_verifier` and
/// `resource`. Tokens are generated for the user and scope returned by that
/// validation. Finally the code is linked to the new refresh token; a failure
/// at that last step is only logged.
///
/// # Errors
///
/// - the error from `extract_required_field` when `code` is missing;
/// - [`TokenError::ServerError`] when the code lookup or token generation fails;
/// - [`TokenError::InvalidGrant`] when the code is unknown or fails validation;
/// - [`TokenError::InvalidClientSecret`] when client authentication fails.
async fn handle_authorization_code_grant(
    repo: OAuthRepository,
    request: TokenRequest,
    headers: &HeaderMap,
    state: &OAuthState,
) -> Result<TokenResponse, TokenError> {
    let code = AuthorizationCode::new(extract_required_field(request.code.as_deref(), "code")?);

    let client_id = match request.client_id.as_deref() {
        Some(explicit) => ClientId::new(explicit),
        None => {
            // No client_id in the form: fall back to the client the code was issued to.
            let recorded = repo
                .get_client_id_from_auth_code(&code)
                .await
                .map_err(|e| TokenError::ServerError {
                    message: format!("Failed to lookup authorization code: {e}"),
                })?;
            recorded.ok_or_else(|| TokenError::InvalidGrant {
                reason: "Invalid or expired authorization code".to_string(),
            })?
        },
    };

    // The underlying failure reason is deliberately not surfaced to the caller.
    if validate_client_credentials(&repo, &client_id, request.client_secret.as_deref())
        .await
        .is_err()
    {
        return Err(TokenError::InvalidClientSecret);
    }

    let validated = validate_authorization_code(AuthCodeValidationParams {
        repo: &repo,
        code: &code,
        client_id: &client_id,
        redirect_uri: request.redirect_uri.as_deref(),
        code_verifier: request.code_verifier.as_deref(),
        request_resource: request.resource.as_deref(),
    })
    .await
    .map_err(|e: anyhow::Error| TokenError::InvalidGrant {
        reason: e.to_string(),
    })?;

    let params = TokenGenerationParams {
        client_id: &client_id,
        user_id: &validated.user_id,
        scope: Some(&validated.scope),
        headers,
        resource: validated.resource.as_deref(),
        family_id: None,
    };
    let generated = generate_tokens_by_user_id(&repo, params, state)
        .await
        .map_err(|e| TokenError::ServerError {
            message: e.to_string(),
        })?;

    let link_outcome = repo
        .link_auth_code_to_refresh_token(&code, &generated.refresh_token_id)
        .await;
    if let Err(e) = link_outcome {
        // Best-effort: tokens were already generated, so a failed link is only logged.
        tracing::warn!(error = %e, "Failed to link auth code to refresh token");
    }

    let issued = generated.response;
    tracing::info!(
        grant_type = "authorization_code",
        client_id = %client_id,
        user_id = %validated.user_id,
        scope = %validated.scope,
        resource = ?validated.resource,
        token_type = %issued.token_type,
        expires_in = issued.expires_in,
        "Token issued"
    );

    Ok(issued)
}

/// Handles `grant_type=refresh_token`: consumes a refresh token and issues a
/// new token set for the same user and token family.
///
/// The client is taken from the `client_id` form field. When that field is
/// absent, the client is looked up from the refresh token. The client is then
/// authenticated with `client_secret`. The optional `scope` parameter may only
/// narrow the scope recorded on the consumed token (RFC 6749 §6). When `scope`
/// is omitted, or present but blank, the original scope is reused.
///
/// # Errors
///
/// - the error from `extract_required_field` when `refresh_token` is missing;
/// - [`TokenError::ServerError`] when the client lookup or token generation fails;
/// - [`TokenError::InvalidRefreshToken`] when the token is unknown or
///   `consume_refresh_token` fails for this client;
/// - [`TokenError::InvalidClientSecret`] when client authentication fails;
/// - [`TokenError::InvalidRequest`] (field `scope`) when a requested scope was
///   not part of the original grant.
async fn handle_refresh_token_grant(
    repo: OAuthRepository,
    request: TokenRequest,
    headers: &HeaderMap,
    state: &OAuthState,
) -> Result<TokenResponse, TokenError> {
    let refresh_token_str =
        extract_required_field(request.refresh_token.as_deref(), "refresh_token")?;
    let refresh_token = RefreshTokenId::new(refresh_token_str);

    let client_id = if let Some(id) = request.client_id.as_deref() {
        ClientId::new(id)
    } else {
        repo.get_client_id_from_refresh_token(&refresh_token)
            .await
            .map_err(|e| TokenError::ServerError {
                message: format!("Failed to lookup refresh token: {e}"),
            })?
            .ok_or_else(|| TokenError::InvalidRefreshToken {
                reason: "Invalid refresh token".to_string(),
            })?
    };

    validate_client_credentials(&repo, &client_id, request.client_secret.as_deref())
        .await
        .map_err(|_| TokenError::InvalidClientSecret)?;

    // NOTE(review): the refresh token is consumed before the requested scope is
    // checked. A request with an invalid `scope`, or a later generation failure,
    // therefore burns the token without issuing a replacement. Checking scope
    // first would need a non-consuming lookup of the original scope; confirm
    // whether the repository offers one.
    let consumed = repo
        .consume_refresh_token(&refresh_token, &client_id)
        .await
        .map_err(|e| TokenError::InvalidRefreshToken {
            reason: e.to_string(),
        })?;
    let user_id = consumed.user_id;
    let original_scope = consumed.scope;
    let family_id = consumed.family_id;

    // RFC 6749 §6: an omitted scope means "the scope originally granted". A blank
    // `scope=` field carries no scope-token. Previously it bypassed the subset
    // check (or was rejected as `''`) and could yield an empty-scope token, so it
    // is now treated exactly like an omitted scope.
    let requested_scope = request
        .scope
        .as_deref()
        .filter(|scope| !scope.trim().is_empty());

    let effective_scope: &str = match requested_scope {
        Some(requested_scope) => {
            let original_scopes = OAuthRepository::parse_scopes(&original_scope);
            let requested_scopes = OAuthRepository::parse_scopes(requested_scope);

            // A refresh may only narrow the original grant, never widen it.
            for requested in &requested_scopes {
                if !original_scopes.contains(requested) {
                    return Err(TokenError::InvalidRequest {
                        field: "scope".to_string(),
                        message: format!("Requested scope '{requested}' not in original scope"),
                    });
                }
            }
            requested_scope
        },
        None => &original_scope,
    };

    let generated = generate_tokens_by_user_id(
        &repo,
        TokenGenerationParams {
            client_id: &client_id,
            user_id: &user_id,
            scope: Some(effective_scope),
            headers,
            // NOTE(review): `resource` is taken from the request as-is and is not
            // compared with the resource bound to the original grant (RFC 8707
            // §2.2). Confirm whether consumption or generation enforces this.
            resource: request.resource.as_deref(),
            // Keep the new refresh token in the same family as the consumed one.
            family_id: Some(family_id.as_str()),
        },
        state,
    )
    .await
    .map_err(|e| TokenError::ServerError {
        message: e.to_string(),
    })?;

    let token_response = generated.response;
    tracing::info!(
        grant_type = "refresh_token",
        client_id = %client_id,
        user_id = %user_id,
        scope = %effective_scope,
        token_type = %token_response.token_type,
        expires_in = token_response.expires_in,
        "Token issued"
    );

    Ok(token_response)
}

/// Handles the token-exchange grant
/// (`urn:ietf:params:oauth:grant-type:token-exchange`).
///
/// `subject_token`, `subject_token_type` and `client_id` are required and are
/// checked in that order. The client is authenticated with `client_secret`.
/// The remaining optional exchange parameters (`actor_token`,
/// `actor_token_type`, `requested_token_type`, `scope`, `audience`,
/// `resource`) are forwarded unchanged to [`handle_token_exchange`].
///
/// # Errors
///
/// - the error from `extract_required_field` for a missing required field;
/// - [`TokenError::InvalidClientSecret`] when client authentication fails;
/// - whatever `map_exchange_error` derives from a failed exchange.
async fn handle_token_exchange_grant(
    repo: OAuthRepository,
    request: TokenRequest,
    headers: &HeaderMap,
    state: &OAuthState,
) -> Result<TokenResponse, TokenError> {
    let subject_token = extract_required_field(request.subject_token.as_deref(), "subject_token")?;
    let subject_token_type =
        extract_required_field(request.subject_token_type.as_deref(), "subject_token_type")?;
    let client_id =
        ClientId::new(extract_required_field(request.client_id.as_deref(), "client_id")?);

    let authenticated =
        validate_client_credentials(&repo, &client_id, request.client_secret.as_deref())
            .await
            .is_ok();
    if !authenticated {
        return Err(TokenError::InvalidClientSecret);
    }

    let exchange = TokenExchangeRequest {
        subject_token,
        subject_token_type,
        actor_token: request.actor_token.as_deref(),
        actor_token_type: request.actor_token_type.as_deref(),
        requested_token_type: request.requested_token_type.as_deref(),
        scope: request.scope.as_deref(),
        audience: request.audience.as_deref(),
        resource: request.resource.as_deref(),
    };

    let exchanged = match handle_token_exchange(&repo, &client_id, exchange, headers, state).await
    {
        Ok(response) => response,
        Err(err) => return Err(map_exchange_error(&err)),
    };

    tracing::info!(
        grant_type = "urn:ietf:params:oauth:grant-type:token-exchange",
        client_id = %client_id,
        scope = %exchanged.scope.as_deref().unwrap_or(""),
        "Token exchanged"
    );

    Ok(exchanged)
}

/// Converts an error returned by [`handle_token_exchange`] into a [`TokenError`].
///
/// If `err` can be downcast to a [`TokenError`], an owned copy of that error is
/// returned. Because only a reference is available here, the copy is rebuilt by
/// `clone_token_error`. Any other error becomes [`TokenError::ServerError`]
/// carrying the error's display text.
fn map_exchange_error(err: &anyhow::Error) -> TokenError {
    match err.downcast_ref::<TokenError>() {
        Some(inner) => clone_token_error(inner),
        None => TokenError::ServerError {
            message: err.to_string(),
        },
    }
}

/// Builds an owned copy of `err`, variant for variant, duplicating every field.
///
/// Used to turn a borrowed `&TokenError` (as obtained from
/// `anyhow::Error::downcast_ref`) back into an owned value.
/// NOTE(review): presumably `TokenError` does not implement `Clone`. If it ever
/// does, this helper can be replaced with `err.clone()`.
fn clone_token_error(err: &TokenError) -> TokenError {
    match *err {
        // Field-less variants are simply rebuilt.
        TokenError::InvalidClient => TokenError::InvalidClient,
        TokenError::InvalidCredentials => TokenError::InvalidCredentials,
        TokenError::InvalidClientSecret => TokenError::InvalidClientSecret,
        TokenError::ExpiredCode => TokenError::ExpiredCode,
        // Variants with payloads get owned copies of each field.
        TokenError::InvalidRequest {
            ref field,
            ref message,
        } => TokenError::InvalidRequest {
            field: field.to_owned(),
            message: message.to_owned(),
        },
        TokenError::UnsupportedGrantType { ref grant_type } => TokenError::UnsupportedGrantType {
            grant_type: grant_type.to_owned(),
        },
        TokenError::InvalidGrant { ref reason } => TokenError::InvalidGrant {
            reason: reason.to_owned(),
        },
        TokenError::InvalidRefreshToken { ref reason } => TokenError::InvalidRefreshToken {
            reason: reason.to_owned(),
        },
        TokenError::ServerError { ref message } => TokenError::ServerError {
            message: message.to_owned(),
        },
        TokenError::InvalidTarget { ref message } => TokenError::InvalidTarget {
            message: message.to_owned(),
        },
    }
}

/// Handles `grant_type=client_credentials`: issues tokens to the client itself.
///
/// `client_id` is required, and the client is authenticated with
/// `client_secret`. The optional `scope`, `plugin_id` and `audience` form
/// fields are passed through to [`generate_client_tokens`].
///
/// # Errors
///
/// - the error from `extract_required_field` when `client_id` is missing;
/// - [`TokenError::InvalidClientSecret`] when client authentication fails;
/// - [`TokenError::ServerError`] when token generation fails.
async fn handle_client_credentials_grant(
    repo: OAuthRepository,
    request: TokenRequest,
    headers: &HeaderMap,
    state: &OAuthState,
) -> Result<TokenResponse, TokenError> {
    let client_id =
        ClientId::new(extract_required_field(request.client_id.as_deref(), "client_id")?);

    let authenticated =
        validate_client_credentials(&repo, &client_id, request.client_secret.as_deref())
            .await
            .is_ok();
    if !authenticated {
        return Err(TokenError::InvalidClientSecret);
    }

    let generated = generate_client_tokens(
        &repo,
        &client_id,
        headers,
        state,
        ClientTokenOptions {
            scope: request.scope.as_deref(),
            plugin_id: request.plugin_id.as_deref(),
            audience: request.audience.as_deref(),
        },
    )
    .await;
    let issued = match generated {
        Ok(response) => response,
        Err(e) => {
            return Err(TokenError::ServerError {
                message: e.to_string(),
            });
        },
    };

    tracing::info!(
        grant_type = "client_credentials",
        client_id = %client_id,
        scope = %issued.scope.as_deref().unwrap_or(""),
        token_type = %issued.token_type,
        expires_in = issued.expires_in,
        "Token issued"
    );

    Ok(issued)
}