systemprompt_security/auth/
validation.rs1use axum::http::HeaderMap;
2use systemprompt_identifiers::{Actor, ContextId, SessionId, UserId};
3use systemprompt_models::auth::{
4 JwtAudience, JwtClaims, MAX_ACT_CHAIN_DEPTH, Permission, UserType,
5};
6use systemprompt_models::execution::context::RequestContext;
7
8use crate::error::{AuthError, AuthResult};
9use crate::extraction::HeaderExtractor;
10use crate::keys::authority;
11use crate::session::ValidatedSessionClaims;
12
/// Session id given to the context built for unauthenticated requests.
13const ANONYMOUS_SESSION_ID: &str = "anonymous";
/// Scheme prefix (including its separating space) expected at the start of
/// the `authorization` header value; the token is whatever follows it.
14const BEARER_PREFIX: &str = "Bearer ";
15
/// Leeway, in seconds, assigned to the JWT validator's `leeway` setting when
/// tokens are validated. Visible to the parent module.
16pub(super) const JWT_LEEWAY_SECONDS: u64 = 30;
20
/// How strictly a request must be authenticated.
///
/// Passed to [`AuthValidationService::validate_request`] to choose between
/// failing the request and falling back to an anonymous context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; a missing or invalid token is an error.
    Required,
    /// A bearer token is used when present and valid; otherwise the request
    /// proceeds with an anonymous context.
    Optional,
}
26
27#[derive(Debug)]
28pub struct AuthValidationService {
29 issuer: String,
30 audiences: Vec<JwtAudience>,
31}
32
33impl AuthValidationService {
34 #[must_use]
35 pub const fn new(issuer: String, audiences: Vec<JwtAudience>) -> Self {
36 Self { issuer, audiences }
37 }
38
39 pub fn validate_request(
40 &self,
41 headers: &HeaderMap,
42 mode: AuthMode,
43 ) -> AuthResult<RequestContext> {
44 match mode {
45 AuthMode::Required => self.validate_and_fail_fast(headers),
46 AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
47 }
48 }
49
50 fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
51 let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
52
53 let claims = self.validate_token(token)?;
54 Ok(Self::create_context_from_claims(&claims, token, headers))
55 }
56
57 fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
58 Self::extract_token(headers).map_or_else(
59 || Self::create_anonymous_context(headers),
60 |token| {
61 self.validate_token(token)
62 .map_err(|e| {
63 tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
64 e
65 })
66 .map_or_else(
67 |_| Self::create_anonymous_context(headers),
68 |claims| Self::create_context_from_claims(&claims, token, headers),
69 )
70 },
71 )
72 }
73
74 fn extract_token(headers: &HeaderMap) -> Option<&str> {
75 headers
76 .get("authorization")
77 .and_then(|h| {
78 h.to_str()
79 .map_err(|e| {
80 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
81 e
82 })
83 .ok()
84 })
85 .and_then(|s| s.strip_prefix(BEARER_PREFIX))
86 }
87
88 fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
89 use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
90
91 let header = decode_header(token).map_err(AuthError::InvalidToken)?;
92 if header.alg != Algorithm::RS256 {
93 return Err(AuthError::UnsupportedAlgorithm);
94 }
95 let kid = header.kid.as_deref().ok_or(AuthError::MissingKid)?;
96 let key = authority::decoding_key_for_kid(kid)
97 .map_err(|e| AuthError::KeyLookup(e.to_string()))?
98 .ok_or_else(|| AuthError::UnknownKid(kid.to_string()))?;
99
100 let mut validation = Validation::new(Algorithm::RS256);
101 validation.leeway = JWT_LEEWAY_SECONDS;
102 validation.validate_nbf = true;
103
104 validation.set_issuer(&[&self.issuer]);
105
106 let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
107 validation.set_audience(&audience_strs);
108
109 let token_data =
110 decode::<JwtClaims>(token, key, &validation).map_err(AuthError::InvalidToken)?;
111
112 let claims = token_data.claims;
113
114 if let Some(ref act) = claims.act {
115 let depth = act.depth();
116 if depth > MAX_ACT_CHAIN_DEPTH {
117 return Err(AuthError::ActChainTooDeep {
118 depth,
119 max: MAX_ACT_CHAIN_DEPTH,
120 });
121 }
122 }
123
124 let user_type = if claims.scope.contains(&Permission::Admin) {
125 UserType::Admin
126 } else {
127 claims.user_type
128 };
129
130 Ok(ValidatedSessionClaims {
131 user_id: UserId::new(claims.sub),
132 session_id: claims
133 .session_id
134 .map(SessionId::new)
135 .ok_or(AuthError::MissingSessionId)?,
136 user_type,
137 jti: claims.jti,
138 exp: claims.exp,
139 })
140 }
141
142 fn create_context_from_claims(
143 claims: &ValidatedSessionClaims,
144 token: &str,
145 headers: &HeaderMap,
146 ) -> RequestContext {
147 let session_id = claims.session_id.clone();
148 let user_id = claims.user_id.clone();
149
150 RequestContext::new(
151 session_id,
152 HeaderExtractor::extract_trace_id(headers),
153 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
154 HeaderExtractor::extract_agent_name(headers),
155 )
156 .with_actor(Actor::user(user_id))
157 .with_auth_token(token)
158 .with_user_type(claims.user_type)
159 .with_jti(claims.jti.clone())
160 .with_token_exp(claims.exp)
161 }
162
163 fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
164 RequestContext::new(
165 SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
166 HeaderExtractor::extract_trace_id(headers),
167 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
168 HeaderExtractor::extract_agent_name(headers),
169 )
170 .with_actor(Actor::anonymous(
171 systemprompt_identifiers::bootstrap::anonymous(),
172 ))
173 .with_user_type(UserType::Anon)
174 }
175}