// jwt-verify 0.1.3
//
// JWT verification library for AWS Cognito tokens and any OIDC-compatible IDP
// Documentation
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use crate::claims::CognitoJwtClaims;
use crate::cognito::config::{TokenUse, VerifierConfig};
use crate::common::error::{ErrorVerbosity, JwtError};

/// Claims validator for JWT tokens.
///
/// Wraps a shared [`VerifierConfig`] and checks decoded token claims against
/// it: required claims, token use, time-based claims, issuer, client ID and
/// any configured custom validators.
pub struct ClaimsValidator {
    /// Shared verifier configuration read by every `validate_*` method
    /// (clock skew, allowed client IDs, allowed token uses, required claims,
    /// custom validators and the expected issuer URL).
    config: Arc<VerifierConfig>,
}

impl ClaimsValidator {
    /// Create a new claims validator backed by the given shared configuration.
    pub fn new(config: Arc<VerifierConfig>) -> Self {
        Self { config }
    }

    /// Current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// A system clock set before the epoch is reported as `0` instead of
    /// panicking.
    fn current_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Validate a full set of claims against the configuration.
    ///
    /// Checks run in this order and stop at the first failure:
    ///
    /// 1. configured required claims (`sub`, `iss`, `client_id`) are non-empty;
    ///    any other configured claim name is ignored here,
    /// 2. `token_use`,
    /// 3. `exp`,
    /// 4. `iat`,
    /// 5. `iss`,
    /// 6. `client_id`,
    /// 7. every custom validator from the configuration.
    ///
    /// [`validate_not_before`](Self::validate_not_before) is not invoked from
    /// here; callers that need an `nbf` check must call it themselves.
    ///
    /// # Errors
    ///
    /// Returns the [`JwtError`] produced by the first failing check. A failing
    /// custom validator is reported as `JwtError::InvalidClaim` with the claim
    /// name `"custom"` and the validator's message as the reason.
    pub fn validate_claims(&self, claims: &CognitoJwtClaims) -> Result<(), JwtError> {
        // Required claims must be present, i.e. non-empty.
        for claim in &self.config.required_claims {
            match claim.as_str() {
                "sub" if claims.sub.is_empty() => {
                    return Err(JwtError::InvalidClaim {
                        claim: "sub".to_string(),
                        reason: "Subject is empty".to_string(),
                        value: None,
                    });
                }
                "iss" if claims.iss.is_empty() => {
                    return Err(JwtError::InvalidClaim {
                        claim: "iss".to_string(),
                        reason: "Issuer is empty".to_string(),
                        value: None,
                    });
                }
                "client_id" if claims.client_id.is_empty() => {
                    return Err(JwtError::InvalidClaim {
                        claim: "client_id".to_string(),
                        reason: "Client ID is empty".to_string(),
                        value: None,
                    });
                }
                _ => {}
            }
        }

        // Only the pass/fail outcome matters here; the parsed token use is discarded.
        self.validate_token_use(claims.token_use.as_str())?;

        self.validate_expiration(claims.exp)?;
        self.validate_issued_at(claims.iat)?;
        self.validate_issuer(&claims.iss)?;
        self.validate_client_id(&claims.client_id)?;

        // NOTE(review): token type-specific requirements are not validated here.

        // Custom validators report failures as a plain reason string.
        for validator in &self.config.custom_validators {
            if let Err(reason) = validator.validate(claims) {
                return Err(JwtError::InvalidClaim {
                    claim: "custom".to_string(),
                    reason,
                    value: None,
                });
            }
        }

        Ok(())
    }

    /// Validate the expiration time (`exp`), allowing for clock skew.
    ///
    /// The token is accepted while `now < exp + clock_skew`.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::ExpiredToken` carrying `exp` and the current time
    /// once the expiration time plus the configured clock skew has passed.
    pub fn validate_expiration(&self, exp: u64) -> Result<(), JwtError> {
        let now = Self::current_unix_secs();
        let skew = self.config.clock_skew.as_secs();

        // Add the skew to `exp` rather than subtracting it from `now`:
        // `now - skew` underflows whenever the skew exceeds the current epoch
        // seconds (large configured skew, or a pre-epoch clock reported as 0).
        if exp.saturating_add(skew) <= now {
            return Err(JwtError::ExpiredToken {
                exp: Some(exp),
                current_time: Some(now),
            });
        }

        Ok(())
    }

    /// Validate the not-before time (`nbf`), allowing for clock skew.
    ///
    /// A missing `nbf` (`None`) is always accepted.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::TokenNotYetValid` when `nbf` lies further in the
    /// future than the configured clock skew.
    pub fn validate_not_before(&self, nbf: Option<u64>) -> Result<(), JwtError> {
        if let Some(nbf) = nbf {
            let now = Self::current_unix_secs();
            let skew = self.config.clock_skew.as_secs();

            // Saturate so a very large configured skew cannot overflow.
            if nbf > now.saturating_add(skew) {
                return Err(JwtError::TokenNotYetValid {
                    nbf: Some(nbf),
                    current_time: Some(now),
                });
            }
        }

        Ok(())
    }

    /// Validate the issued-at time (`iat`), allowing for clock skew.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidClaim` for `iat` when the token claims to
    /// have been issued further in the future than the configured clock skew.
    pub fn validate_issued_at(&self, iat: u64) -> Result<(), JwtError> {
        let now = Self::current_unix_secs();
        let skew = self.config.clock_skew.as_secs();

        // Saturate so a very large configured skew cannot overflow.
        if iat > now.saturating_add(skew) {
            return Err(JwtError::InvalidClaim {
                claim: "iat".to_string(),
                reason: "Token issued in the future".to_string(),
                value: Some(iat.to_string()),
            });
        }

        Ok(())
    }

    /// Validate the issuer against the issuer URL derived from the configuration.
    ///
    /// The comparison is an exact string match.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidIssuer` with the expected and actual issuer
    /// when they differ.
    pub fn validate_issuer(&self, issuer: &str) -> Result<(), JwtError> {
        let expected_issuer = self.config.get_issuer_url();

        if issuer != expected_issuer {
            return Err(JwtError::InvalidIssuer {
                expected: expected_issuer,
                actual: issuer.to_string(),
            });
        }

        Ok(())
    }

    /// Validate the client ID against the configured list of allowed client IDs.
    ///
    /// When no client IDs are configured the check is skipped and any value
    /// is accepted.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidClientId` with the allowed list and the
    /// actual value when the client ID is not in the list.
    pub fn validate_client_id(&self, client_id: &str) -> Result<(), JwtError> {
        let allowed = &self.config.client_ids;

        // An empty allow-list disables the check; otherwise compare by
        // reference without allocating a temporary `String`.
        if allowed.is_empty() || allowed.iter().any(|id| id == client_id) {
            return Ok(());
        }

        Err(JwtError::InvalidClientId {
            expected: allowed.clone(),
            actual: client_id.to_string(),
        })
    }

    /// Validate the `token_use` claim and return the parsed [`TokenUse`].
    ///
    /// # Errors
    ///
    /// * `JwtError::InvalidClaim` when the value is empty.
    /// * `JwtError::InvalidTokenUse` when the value cannot be parsed by
    ///   `TokenUse::from_str`, or when the parsed value is not among the
    ///   configured allowed token uses.
    pub fn validate_token_use(&self, claims_token_use: &str) -> Result<TokenUse, JwtError> {
        if claims_token_use.is_empty() {
            return Err(JwtError::InvalidClaim {
                claim: "token_use".to_string(),
                reason: "Token use is empty".to_string(),
                value: None,
            });
        }

        let actual_token_use = match TokenUse::from_str(claims_token_use) {
            Some(token_use) => token_use,
            None => {
                return Err(JwtError::InvalidTokenUse {
                    expected: "id or access".to_string(),
                    actual: claims_token_use.to_string(),
                });
            }
        };

        // The parsed value must also be one the configuration allows.
        let allowed_token_uses = &self.config.allowed_token_uses;
        if !allowed_token_uses.contains(&actual_token_use) {
            let expected = allowed_token_uses
                .iter()
                .map(|t| t.as_str())
                .collect::<Vec<_>>()
                .join(" or ");

            return Err(JwtError::InvalidTokenUse {
                expected,
                actual: claims_token_use.to_string(),
            });
        }

        Ok(actual_token_use)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Clock skew, in seconds, used by every test fixture.
    const SKEW_SECS: u64 = 60;

    /// Build a validator for the example user pool with the given allowed
    /// client IDs, both token uses allowed and `SKEW_SECS` of clock skew.
    fn validator_with_client_ids(client_ids: &[&str]) -> ClaimsValidator {
        let config = VerifierConfig {
            region: "us-east-1".to_string(),
            user_pool_id: "us-east-1_example".to_string(),
            client_ids: client_ids.iter().map(|&id| id.to_string()).collect(),
            allowed_token_uses: vec![TokenUse::Id, TokenUse::Access],
            clock_skew: Duration::from_secs(SKEW_SECS),
            jwk_cache_duration: Duration::from_secs(3600),
            required_claims: HashSet::new(),
            custom_validators: Vec::new(),
            error_verbosity: ErrorVerbosity::Standard,
        };

        ClaimsValidator::new(Arc::new(config))
    }

    /// Current time in seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs()
    }

    #[test]
    fn test_validate_expiration() {
        let validator = validator_with_client_ids(&["client1"]);
        let now = unix_now();

        // A token that expires in the future is valid
        assert!(validator.validate_expiration(now + 3600).is_ok());

        // A token that expired just now is still valid thanks to the clock skew
        assert!(validator.validate_expiration(now - 30).is_ok());

        // A token that expired more than the clock skew ago is rejected
        assert!(validator.validate_expiration(now - 120).is_err());
    }

    #[test]
    fn test_validate_not_before() {
        let validator = validator_with_client_ids(&["client1"]);
        let now = unix_now();

        // A missing nbf is always accepted
        assert!(validator.validate_not_before(None).is_ok());

        // An nbf in the past is accepted
        assert!(validator.validate_not_before(Some(now - 10)).is_ok());

        // An nbf slightly in the future is accepted thanks to the clock skew
        assert!(validator.validate_not_before(Some(now + 30)).is_ok());

        // An nbf further in the future than the clock skew is rejected
        assert!(validator.validate_not_before(Some(now + 3600)).is_err());
    }

    #[test]
    fn test_validate_issued_at() {
        let validator = validator_with_client_ids(&["client1"]);
        let now = unix_now();

        // A token issued now is accepted
        assert!(validator.validate_issued_at(now).is_ok());

        // A token issued slightly in the future is accepted thanks to the clock skew
        assert!(validator.validate_issued_at(now + 30).is_ok());

        // A token issued further in the future than the clock skew is rejected
        assert!(validator.validate_issued_at(now + 3600).is_err());
    }

    #[test]
    fn test_validate_issuer() {
        let validator = validator_with_client_ids(&["client1"]);

        // The issuer derived from the configured region and user pool is accepted
        assert!(validator
            .validate_issuer("https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example")
            .is_ok());

        // Any other issuer is rejected
        assert!(validator
            .validate_issuer("https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example")
            .is_err());
    }

    #[test]
    fn test_validate_client_id() {
        let validator = validator_with_client_ids(&["client1", "client2"]);

        // Configured client IDs are accepted
        assert!(validator.validate_client_id("client1").is_ok());
        assert!(validator.validate_client_id("client2").is_ok());

        // An unknown client ID is rejected
        assert!(validator.validate_client_id("client3").is_err());

        // With no client IDs configured the check is skipped entirely
        let validator = validator_with_client_ids(&[]);
        assert!(validator.validate_client_id("any_client").is_ok());
    }

    #[test]
    fn test_validate_token_use_rejects_empty() {
        let validator = validator_with_client_ids(&["client1"]);

        // An empty token_use is rejected before any parsing happens
        assert!(validator.validate_token_use("").is_err());
    }
}