systemprompt_security/auth/
validation.rs1use anyhow::{Result, anyhow};
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
4use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
5use systemprompt_models::execution::context::RequestContext;
6
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Session id assigned to requests that carry no usable credentials.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

/// Bearer auth-scheme prefix (including the separating space) looked for at the
/// start of the `authorization` header.
const BEARER_PREFIX: &str = "Bearer ";

// Fixed identifiers for the synthetic context produced when authentication is
// disabled.
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
const TEST_USER_ID: &str = "test-user";
17
/// How strictly a request's credentials are enforced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; otherwise the request is rejected.
    Required,
    /// A valid bearer token is used when present; otherwise the request
    /// proceeds with an anonymous identity.
    Optional,
    /// Credentials are not inspected at all; a fixed test identity is used.
    Disabled,
}
24
25#[derive(Debug)]
26pub struct AuthValidationService {
27 secret: String,
28 issuer: String,
29 audiences: Vec<JwtAudience>,
30}
31
32impl AuthValidationService {
33 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
34 Self {
35 secret,
36 issuer,
37 audiences,
38 }
39 }
40
41 pub fn validate_request(&self, headers: &HeaderMap, mode: AuthMode) -> Result<RequestContext> {
42 match mode {
43 AuthMode::Required => self.validate_and_fail_fast(headers),
44 AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
45 AuthMode::Disabled => Ok(Self::create_test_context()),
46 }
47 }
48
49 fn validate_and_fail_fast(&self, headers: &HeaderMap) -> Result<RequestContext> {
50 let token =
51 Self::extract_token(headers).ok_or_else(|| anyhow!("Missing authorization header"))?;
52
53 let claims = self.validate_token(token)?;
54 Ok(Self::create_context_from_claims(&claims, token, headers))
55 }
56
57 fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
58 Self::extract_token(headers).map_or_else(
59 || Self::create_anonymous_context(headers),
60 |token| {
61 self.validate_token(token)
62 .map_err(|e| {
63 tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
64 e
65 })
66 .map_or_else(
67 |_| Self::create_anonymous_context(headers),
68 |claims| Self::create_context_from_claims(&claims, token, headers),
69 )
70 },
71 )
72 }
73
74 fn extract_token(headers: &HeaderMap) -> Option<&str> {
75 headers
76 .get("authorization")
77 .and_then(|h| {
78 h.to_str()
79 .map_err(|e| {
80 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
81 e
82 })
83 .ok()
84 })
85 .and_then(|s| s.strip_prefix(BEARER_PREFIX))
86 }
87
88 fn validate_token(&self, token: &str) -> Result<ValidatedSessionClaims> {
89 use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
90
91 let mut validation = Validation::new(Algorithm::HS256);
92
93 validation.set_issuer(&[&self.issuer]);
94
95 let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
96 validation.set_audience(&audience_strs);
97
98 let token_data = decode::<JwtClaims>(
99 token,
100 &DecodingKey::from_secret(self.secret.as_bytes()),
101 &validation,
102 )
103 .map_err(|e| anyhow!("Invalid JWT token: {e}"))?;
104
105 let claims = token_data.claims;
106
107 let user_type = if claims.scope.contains(&Permission::Admin) {
108 UserType::Admin
109 } else {
110 claims.user_type
111 };
112
113 Ok(ValidatedSessionClaims {
114 user_id: UserId::new(claims.sub),
115 session_id: claims
116 .session_id
117 .map(SessionId::new)
118 .ok_or_else(|| anyhow!("Missing session_id in token"))?,
119 user_type,
120 })
121 }
122
123 fn create_context_from_claims(
124 claims: &ValidatedSessionClaims,
125 token: &str,
126 headers: &HeaderMap,
127 ) -> RequestContext {
128 let session_id = claims.session_id.clone();
129 let user_id = claims.user_id.clone();
130
131 RequestContext::new(
132 session_id,
133 HeaderExtractor::extract_trace_id(headers),
134 HeaderExtractor::extract_context_id(headers),
135 HeaderExtractor::extract_agent_name(headers),
136 )
137 .with_user_id(user_id)
138 .with_auth_token(token)
139 .with_user_type(claims.user_type)
140 }
141
142 fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
143 RequestContext::new(
144 SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
145 HeaderExtractor::extract_trace_id(headers),
146 HeaderExtractor::extract_context_id(headers),
147 HeaderExtractor::extract_agent_name(headers),
148 )
149 .with_user_id(UserId::anonymous())
150 .with_user_type(UserType::Anon)
151 }
152
153 fn create_test_context() -> RequestContext {
154 RequestContext::new(
155 SessionId::new(TEST_SESSION_ID.to_string()),
156 TraceId::new(TEST_TRACE_ID.to_string()),
157 ContextId::new(TEST_CONTEXT_ID.to_string()),
158 AgentName::new(TEST_AGENT_NAME.to_string()),
159 )
160 .with_user_id(UserId::new(TEST_USER_ID.to_string()))
161 .with_user_type(UserType::User)
162 }
163}