//! systemprompt-oauth 0.1.22
//!
//! OAuth 2.0 authentication and authorization module for the systemprompt.io OS.
use async_trait::async_trait;
use std::sync::Arc;
use systemprompt_models::auth::JwtAudience;
use systemprompt_traits::{
    AuthAction, AuthPermission, AuthProvider, AuthProviderError, AuthResult, AuthorizationProvider,
    TokenClaims, TokenPair,
};

use crate::models::JwtClaims as OAuthJwtClaims;
use crate::services::validation::jwt as jwt_validation;

/// JWT-backed [`AuthProvider`].
///
/// Tokens are checked with `jwt_validation::validate_jwt_token` against the
/// stored secret, expected issuer, and accepted audiences.
///
/// `Debug` is implemented by hand (rather than derived) so the signing secret
/// is never written to logs, panic messages, or other `{:?}` output.
#[derive(Clone)]
pub struct JwtAuthProvider {
    /// Secret handed to the JWT validator; redacted from `Debug` output.
    secret: String,
    /// Issuer the validator expects tokens to carry.
    issuer: String,
    /// Audiences the validator accepts.
    audiences: Vec<JwtAudience>,
}

impl std::fmt::Debug for JwtAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtAuthProvider")
            // Never expose the signing secret.
            .field("secret", &"<redacted>")
            .field("issuer", &self.issuer)
            .field("audiences", &self.audiences)
            .finish()
    }
}

impl JwtAuthProvider {
    /// Creates a provider from an explicit signing secret, expected issuer,
    /// and list of accepted audiences.
    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
        Self {
            secret,
            issuer,
            audiences,
        }
    }

    /// Creates a provider from the global configuration.
    ///
    /// The issuer and audiences are read from `systemprompt_models::Config::get()`;
    /// the secret comes from `SecretsBootstrap::jwt_secret()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration or the JWT secret cannot be loaded.
    pub fn from_config() -> anyhow::Result<Self> {
        let settings = systemprompt_models::Config::get()?;
        let signing_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
        Ok(Self::new(
            signing_secret,
            settings.jwt_issuer.clone(),
            settings.jwt_audiences.clone(),
        ))
    }
}

/// Maps OAuth-specific JWT claims onto the provider-agnostic [`TokenClaims`].
///
/// Audiences and scopes are converted to strings via `ToString`; the email
/// claim is always wrapped in `Some`.
fn convert_claims(claims: OAuthJwtClaims) -> TokenClaims {
    let audiences = claims.aud.iter().map(ToString::to_string).collect();
    let permissions = claims.scope.iter().map(ToString::to_string).collect();

    TokenClaims {
        subject: claims.sub,
        username: claims.username,
        email: Some(claims.email),
        audiences,
        permissions,
        expires_at: claims.exp,
        issued_at: claims.iat,
    }
}

#[async_trait]
impl AuthProvider for JwtAuthProvider {
    async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
        let claims =
            jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
                .map_err(|e| {
                    AuthProviderError::Internal(format!("Token validation failed: {e}"))
                })?;

        Ok(convert_claims(claims))
    }

    async fn refresh_token(&self, _refresh_token: &str) -> AuthResult<TokenPair> {
        Err(AuthProviderError::Internal(
            "Token refresh not yet implemented via trait".to_string(),
        ))
    }

    async fn revoke_token(&self, _token: &str) -> AuthResult<()> {
        Err(AuthProviderError::Internal(
            "Token revocation not yet implemented via trait".to_string(),
        ))
    }
}

/// JWT-backed [`AuthorizationProvider`].
///
/// Audience checks first validate the token with
/// `jwt_validation::validate_jwt_token` using the stored secret, expected
/// issuer, and accepted audiences.
///
/// `Debug` is implemented by hand (rather than derived) so the signing secret
/// is never written to logs, panic messages, or other `{:?}` output.
#[derive(Clone)]
pub struct JwtAuthorizationProvider {
    /// Secret handed to the JWT validator; redacted from `Debug` output.
    secret: String,
    /// Issuer the validator expects tokens to carry.
    issuer: String,
    /// Audiences the validator accepts.
    audiences: Vec<JwtAudience>,
}

impl std::fmt::Debug for JwtAuthorizationProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtAuthorizationProvider")
            // Never expose the signing secret.
            .field("secret", &"<redacted>")
            .field("issuer", &self.issuer)
            .field("audiences", &self.audiences)
            .finish()
    }
}

impl JwtAuthorizationProvider {
    /// Creates an authorization provider that validates tokens using the given
    /// signing secret, expected issuer, and accepted audiences.
    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
        Self {
            secret,
            issuer,
            audiences,
        }
    }
}

#[async_trait]
impl AuthorizationProvider for JwtAuthorizationProvider {
    /// Decides whether `_user_id` may perform `_action` on `_resource`.
    ///
    /// NOTE(review): this is currently allow-all. Every request is granted
    /// regardless of user, resource, or action. Confirm this is intended
    /// before relying on it for access control.
    async fn authorize(
        &self,
        _user_id: &str,
        _resource: &str,
        _action: &AuthAction,
    ) -> AuthResult<bool> {
        Ok(true)
    }

    /// Returns the permissions held by `_user_id`.
    ///
    /// This provider performs no lookup and always returns an empty list.
    async fn get_permissions(&self, _user_id: &str) -> AuthResult<Vec<AuthPermission>> {
        Ok(Vec::new())
    }

    /// Validates `token` and reports whether its `aud` claim contains `audience`.
    ///
    /// Each claimed audience is compared by its `to_string()` representation.
    ///
    /// # Errors
    ///
    /// Returns [`AuthProviderError::Internal`] if the token fails validation.
    async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
        jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
            .map_err(|err| AuthProviderError::Internal(format!("Token validation failed: {err}")))
            .map(|claims| {
                claims
                    .aud
                    .iter()
                    .any(|candidate| candidate.to_string() == audience)
            })
    }
}

/// Facade pairing an [`AuthProvider`] (token validation) with an
/// [`AuthorizationProvider`] (audience/permission checks).
///
/// Both providers are held as `Arc` trait objects, so cloning the service
/// clones the `Arc`s rather than the providers themselves.
#[derive(Clone)]
pub struct TraitBasedAuthService {
    /// Provider used for token validation.
    auth_provider: Arc<dyn AuthProvider>,
    /// Provider used for authorization and audience checks.
    authorization_provider: Arc<dyn AuthorizationProvider>,
}

impl std::fmt::Debug for TraitBasedAuthService {
    /// Renders fixed placeholder strings for both providers; the trait objects
    /// themselves are never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TraitBasedAuthService");
        builder.field("auth_provider", &"AuthProvider");
        builder.field("authorization_provider", &"AuthorizationProvider");
        builder.finish()
    }
}

impl TraitBasedAuthService {
    /// Wraps an already-constructed pair of providers.
    pub fn new(
        auth_provider: Arc<dyn AuthProvider>,
        authorization_provider: Arc<dyn AuthorizationProvider>,
    ) -> Self {
        Self {
            auth_provider,
            authorization_provider,
        }
    }

    /// Builds a service backed by [`JwtAuthProvider`] and
    /// [`JwtAuthorizationProvider`], both configured with the same secret,
    /// issuer, and audiences.
    ///
    /// The issuer and audiences come from `systemprompt_models::Config::get()`;
    /// the secret comes from `SecretsBootstrap::jwt_secret()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration or the JWT secret cannot be loaded.
    pub fn from_config() -> anyhow::Result<Self> {
        let settings = systemprompt_models::Config::get()?;
        let signing_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
        let issuer = settings.jwt_issuer.clone();
        let audiences = settings.jwt_audiences.clone();

        let authorizer = Arc::new(JwtAuthorizationProvider::new(
            signing_secret.clone(),
            issuer.clone(),
            audiences.clone(),
        ));
        let authenticator = Arc::new(JwtAuthProvider::new(signing_secret, issuer, audiences));

        Ok(Self::new(authenticator, authorizer))
    }

    /// Borrows the token-validation provider.
    pub fn auth_provider(&self) -> &Arc<dyn AuthProvider> {
        &self.auth_provider
    }

    /// Borrows the authorization provider.
    pub fn authorization_provider(&self) -> &Arc<dyn AuthorizationProvider> {
        &self.authorization_provider
    }

    /// Validates `token` by delegating to the wrapped [`AuthProvider`].
    pub async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
        let provider = &self.auth_provider;
        provider.validate_token(token).await
    }

    /// Checks whether `token` carries `audience` by delegating to the wrapped
    /// [`AuthorizationProvider`].
    pub async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
        let provider = &self.authorization_provider;
        provider.has_audience(token, audience).await
    }
}