systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use anyhow::{Result, anyhow};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

use systemprompt_identifiers::{ClientId, SessionId, UserId};
use systemprompt_models::auth::UserType;
use systemprompt_oauth::models::JwtClaims;

/// Identity and authorization data taken from a validated JWT.
///
/// Built by `JwtExtractor::extract_user_context` from the token's claims.
#[derive(Debug, Clone)]
pub struct JwtUserContext {
    /// User identifier, taken from the token's `sub` claim.
    pub user_id: UserId,
    /// Session identifier, taken from the (mandatory) `session_id` claim.
    pub session_id: SessionId,
    /// Permission taken from the first entry of the token's `scope` claim.
    pub role: systemprompt_models::auth::Permission,
    /// User type, copied from the token's `user_type` claim.
    pub user_type: UserType,
    /// OAuth client identifier, present only if the token carries a `client_id` claim.
    pub client_id: Option<ClientId>,
}

/// Validates HS256-signed JWTs against a shared secret and extracts user context from them.
pub struct JwtExtractor {
    /// Key derived from the shared secret passed to `JwtExtractor::new`.
    decoding_key: DecodingKey,
    /// Validation rules applied on every decode (HS256, `exp` checked, `aud` not checked).
    validation: Validation,
}

impl std::fmt::Debug for JwtExtractor {
    /// Renders the extractor for diagnostics.
    ///
    /// The decoding key is replaced by a fixed placeholder so that the shared
    /// secret it was built from never appears in debug output or logs.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED_KEY: &str = "<DecodingKey>";

        let mut builder = formatter.debug_struct("JwtExtractor");
        builder.field("decoding_key", &REDACTED_KEY);
        builder.field("validation", &self.validation);
        builder.finish()
    }
}

impl JwtExtractor {
    pub fn new(jwt_secret: &str) -> Self {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = true;
        validation.validate_aud = false;

        Self {
            decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
            validation,
        }
    }

    pub fn validate_token(&self, token: &str) -> Result<(), String> {
        match decode::<JwtClaims>(token, &self.decoding_key, &self.validation) {
            Ok(_) => Ok(()),
            Err(err) => {
                let reason = err.to_string();
                if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
                    Err("Invalid signature".to_string())
                } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
                    Err("Token expired".to_string())
                } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
                    Err("Missing required claim".to_string())
                } else {
                    Err("Invalid token".to_string())
                }
            },
        }
    }

    pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;

        let session_id_str = token_data
            .claims
            .session_id
            .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;

        let role = *token_data
            .claims
            .scope
            .first()
            .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;

        let client_id = token_data.claims.client_id.map(ClientId::new);

        Ok(JwtUserContext {
            user_id: UserId::new(token_data.claims.sub),
            session_id: SessionId::new(session_id_str),
            role,
            user_type: token_data.claims.user_type,
            client_id,
        })
    }
}