systemprompt-security 0.1.22

Security module for systemprompt.io: authentication, authorization, JWT validation, and bearer-token extraction.
Documentation
use anyhow::{Result, anyhow};
use axum::http::HeaderMap;
use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
use systemprompt_models::execution::context::RequestContext;

use crate::extraction::HeaderExtractor;
use crate::session::ValidatedSessionClaims;

/// Session id assigned to requests that end up with an anonymous context (no valid token).
const ANONYMOUS_SESSION_ID: &str = "anonymous";
// Hard-coded identity used by `create_test_context` (the `AuthMode::Disabled` path);
// every request in that mode receives these exact values.
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
const TEST_USER_ID: &str = "test-user";
/// Authorization header scheme prefix, including the separating space, that precedes the JWT.
const BEARER_PREFIX: &str = "Bearer ";

/// How strictly [`AuthValidationService::validate_request`] enforces authentication.
///
/// `Hash` is derived alongside `Eq` so modes can be used as map/set keys
/// (e.g. per-route policy tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; a missing or invalid token is an error.
    Required,
    /// A valid token authenticates the request; a missing or invalid token
    /// falls back to an anonymous context instead of failing.
    Optional,
    /// No token is inspected; a fixed, hard-coded test identity is used.
    Disabled,
}

/// Validates HS256-signed JWT bearer tokens and turns them into request contexts.
///
/// `Debug` is implemented by hand so the signing secret is never printed
/// (a derived impl would leak it into logs or panic messages).
pub struct AuthValidationService {
    /// Shared HMAC key used to verify token signatures. Redacted in `Debug` output.
    secret: String,
    /// The only accepted `iss` claim value.
    issuer: String,
    /// Accepted `aud` claim values.
    audiences: Vec<JwtAudience>,
}

impl std::fmt::Debug for AuthValidationService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthValidationService")
            .field("secret", &"<redacted>")
            .field("issuer", &self.issuer)
            .field("audiences", &self.audiences)
            .finish()
    }
}

impl AuthValidationService {
    /// Creates a validator for HS256-signed JWTs.
    ///
    /// * `secret` – shared HMAC key used to verify signatures.
    /// * `issuer` – the only accepted `iss` claim value.
    /// * `audiences` – the `aud` claim values passed to the JWT validator.
    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
        Self {
            secret,
            issuer,
            audiences,
        }
    }

    /// Builds a [`RequestContext`] for an incoming request according to `mode`.
    ///
    /// - [`AuthMode::Required`]: a valid bearer token must be present.
    /// - [`AuthMode::Optional`]: a valid token yields an authenticated context;
    ///   a missing or invalid token yields an anonymous context.
    /// - [`AuthMode::Disabled`]: headers are ignored and a fixed test identity
    ///   is returned.
    ///
    /// # Errors
    ///
    /// Only in `Required` mode, when:
    /// - no bearer token is present;
    /// - the token fails JWT decoding/validation;
    /// - the token lacks a `session_id` claim.
    pub fn validate_request(&self, headers: &HeaderMap, mode: AuthMode) -> Result<RequestContext> {
        match mode {
            AuthMode::Required => self.validate_and_fail_fast(headers),
            AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
            AuthMode::Disabled => Ok(Self::create_test_context()),
        }
    }

    /// Requires a valid bearer token; any failure is returned as an error.
    fn validate_and_fail_fast(&self, headers: &HeaderMap) -> Result<RequestContext> {
        let token =
            Self::extract_token(headers).ok_or_else(|| anyhow!("Missing authorization header"))?;

        let claims = self.validate_token(token)?;
        Ok(Self::create_context_from_claims(&claims, token, headers))
    }

    /// Authenticates when possible, otherwise falls back to an anonymous context.
    ///
    /// Validation failures are deliberately not surfaced to the caller; they are
    /// only logged at debug level.
    fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
        let Some(token) = Self::extract_token(headers) else {
            return Self::create_anonymous_context(headers);
        };

        match self.validate_token(token) {
            Ok(claims) => Self::create_context_from_claims(&claims, token, headers),
            Err(e) => {
                tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
                Self::create_anonymous_context(headers)
            }
        }
    }

    /// Returns the bearer token from the `Authorization` header, if any.
    ///
    /// Returns `None` in each of these cases:
    /// - the header is absent;
    /// - the header is not valid visible ASCII;
    /// - the header does not use the `Bearer` scheme;
    /// - the header carries empty credentials.
    fn extract_token(headers: &HeaderMap) -> Option<&str> {
        let raw = headers.get("authorization")?;
        let value = match raw.to_str() {
            Ok(value) => value,
            Err(e) => {
                tracing::debug!(
                    error = %e,
                    "Authorization header contains non-ASCII characters"
                );
                return None;
            }
        };
        Self::strip_bearer_prefix(value)
    }

    /// Splits off a `Bearer ` scheme prefix and returns the trimmed credentials.
    ///
    /// The scheme name is compared case-insensitively, because RFC 7235 §2.1
    /// defines auth-scheme tokens as case-insensitive. Empty credentials yield
    /// `None`, so they are treated the same as a missing token.
    fn strip_bearer_prefix(value: &str) -> Option<&str> {
        let prefix_len = BEARER_PREFIX.len();
        let scheme = value.get(..prefix_len)?;
        let credentials = value.get(prefix_len..)?;
        if !scheme.eq_ignore_ascii_case(BEARER_PREFIX) {
            return None;
        }
        let token = credentials.trim();
        (!token.is_empty()).then_some(token)
    }

    /// Decodes `token` and checks it against the configured key, issuer and audiences.
    ///
    /// The algorithm is pinned to HS256. Other checks (e.g. `exp`) follow the
    /// defaults of `jsonwebtoken::Validation::new`.
    ///
    /// # Errors
    ///
    /// Returns an error when decoding/validation fails or when the token has
    /// no `session_id` claim.
    fn validate_token(&self, token: &str) -> Result<ValidatedSessionClaims> {
        use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

        // Pinning the algorithm prevents a token from choosing a weaker/different one.
        let mut validation = Validation::new(Algorithm::HS256);

        validation.set_issuer(&[&self.issuer]);

        let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
        validation.set_audience(&audience_strs);

        let token_data = decode::<JwtClaims>(
            token,
            &DecodingKey::from_secret(self.secret.as_bytes()),
            &validation,
        )
        .map_err(|e| anyhow!("Invalid JWT token: {e}"))?;

        let claims = token_data.claims;

        // An `Admin` permission in the scope promotes the user to `UserType::Admin`
        // regardless of the `user_type` claim.
        let user_type = if claims.scope.contains(&Permission::Admin) {
            UserType::Admin
        } else {
            claims.user_type
        };

        Ok(ValidatedSessionClaims {
            user_id: claims.sub,
            session_id: claims
                .session_id
                .ok_or_else(|| anyhow!("Missing session_id in token"))?,
            user_type,
        })
    }

    /// Builds an authenticated context from validated claims.
    ///
    /// Trace, context and agent identifiers are read from the request headers
    /// via `HeaderExtractor`. The raw token is attached to the context.
    fn create_context_from_claims(
        claims: &ValidatedSessionClaims,
        token: &str,
        headers: &HeaderMap,
    ) -> RequestContext {
        let session_id = SessionId::new(claims.session_id.clone());
        let user_id = UserId::new(claims.user_id.clone());

        RequestContext::new(
            session_id,
            HeaderExtractor::extract_trace_id(headers),
            HeaderExtractor::extract_context_id(headers),
            HeaderExtractor::extract_agent_name(headers),
        )
        .with_user_id(user_id)
        .with_auth_token(token)
        .with_user_type(claims.user_type)
    }

    /// Builds an unauthenticated context.
    ///
    /// Uses the `"anonymous"` session id, `UserId::anonymous()` and
    /// `UserType::Anon`, and carries no auth token.
    fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
        RequestContext::new(
            SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
            HeaderExtractor::extract_trace_id(headers),
            HeaderExtractor::extract_context_id(headers),
            HeaderExtractor::extract_agent_name(headers),
        )
        .with_user_id(UserId::anonymous())
        .with_user_type(UserType::Anon)
    }

    /// Builds the fixed, hard-coded context used when authentication is disabled.
    ///
    /// NOTE(review): this grants a `UserType::User` identity without any
    /// credentials. Confirm that `AuthMode::Disabled` cannot be selected in
    /// production configurations.
    fn create_test_context() -> RequestContext {
        RequestContext::new(
            SessionId::new(TEST_SESSION_ID.to_string()),
            TraceId::new(TEST_TRACE_ID.to_string()),
            ContextId::new(TEST_CONTEXT_ID.to_string()),
            AgentName::new(TEST_AGENT_NAME.to_string()),
        )
        .with_user_id(UserId::new(TEST_USER_ID.to_string()))
        .with_user_type(UserType::User)
    }
}