Skip to main content

systemprompt_security/auth/
validation.rs

1use axum::http::HeaderMap;
2use systemprompt_identifiers::{Actor, ContextId, SessionId, UserId};
3use systemprompt_models::auth::{JwtAudience, MAX_ACT_CHAIN_DEPTH, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::{HeaderExtractor, TokenExtractor};
8use crate::jwt::{ValidationPolicy, decode_rs256_claims};
9use crate::session::ValidatedSessionClaims;
10
11#[derive(Debug)]
12pub struct AuthValidationService {
13    issuer: String,
14    audiences: Vec<JwtAudience>,
15}
16
17impl AuthValidationService {
18    #[must_use]
19    pub const fn new(issuer: String, audiences: Vec<JwtAudience>) -> Self {
20        Self { issuer, audiences }
21    }
22
23    pub fn validate_request(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
24        let token = TokenExtractor::extract_from_authorization(headers)
25            .map_err(|_e| AuthError::MissingAuthorization)?;
26        let claims = self.validate_token(&token)?;
27        Ok(Self::create_context_from_claims(&claims, &token, headers))
28    }
29
30    fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
31        let policy = ValidationPolicy::issuer_scoped(&self.issuer, &self.audiences);
32        let claims = decode_rs256_claims(token, &policy)?;
33
34        if let Some(ref act) = claims.act {
35            let depth = act.depth();
36            if depth > MAX_ACT_CHAIN_DEPTH {
37                return Err(AuthError::ActChainTooDeep {
38                    depth,
39                    max: MAX_ACT_CHAIN_DEPTH,
40                });
41            }
42        }
43
44        let user_type = if claims.scope.contains(&Permission::Admin) {
45            UserType::Admin
46        } else {
47            claims.user_type
48        };
49
50        Ok(ValidatedSessionClaims {
51            user_id: UserId::new(claims.sub),
52            session_id: claims
53                .session_id
54                .map(SessionId::new)
55                .ok_or(AuthError::MissingSessionId)?,
56            user_type,
57            jti: claims.jti,
58            exp: claims.exp,
59        })
60    }
61
62    fn create_context_from_claims(
63        claims: &ValidatedSessionClaims,
64        token: &str,
65        headers: &HeaderMap,
66    ) -> RequestContext {
67        let session_id = claims.session_id.clone();
68        let user_id = claims.user_id.clone();
69
70        RequestContext::new(
71            session_id,
72            HeaderExtractor::extract_trace_id(headers),
73            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
74            HeaderExtractor::extract_agent_name(headers),
75        )
76        .with_actor(Actor::user(user_id))
77        .with_auth_token(token)
78        .with_user_type(claims.user_type)
79        .with_jti(claims.jti.clone())
80        .with_token_exp(claims.exp)
81    }
82}