systemprompt-oauth 0.10.2

OAuth 2.0 / OIDC with PKCE, token introspection, and audience/issuer validation for systemprompt.io AI governance infrastructure. WebAuthn and JWT auth for the MCP governance pipeline.
Documentation
//! Aggregated OAuth repository orchestrating client, code, and token
//! operations.

mod auth_code;
mod cleanup;
mod refresh_token;
mod scopes;
mod user;

pub use auth_code::{AuthCodeParams, AuthCodeValidationResult};
pub use refresh_token::RefreshTokenParams;

use super::{ClientRepository, CreateClientParams, UpdateClientParams};
use crate::error::{OauthError, OauthResult};
use crate::models::OAuthClient;
use crate::services::generate_client_secret;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Instant;
use systemprompt_database::DbPool;
use systemprompt_identifiers::ClientId;
use tracing::instrument;
use uuid::Uuid;

/// Aggregate OAuth repository bundling two shared database pools with the
/// client repository.
///
/// Both pools are held behind `Arc`, so cloning this struct shares them
/// rather than duplicating them.
#[derive(Clone, Debug)]
pub struct OAuthRepository {
    /// Pool obtained from `DbPool::pool_arc` in [`OAuthRepository::new`].
    // NOTE(review): presumably the general/read pool — confirm against `DbPool`.
    pool: Arc<PgPool>,
    /// Pool obtained from `DbPool::write_pool_arc` in [`OAuthRepository::new`].
    write_pool: Arc<PgPool>,
    /// Repository that the client-related methods in this file delegate to.
    client_repo: ClientRepository,
}

impl OAuthRepository {
    /// Builds the repository from a database handle.
    ///
    /// Fetches the pool via `pool_arc`, the write pool via `write_pool_arc`,
    /// and constructs the nested `ClientRepository` from the same `db`, in
    /// that order.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised while obtaining either pool or while
    /// constructing the client repository.
    pub fn new(db: &DbPool) -> OauthResult<Self> {
        // Field initialisers run top to bottom, so the fallible calls keep
        // their original order.
        Ok(Self {
            pool: db.pool_arc()?,
            write_pool: db.write_pool_arc()?,
            client_repo: ClientRepository::new(db)?,
        })
    }

    /// Borrows the pool stored in the `pool` field.
    pub fn pool_ref(&self) -> &PgPool {
        self.pool.as_ref()
    }

    /// Borrows the pool stored in the `write_pool` field.
    pub fn write_pool_ref(&self) -> &PgPool {
        self.write_pool.as_ref()
    }

    /// Creates an OAuth client through the client repository and logs the
    /// outcome together with how long the call took.
    ///
    /// On success an info event records the client id, name, scopes, redirect
    /// URIs and elapsed milliseconds; a warning follows when creation took
    /// longer than 500 ms. On failure an error event is emitted.
    ///
    /// # Errors
    ///
    /// Returns the client repository's error unchanged.
    #[instrument(skip(self, params), fields(client_id = %params.client_id, client_name = %params.client_name))]
    pub async fn create_client(&self, params: CreateClientParams) -> OauthResult<OAuthClient> {
        // Creation slower than this many milliseconds is reported as a warning.
        const SLOW_CREATE_MS: u128 = 500;

        let started = Instant::now();

        tracing::info!("Creating OAuth client");

        // `params` is moved into the repository call, so copy out everything
        // the log events need beforehand.
        let logged_id = params.client_id.clone();
        let logged_name = params.client_name.clone();
        let logged_scopes = params.scopes.clone();
        let logged_uris = params.redirect_uris.clone();

        let outcome = self.client_repo.create(params).await;
        let elapsed_ms = started.elapsed().as_millis();

        match &outcome {
            Ok(_) => {
                tracing::info!(
                    client_id = %logged_id,
                    client_name = %logged_name,
                    scopes = ?logged_scopes,
                    redirect_uris = ?logged_uris,
                    created_in_ms = elapsed_ms,
                    "OAuth client created"
                );

                if elapsed_ms > SLOW_CREATE_MS {
                    tracing::warn!(
                        client_id = %logged_id,
                        duration_ms = elapsed_ms,
                        "Slow OAuth client creation"
                    );
                }
            },
            Err(e) => {
                tracing::error!(
                    error = %e,
                    client_id = %logged_id,
                    duration_ms = elapsed_ms,
                    "OAuth client creation failed"
                );
            },
        }

        outcome
    }

    /// Returns the clients produced by the client repository's `list` call.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn list_clients(&self) -> OauthResult<Vec<OAuthClient>> {
        self.client_repo.list().await
    }

    /// Returns one page of clients from the client repository.
    ///
    /// `limit` and `offset` are forwarded unchanged to
    /// `ClientRepository::list_paginated`; no range checking happens here.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn list_clients_paginated(
        &self,
        limit: i32,
        offset: i32,
    ) -> OauthResult<Vec<OAuthClient>> {
        self.client_repo.list_paginated(limit, offset).await
    }

    /// Returns the client count reported by the client repository.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn count_clients(&self) -> OauthResult<i64> {
        self.client_repo.count().await
    }

    /// Looks up a client by its identifier via the client repository.
    ///
    /// Yields `Ok(None)` when the repository finds no match.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn find_client_by_id(
        &self,
        client_id: &ClientId,
    ) -> OauthResult<Option<OAuthClient>> {
        self.client_repo.get_by_client_id(client_id).await
    }

    /// Looks up a client by redirect URI via the client repository.
    ///
    /// The URI is forwarded as given; how it is matched is decided by
    /// `ClientRepository::find_by_redirect_uri`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn find_client_by_redirect_uri(
        &self,
        redirect_uri: &str,
    ) -> OauthResult<Option<OAuthClient>> {
        self.client_repo.find_by_redirect_uri(redirect_uri).await
    }

    /// Looks up a client by redirect URI and required scopes via the client
    /// repository.
    ///
    /// Both arguments are forwarded unchanged to
    /// `ClientRepository::find_by_redirect_uri_with_scope`.
    // NOTE(review): presumably the client must hold every scope in
    // `required_scopes` — confirm against the repository implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn find_client_by_redirect_uri_with_scope(
        &self,
        redirect_uri: &str,
        required_scopes: &[&str],
    ) -> OauthResult<Option<OAuthClient>> {
        let repo = &self.client_repo;
        let lookup = repo.find_by_redirect_uri_with_scope(redirect_uri, required_scopes);
        lookup.await
    }

    /// Replaces a client's name, redirect URIs and scopes while carrying over
    /// its other stored settings.
    ///
    /// The three inputs are typed as `Option`, but each is mandatory and must
    /// be non-empty. Inputs are validated in the order name, redirect URIs,
    /// scopes, before the client is looked up. Grant types, response types,
    /// token endpoint auth method, client URI, logo URI and contacts are copied
    /// from the existing record.
    ///
    /// # Errors
    ///
    /// Returns `OauthError::Validation` when `name`, `redirect_uris` or
    /// `scopes` is `None` or empty, and when no client matches `client_id`
    /// (either at lookup or when the update yields no record). Errors from the
    /// client repository are propagated.
    pub async fn update_client(
        &self,
        client_id: &ClientId,
        name: Option<&str>,
        redirect_uris: Option<&[String]>,
        scopes: Option<&[String]>,
    ) -> OauthResult<OAuthClient> {
        let new_name = name
            .filter(|candidate| !candidate.is_empty())
            .ok_or_else(|| {
                OauthError::Validation("Client name is required for update".to_string())
            })?
            .to_string();

        let new_redirect_uris = match redirect_uris {
            Some(uris) if !uris.is_empty() => uris,
            _ => {
                return Err(OauthError::Validation(
                    "At least one redirect URI required".to_string(),
                ));
            },
        };

        let new_scopes = match scopes {
            Some(requested) if !requested.is_empty() => requested,
            _ => {
                return Err(OauthError::Validation(
                    "At least one scope required".to_string(),
                ));
            },
        };

        // Load the current record so the fields this call does not touch can
        // be written back unchanged.
        let existing = match self.find_client_by_id(client_id).await? {
            Some(found) => found,
            None => return Err(OauthError::Validation("Client not found".to_string())),
        };

        let params = UpdateClientParams {
            client_id: client_id.clone(),
            client_name: new_name,
            redirect_uris: new_redirect_uris.to_vec(),
            grant_types: Some(existing.grant_types.clone()),
            response_types: Some(existing.response_types.clone()),
            scopes: new_scopes.to_vec(),
            token_endpoint_auth_method: Some(existing.token_endpoint_auth_method.clone()),
            client_uri: existing.client_uri.clone(),
            logo_uri: existing.logo_uri.clone(),
            contacts: existing.contacts.clone(),
        };

        match self.client_repo.update(params).await? {
            Some(updated) => Ok(updated),
            None => Err(OauthError::Validation("Client not found".to_string())),
        }
    }

    /// Stores a new secret hash for a client via the client repository.
    ///
    /// `client_secret_hash` is forwarded as given; nothing is hashed here.
    /// The result, including `Ok(None)`, is whatever
    /// `ClientRepository::update_secret` returns.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client repository reports.
    pub async fn update_client_secret(
        &self,
        client_id: &ClientId,
        client_secret_hash: &str,
    ) -> OauthResult<Option<OAuthClient>> {
        let repo = &self.client_repo;
        repo.update_secret(client_id, client_secret_hash).await
    }

    /// Writes the updatable fields of `client` back through the client
    /// repository.
    ///
    /// Unlike `update_client`, no input validation or prior lookup happens
    /// here: the fields of `client` are copied into the update parameters as
    /// they are.
    ///
    /// # Errors
    ///
    /// Returns `OauthError::Validation` with "Client not found" when the
    /// update yields no record. Errors from the client repository are
    /// propagated.
    pub async fn update_client_full(&self, client: &OAuthClient) -> OauthResult<OAuthClient> {
        let outcome = self
            .client_repo
            .update(UpdateClientParams {
                client_id: client.client_id.clone(),
                client_name: client.client_name.clone(),
                redirect_uris: client.redirect_uris.clone(),
                grant_types: Some(client.grant_types.clone()),
                response_types: Some(client.response_types.clone()),
                scopes: client.scopes.clone(),
                token_endpoint_auth_method: Some(client.token_endpoint_auth_method.clone()),
                client_uri: client.client_uri.clone(),
                logo_uri: client.logo_uri.clone(),
                contacts: client.contacts.clone(),
            })
            .await?;

        match outcome {
            Some(updated) => Ok(updated),
            None => Err(OauthError::Validation("Client not found".to_string())),
        }
    }

    /// Deletes a client through the client repository.
    ///
    /// Returns `true` when the repository reports more than zero affected
    /// rows, `false` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error from the client repository.
    pub async fn delete_client(&self, client_id: &ClientId) -> OauthResult<bool> {
        Ok(self.client_repo.delete(client_id).await? > 0)
    }

    #[must_use]
    pub fn generate_client_secret() -> String {
        generate_client_secret()
    }

    /// Returns a new client identifier: the prefix `client_` followed by a
    /// freshly generated v4 UUID in its `simple` formatting.
    #[must_use]
    pub fn generate_client_id() -> String {
        let unique = Uuid::new_v4().simple();
        format!("client_{unique}")
    }
}