// jwt_verify/integration/extractor.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::sync::Arc;
4
5use crate::common::error::JwtError;
6
/// Read-only, framework-agnostic view over a set of HTTP request headers.
///
/// Token extractors only ever need to look a header up by name, so this trait
/// exposes exactly that one operation. Adapters for a particular web framework
/// implement it on top of the framework's own header map.
///
/// HTTP field names are case-insensitive (RFC 9110, section 5.1), so
/// implementations should ideally resolve `name` without regard to ASCII case.
pub trait Headers {
    /// Looks up the value of the header called `name`.
    ///
    /// # Parameters
    ///
    /// * `name` - Header name to look up.
    ///
    /// # Returns
    ///
    /// `Some(value)` when the header is present, `None` otherwise.
    fn get(&self, name: &str) -> Option<&str>;
}
24
25/// Implementation of Headers for a HashMap
26///
27/// This is a simple implementation that can be used for testing or
28/// for frameworks that don't have a specialized header type.
29impl Headers for HashMap<String, String> {
30    fn get(&self, name: &str) -> Option<&str> {
31        self.get(name).map(|s| s.as_str())
32    }
33}
34
35/// Trait for extracting JWT tokens from HTTP headers
36///
37/// This trait defines the interface for extracting JWT tokens from HTTP headers.
38/// Implementations can extract tokens from different headers, with different prefixes,
39/// or using custom logic.
40pub trait TokenExtractor: Send + Sync {
41    /// Extract a token from HTTP headers
42    ///
43    /// # Parameters
44    ///
45    /// * `headers` - The HTTP headers to extract the token from
46    ///
47    /// # Returns
48    ///
49    /// Returns a `Result` containing an `Option<String>` with the token if found,
50    /// or a `JwtError` if there was an error extracting the token.
51    /// Returns `Ok(None)` if no token was found.
52    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError>;
53    
54    /// Get a description of this extractor for debugging
55    fn description(&self) -> &str;
56}
57
/// Extracts a token from a single, configurable HTTP header.
///
/// The header name is fixed at construction time. An optional prefix (such as
/// `"Bearer "`) may be expected in front of the token and is removed before the
/// token is returned.
#[derive(Clone)]
pub struct HeaderTokenExtractor {
    /// Header to read the token from (e.g. `"Authorization"`).
    header_name: String,
    /// Prefix expected in front of the token and stripped from it; `None` means
    /// the whole header value is the token.
    token_prefix: Option<String>,
    /// Precomputed label returned by `TokenExtractor::description`.
    description: String,
}
71
72impl fmt::Debug for HeaderTokenExtractor {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        f.debug_struct("HeaderTokenExtractor")
75            .field("header_name", &self.header_name)
76            .field("token_prefix", &self.token_prefix)
77            .finish()
78    }
79}
80
81impl HeaderTokenExtractor {
82    /// Create a new HeaderTokenExtractor with the given header name and token prefix
83    ///
84    /// # Parameters
85    ///
86    /// * `header_name` - The name of the header to extract the token from
87    /// * `token_prefix` - Optional prefix to remove from the token (e.g., "Bearer ")
88    ///
89    /// # Returns
90    ///
91    /// Returns a new `HeaderTokenExtractor` instance.
92    pub fn new(header_name: &str, token_prefix: Option<&str>) -> Self {
93        let prefix_str = token_prefix.unwrap_or("none");
94        let description = format!("HeaderTokenExtractor({}, {})", header_name, prefix_str);
95        
96        Self {
97            header_name: header_name.to_string(),
98            token_prefix: token_prefix.map(|s| s.to_string()),
99            description,
100        }
101    }
102
103    /// Create a new HeaderTokenExtractor for Bearer tokens
104    ///
105    /// This is a convenience constructor that creates a `HeaderTokenExtractor` that
106    /// extracts Bearer tokens from the specified header.
107    ///
108    /// # Parameters
109    ///
110    /// * `header_name` - The name of the header to extract the token from
111    ///
112    /// # Returns
113    ///
114    /// Returns a new `HeaderTokenExtractor` instance configured for Bearer tokens.
115    pub fn bearer(header_name: &str) -> Self {
116        Self::new(header_name, Some("Bearer "))
117    }
118}
119
120impl TokenExtractor for HeaderTokenExtractor {
121    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
122        // Check if the header exists
123        let header_str = match headers.get(&self.header_name) {
124            Some(value) => value,
125            None => return Ok(None), // Header not found
126        };
127
128        // If a token prefix is specified, check if the header value starts with it
129        if let Some(prefix) = &self.token_prefix {
130            if !header_str.starts_with(prefix) {
131                return Err(JwtError::InvalidToken(format!(
132                    "{} header must start with '{}'",
133                    self.header_name, prefix
134                )));
135            }
136
137            // Remove the prefix and return the token
138            Ok(Some(header_str[prefix.len()..].to_string()))
139        } else {
140            // No prefix, return the entire header value
141            Ok(Some(header_str.to_string()))
142        }
143    }
144    
145    fn description(&self) -> &str {
146        &self.description
147    }
148}
149
150/// A wrapper for TokenExtractor that implements Debug
151#[derive(Clone)]
152pub struct DebugTokenExtractor(pub Arc<dyn TokenExtractor>);
153
154impl fmt::Debug for DebugTokenExtractor {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156        write!(f, "TokenExtractor({})", self.0.description())
157    }
158}
159
160/// Extract tokens from multiple sources
161///
162/// This struct tries multiple token extractors in sequence until one returns a token.
163/// It can be used to extract tokens from different headers or using different methods.
164#[derive(Clone)]
165pub struct ChainedTokenExtractor {
166    /// The extractors to try in sequence
167    extractors: Vec<DebugTokenExtractor>,
168    /// Description for debugging
169    description: String,
170}
171
172impl fmt::Debug for ChainedTokenExtractor {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        f.debug_struct("ChainedTokenExtractor")
175            .field("extractors", &self.extractors)
176            .finish()
177    }
178}
179
180impl ChainedTokenExtractor {
181    /// Create a new ChainedTokenExtractor with the given extractors
182    ///
183    /// # Parameters
184    ///
185    /// * `extractors` - The extractors to try in sequence
186    ///
187    /// # Returns
188    ///
189    /// Returns a new `ChainedTokenExtractor` instance.
190    pub fn new(extractors: Vec<Arc<dyn TokenExtractor>>) -> Self {
191        let debug_extractors = extractors
192            .into_iter()
193            .map(DebugTokenExtractor)
194            .collect::<Vec<_>>();
195        
196        let description = format!("ChainedTokenExtractor({})", debug_extractors.len());
197        
198        Self {
199            extractors: debug_extractors,
200            description,
201        }
202    }
203
204    /// Add an extractor to the chain
205    ///
206    /// # Parameters
207    ///
208    /// * `extractor` - The extractor to add
209    ///
210    /// # Returns
211    ///
212    /// Returns a new `ChainedTokenExtractor` instance with the added extractor.
213    pub fn add_extractor(mut self, extractor: Arc<dyn TokenExtractor>) -> Self {
214        self.extractors.push(DebugTokenExtractor(extractor));
215        self.description = format!("ChainedTokenExtractor({})", self.extractors.len());
216        self
217    }
218}
219
220impl TokenExtractor for ChainedTokenExtractor {
221    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
222        // Try each extractor in sequence
223        for extractor in &self.extractors {
224            match extractor.0.extract_token(headers)? {
225                Some(token) => return Ok(Some(token)),
226                None => continue,
227            }
228        }
229
230        // No token found
231        Ok(None)
232    }
233    
234    fn description(&self) -> &str {
235        &self.description
236    }
237}
238
239/// Configuration for token extraction
240///
241/// This struct holds the configuration for token extraction, including
242/// extractors for different token types.
243#[derive(Clone)]
244pub struct TokenExtractorConfig {
245    /// Extractor for access tokens
246    pub access_token_extractor: DebugTokenExtractor,
247    /// Extractor for ID tokens
248    pub id_token_extractor: DebugTokenExtractor,
249    /// Whether authentication is required (error if no token found)
250    pub require_auth: bool,
251}
252
253impl fmt::Debug for TokenExtractorConfig {
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        f.debug_struct("TokenExtractorConfig")
256            .field("access_token_extractor", &self.access_token_extractor)
257            .field("id_token_extractor", &self.id_token_extractor)
258            .field("require_auth", &self.require_auth)
259            .finish()
260    }
261}
262
263impl Default for TokenExtractorConfig {
264    fn default() -> Self {
265        Self {
266            // Default to extracting Bearer tokens from the Authorization header
267            access_token_extractor: DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization"))),
268            id_token_extractor: DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization"))),
269            require_auth: true,
270        }
271    }
272}
273
274impl TokenExtractorConfig {
275    /// Create a new TokenExtractorConfig with the given extractors
276    ///
277    /// # Parameters
278    ///
279    /// * `access_token_extractor` - Extractor for access tokens
280    /// * `id_token_extractor` - Extractor for ID tokens
281    /// * `require_auth` - Whether authentication is required
282    ///
283    /// # Returns
284    ///
285    /// Returns a new `TokenExtractorConfig` instance.
286    pub fn new(
287        access_token_extractor: Arc<dyn TokenExtractor>,
288        id_token_extractor: Arc<dyn TokenExtractor>,
289        require_auth: bool,
290    ) -> Self {
291        Self {
292            access_token_extractor: DebugTokenExtractor(access_token_extractor),
293            id_token_extractor: DebugTokenExtractor(id_token_extractor),
294            require_auth,
295        }
296    }
297
298    /// Set the access token extractor
299    ///
300    /// # Parameters
301    ///
302    /// * `extractor` - The extractor to use for access tokens
303    ///
304    /// # Returns
305    ///
306    /// Returns a new `TokenExtractorConfig` instance with the updated extractor.
307    pub fn with_access_token_extractor(
308        mut self,
309        extractor: Arc<dyn TokenExtractor>,
310    ) -> Self {
311        self.access_token_extractor = DebugTokenExtractor(extractor);
312        self
313    }
314
315    /// Set the ID token extractor
316    ///
317    /// # Parameters
318    ///
319    /// * `extractor` - The extractor to use for ID tokens
320    ///
321    /// # Returns
322    ///
323    /// Returns a new `TokenExtractorConfig` instance with the updated extractor.
324    pub fn with_id_token_extractor(
325        mut self,
326        extractor: Arc<dyn TokenExtractor>,
327    ) -> Self {
328        self.id_token_extractor = DebugTokenExtractor(extractor);
329        self
330    }
331
332    /// Set whether authentication is required
333    ///
334    /// # Parameters
335    ///
336    /// * `require_auth` - Whether authentication is required
337    ///
338    /// # Returns
339    ///
340    /// Returns a new `TokenExtractorConfig` instance with the updated setting.
341    pub fn with_require_auth(mut self, require_auth: bool) -> Self {
342        self.require_auth = require_auth;
343        self
344    }
345    
346    /// Get the access token extractor
347    pub fn access_token_extractor(&self) -> &dyn TokenExtractor {
348        &*self.access_token_extractor.0
349    }
350    
351    /// Get the ID token extractor
352    pub fn id_token_extractor(&self) -> &dyn TokenExtractor {
353        &*self.id_token_extractor.0
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a header map from `(name, value)` pairs.
    fn headers_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_header_token_extractor_with_prefix() {
        let extractor = HeaderTokenExtractor::bearer("Authorization");
        let mut headers = HashMap::new();

        // Test with valid Bearer token
        headers.insert("Authorization".to_string(), "Bearer test-token".to_string());

        let result = extractor.extract_token(&headers).unwrap();
        assert_eq!(result, Some("test-token".to_string()));

        // Test with invalid prefix
        headers.insert("Authorization".to_string(), "Basic test-token".to_string());

        let result = extractor.extract_token(&headers);
        assert!(result.is_err());

        // Test with missing header
        let empty_headers: HashMap<String, String> = HashMap::new();
        let result = extractor.extract_token(&empty_headers).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_header_token_extractor_without_prefix() {
        let extractor = HeaderTokenExtractor::new("X-Token", None);
        let mut headers = HashMap::new();

        // Test with token
        headers.insert("X-Token".to_string(), "test-token".to_string());

        let result = extractor.extract_token(&headers).unwrap();
        assert_eq!(result, Some("test-token".to_string()));

        // Test with missing header
        let empty_headers: HashMap<String, String> = HashMap::new();
        let result = extractor.extract_token(&empty_headers).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_header_token_extractor_description() {
        assert_eq!(
            HeaderTokenExtractor::bearer("Authorization").description(),
            "HeaderTokenExtractor(Authorization, Bearer )"
        );
        assert_eq!(
            HeaderTokenExtractor::new("X-Token", None).description(),
            "HeaderTokenExtractor(X-Token, none)"
        );
    }

    #[test]
    fn test_chained_token_extractor() {
        let extractor1: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let extractor2: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::new("X-Token", None));

        let chained = ChainedTokenExtractor::new(vec![extractor1, extractor2]);

        let mut headers = HashMap::new();

        // Test with first extractor matching
        headers.insert("Authorization".to_string(), "Bearer test-token-1".to_string());

        let result = chained.extract_token(&headers).unwrap();
        assert_eq!(result, Some("test-token-1".to_string()));

        // Test with second extractor matching
        let mut headers = HashMap::new();
        headers.insert("X-Token".to_string(), "test-token-2".to_string());

        let result = chained.extract_token(&headers).unwrap();
        assert_eq!(result, Some("test-token-2".to_string()));

        // Test with no extractors matching
        let empty_headers: HashMap<String, String> = HashMap::new();
        let result = chained.extract_token(&empty_headers).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_chained_token_extractor_propagates_errors() {
        // A malformed header seen by an earlier extractor is an error, even when
        // a later extractor could have produced a token.
        let bearer: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let plain: Arc<dyn TokenExtractor> = Arc::new(HeaderTokenExtractor::new("X-Token", None));
        let chained = ChainedTokenExtractor::new(vec![bearer, plain]);

        let headers = headers_from(&[("Authorization", "Basic abc"), ("X-Token", "fallback")]);
        assert!(chained.extract_token(&headers).is_err());
    }

    #[test]
    fn test_chained_token_extractor_stops_at_first_token() {
        // Once an earlier extractor yields a token, later ones are not consulted,
        // so a malformed header further down the chain does not matter.
        let plain: Arc<dyn TokenExtractor> = Arc::new(HeaderTokenExtractor::new("X-Token", None));
        let bearer: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let chained = ChainedTokenExtractor::new(vec![plain, bearer]);

        let headers = headers_from(&[("X-Token", "abc"), ("Authorization", "Basic xyz")]);
        assert_eq!(chained.extract_token(&headers).unwrap(), Some("abc".to_string()));
    }

    #[test]
    fn test_chained_token_extractor_add_extractor() {
        let chained = ChainedTokenExtractor::new(Vec::new());
        assert_eq!(chained.description(), "ChainedTokenExtractor(0)");
        assert_eq!(chained.extract_token(&headers_from(&[])).unwrap(), None);

        let chained = chained.add_extractor(Arc::new(HeaderTokenExtractor::new("X-Token", None)));
        assert_eq!(chained.description(), "ChainedTokenExtractor(1)");
        assert_eq!(
            chained
                .extract_token(&headers_from(&[("X-Token", "abc")]))
                .unwrap(),
            Some("abc".to_string())
        );
    }
}