jwt_verify/cognito/
verifier.rs

1use anyhow::Result;
2use serde::de::DeserializeOwned;
3use std::collections::HashMap;
4
5use std::sync::Arc;
6
7use crate::claims::{
8    ClaimsValidator, CognitoAccessTokenClaims, CognitoIdTokenClaims, CognitoJwtClaims,
9};
10use crate::cognito::config::{self, TokenUse, VerifierConfig};
11use crate::cognito::token::CognitoTokenParser;
12use crate::common::error::{ErrorLogger, JwtError};
13use crate::common::token::TokenParser;
14use crate::jwk::provider::JwkProvider;
15use crate::jwk::registry::{JwkProviderRegistry, RegistryError};
16use crate::verifier::{AccessTokenClaims, IdTokenClaims, JwtVerifier};
17
18/// Cognito JWT verifier
19///
20/// This is the main entry point for the JWT verification library. It manages multiple
21/// user pools and provides methods for verifying different types of tokens.
22///
23/// The verifier uses issuer-based provider selection to determine which user pool
24/// should verify a given token. This means that tokens must contain a valid issuer
25/// claim that matches one of the registered user pools.
26///
27/// # Examples
28///
/// ```ignore
30/// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
31/// use std::time::Duration;
32///
33/// // Create configurations for multiple user pools
34/// let config1 = VerifierConfig::new(
35///     "us-east-1",
36///     "us-east-1_example1",
37///     &["client1".to_string()],
38///     None,
39/// ).unwrap();
40///
41/// let config2 = VerifierConfig::new(
42///     "us-west-2",
43///     "us-west-2_example2",
44///     &["client2".to_string()],
45///     None,
46/// ).unwrap();
47///
48/// // Create a verifier with multiple user pools
49/// let verifier = CognitoJwtVerifier::new(vec![config1, config2]).unwrap();
50///
51/// // Verify a token (the issuer in the token will determine which user pool to use)
52/// let token = "..."; // JWT token string
53/// let claims = verifier.verify_id_token(token).await.unwrap();
54/// ```
#[derive(Debug)]
pub struct CognitoJwtVerifier {
    /// Registry of JWK providers, keyed by the pool identifier passed to
    /// `add_user_pool`.
    jwk_registry: JwkProviderRegistry,
    /// Per-pool verifier configurations, stored under the same identifier as the
    /// matching provider in `jwk_registry`. `add_user_pool` and `remove_user_pool`
    /// update both collections together.
    configs: HashMap<String, VerifierConfig>,
    /// Error logger. `new` builds it with `ErrorVerbosity::Standard` and
    /// `set_error_verbosity` replaces it with a freshly constructed logger.
    error_logger: ErrorLogger,
}
64
65impl CognitoJwtVerifier {
66    /// Create a new verifier with multiple user pool configurations
67    ///
68    /// This constructor takes a vector of `VerifierConfig` objects, each representing
69    /// a different Cognito user pool. The verifier will register all these user pools
70    /// and use them for token verification based on the issuer claim in the tokens.
71    ///
72    /// # Parameters
73    ///
74    /// * `configs` - Vector of configurations for different user pools
75    ///
76    /// # Returns
77    ///
78    /// Returns a `Result` containing the new `CognitoJwtVerifier` if successful, or a `JwtError`
79    /// if any of the configurations are invalid or if there's an error registering the user pools.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
85    ///
86    /// // Create configurations for multiple user pools
87    /// let config1 = VerifierConfig::new(
88    ///     "us-east-1",
89    ///     "us-east-1_example1",
90    ///     &["client1".to_string()],
91    ///     None,
92    /// ).unwrap();
93    ///
94    /// let config2 = VerifierConfig::new(
95    ///     "us-west-2",
96    ///     "us-west-2_example2",
97    ///     &["client2".to_string()],
98    ///     None,
99    /// ).unwrap();
100    ///
101    /// // Create a verifier with multiple user pools
102    /// let verifier = CognitoJwtVerifier::new(vec![config1, config2]).unwrap();
103    /// ```
104    pub fn new(configs: Vec<VerifierConfig>) -> Result<Self, JwtError> {
105        let mut verifier = Self {
106            jwk_registry: JwkProviderRegistry::new(),
107            configs: HashMap::new(),
108            error_logger: ErrorLogger::new(crate::common::error::ErrorVerbosity::Standard),
109        };
110
111        for config in configs {
112            let id = format!("{}_{}", config.region, config.user_pool_id);
113            verifier.add_user_pool(&id, config)?;
114        }
115
116        Ok(verifier)
117    }
118
119    /// Create a new verifier with a single user pool configuration
120    ///
121    /// This is a convenience constructor that creates a verifier with a single user pool.
122    ///
123    /// # Parameters
124    ///
125    /// * `region` - AWS region where the Cognito user pool is located (e.g., "us-east-1")
126    /// * `user_pool_id` - Cognito user pool ID in the format "region_poolid"
127    /// * `client_ids` - List of allowed client IDs for this user pool
128    ///
129    /// # Returns
130    ///
131    /// Returns a `Result` containing the new `CognitoJwtVerifier` if successful, or a `JwtError`
132    /// if the configuration is invalid or if there's an error registering the user pool.
133    ///
134    /// # Examples
135    ///
136    /// ```
137    /// use jwt_verify::CognitoJwtVerifier;
138    ///
139    /// // Create a verifier with a single user pool
140    /// let verifier = CognitoJwtVerifier::new_single_pool(
141    ///     "us-east-1",
142    ///     "us-east-1_example",
143    ///     &["client1".to_string()],
144    /// ).unwrap();
145    /// ```
146    pub fn new_single_pool(
147        region: &str,
148        user_pool_id: &str,
149        client_ids: &[String],
150    ) -> Result<Self, JwtError> {
151        let config = VerifierConfig::new(region, user_pool_id, client_ids, None)?;
152        Self::new(vec![config])
153    }
154
155    /// Add a user pool with configuration
156    ///
157    /// This method adds a new user pool to the verifier. The user pool is identified by
158    /// the provided ID, which should be unique among all registered user pools.
159    ///
160    /// # Parameters
161    ///
162    /// * `id` - Unique identifier for the user pool
163    /// * `config` - Configuration for the user pool
164    ///
165    /// # Returns
166    ///
167    /// Returns `Ok(())` if the user pool was successfully added, or a `JwtError` if there
168    /// was an error registering the user pool.
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
174    ///
175    /// // Create a verifier
176    /// let mut verifier = CognitoJwtVerifier::new(vec![]).unwrap();
177    ///
178    /// // Create a configuration for a user pool
179    /// let config = VerifierConfig::new(
180    ///     "us-east-1",
181    ///     "us-east-1_example",
182    ///     &["client1".to_string()],
183    ///     None,
184    /// ).unwrap();
185    ///
186    /// // Add the user pool to the verifier
187    /// verifier.add_user_pool("my_pool", config).unwrap();
188    /// ```
189    pub fn add_user_pool(&mut self, id: &str, config: VerifierConfig) -> Result<(), JwtError> {
190        // Create JWK provider
191        let jwk_provider = JwkProvider::new(
192            &config.region,
193            &config.user_pool_id,
194            config.jwk_cache_duration,
195        )?;
196
197        // Add to registry and configs
198        self.jwk_registry
199            .register(id, jwk_provider)
200            .map_err(|e| match e {
201                RegistryError::JwtError(err) => err,
202                other => JwtError::ConfigurationError {
203                    parameter: Some("provider_registration".to_string()),
204                    error: other.to_string(),
205                },
206            })?;
207
208        self.configs.insert(id.to_string(), config);
209
210        Ok(())
211    }
212
213    /// Add a user pool with region, user pool ID, and client IDs
214    ///
215    /// This is a convenience method that creates a configuration for a user pool
216    /// and adds it to the verifier.
217    ///
218    /// # Parameters
219    ///
220    /// * `id` - Unique identifier for the user pool
221    /// * `region` - AWS region where the Cognito user pool is located (e.g., "us-east-1")
222    /// * `user_pool_id` - Cognito user pool ID in the format "region_poolid"
223    /// * `client_ids` - List of allowed client IDs for this user pool
224    ///
225    /// # Returns
226    ///
227    /// Returns `Ok(())` if the user pool was successfully added, or a `JwtError` if there
228    /// was an error creating the configuration or registering the user pool.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// use jwt_verify::CognitoJwtVerifier;
234    ///
235    /// // Create a verifier
236    /// let mut verifier = CognitoJwtVerifier::new(vec![]).unwrap();
237    ///
238    /// // Add a user pool
239    /// verifier.add_user_pool_with_params(
240    ///     "my_pool",
241    ///     "us-east-1",
242    ///     "us-east-1_example",
243    ///     &["client1".to_string()],
244    /// ).unwrap();
245    /// ```
246    pub fn add_user_pool_with_params(
247        &mut self,
248        id: &str,
249        region: &str,
250        user_pool_id: &str,
251        client_ids: &[String],
252    ) -> Result<(), JwtError> {
253        let config = VerifierConfig::new(region, user_pool_id, client_ids, None)?;
254        self.add_user_pool(id, config)
255    }
256
257    /// Get the list of registered user pool IDs
258    ///
259    /// This method returns a list of all user pool IDs that have been registered
260    /// with the verifier.
261    ///
262    /// # Returns
263    ///
264    /// Returns a vector of user pool IDs.
265    ///
266    /// # Examples
267    ///
268    /// ```
269    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
270    ///
271    /// // Create a verifier with multiple user pools
272    /// let config1 = VerifierConfig::new(
273    ///     "us-east-1",
274    ///     "us-east-1_example1",
275    ///     &["client1".to_string()],
276    ///     None,
277    /// ).unwrap();
278    ///
279    /// let config2 = VerifierConfig::new(
280    ///     "us-west-2",
281    ///     "us-west-2_example2",
282    ///     &["client2".to_string()],
283    ///     None,
284    /// ).unwrap();
285    ///
286    /// let verifier = CognitoJwtVerifier::new(vec![config1, config2]).unwrap();
287    ///
288    /// // Get the list of user pool IDs
289    /// let pool_ids = verifier.get_user_pool_ids();
290    /// assert_eq!(pool_ids.len(), 2);
291    /// ```
    pub fn get_user_pool_ids(&self) -> Vec<String> {
        // The IDs are read from the JWK registry rather than from `self.configs`;
        // `add_user_pool` and `remove_user_pool` update both under the same id.
        // NOTE(review): the ordering of the returned IDs is whatever
        // `JwkProviderRegistry::list_ids` yields — confirm before relying on it.
        self.jwk_registry.list_ids()
    }
295
296    /// Remove a user pool
297    ///
298    /// This method removes a user pool from the verifier. The user pool is identified
299    /// by the provided ID.
300    ///
301    /// # Parameters
302    ///
303    /// * `id` - Unique identifier for the user pool to remove
304    ///
305    /// # Returns
306    ///
307    /// Returns `Ok(())` if the user pool was successfully removed, or a `JwtError` if
308    /// the user pool was not found.
309    ///
310    /// # Examples
311    ///
312    /// ```
313    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
314    ///
315    /// // Create a verifier with a user pool
316    /// let config = VerifierConfig::new(
317    ///     "us-east-1",
318    ///     "us-east-1_example",
319    ///     &["client1".to_string()],
320    ///     None,
321    /// ).unwrap();
322    ///
323    /// let mut verifier = CognitoJwtVerifier::new(vec![config]).unwrap();
324    ///
325    /// // Remove the user pool
326    /// verifier.remove_user_pool("us-east-1_us-east-1_example").unwrap();
327    /// ```
328    pub fn remove_user_pool(&mut self, id: &str) -> Result<(), JwtError> {
329        // Remove from registry
330        self.jwk_registry.remove(id).map_err(|e| match e {
331            RegistryError::ProviderNotFound(_) => JwtError::ConfigurationError {
332                parameter: Some("pool_id".to_string()),
333                error: format!("User pool '{}' not found", id),
334            },
335            RegistryError::JwtError(err) => err,
336            other => JwtError::ConfigurationError {
337                parameter: Some("pool_id".to_string()),
338                error: other.to_string(),
339            },
340        })?;
341
342        // Remove from configs
343        self.configs.remove(id);
344
345        Ok(())
346    }
347
348    /// Set the error verbosity level
349    ///
350    /// This method sets the verbosity level for error logging and reporting.
351    ///
352    /// # Parameters
353    ///
354    /// * `verbosity` - The error verbosity level
355    ///
356    /// # Examples
357    ///
358    /// ```
359    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig, ErrorVerbosity};
360    ///
361    /// // Create a verifier
362    /// let mut verifier = CognitoJwtVerifier::new(vec![]).unwrap();
363    ///
364    /// // Set the error verbosity level
365    /// verifier.set_error_verbosity(ErrorVerbosity::Detailed);
366    /// ```
    pub fn set_error_verbosity(&mut self, verbosity: crate::common::error::ErrorVerbosity) {
        // The logger is rebuilt from scratch rather than adjusted in place, so the
        // previous `ErrorLogger` value is dropped.
        self.error_logger = ErrorLogger::new(verbosity);
    }
370
371    /// Prefetch JWKs for all user pools
372    ///
373    /// This method prefetches JWKs for all registered user pools. This can be useful
374    /// to warm up the cache before handling requests.
375    ///
376    /// # Returns
377    ///
378    /// Returns a vector of tuples containing the user pool ID and the result of the
379    /// prefetch operation.
380    ///
381    /// # Examples
382    ///
    /// ```ignore
384    /// use jwt_verify::{CognitoJwtVerifier, VerifierConfig};
385    ///
386    /// // Create a verifier with multiple user pools
387    /// let config1 = VerifierConfig::new(
388    ///     "us-east-1",
389    ///     "us-east-1_example1",
390    ///     &["client1".to_string()],
391    ///     None,
392    /// ).unwrap();
393    ///
394    /// let config2 = VerifierConfig::new(
395    ///     "us-west-2",
396    ///     "us-west-2_example2",
397    ///     &["client2".to_string()],
398    ///     None,
399    /// ).unwrap();
400    ///
401    /// let verifier = CognitoJwtVerifier::new(vec![config1, config2]).unwrap();
402    ///
403    /// // Prefetch JWKs for all user pools
404    /// let results = verifier.hydrate().await;
405    /// ```
    pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
        // Pure delegation to the registry. Failures are returned per entry inside
        // the vector instead of aborting the whole call.
        // NOTE(review): presumably one `(id, result)` pair per registered provider —
        // confirm against `JwkProviderRegistry::hydrate`.
        self.jwk_registry.hydrate().await
    }
409
410    /// Verify a token with generic type support
411    ///
412    /// This method verifies a JWT token and returns the claims as the specified type.
413    /// It automatically selects the appropriate user pool based on the issuer claim in the token.
414    ///
415    /// # Parameters
416    ///
417    /// * `token` - The JWT token to verify
418    ///
419    /// # Returns
420    ///
421    /// Returns a `Result` containing the verified claims if successful, or a `JwtError`
422    /// if verification fails.
423    ///
424    /// # Examples
425    ///
    /// ```ignore
427    /// use jwt_verify::{CognitoJwtVerifier, CognitoIdTokenClaims};
428    ///
429    /// // Create a verifier
430    /// let verifier = CognitoJwtVerifier::new_single_pool(
431    ///     "us-east-1",
432    ///     "us-east-1_example",
433    ///     &["client1".to_string()],
434    /// ).unwrap();
435    ///
436    /// // Verify a token
437    /// let token = "..."; // JWT token string
438    /// let claims = verifier.verify::<CognitoIdTokenClaims>(token).await.unwrap();
439    /// ```
440    pub async fn verify<T>(&self, token: &str) -> Result<T, JwtError>
441    where
442        T: DeserializeOwned + TryFrom<CognitoJwtClaims, Error = JwtError>,
443    {
444        // Parse header to get the token type
445        let header = TokenParser::parse_token_header(token)?;
446
447        // Extract issuer from token without full validation yet
448        let issuer = TokenParser::extract_issuer(token)?;
449
450        // Find the user pool ID based on the issuer
451        let pool_id = self.find_pool_id_by_issuer(&issuer)?;
452
453        // Verify with the specific pool
454        self.verify_generic_with_pool::<T>(token, &pool_id).await
455    }
456
457    /// Find a user pool ID by issuer
458    ///
459    /// This method finds the user pool ID that matches the given issuer.
460    ///
461    /// # Parameters
462    ///
463    /// * `issuer` - The issuer to match
464    ///
465    /// # Returns
466    ///
467    /// Returns a `Result` containing the user pool ID if found, or a `JwtError`
468    /// if no matching user pool is found.
469    fn find_pool_id_by_issuer(&self, issuer: &str) -> Result<String, JwtError> {
470        for (id, config) in &self.configs {
471            let expected_issuer = format!(
472                "https://cognito-idp.{}.amazonaws.com/{}",
473                config.region, config.user_pool_id
474            );
475            if expected_issuer == issuer {
476                return Ok(id.clone());
477            }
478        }
479
480        Err(JwtError::InvalidIssuer {
481            expected: "a registered Cognito user pool".to_string(),
482            actual: issuer.to_string(),
483        })
484    }
485
486    /// Verify a token with a specific user pool with generic type support
487    async fn verify_generic_with_pool<T>(&self, token: &str, pool_id: &str) -> Result<T, JwtError>
488    where
489        T: DeserializeOwned + TryFrom<CognitoJwtClaims, Error = JwtError>,
490    {
491        // Get provider and config
492        let jwk_provider = self.jwk_registry.get(pool_id).map_err(|e| match e {
493            RegistryError::ProviderNotFound(_) => JwtError::ConfigurationError {
494                parameter: Some("pool_id".to_string()),
495                error: format!("User pool '{}' not found", pool_id),
496            },
497            RegistryError::JwtError(err) => err,
498            other => JwtError::ConfigurationError {
499                parameter: Some("pool_id".to_string()),
500                error: other.to_string(),
501            },
502        })?;
503
504        let config = self
505            .configs
506            .get(pool_id)
507            .ok_or_else(|| JwtError::ConfigurationError {
508                parameter: Some("pool_id".to_string()),
509                error: format!("Configuration for user pool '{}' not found", pool_id),
510            })?;
511
512        // Parse header
513        let header = TokenParser::parse_token_header(token)?;
514
515        // Get key
516        let key = jwk_provider.get_key(&header.kid).await?;
517
518        // Create validation using the registry
519        let validation = self
520            .jwk_registry
521            .create_validation_for_issuer(
522                jwk_provider.get_issuer(),
523                config.clock_skew,
524                &config.client_ids,
525            )
526            .map_err(|e| match e {
527                RegistryError::JwtError(err) => err,
528                other => JwtError::ConfigurationError {
529                    parameter: Some("validation".to_string()),
530                    error: other.to_string(),
531                },
532            })?;
533
534        // Parse claims as base type first
535        let claims: CognitoJwtClaims = TokenParser::parse_token_claims(token, &key, &validation)?;
536
537        // Validate base claims
538        self.validate_claims(&claims, config)?;
539
540        // Convert to the requested type
541        T::try_from(claims)
542    }
543
544    /// Validate claims
545    fn validate_claims(
546        &self,
547        claims: &CognitoJwtClaims,
548        config: &VerifierConfig,
549    ) -> Result<(), JwtError> {
550        // Check if token_use is empty
551        if claims.token_use.is_empty() {
552            return Err(JwtError::InvalidClaim {
553                claim: "token_use".to_string(),
554                reason: "Token use is empty".to_string(),
555                value: None,
556            });
557        }
558
559        // Validate token use against allowed token uses
560        let token_use = match TokenUse::from_str(&claims.token_use) {
561            Some(tu) => tu,
562            None => {
563                // If it's "refresh", explicitly reject
564                if claims.token_use == "refresh" {
565                    return Err(JwtError::UnsupportedTokenType {
566                        token_type: "refresh".to_string(),
567                    });
568                } else {
569                    // Otherwise, it's an invalid token use
570                    return Err(JwtError::InvalidTokenUse {
571                        expected: "id or access".to_string(),
572                        actual: claims.token_use.clone(),
573                    });
574                }
575            }
576        };
577
578        // Check if the token use is allowed for this provider
579        if !config.allowed_token_uses.contains(&token_use) {
580            let expected = config
581                .allowed_token_uses
582                .iter()
583                .map(|t| t.as_str())
584                .collect::<Vec<_>>()
585                .join(" or ");
586
587            return Err(JwtError::InvalidTokenUse {
588                expected,
589                actual: claims.token_use.clone(),
590            });
591        }
592
593        // Create a claims validator and validate the claims
594        let claims_validator = ClaimsValidator::new(Arc::new(config.clone()));
595        claims_validator.validate_claims(claims)
596    }
597}
598
599// Implement the JwtVerifier trait for CognitoJwtVerifier
600impl JwtVerifier for CognitoJwtVerifier {
601    // We've removed the generic verify method to make the trait object-safe
602    // Instead, we'll use the specific methods for ID and Access tokens
603
604    async fn verify_id_token(&self, token: &str) -> Result<Box<dyn IdTokenClaims>, JwtError> {
605        // Use the generic verify method with CognitoIdTokenClaims as the target type
606        let claims = self.verify::<CognitoIdTokenClaims>(token).await?;
607        Ok(Box::new(claims))
608    }
609
610    async fn verify_access_token(
611        &self,
612        token: &str,
613    ) -> Result<Box<dyn AccessTokenClaims>, JwtError> {
614        // Use the generic verify method with CognitoAccessTokenClaims as the target type
615        let claims = self.verify::<CognitoAccessTokenClaims>(token).await?;
616        Ok(Box::new(claims))
617    }
618}
619
620// Implement IdTokenClaims for CognitoIdTokenClaims
impl IdTokenClaims for CognitoIdTokenClaims {
    // Subject, read from the shared base claims.
    fn get_sub(&self) -> &str {
        &self.base.sub
    }

    // Issuer, read from the shared base claims.
    fn get_iss(&self) -> &str {
        &self.base.iss
    }

    // Audience, answered from the base claims' `client_id` field.
    // NOTE(review): Cognito ID tokens carry the audience in an `aud` claim;
    // confirm that `base.client_id` is populated from it for ID tokens.
    fn get_aud(&self) -> &str {
        &self.base.client_id
    }

    // Expiry (`exp`) exactly as stored in the base claims.
    fn get_exp(&self) -> u64 {
        self.base.exp
    }

    // Issued-at (`iat`) exactly as stored in the base claims.
    fn get_iat(&self) -> u64 {
        self.base.iat
    }

    // Email address, or `None` when the claim is absent.
    fn get_email(&self) -> Option<&str> {
        self.email.as_deref()
    }

    // A missing `email_verified` claim is treated as "not verified".
    fn is_email_verified(&self) -> bool {
        self.email_verified.unwrap_or(false)
    }

    // Display name, or `None` when the claim is absent.
    fn get_name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}
654
655// Implement AccessTokenClaims for CognitoAccessTokenClaims
656impl AccessTokenClaims for CognitoAccessTokenClaims {
657    fn get_sub(&self) -> &str {
658        &self.base.sub
659    }
660
661    fn get_iss(&self) -> &str {
662        &self.base.iss
663    }
664
665    fn get_aud(&self) -> &str {
666        &self.base.client_id
667    }
668
669    fn get_exp(&self) -> u64 {
670        self.base.exp
671    }
672
673    fn get_iat(&self) -> u64 {
674        self.base.iat
675    }
676
677    fn get_scopes(&self) -> Vec<String> {
678        match &self.scope {
679            Some(scope) => scope.split_whitespace().map(|s| s.to_string()).collect(),
680            None => Vec::new(),
681        }
682    }
683
684    fn has_scope(&self, scope: &str) -> bool {
685        self.get_scopes().contains(&scope.to_string())
686    }
687
688    fn get_client_id(&self) -> Option<&str> {
689        Some(&self.base.client_id)
690    }
691}