jwt-verify 0.1.0

JWT verification library for AWS Cognito tokens and for tokens issued by any OIDC-compatible identity provider (IdP)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use crate::common::error::JwtError;

/// Framework-agnostic, read-only view of a request's HTTP headers.
///
/// Token extractors only ever need to look a header up by name, so that single
/// lookup is the whole interface. Adapters for a particular web framework
/// implement this trait on top of that framework's own header container.
pub trait Headers {
    /// Look up the value of the header called `name`.
    ///
    /// # Parameters
    ///
    /// * `name` - Name of the header to look up
    ///
    /// # Returns
    ///
    /// `Some(value)` when the header is present, `None` when it is not.
    fn get(&self, name: &str) -> Option<&str>;
}

/// Implementation of [`Headers`] for a `HashMap<String, String>`.
///
/// This is a simple implementation that can be used for testing or
/// for frameworks that don't have a specialized header type.
///
/// HTTP header field names are case-insensitive (RFC 9110 §5.1), but map keys
/// are not. The lookup therefore tries the exact key first and then falls back
/// to a linear scan that compares names ignoring ASCII case. As a result, a
/// header stored as `authorization` (the lowercase form HTTP/2 requires) is
/// found by `get("Authorization")`.
///
/// If several keys differ only in case, an exact match wins. Otherwise, which
/// of them is returned is unspecified, because it depends on map iteration order.
impl Headers for HashMap<String, String> {
    fn get(&self, name: &str) -> Option<&str> {
        // Call the inherent `HashMap::get` by path so it cannot be confused with
        // this trait method of the same name.
        HashMap::get(self, name)
            .or_else(|| {
                self.iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }
}

/// Strategy for pulling a raw JWT out of a request's HTTP headers.
///
/// Implementors decide which header(s) to read and which prefix (if any) to
/// strip, or they apply entirely custom logic. The `Send + Sync` bound allows
/// one extractor to be shared between threads, for example behind an `Arc`.
pub trait TokenExtractor: Send + Sync {
    /// Pull a token out of `headers`.
    ///
    /// # Parameters
    ///
    /// * `headers` - The request headers to search
    ///
    /// # Returns
    ///
    /// * `Ok(Some(token))` - a token was found
    /// * `Ok(None)` - no token is present
    /// * `Err(JwtError)` - extraction failed
    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError>;

    /// Human-readable summary of this extractor, used in debug output.
    fn description(&self) -> &str;
}

/// [`TokenExtractor`] that reads one configurable HTTP header.
///
/// The header name and an optional prefix to strip (for example `"Bearer "`)
/// are fixed at construction time. One type therefore covers
/// `Authorization: Bearer <token>`, custom headers such as `X-Token`, and
/// similar layouts.
#[derive(Clone)]
pub struct HeaderTokenExtractor {
    /// Header that carries the token
    header_name: String,
    /// Prefix that must precede the token and is stripped from it, if any
    token_prefix: Option<String>,
    /// Precomputed text returned by `TokenExtractor::description`
    description: String,
}

/// Debug output shows the header name and the optional prefix. The
/// `description` field is intentionally omitted from the output.
impl fmt::Debug for HeaderTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            header_name,
            token_prefix,
            ..
        } = self;

        let mut out = f.debug_struct("HeaderTokenExtractor");
        out.field("header_name", header_name);
        out.field("token_prefix", token_prefix);
        out.finish()
    }
}

impl HeaderTokenExtractor {
    /// Build an extractor for `header_name`.
    ///
    /// If `token_prefix` is given, the header value must begin with it and the
    /// prefix is removed from the returned token.
    ///
    /// # Parameters
    ///
    /// * `header_name` - Header to read the token from
    /// * `token_prefix` - Optional prefix to strip (e.g. `"Bearer "`)
    ///
    /// # Returns
    ///
    /// The configured `HeaderTokenExtractor`.
    pub fn new(header_name: &str, token_prefix: Option<&str>) -> Self {
        // The description spells out a missing prefix as the word "none".
        let shown_prefix = match token_prefix {
            Some(prefix) => prefix,
            None => "none",
        };

        Self {
            header_name: String::from(header_name),
            token_prefix: token_prefix.map(String::from),
            description: format!("HeaderTokenExtractor({}, {})", header_name, shown_prefix),
        }
    }

    /// Shorthand for an extractor that reads `Bearer <token>` from `header_name`.
    ///
    /// # Parameters
    ///
    /// * `header_name` - Header to read the token from
    ///
    /// # Returns
    ///
    /// A `HeaderTokenExtractor` whose prefix is `"Bearer "`.
    pub fn bearer(header_name: &str) -> Self {
        Self::new(header_name, Some("Bearer "))
    }
}

impl TokenExtractor for HeaderTokenExtractor {
    /// Read the configured header and return the token it carries.
    ///
    /// * If the header is absent, returns `Ok(None)`.
    /// * If a prefix is configured, the header value must begin with it. The
    ///   comparison ignores ASCII case, because HTTP authentication scheme names
    ///   such as `Bearer` are case-insensitive (RFC 9110 §11.1). So
    ///   `bearer <token>` is accepted as well as `Bearer <token>`.
    /// * Whitespace around the token is trimmed. A compact-serialized JWT never
    ///   contains whitespace, so this only removes padding such as
    ///   `Bearer  <token>`.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidToken` in two cases:
    ///
    /// * the header does not start with the configured prefix, or
    /// * nothing but whitespace remains once the prefix is removed
    ///   (e.g. `Authorization: Bearer `).
    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
        let header_value = match headers.get(&self.header_name) {
            Some(value) => value,
            None => return Ok(None), // Header not present: no token, but not an error.
        };

        let token_part = match &self.token_prefix {
            None => header_value,
            Some(prefix) => match header_value.get(..prefix.len()) {
                // `str::get` returns `None` (instead of panicking) when the value is
                // shorter than the prefix or `prefix.len()` is not a char boundary.
                // A successful `get` also guarantees the slice below is in bounds.
                Some(head) if head.eq_ignore_ascii_case(prefix) => &header_value[prefix.len()..],
                _ => {
                    return Err(JwtError::InvalidToken(format!(
                        "{} header must start with '{}'",
                        self.header_name, prefix
                    )));
                }
            },
        };

        let token = token_part.trim();
        if token.is_empty() {
            // Reject here, with a message naming the header, rather than handing
            // an empty string to the caller as if it were a token.
            return Err(JwtError::InvalidToken(format!(
                "{} header does not contain a token",
                self.header_name
            )));
        }

        Ok(Some(token.to_string()))
    }

    fn description(&self) -> &str {
        &self.description
    }
}

/// Newtype around a shared [`TokenExtractor`] trait object that adds `Debug`.
///
/// `dyn TokenExtractor` carries no `Debug` bound, so containers that need
/// debug output hold this wrapper instead. It prints the wrapped extractor's
/// `description()`.
#[derive(Clone)]
pub struct DebugTokenExtractor(pub Arc<dyn TokenExtractor>);

impl fmt::Debug for DebugTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        f.write_fmt(format_args!("TokenExtractor({})", inner.description()))
    }
}

/// [`TokenExtractor`] that delegates to an ordered list of other extractors.
///
/// The extractors are consulted in order and the first token produced wins.
/// This allows a token to be accepted from any of several headers or formats.
#[derive(Clone)]
pub struct ChainedTokenExtractor {
    /// Extractors, in the order they are consulted
    extractors: Vec<DebugTokenExtractor>,
    /// Precomputed text returned by `TokenExtractor::description`
    description: String,
}

/// Debug output lists the wrapped extractors, each printed through its own
/// `Debug` implementation. The `description` field is intentionally omitted.
impl fmt::Debug for ChainedTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ChainedTokenExtractor");
        builder.field("extractors", &self.extractors);
        builder.finish()
    }
}

impl ChainedTokenExtractor {
    /// Build a chain that consults `extractors` in the given order.
    ///
    /// # Parameters
    ///
    /// * `extractors` - Extractors to try, first to last
    ///
    /// # Returns
    ///
    /// The configured `ChainedTokenExtractor`.
    pub fn new(extractors: Vec<Arc<dyn TokenExtractor>>) -> Self {
        let mut wrapped = Vec::with_capacity(extractors.len());
        for extractor in extractors {
            wrapped.push(DebugTokenExtractor(extractor));
        }

        Self {
            description: Self::describe(wrapped.len()),
            extractors: wrapped,
        }
    }

    /// Append `extractor` to the end of the chain (builder style).
    ///
    /// # Parameters
    ///
    /// * `extractor` - Extractor to try after all existing ones
    ///
    /// # Returns
    ///
    /// The chain with `extractor` appended and its description refreshed.
    pub fn add_extractor(mut self, extractor: Arc<dyn TokenExtractor>) -> Self {
        self.extractors.push(DebugTokenExtractor(extractor));
        self.description = Self::describe(self.extractors.len());
        self
    }

    /// Description text for a chain holding `count` extractors.
    fn describe(count: usize) -> String {
        format!("ChainedTokenExtractor({})", count)
    }
}

impl TokenExtractor for ChainedTokenExtractor {
    /// Consult each extractor in order and return the first token found.
    ///
    /// An error from one extractor does not end the search. The chain exists so
    /// that a token may come from any of several sources, and a malformed value
    /// in one source (e.g. `Authorization: Basic ...` next to a valid `X-Token`)
    /// should not hide a usable token in another.
    ///
    /// If no extractor yields a token, the first error encountered is returned.
    /// If no extractor failed either, the result is `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns the first extractor error, but only when no extractor produced
    /// a token.
    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
        let mut first_error: Option<JwtError> = None;

        for extractor in &self.extractors {
            match extractor.0.extract_token(headers) {
                Ok(Some(token)) => return Ok(Some(token)),
                Ok(None) => {}
                // Remember the earliest failure so it can be reported if nothing
                // later in the chain succeeds; later errors are discarded.
                Err(err) => {
                    if first_error.is_none() {
                        first_error = Some(err);
                    }
                }
            }
        }

        // No token anywhere: surface the first failure, or report "no token".
        first_error.map_or(Ok(None), Err)
    }

    fn description(&self) -> &str {
        &self.description
    }
}

/// Settings that control how tokens are located in a request.
///
/// Access tokens and ID tokens each get their own extractor. A flag records
/// whether a missing token should be treated as a failure.
#[derive(Clone)]
pub struct TokenExtractorConfig {
    /// Locates the access token
    pub access_token_extractor: DebugTokenExtractor,
    /// Locates the ID token
    pub id_token_extractor: DebugTokenExtractor,
    /// When `true`, authentication is mandatory (finding no token is an error)
    pub require_auth: bool,
}

/// Debug output includes every field. The extractors are printed through
/// their own `Debug` implementations.
impl fmt::Debug for TokenExtractorConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this impl becomes a compile error.
        let Self {
            access_token_extractor,
            id_token_extractor,
            require_auth,
        } = self;

        f.debug_struct("TokenExtractorConfig")
            .field("access_token_extractor", access_token_extractor)
            .field("id_token_extractor", id_token_extractor)
            .field("require_auth", require_auth)
            .finish()
    }
}

impl Default for TokenExtractorConfig {
    /// Read both the access token and the ID token as Bearer tokens from the
    /// `Authorization` header, and require authentication.
    fn default() -> Self {
        // Each field gets its own, independently allocated extractor.
        let authorization_bearer =
            || DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization")));

        Self {
            access_token_extractor: authorization_bearer(),
            id_token_extractor: authorization_bearer(),
            require_auth: true,
        }
    }
}

impl TokenExtractorConfig {
    /// Assemble a configuration from explicit parts.
    ///
    /// # Parameters
    ///
    /// * `access_token_extractor` - Used to locate the access token
    /// * `id_token_extractor` - Used to locate the ID token
    /// * `require_auth` - Whether authentication is mandatory
    ///
    /// # Returns
    ///
    /// The new `TokenExtractorConfig`.
    pub fn new(
        access_token_extractor: Arc<dyn TokenExtractor>,
        id_token_extractor: Arc<dyn TokenExtractor>,
        require_auth: bool,
    ) -> Self {
        let access = DebugTokenExtractor(access_token_extractor);
        let id = DebugTokenExtractor(id_token_extractor);
        Self {
            access_token_extractor: access,
            id_token_extractor: id,
            require_auth,
        }
    }

    /// Replace the access-token extractor (builder style).
    ///
    /// # Parameters
    ///
    /// * `extractor` - New extractor for access tokens
    ///
    /// # Returns
    ///
    /// The configuration with the access-token extractor swapped.
    pub fn with_access_token_extractor(self, extractor: Arc<dyn TokenExtractor>) -> Self {
        Self {
            access_token_extractor: DebugTokenExtractor(extractor),
            ..self
        }
    }

    /// Replace the ID-token extractor (builder style).
    ///
    /// # Parameters
    ///
    /// * `extractor` - New extractor for ID tokens
    ///
    /// # Returns
    ///
    /// The configuration with the ID-token extractor swapped.
    pub fn with_id_token_extractor(self, extractor: Arc<dyn TokenExtractor>) -> Self {
        Self {
            id_token_extractor: DebugTokenExtractor(extractor),
            ..self
        }
    }

    /// Set whether authentication is mandatory (builder style).
    ///
    /// # Parameters
    ///
    /// * `require_auth` - The new setting
    ///
    /// # Returns
    ///
    /// The configuration with `require_auth` updated.
    pub fn with_require_auth(self, require_auth: bool) -> Self {
        Self {
            require_auth,
            ..self
        }
    }

    /// Borrow the access-token extractor as a trait object.
    pub fn access_token_extractor(&self) -> &dyn TokenExtractor {
        self.access_token_extractor.0.as_ref()
    }

    /// Borrow the ID-token extractor as a trait object.
    pub fn id_token_extractor(&self) -> &dyn TokenExtractor {
        self.id_token_extractor.0.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_header_token_extractor_with_prefix() {
        let extractor = HeaderTokenExtractor::bearer("Authorization");

        // A well-formed Bearer header yields the bare token.
        let bearer = header_map(&[("Authorization", "Bearer test-token")]);
        assert_eq!(
            extractor.extract_token(&bearer).unwrap(),
            Some("test-token".to_string())
        );

        // A different scheme is rejected.
        let basic = header_map(&[("Authorization", "Basic test-token")]);
        assert!(extractor.extract_token(&basic).is_err());

        // No header at all is not an error, just no token.
        assert_eq!(extractor.extract_token(&header_map(&[])).unwrap(), None);
    }

    #[test]
    fn test_header_token_extractor_without_prefix() {
        let extractor = HeaderTokenExtractor::new("X-Token", None);

        // Without a prefix, the whole header value is the token.
        let with_token = header_map(&[("X-Token", "test-token")]);
        assert_eq!(
            extractor.extract_token(&with_token).unwrap(),
            Some("test-token".to_string())
        );

        // Missing header: no token.
        assert_eq!(extractor.extract_token(&header_map(&[])).unwrap(), None);
    }

    #[test]
    fn test_chained_token_extractor() {
        let bearer: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let custom: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::new("X-Token", None));
        let chained = ChainedTokenExtractor::new(vec![bearer, custom]);

        // The first extractor supplies the token.
        let first = header_map(&[("Authorization", "Bearer test-token-1")]);
        assert_eq!(
            chained.extract_token(&first).unwrap(),
            Some("test-token-1".to_string())
        );

        // Only the second extractor finds anything.
        let second = header_map(&[("X-Token", "test-token-2")]);
        assert_eq!(
            chained.extract_token(&second).unwrap(),
            Some("test-token-2".to_string())
        );

        // Nothing matches anywhere.
        assert_eq!(chained.extract_token(&header_map(&[])).unwrap(), None);
    }
}