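//! jwt-verify 0.1.0
//!
//! JWT verification library for AWS Cognito tokens and any OIDC-compatible IdP.
//!
//! Claim types for Cognito tokens: shared base claims, ID-token claims and access-token claims.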
use crate::common::error::JwtError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;

/// Claims common to every Amazon Cognito JWT (ID and access tokens alike).
///
/// Any claim without a dedicated field below is gathered into
/// [`custom_claims`](Self::custom_claims) through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitoJwtClaims {
    /// `sub`: subject, identifying the principal the token was issued to.
    pub sub: String,
    /// `iss`: the issuer of the token.
    pub iss: String,
    /// App client identifier.
    ///
    /// Deserialized from `client_id`, or from `aud` through the serde alias, so
    /// tokens that carry the client identifier under `aud` also parse.
    /// Always serialized as `client_id`.
    #[serde(alias = "aud")]
    pub client_id: String,
    /// `origin_jti` claim, when present.
    pub origin_jti: Option<String>,
    /// `event_id` claim, when present.
    pub event_id: Option<String>,
    /// `token_use`: the kind of token (e.g. `"id"` or `"access"`).
    pub token_use: String,
    /// `scope` claim, when present.
    pub scope: Option<String>,
    /// `auth_time`: when the user authenticated (JWT NumericDate).
    pub auth_time: u64,
    /// `exp`: expiration time (JWT NumericDate).
    pub exp: u64,
    /// `iat`: issued-at time (JWT NumericDate).
    pub iat: u64,
    /// `jti`: unique identifier of this token.
    pub jti: String,
    /// `username` claim, when present.
    pub username: Option<String>,
    /// Every remaining claim, keyed by claim name.
    #[serde(flatten)]
    pub custom_claims: HashMap<String, Value>,
}

/// Claims of a Cognito ID token: the shared base claims plus identity attributes.
///
/// NOTE(review): `aud` is declared here and is also an alias of `base.client_id`.
/// With `#[serde(flatten)]`, serde matches this struct's own fields before it hands
/// the leftover keys to `base`. Deserializing an ID token straight into this type
/// therefore likely leaves `base` without `client_id`/`aud` and fails. Confirm
/// this; building it through `TryFrom<CognitoJwtClaims>` avoids the problem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitoIdTokenClaims {
    /// Claims shared by all Cognito tokens (flattened into the same JSON object).
    #[serde(flatten)]
    pub base: CognitoJwtClaims,
    /// `email` claim, when present.
    pub email: Option<String>,
    /// `email_verified` claim, when present.
    pub email_verified: Option<bool>,
    /// `phone_number` claim, when present.
    pub phone_number: Option<String>,
    /// `phone_number_verified` claim, when present.
    pub phone_number_verified: Option<bool>,
    /// `cognito:groups` claim: names of the groups the user belongs to.
    #[serde(rename = "cognito:groups")]
    pub cognito_groups: Option<Vec<String>>,
    /// `name` claim, when present.
    pub name: Option<String>,
    /// `aud`: audience of the token.
    pub aud: String,
    /// `at_hash` (access token hash) claim, when present.
    pub at_hash: Option<String>,
    /// `cognito:username` claim, when present.
    #[serde(rename = "cognito:username")]
    pub cognito_username: Option<String>,
    /// `cognito:roles` claim, when present.
    #[serde(rename = "cognito:roles")]
    pub cognito_roles: Option<Vec<String>>,
    /// `cognito:preferred_role` claim, when present.
    #[serde(rename = "cognito:preferred_role")]
    pub cognito_preferred_role: Option<String>,
}

/// Claims of a Cognito access token: the shared base claims plus scope/version.
///
/// NOTE(review): `scope` is declared here and also exists as `base.scope`. Under
/// `#[serde(flatten)]`, this struct's own `scope` field likely takes the key first
/// when deserializing directly, which leaves `base.scope` as `None`. Serializing
/// may emit `scope` twice. Confirm this before relying on either copy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitoAccessTokenClaims {
    /// Claims shared by all Cognito tokens (flattened into the same JSON object).
    #[serde(flatten)]
    pub base: CognitoJwtClaims,
    /// `scope` claim; optional because some Cognito access tokens omit it.
    pub scope: Option<String>,
    /// `version` claim. Not a registered JWT claim; kept for compatibility.
    pub version: Option<u32>,
}

/// Identity struct for ID tokens
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Identity {
    /// User ID
    pub user_id: String,
    /// Provider name
    pub provider_name: String,
    /// Provider type
    pub provider_type: String,
    /// Primary
    pub primary: String,
    /// Date created
    pub date_created: String,
}

impl CognitoJwtClaims {
    // Removed redundant validation methods that duplicate jsonwebtoken functionality:
    // - validate_expiration
    // - validate_not_before
    // - validate_issued_at
    // These validations are handled by jsonwebtoken's Validation struct

    /// Validate that the token has the expected issuer
    pub fn validate_issuer(&self, expected_issuer: &str) -> bool {
        self.iss == expected_issuer
    }

    /// Validate that the token has one of the expected client IDs
    pub fn validate_client_id(&self, expected_client_ids: &[String]) -> bool {
        expected_client_ids.is_empty() || expected_client_ids.contains(&self.client_id)
    }

    /// Validate that the token has the expected token use
    pub fn validate_token_use(&self, expected_token_use: &str) -> bool {
        self.token_use == expected_token_use
    }

    /// Get a custom claim as a string
    pub fn get_custom_claim_string(&self, claim_name: &str) -> Option<String> {
        self.custom_claims
            .get(claim_name)
            .and_then(|v| v.as_str().map(|s| s.to_string()))
    }

    /// Get a custom claim as a number
    pub fn get_custom_claim_number(&self, claim_name: &str) -> Option<f64> {
        self.custom_claims.get(claim_name).and_then(|v| v.as_f64())
    }

    /// Get a custom claim as a boolean
    pub fn get_custom_claim_bool(&self, claim_name: &str) -> Option<bool> {
        self.custom_claims.get(claim_name).and_then(|v| v.as_bool())
    }

    /// Get a custom claim as an array of strings
    pub fn get_custom_claim_string_array(&self, claim_name: &str) -> Option<Vec<String>> {
        self.custom_claims.get(claim_name).and_then(|v| {
            v.as_array().map(|arr| {
                arr.iter()
                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
                    .collect()
            })
        })
    }
}

impl fmt::Display for CognitoJwtClaims {
    /// Writes a one-line summary with the subject, issuer and expiry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sub, iss, exp, .. } = self;
        write!(f, "JWT Claims (sub={}, iss={}, exp={})", sub, iss, exp)
    }
}

impl CognitoIdTokenClaims {
    /// Returns the `email` claim, if present.
    pub fn get_email(&self) -> Option<&str> {
        self.email.as_deref()
    }

    /// Returns `true` only when `email_verified` is present and `true`.
    ///
    /// A missing claim counts as unverified.
    pub fn is_email_verified(&self) -> bool {
        self.email_verified.unwrap_or(false)
    }

    /// Returns the `phone_number` claim, if present.
    pub fn get_phone_number(&self) -> Option<&str> {
        self.phone_number.as_deref()
    }

    /// Returns `true` only when `phone_number_verified` is present and `true`.
    ///
    /// A missing claim counts as unverified.
    pub fn is_phone_verified(&self) -> bool {
        self.phone_number_verified.unwrap_or(false)
    }

    /// Returns an owned copy of the `cognito:groups` claim, or an empty list when absent.
    pub fn get_groups(&self) -> Vec<String> {
        self.cognito_groups.clone().unwrap_or_default()
    }

    /// Returns `true` if `group` is listed in `cognito:groups` (exact, case-sensitive).
    pub fn is_in_group(&self, group: &str) -> bool {
        // Compare borrowed strings directly instead of allocating a `String`
        // just to call `Vec::contains`.
        self.cognito_groups.iter().flatten().any(|g| g == group)
    }

    /// Returns an owned copy of the `cognito:roles` claim, or an empty list when absent.
    pub fn get_roles(&self) -> Vec<String> {
        self.cognito_roles.clone().unwrap_or_default()
    }

    /// Returns `true` if `role` is listed in `cognito:roles` (exact, case-sensitive).
    pub fn has_role(&self, role: &str) -> bool {
        self.cognito_roles.iter().flatten().any(|r| r == role)
    }

    /// Returns the `cognito:preferred_role` claim, if present.
    pub fn get_preferred_role(&self) -> Option<&str> {
        self.cognito_preferred_role.as_deref()
    }
}

impl fmt::Display for CognitoIdTokenClaims {
    /// Writes a one-line summary taken from the base claims: subject, issuer, expiry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let base = &self.base;
        write!(
            f,
            "ID Token Claims (sub={}, iss={}, exp={})",
            base.sub, base.iss, base.exp
        )
    }
}

impl TryFrom<CognitoJwtClaims> for CognitoIdTokenClaims {
    type Error = JwtError;

    /// Promotes generic Cognito claims to ID-token claims.
    ///
    /// The ID-token-specific fields are read from `claims.custom_claims`. Each one
    /// is `None` when the claim is missing or has an unexpected JSON type. The
    /// source entries are also left in `custom_claims`.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidTokenUse`] when `token_use` is not `"id"`.
    fn try_from(claims: CognitoJwtClaims) -> Result<Self, Self::Error> {
        if claims.token_use != "id" {
            return Err(JwtError::InvalidTokenUse {
                expected: "id".to_string(),
                actual: claims.token_use.clone(),
            });
        }

        let email = claims.get_custom_claim_string("email");
        let email_verified = claims.get_custom_claim_bool("email_verified");
        let phone_number = claims.get_custom_claim_string("phone_number");
        let phone_number_verified = claims.get_custom_claim_bool("phone_number_verified");
        let cognito_groups = claims.get_custom_claim_string_array("cognito:groups");
        let name = claims.get_custom_claim_string("name");
        // During deserialization, `aud` is consumed by the `client_id` field's
        // `#[serde(alias = "aud")]`, so it normally never reaches `custom_claims`.
        // Fall back to `client_id`, which then holds the `aud` value, instead of
        // silently producing an empty audience.
        let aud = claims
            .get_custom_claim_string("aud")
            .unwrap_or_else(|| claims.client_id.clone());
        let at_hash = claims.get_custom_claim_string("at_hash");
        let cognito_username = claims.get_custom_claim_string("cognito:username");
        let cognito_roles = claims.get_custom_claim_string_array("cognito:roles");
        let cognito_preferred_role = claims.get_custom_claim_string("cognito:preferred_role");

        Ok(Self {
            base: claims,
            email,
            email_verified,
            phone_number,
            phone_number_verified,
            cognito_groups,
            name,
            aud,
            at_hash,
            cognito_username,
            cognito_roles,
            cognito_preferred_role,
        })
    }
}

impl CognitoAccessTokenClaims {
    /// Splits the whitespace-separated `scope` claim into individual scopes.
    ///
    /// Returns an empty list when the token has no `scope` claim.
    pub fn get_scopes(&self) -> Vec<String> {
        self.scope
            .as_deref()
            .unwrap_or_default()
            .split_whitespace()
            .map(String::from)
            .collect()
    }

    /// Returns `true` if `scope` is one of the token's scopes (exact match).
    pub fn has_scope(&self, scope: &str) -> bool {
        // Scan the borrowed tokens directly. This avoids building a whole
        // `Vec<String>` and allocating `scope` just to call `contains`.
        self.scope
            .as_deref()
            .unwrap_or_default()
            .split_whitespace()
            .any(|granted| granted == scope)
    }

    /// Returns the `version` claim, defaulting to `1` when it is absent.
    pub fn get_version(&self) -> u32 {
        self.version.unwrap_or(1)
    }
}

impl fmt::Display for CognitoAccessTokenClaims {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let scope_display = match &self.scope {
            Some(s) => s,
            None => "none",
        };

        write!(
            f,
            "Access Token Claims (sub={}, iss={}, exp={}, scope={})",
            self.base.sub, self.base.iss, self.base.exp, scope_display
        )
    }
}

impl TryFrom<CognitoJwtClaims> for CognitoAccessTokenClaims {
    type Error = JwtError;

    fn try_from(claims: CognitoJwtClaims) -> Result<Self, Self::Error> {
        if claims.token_use != "access" {
            return Err(JwtError::InvalidTokenUse {
                expected: "access".to_string(),
                actual: claims.token_use.clone(),
            });
        }

        // Extract the scope from the claims (now optional)
        let scope = claims.scope.clone();

        // Extract the version from the claims (now optional)
        let version = claims.get_custom_claim_number("version").map(|v| v as u32);

        Ok(Self {
            base: claims,
            scope,
            version,
        })
    }
}