Skip to main content

systemprompt_security/auth/
validation.rs

1use axum::http::HeaderMap;
2use systemprompt_identifiers::{Actor, ContextId, SessionId, UserId};
3use systemprompt_models::auth::{
4    JwtAudience, JwtClaims, MAX_ACT_CHAIN_DEPTH, Permission, UserType,
5};
6use systemprompt_models::execution::context::RequestContext;
7
8use crate::error::{AuthError, AuthResult};
9use crate::extraction::HeaderExtractor;
10use crate::keys::authority;
11use crate::session::ValidatedSessionClaims;
12
/// Scheme prefix compared against the start of the `Authorization` header
/// value; the trailing space separates the scheme from the credential.
const BEARER_PREFIX: &str = "Bearer ";
/// Session id assigned to request contexts built without a valid token.
const ANONYMOUS_SESSION_ID: &str = "anonymous";
15
16/// Maximum clock-skew tolerance (seconds) for `exp`, `nbf`, and `iat`
17/// validation. Pinned explicitly so deployments see this value in code
18/// review rather than inheriting the `jsonwebtoken` default.
19pub(super) const JWT_LEEWAY_SECONDS: u64 = 30;
20
/// How strictly a request's `Authorization` header is enforced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; any failure is returned as an error.
    Required,
    /// A valid bearer token is used when present; otherwise the request
    /// continues with an anonymous context.
    Optional,
}
26
27#[derive(Debug)]
28pub struct AuthValidationService {
29    issuer: String,
30    audiences: Vec<JwtAudience>,
31}
32
33impl AuthValidationService {
34    #[must_use]
35    pub const fn new(issuer: String, audiences: Vec<JwtAudience>) -> Self {
36        Self { issuer, audiences }
37    }
38
39    pub fn validate_request(
40        &self,
41        headers: &HeaderMap,
42        mode: AuthMode,
43    ) -> AuthResult<RequestContext> {
44        match mode {
45            AuthMode::Required => self.validate_and_fail_fast(headers),
46            AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
47        }
48    }
49
50    fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
51        let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
52
53        let claims = self.validate_token(token)?;
54        Ok(Self::create_context_from_claims(&claims, token, headers))
55    }
56
57    fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
58        Self::extract_token(headers).map_or_else(
59            || Self::create_anonymous_context(headers),
60            |token| {
61                self.validate_token(token)
62                    .map_err(|e| {
63                        tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
64                        e
65                    })
66                    .map_or_else(
67                        |_| Self::create_anonymous_context(headers),
68                        |claims| Self::create_context_from_claims(&claims, token, headers),
69                    )
70            },
71        )
72    }
73
74    fn extract_token(headers: &HeaderMap) -> Option<&str> {
75        headers
76            .get("authorization")
77            .and_then(|h| {
78                h.to_str()
79                    .map_err(|e| {
80                        tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
81                        e
82                    })
83                    .ok()
84            })
85            .and_then(|s| s.strip_prefix(BEARER_PREFIX))
86    }
87
88    fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
89        use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
90
91        let header = decode_header(token).map_err(AuthError::InvalidToken)?;
92        if header.alg != Algorithm::RS256 {
93            return Err(AuthError::UnsupportedAlgorithm);
94        }
95        let kid = header.kid.as_deref().ok_or(AuthError::MissingKid)?;
96        let key = authority::decoding_key_for_kid(kid)
97            .map_err(|e| AuthError::KeyLookup(e.to_string()))?
98            .ok_or_else(|| AuthError::UnknownKid(kid.to_owned()))?;
99
100        let mut validation = Validation::new(Algorithm::RS256);
101        validation.leeway = JWT_LEEWAY_SECONDS;
102        validation.validate_nbf = true;
103
104        validation.set_issuer(&[&self.issuer]);
105
106        let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
107        validation.set_audience(&audience_strs);
108
109        let token_data =
110            decode::<JwtClaims>(token, key, &validation).map_err(AuthError::InvalidToken)?;
111
112        let claims = token_data.claims;
113
114        if let Some(ref act) = claims.act {
115            let depth = act.depth();
116            if depth > MAX_ACT_CHAIN_DEPTH {
117                return Err(AuthError::ActChainTooDeep {
118                    depth,
119                    max: MAX_ACT_CHAIN_DEPTH,
120                });
121            }
122        }
123
124        let user_type = if claims.scope.contains(&Permission::Admin) {
125            UserType::Admin
126        } else {
127            claims.user_type
128        };
129
130        Ok(ValidatedSessionClaims {
131            user_id: UserId::new(claims.sub),
132            session_id: claims
133                .session_id
134                .map(SessionId::new)
135                .ok_or(AuthError::MissingSessionId)?,
136            user_type,
137            jti: claims.jti,
138            exp: claims.exp,
139        })
140    }
141
142    fn create_context_from_claims(
143        claims: &ValidatedSessionClaims,
144        token: &str,
145        headers: &HeaderMap,
146    ) -> RequestContext {
147        let session_id = claims.session_id.clone();
148        let user_id = claims.user_id.clone();
149
150        RequestContext::new(
151            session_id,
152            HeaderExtractor::extract_trace_id(headers),
153            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
154            HeaderExtractor::extract_agent_name(headers),
155        )
156        .with_actor(Actor::user(user_id))
157        .with_auth_token(token)
158        .with_user_type(claims.user_type)
159        .with_jti(claims.jti.clone())
160        .with_token_exp(claims.exp)
161    }
162
163    fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
164        RequestContext::new(
165            SessionId::new(ANONYMOUS_SESSION_ID.to_owned()),
166            HeaderExtractor::extract_trace_id(headers),
167            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
168            HeaderExtractor::extract_agent_name(headers),
169        )
170        .with_actor(Actor::anonymous(
171            systemprompt_identifiers::bootstrap::anonymous(),
172        ))
173        .with_user_type(UserType::Anon)
174    }
175}