jwt_verify/integration/
extractor.rs1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use crate::common::error::JwtError;
6
/// Read-only access to a request's headers, looked up by header name.
pub trait Headers {
    /// Returns the value stored under `key`, or `None` when no such header exists.
    fn get(&self, key: &str) -> Option<&str>;
}

/// `Headers` adapter for a plain `HashMap`; lookup is an exact (case-sensitive) key match.
impl Headers for HashMap<String, String> {
    fn get(&self, name: &str) -> Option<&str> {
        // Name the inherent `HashMap::get` explicitly so it is obvious this is a map
        // lookup and not a recursive call back into `Headers::get`.
        HashMap::get(self, name).map(String::as_str)
    }
}
34
35pub trait TokenExtractor: Send + Sync {
41 fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError>;
53
54 fn description(&self) -> &str;
56}
57
/// Extracts a token from one named request header.
///
/// When a prefix such as `"Bearer "` is configured, the header value is expected to
/// start with it and the prefix is removed from the extracted token.
#[derive(Clone)]
pub struct HeaderTokenExtractor {
    /// Header name passed to [`Headers::get`].
    header_name: String,
    /// Optional scheme prefix expected at the start of the header value.
    token_prefix: Option<String>,
    /// Label built in `new` and returned by `TokenExtractor::description`.
    description: String,
}

impl fmt::Debug for HeaderTokenExtractor {
    /// Shows the header name and prefix; `description` is left out because `new`
    /// derives it from those two fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            header_name,
            token_prefix,
            ..
        } = self;
        f.debug_struct("HeaderTokenExtractor")
            .field("header_name", header_name)
            .field("token_prefix", token_prefix)
            .finish()
    }
}

impl HeaderTokenExtractor {
    /// Creates an extractor reading `header_name`, optionally requiring `token_prefix`.
    ///
    /// The description has the form `HeaderTokenExtractor(<header>, <prefix>)`, with
    /// `none` written in place of a missing prefix.
    pub fn new(header_name: &str, token_prefix: Option<&str>) -> Self {
        let description = format!(
            "HeaderTokenExtractor({}, {})",
            header_name,
            token_prefix.unwrap_or("none")
        );

        Self {
            header_name: String::from(header_name),
            token_prefix: token_prefix.map(String::from),
            description,
        }
    }

    /// Creates an extractor expecting `header_name` to hold `"Bearer <token>"`.
    pub fn bearer(header_name: &str) -> Self {
        const BEARER_PREFIX: &str = "Bearer ";
        Self::new(header_name, Some(BEARER_PREFIX))
    }
}
119
120impl TokenExtractor for HeaderTokenExtractor {
121 fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
122 let header_str = match headers.get(&self.header_name) {
124 Some(value) => value,
125 None => return Ok(None), };
127
128 if let Some(prefix) = &self.token_prefix {
130 if !header_str.starts_with(prefix) {
131 return Err(JwtError::InvalidToken(format!(
132 "{} header must start with '{}'",
133 self.header_name, prefix
134 )));
135 }
136
137 Ok(Some(header_str[prefix.len()..].to_string()))
139 } else {
140 Ok(Some(header_str.to_string()))
142 }
143 }
144
145 fn description(&self) -> &str {
146 &self.description
147 }
148}
149
150#[derive(Clone)]
152pub struct DebugTokenExtractor(pub Arc<dyn TokenExtractor>);
153
154impl fmt::Debug for DebugTokenExtractor {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "TokenExtractor({})", self.0.description())
157 }
158}
159
160#[derive(Clone)]
165pub struct ChainedTokenExtractor {
166 extractors: Vec<DebugTokenExtractor>,
168 description: String,
170}
171
172impl fmt::Debug for ChainedTokenExtractor {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 f.debug_struct("ChainedTokenExtractor")
175 .field("extractors", &self.extractors)
176 .finish()
177 }
178}
179
180impl ChainedTokenExtractor {
181 pub fn new(extractors: Vec<Arc<dyn TokenExtractor>>) -> Self {
191 let debug_extractors = extractors
192 .into_iter()
193 .map(DebugTokenExtractor)
194 .collect::<Vec<_>>();
195
196 let description = format!("ChainedTokenExtractor({})", debug_extractors.len());
197
198 Self {
199 extractors: debug_extractors,
200 description,
201 }
202 }
203
204 pub fn add_extractor(mut self, extractor: Arc<dyn TokenExtractor>) -> Self {
214 self.extractors.push(DebugTokenExtractor(extractor));
215 self.description = format!("ChainedTokenExtractor({})", self.extractors.len());
216 self
217 }
218}
219
220impl TokenExtractor for ChainedTokenExtractor {
221 fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
222 for extractor in &self.extractors {
224 match extractor.0.extract_token(headers)? {
225 Some(token) => return Ok(Some(token)),
226 None => continue,
227 }
228 }
229
230 Ok(None)
232 }
233
234 fn description(&self) -> &str {
235 &self.description
236 }
237}
238
239#[derive(Clone)]
244pub struct TokenExtractorConfig {
245 pub access_token_extractor: DebugTokenExtractor,
247 pub id_token_extractor: DebugTokenExtractor,
249 pub require_auth: bool,
251}
252
253impl fmt::Debug for TokenExtractorConfig {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("TokenExtractorConfig")
256 .field("access_token_extractor", &self.access_token_extractor)
257 .field("id_token_extractor", &self.id_token_extractor)
258 .field("require_auth", &self.require_auth)
259 .finish()
260 }
261}
262
263impl Default for TokenExtractorConfig {
264 fn default() -> Self {
265 Self {
266 access_token_extractor: DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization"))),
268 id_token_extractor: DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization"))),
269 require_auth: true,
270 }
271 }
272}
273
274impl TokenExtractorConfig {
275 pub fn new(
287 access_token_extractor: Arc<dyn TokenExtractor>,
288 id_token_extractor: Arc<dyn TokenExtractor>,
289 require_auth: bool,
290 ) -> Self {
291 Self {
292 access_token_extractor: DebugTokenExtractor(access_token_extractor),
293 id_token_extractor: DebugTokenExtractor(id_token_extractor),
294 require_auth,
295 }
296 }
297
298 pub fn with_access_token_extractor(
308 mut self,
309 extractor: Arc<dyn TokenExtractor>,
310 ) -> Self {
311 self.access_token_extractor = DebugTokenExtractor(extractor);
312 self
313 }
314
315 pub fn with_id_token_extractor(
325 mut self,
326 extractor: Arc<dyn TokenExtractor>,
327 ) -> Self {
328 self.id_token_extractor = DebugTokenExtractor(extractor);
329 self
330 }
331
332 pub fn with_require_auth(mut self, require_auth: bool) -> Self {
342 self.require_auth = require_auth;
343 self
344 }
345
346 pub fn access_token_extractor(&self) -> &dyn TokenExtractor {
348 &*self.access_token_extractor.0
349 }
350
351 pub fn id_token_extractor(&self) -> &dyn TokenExtractor {
353 &*self.id_token_extractor.0
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_header_token_extractor_with_prefix() {
        let extractor = HeaderTokenExtractor::bearer("Authorization");

        let bearer = header_map(&[("Authorization", "Bearer test-token")]);
        assert_eq!(
            extractor.extract_token(&bearer).unwrap(),
            Some("test-token".to_string())
        );

        let basic = header_map(&[("Authorization", "Basic test-token")]);
        assert!(extractor.extract_token(&basic).is_err());

        let missing = header_map(&[]);
        assert_eq!(extractor.extract_token(&missing).unwrap(), None);
    }

    #[test]
    fn test_header_token_extractor_without_prefix() {
        let extractor = HeaderTokenExtractor::new("X-Token", None);

        let present = header_map(&[("X-Token", "test-token")]);
        assert_eq!(
            extractor.extract_token(&present).unwrap(),
            Some("test-token".to_string())
        );

        let missing = header_map(&[]);
        assert_eq!(extractor.extract_token(&missing).unwrap(), None);
    }

    #[test]
    fn test_chained_token_extractor() {
        let bearer: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let raw: Arc<dyn TokenExtractor> = Arc::new(HeaderTokenExtractor::new("X-Token", None));
        let chained = ChainedTokenExtractor::new(vec![bearer, raw]);

        let via_authorization = header_map(&[("Authorization", "Bearer test-token-1")]);
        assert_eq!(
            chained.extract_token(&via_authorization).unwrap(),
            Some("test-token-1".to_string())
        );

        let via_x_token = header_map(&[("X-Token", "test-token-2")]);
        assert_eq!(
            chained.extract_token(&via_x_token).unwrap(),
            Some("test-token-2".to_string())
        );

        let missing = header_map(&[]);
        assert_eq!(chained.extract_token(&missing).unwrap(), None);
    }
}