// systemprompt_security/extraction/token.rs
use axum::http::HeaderMap;
use std::error::Error;
use std::fmt;
4
/// Prefix a header value must start with for the remainder to be treated as
/// a bearer token. The trailing space is part of the prefix.
const BEARER_PREFIX: &str = "Bearer ";
/// Cookie name a `TokenExtractor` reads unless overridden via
/// `with_cookie_name`.
const DEFAULT_COOKIE_NAME: &str = "access_token";
/// Header name a `TokenExtractor` reads for the MCP proxy method unless
/// overridden via `with_mcp_header_name`.
const DEFAULT_MCP_HEADER_NAME: &str = "x-mcp-proxy-auth";
8
/// A place in an HTTP request's headers where a token may be looked up.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtractionMethod {
    /// Look in the `Authorization` header.
    AuthorizationHeader,
    /// Look in the MCP proxy authorization header.
    McpProxyHeader,
    /// Look in the request cookies.
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable label for the method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        f.write_str(label)
    }
}
25
26#[derive(Debug, Clone)]
27pub struct TokenExtractor {
28 fallback_chain: Vec<ExtractionMethod>,
29 cookie_name: String,
30 mcp_header_name: String,
31}
32
33impl TokenExtractor {
34 pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
35 Self {
36 fallback_chain,
37 cookie_name: DEFAULT_COOKIE_NAME.to_string(),
38 mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
39 }
40 }
41
42 pub fn with_cookie_name(mut self, name: String) -> Self {
43 self.cookie_name = name;
44 self
45 }
46
47 pub fn with_mcp_header_name(mut self, name: String) -> Self {
48 self.mcp_header_name = name;
49 self
50 }
51
52 pub fn standard() -> Self {
53 Self::new(vec![
54 ExtractionMethod::AuthorizationHeader,
55 ExtractionMethod::McpProxyHeader,
56 ExtractionMethod::Cookie,
57 ])
58 }
59
60 pub fn browser_only() -> Self {
61 Self::new(vec![
62 ExtractionMethod::AuthorizationHeader,
63 ExtractionMethod::Cookie,
64 ])
65 }
66
67 pub fn api_only() -> Self {
68 Self::new(vec![ExtractionMethod::AuthorizationHeader])
69 }
70
71 pub fn chain(&self) -> &[ExtractionMethod] {
72 &self.fallback_chain
73 }
74
75 pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
76 for method in &self.fallback_chain {
77 match method {
78 ExtractionMethod::AuthorizationHeader => {
79 if let Ok(token) = Self::extract_from_authorization(headers) {
80 return Ok(token);
81 }
82 },
83 ExtractionMethod::McpProxyHeader => {
84 if let Ok(token) = self.extract_from_mcp_proxy(headers) {
85 return Ok(token);
86 }
87 },
88 ExtractionMethod::Cookie => {
89 if let Ok(token) = self.extract_from_cookie(headers) {
90 return Ok(token);
91 }
92 },
93 }
94 }
95
96 Err(TokenExtractionError::NoTokenFound)
97 }
98
99 pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
100 let auth_headers = headers.get_all("authorization");
101
102 if auth_headers.iter().count() == 0 {
103 return Err(TokenExtractionError::MissingAuthorizationHeader);
104 }
105
106 for auth_value in &auth_headers {
107 let Ok(auth_header) = auth_value.to_str().map_err(|e| {
108 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
109 e
110 }) else {
111 continue;
112 };
113
114 if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
115 let token = token.trim();
116 if !token.is_empty() {
117 return Ok(token.to_string());
118 }
119 }
120 }
121
122 Err(TokenExtractionError::InvalidAuthorizationFormat)
123 }
124
125 pub fn extract_from_mcp_proxy(
126 &self,
127 headers: &HeaderMap,
128 ) -> Result<String, TokenExtractionError> {
129 let header_value = headers
130 .get(&self.mcp_header_name)
131 .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
132
133 let auth_header = header_value
134 .to_str()
135 .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
136
137 auth_header
138 .strip_prefix(BEARER_PREFIX)
139 .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
140 .map(ToString::to_string)
141 }
142
143 pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
144 let cookie_header = headers
145 .get("cookie")
146 .ok_or(TokenExtractionError::MissingCookie)?
147 .to_str()
148 .map_err(|_| TokenExtractionError::InvalidCookieFormat)?;
149
150 for cookie in cookie_header.split(';') {
151 let cookie = cookie.trim();
152 let cookie_prefix = format!("{}=", self.cookie_name);
153 if let Some(value) = cookie.strip_prefix(&cookie_prefix) {
154 if !value.is_empty() {
155 return Ok(value.to_string());
156 }
157 }
158 }
159
160 Err(TokenExtractionError::TokenNotFoundInCookie)
161 }
162}
163
/// Reasons a token could not be obtained from request headers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenExtractionError {
    /// No token was found in the request.
    NoTokenFound,
    /// The `Authorization` header is missing.
    MissingAuthorizationHeader,
    /// The `Authorization` header is not of the form `Bearer <token>`.
    InvalidAuthorizationFormat,
    /// The MCP proxy authorization header is missing.
    MissingMcpProxyHeader,
    /// The MCP proxy header is not of the form `Bearer <token>`.
    InvalidMcpProxyFormat,
    /// The cookie header is missing.
    MissingCookie,
    /// The cookie header has an invalid format.
    InvalidCookieFormat,
    /// The cookies do not contain the token.
    TokenNotFoundInCookie,
}

impl fmt::Display for TokenExtractionError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::NoTokenFound => "No token found in request",
            Self::MissingAuthorizationHeader => "Missing Authorization header",
            Self::InvalidAuthorizationFormat => {
                "Invalid Authorization header format (expected 'Bearer <token>')"
            },
            Self::MissingMcpProxyHeader => "Missing MCP proxy authorization header",
            Self::InvalidMcpProxyFormat => {
                "Invalid MCP proxy header format (expected 'Bearer <token>')"
            },
            Self::MissingCookie => "Missing cookie header",
            Self::InvalidCookieFormat => "Invalid cookie format",
            Self::TokenNotFoundInCookie => "Token not found in cookies",
        };
        f.write_str(message)
    }
}

impl Error for TokenExtractionError {}