Skip to main content

systemprompt_security/auth/
validation.rs

1use axum::http::HeaderMap;
2use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
3use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Expected prefix of an `Authorization` header that carries a bearer token.
const BEARER_PREFIX: &'static str = "Bearer ";

/// Session id given to requests that proceed without a validated token.
const ANONYMOUS_SESSION_ID: &'static str = "anonymous";

// Fixture identity used to build the request context when authentication is
// disabled (`AuthMode::Disabled`).

/// Session id of the synthetic test context.
const TEST_SESSION_ID: &'static str = "test";
/// Trace id of the synthetic test context.
const TEST_TRACE_ID: &'static str = "test-trace";
/// Agent name of the synthetic test context.
const TEST_AGENT_NAME: &'static str = "test-agent";
/// User id of the synthetic test context.
const TEST_USER_ID: &'static str = "test-user";
16
/// How strictly [`AuthValidationService::validate_request`] treats the
/// request's `Authorization` header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; a missing or invalid one is an error.
    Required,
    /// A valid token authenticates the request; a missing or invalid one
    /// yields an anonymous request context instead of an error.
    Optional,
    /// Headers are not inspected for a token; a synthetic test context is used.
    Disabled,
}
23
24#[derive(Debug)]
25pub struct AuthValidationService {
26    secret: String,
27    issuer: String,
28    audiences: Vec<JwtAudience>,
29}
30
31impl AuthValidationService {
32    #[must_use]
33    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
34        Self {
35            secret,
36            issuer,
37            audiences,
38        }
39    }
40
41    pub fn validate_request(
42        &self,
43        headers: &HeaderMap,
44        mode: AuthMode,
45    ) -> AuthResult<RequestContext> {
46        match mode {
47            AuthMode::Required => self.validate_and_fail_fast(headers),
48            AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
49            AuthMode::Disabled => Ok(Self::create_test_context()),
50        }
51    }
52
53    fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
54        let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
55
56        let claims = self.validate_token(token)?;
57        Ok(Self::create_context_from_claims(&claims, token, headers))
58    }
59
60    fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
61        Self::extract_token(headers).map_or_else(
62            || Self::create_anonymous_context(headers),
63            |token| {
64                self.validate_token(token)
65                    .map_err(|e| {
66                        tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
67                        e
68                    })
69                    .map_or_else(
70                        |_| Self::create_anonymous_context(headers),
71                        |claims| Self::create_context_from_claims(&claims, token, headers),
72                    )
73            },
74        )
75    }
76
77    fn extract_token(headers: &HeaderMap) -> Option<&str> {
78        headers
79            .get("authorization")
80            .and_then(|h| {
81                h.to_str()
82                    .map_err(|e| {
83                        tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
84                        e
85                    })
86                    .ok()
87            })
88            .and_then(|s| s.strip_prefix(BEARER_PREFIX))
89    }
90
91    fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
92        use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
93
94        let mut validation = Validation::new(Algorithm::HS256);
95
96        validation.set_issuer(&[&self.issuer]);
97
98        let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
99        validation.set_audience(&audience_strs);
100
101        let token_data = decode::<JwtClaims>(
102            token,
103            &DecodingKey::from_secret(self.secret.as_bytes()),
104            &validation,
105        )
106        .map_err(AuthError::InvalidToken)?;
107
108        let claims = token_data.claims;
109
110        let user_type = if claims.scope.contains(&Permission::Admin) {
111            UserType::Admin
112        } else {
113            claims.user_type
114        };
115
116        Ok(ValidatedSessionClaims {
117            user_id: UserId::new(claims.sub),
118            session_id: claims
119                .session_id
120                .map(SessionId::new)
121                .ok_or(AuthError::MissingSessionId)?,
122            user_type,
123        })
124    }
125
126    fn create_context_from_claims(
127        claims: &ValidatedSessionClaims,
128        token: &str,
129        headers: &HeaderMap,
130    ) -> RequestContext {
131        let session_id = claims.session_id.clone();
132        let user_id = claims.user_id.clone();
133
134        RequestContext::new(
135            session_id,
136            HeaderExtractor::extract_trace_id(headers),
137            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
138            HeaderExtractor::extract_agent_name(headers),
139        )
140        .with_user_id(user_id)
141        .with_auth_token(token)
142        .with_user_type(claims.user_type)
143    }
144
145    fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
146        RequestContext::new(
147            SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
148            HeaderExtractor::extract_trace_id(headers),
149            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
150            HeaderExtractor::extract_agent_name(headers),
151        )
152        .with_user_id(UserId::anonymous())
153        .with_user_type(UserType::Anon)
154    }
155
156    fn create_test_context() -> RequestContext {
157        RequestContext::new(
158            SessionId::new(TEST_SESSION_ID.to_string()),
159            TraceId::new(TEST_TRACE_ID.to_string()),
160            ContextId::generate(),
161            AgentName::new(TEST_AGENT_NAME.to_string()),
162        )
163        .with_user_id(UserId::new(TEST_USER_ID.to_string()))
164        .with_user_type(UserType::User)
165    }
166}