systemprompt_security/extraction/
token.rs1use axum::http::HeaderMap;
2use std::error::Error;
3use std::fmt;
4
5use super::cookie::{CookieExtractionError, CookieExtractor};
6
/// Scheme prefix, including the single separating space, that precedes a bearer token.
const BEARER_PREFIX: &'static str = "Bearer ";
/// Cookie name used by a freshly constructed extractor for cookie-based lookup.
const DEFAULT_COOKIE_NAME: &'static str = "access_token";
/// Header name used by a freshly constructed extractor for MCP proxy lookup.
const DEFAULT_MCP_HEADER_NAME: &'static str = "x-mcp-proxy-auth";
10
/// A location in an HTTP request from which a bearer token may be read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// The standard `Authorization` request header.
    AuthorizationHeader,
    /// The MCP proxy authorization header (its name is configurable on the extractor).
    McpProxyHeader,
    /// A request cookie.
    Cookie,
}

impl fmt::Display for ExtractionMethod {
    /// Writes a short human-readable label for the method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::AuthorizationHeader => "Authorization header",
            Self::McpProxyHeader => "MCP proxy header",
            Self::Cookie => "Cookie",
        };
        f.write_str(label)
    }
}
27
28#[derive(Debug, Clone)]
29pub struct TokenExtractor {
30 fallback_chain: Vec<ExtractionMethod>,
31 cookie_name: String,
32 mcp_header_name: String,
33}
34
35impl TokenExtractor {
36 #[must_use]
37 pub fn new(fallback_chain: Vec<ExtractionMethod>) -> Self {
38 Self {
39 fallback_chain,
40 cookie_name: DEFAULT_COOKIE_NAME.to_string(),
41 mcp_header_name: DEFAULT_MCP_HEADER_NAME.to_string(),
42 }
43 }
44
45 #[must_use]
46 pub fn with_cookie_name(mut self, name: String) -> Self {
47 self.cookie_name = name;
48 self
49 }
50
51 #[must_use]
52 pub fn with_mcp_header_name(mut self, name: String) -> Self {
53 self.mcp_header_name = name;
54 self
55 }
56
57 #[must_use]
58 pub fn standard() -> Self {
59 Self::new(vec![
60 ExtractionMethod::AuthorizationHeader,
61 ExtractionMethod::McpProxyHeader,
62 ExtractionMethod::Cookie,
63 ])
64 }
65
66 #[must_use]
67 pub fn browser_only() -> Self {
68 Self::new(vec![
69 ExtractionMethod::AuthorizationHeader,
70 ExtractionMethod::Cookie,
71 ])
72 }
73
74 #[must_use]
75 pub fn api_only() -> Self {
76 Self::new(vec![ExtractionMethod::AuthorizationHeader])
77 }
78
79 #[must_use]
80 pub fn chain(&self) -> &[ExtractionMethod] {
81 &self.fallback_chain
82 }
83
84 pub fn extract(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
85 for method in &self.fallback_chain {
86 match method {
87 ExtractionMethod::AuthorizationHeader => {
88 if let Ok(token) = Self::extract_from_authorization(headers) {
89 return Ok(token);
90 }
91 },
92 ExtractionMethod::McpProxyHeader => {
93 if let Ok(token) = self.extract_from_mcp_proxy(headers) {
94 return Ok(token);
95 }
96 },
97 ExtractionMethod::Cookie => {
98 if let Ok(token) = self.extract_from_cookie(headers) {
99 return Ok(token);
100 }
101 },
102 }
103 }
104
105 Err(TokenExtractionError::NoTokenFound)
106 }
107
108 pub fn extract_from_authorization(headers: &HeaderMap) -> Result<String, TokenExtractionError> {
109 let auth_headers = headers.get_all("authorization");
110
111 if auth_headers.iter().count() == 0 {
112 return Err(TokenExtractionError::MissingAuthorizationHeader);
113 }
114
115 for auth_value in &auth_headers {
116 let Ok(auth_header) = auth_value.to_str().map_err(|e| {
117 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
118 e
119 }) else {
120 continue;
121 };
122
123 if let Some(token) = auth_header.strip_prefix(BEARER_PREFIX) {
124 let token = token.trim();
125 if !token.is_empty() {
126 return Ok(token.to_string());
127 }
128 }
129 }
130
131 Err(TokenExtractionError::InvalidAuthorizationFormat)
132 }
133
134 pub fn extract_from_mcp_proxy(
135 &self,
136 headers: &HeaderMap,
137 ) -> Result<String, TokenExtractionError> {
138 let header_value = headers
139 .get(&self.mcp_header_name)
140 .ok_or(TokenExtractionError::MissingMcpProxyHeader)?;
141
142 let auth_header = header_value
143 .to_str()
144 .map_err(|_| TokenExtractionError::InvalidMcpProxyFormat)?;
145
146 auth_header
147 .strip_prefix(BEARER_PREFIX)
148 .ok_or(TokenExtractionError::InvalidMcpProxyFormat)
149 .map(ToString::to_string)
150 }
151
152 pub fn extract_from_cookie(&self, headers: &HeaderMap) -> Result<String, TokenExtractionError> {
153 CookieExtractor::new(self.cookie_name.clone())
154 .extract(headers)
155 .map_err(|err| match err {
156 CookieExtractionError::MissingCookie => TokenExtractionError::MissingCookie,
157 CookieExtractionError::InvalidCookieFormat => {
158 TokenExtractionError::InvalidCookieFormat
159 },
160 CookieExtractionError::TokenNotFoundInCookie => {
161 TokenExtractionError::TokenNotFoundInCookie
162 },
163 })
164 }
165}
166
/// Reasons a bearer token could not be extracted from a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenExtractionError {
    /// No method in the extractor's fallback chain produced a token.
    NoTokenFound,
    /// The request carries no `Authorization` header.
    MissingAuthorizationHeader,
    /// `Authorization` header(s) are present, but none holds a usable bearer token.
    InvalidAuthorizationFormat,
    /// The configured MCP proxy header is absent.
    MissingMcpProxyHeader,
    /// The MCP proxy header is unreadable as text or lacks a bearer token.
    InvalidMcpProxyFormat,
    /// Cookie lookup failed: no cookie header (mapped from the cookie extractor).
    MissingCookie,
    /// Cookie lookup failed: malformed cookie data (mapped from the cookie extractor).
    InvalidCookieFormat,
    /// Cookie lookup failed: the token cookie was not found (mapped from the cookie extractor).
    TokenNotFoundInCookie,
}

impl fmt::Display for TokenExtractionError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::NoTokenFound => "No token found in request",
            Self::MissingAuthorizationHeader => "Missing Authorization header",
            Self::InvalidAuthorizationFormat => {
                "Invalid Authorization header format (expected 'Bearer <token>')"
            },
            Self::MissingMcpProxyHeader => "Missing MCP proxy authorization header",
            Self::InvalidMcpProxyFormat => {
                "Invalid MCP proxy header format (expected 'Bearer <token>')"
            },
            Self::MissingCookie => "Missing cookie header",
            Self::InvalidCookieFormat => "Invalid cookie format",
            Self::TokenNotFoundInCookie => "Token not found in cookies",
        };
        f.write_str(message)
    }
}

impl Error for TokenExtractionError {}