// jwt-verify 0.1.3
//
// JWT verification library for AWS Cognito tokens and any OIDC-compatible IdP.
// Documentation
use base64::Engine;
use jsonwebtoken::{decode_header, Algorithm, DecodingKey};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;

use crate::common::error::JwtError;

/// The kind of JWT being handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// An ID token.
    IdToken,
    /// An access token.
    AccessToken,
    /// A token whose type is unknown.
    Unknown,
    /// No token was provided.
    None,
}

impl fmt::Display for TokenType {
    /// Writes the human-readable label of the variant.
    ///
    /// The label is emitted verbatim; width, fill and alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::IdToken => "ID token",
            Self::AccessToken => "Access token",
            Self::Unknown => "Unknown token",
            Self::None => "No token",
        };
        f.write_str(label)
    }
}

/// Decoded JWT header.
///
/// Derives `Serialize`/`Deserialize`, so it maps to and from the header's
/// JSON object. Instances are also built by
/// `TokenParser::parse_token_header`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtHeader {
    /// Signing algorithm name as a string (for example `"RS256"`).
    pub alg: String,
    /// Key ID. This is a plain `String`, not an `Option`, so a header
    /// without `kid` cannot be represented by this type.
    pub kid: String,
    /// Token type; (de)serialized under the JSON key `typ` and `None` when
    /// the header does not carry it.
    #[serde(rename = "typ")]
    pub token_type: Option<String>,
    /// Any other header parameters, keyed by name.
    ///
    /// Flattened by serde: on (de)serialization these entries sit next to
    /// `alg`, `kid` and `typ` in the same JSON object rather than in a
    /// nested one. `TokenParser::parse_token_header` only ever inserts
    /// `cty` here.
    #[serde(flatten)]
    pub additional_headers: HashMap<String, Value>,
}

/// Token parser.
///
/// A unit struct with no state: it exists only to group the JWT parsing
/// helpers, which are associated functions (none of them take `self`).
pub struct TokenParser;

// TokenPayload moved to cognito/token.rs

impl TokenParser {
    /// Parse and validate the header of a JWT without verifying its signature.
    ///
    /// Checks are applied in this order, and the first failure is returned:
    ///
    /// 1. the token is not empty;
    /// 2. the token has exactly three dot-separated segments;
    /// 3. the header segment decodes via `jsonwebtoken::decode_header`;
    /// 4. the header carries a `kid`;
    /// 5. the header algorithm is RS256.
    ///
    /// # Parameters
    ///
    /// * `token` - The JWT in compact (dot-separated) form
    ///
    /// # Returns
    ///
    /// A [`JwtHeader`] whose `alg` is always `"RS256"` (any other algorithm
    /// is rejected), whose `token_type` mirrors the header's `typ`, and whose
    /// `additional_headers` contains `cty` when the header has one.
    ///
    /// # Errors
    ///
    /// * `JwtError::MissingToken` - `token` is empty
    /// * `JwtError::ParseError` - wrong number of segments, undecodable
    ///   header, or missing `kid`
    /// * `JwtError::InvalidClaim` - the algorithm is not RS256
    pub fn parse_token_header(token: &str) -> Result<JwtHeader, JwtError> {
        if token.is_empty() {
            return Err(JwtError::MissingToken);
        }

        // Exactly two dots means exactly three segments. A token without any
        // dot has a count of zero and is rejected by the same comparison.
        if token.matches('.').count() != 2 {
            return Err(JwtError::ParseError {
                part: Some("token".to_string()),
                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
            });
        }

        let header = decode_header(token).map_err(|e| JwtError::ParseError {
            part: Some("header".to_string()),
            error: format!("Failed to decode header: {}", e),
        })?;

        // `kid` is checked before the algorithm so that a header failing both
        // checks keeps reporting the missing key ID first.
        let kid = header.kid.ok_or_else(|| JwtError::ParseError {
            part: Some("header".to_string()),
            error: "Missing 'kid' in token header".to_string(),
        })?;

        // Only RS256 is accepted (Cognito uses RS256).
        if header.alg != Algorithm::RS256 {
            return Err(JwtError::InvalidClaim {
                claim: "alg".to_string(),
                reason: format!("Unsupported algorithm: {:?}, expected RS256", header.alg),
                value: Some(format!("{:?}", header.alg)),
            });
        }

        // `header` is owned, so its optional fields are moved out rather than
        // cloned. `cty` is the only extra header parameter carried over.
        let mut additional_headers = HashMap::new();
        if let Some(cty) = header.cty {
            additional_headers.insert("cty".to_string(), Value::String(cty));
        }

        Ok(JwtHeader {
            // The guard above leaves RS256 as the only possible algorithm.
            alg: "RS256".to_string(),
            kid,
            token_type: header.typ,
            additional_headers,
        })
    }

    /// Decode a token, validate it, and return its claims.
    ///
    /// The token is handed to `jsonwebtoken::decode` together with `key` and
    /// `validation`. On failure the `jsonwebtoken` error kind is translated
    /// into a `JwtError`; where possible the offending claim is re-read from
    /// the (unverified) payload to enrich the error.
    ///
    /// # Parameters
    ///
    /// * `token` - The JWT in compact (dot-separated) form
    /// * `key` - Key passed to `jsonwebtoken::decode`
    /// * `validation` - Validation rules passed to `jsonwebtoken::decode`
    ///
    /// # Returns
    ///
    /// The claims deserialized into `T`.
    ///
    /// # Errors
    ///
    /// * `JwtError::MissingToken` - `token` is empty
    /// * `JwtError::ParseError` - wrong number of segments, or any decode
    ///   failure not listed below
    /// * `JwtError::ExpiredToken` - expired; `exp` and `current_time` are
    ///   both set when `exp` can be read from the payload, otherwise both
    ///   are `None`
    /// * `JwtError::TokenNotYetValid` - same shape, for `nbf`
    /// * `JwtError::InvalidSignature` - signature check failed
    /// * `JwtError::InvalidIssuer` - issuer mismatch and `iss` is readable;
    ///   `expected` is one entry of `validation.iss`, or empty if unset
    /// * `JwtError::InvalidClaim` - invalid `aud`, `sub` or `alg`, or an
    ///   issuer mismatch where `iss` cannot be read
    pub fn parse_token_claims<T: for<'de> Deserialize<'de>>(
        token: &str,
        key: &DecodingKey,
        validation: &jsonwebtoken::Validation,
    ) -> Result<T, JwtError> {
        use jsonwebtoken::errors::ErrorKind;

        if token.is_empty() {
            return Err(JwtError::MissingToken);
        }

        // Two dots <=> three segments; zero dots fails the same comparison.
        if token.matches('.').count() != 2 {
            return Err(JwtError::ParseError {
                part: Some("token".to_string()),
                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
            });
        }

        // Seconds since the Unix epoch; falls back to 0 if the system clock
        // reports a time before the epoch.
        let unix_now = || {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs()
        };

        // Happy path returns immediately; everything below is error mapping.
        let failure = match jsonwebtoken::decode::<T>(token, key, validation) {
            Ok(decoded) => return Ok(decoded.claims),
            Err(failure) => failure,
        };

        let mapped = match failure.kind() {
            ErrorKind::ExpiredSignature => {
                // The clock is only consulted when `exp` could be read.
                let exp = Self::extract_claim_from_token::<u64>(token, "exp").ok();
                JwtError::ExpiredToken {
                    exp,
                    current_time: exp.map(|_| unix_now()),
                }
            }
            ErrorKind::ImmatureSignature => {
                // The clock is only consulted when `nbf` could be read.
                let nbf = Self::extract_claim_from_token::<u64>(token, "nbf").ok();
                JwtError::TokenNotYetValid {
                    nbf,
                    current_time: nbf.map(|_| unix_now()),
                }
            }
            ErrorKind::InvalidSignature => JwtError::InvalidSignature,
            ErrorKind::InvalidIssuer => {
                match Self::extract_claim_from_token::<String>(token, "iss") {
                    Ok(actual) => {
                        // Report one of the configured issuers, or an empty
                        // string when none is configured.
                        let expected = validation
                            .iss
                            .as_ref()
                            .and_then(|allowed| allowed.iter().next())
                            .cloned()
                            .unwrap_or_default();
                        JwtError::InvalidIssuer { expected, actual }
                    }
                    Err(_) => JwtError::InvalidClaim {
                        claim: "iss".to_string(),
                        reason: "Invalid issuer".to_string(),
                        value: None,
                    },
                }
            }
            ErrorKind::InvalidAudience => JwtError::InvalidClaim {
                claim: "aud".to_string(),
                reason: "Invalid audience".to_string(),
                value: None,
            },
            ErrorKind::InvalidSubject => JwtError::InvalidClaim {
                claim: "sub".to_string(),
                reason: "Invalid subject".to_string(),
                value: None,
            },
            ErrorKind::InvalidAlgorithm => JwtError::InvalidClaim {
                claim: "alg".to_string(),
                reason: "Invalid algorithm".to_string(),
                value: None,
            },
            _ => JwtError::ParseError {
                part: Some("claims".to_string()),
                error: format!("Failed to decode token: {}", failure),
            },
        };

        Err(mapped)
    }

    // parse_token_payload moved to cognito/token.rs

    /// Extract a specific claim from a token without validating the signature
    pub fn extract_claim_from_token<T: for<'de> Deserialize<'de>>(
        token: &str,
        claim_name: &str,
    ) -> Result<T, JwtError> {
        // Split the token
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return Err(JwtError::ParseError {
                part: Some("token".to_string()),
                error: "Invalid token format: expected 3 parts separated by dots".to_string(),
            });
        }

        // Decode the payload (second part)
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(parts[1])
            .map_err(|e| JwtError::ParseError {
                part: Some("payload".to_string()),
                error: format!("Invalid base64 in payload: {}", e),
            })?;

        // Parse the payload
        let payload: serde_json::Value =
            serde_json::from_slice(&payload).map_err(|e| JwtError::ParseError {
                part: Some("payload".to_string()),
                error: format!("Invalid JSON in payload: {}", e),
            })?;

        // Extract the claim
        let claim = payload
            .get(claim_name)
            .ok_or_else(|| JwtError::ParseError {
                part: Some("payload".to_string()),
                error: format!("Claim '{}' not found in payload", claim_name),
            })?;

        // Deserialize the claim
        serde_json::from_value(claim.clone()).map_err(|e| JwtError::ParseError {
            part: Some("payload".to_string()),
            error: format!("Failed to deserialize claim '{}': {}", claim_name, e),
        })
    }

    /// Read the `iss` claim of a token without validating the signature.
    ///
    /// Convenience wrapper that delegates to
    /// `TokenParser::extract_claim_from_token` with the claim name `"iss"`
    /// and a `String` target type.
    ///
    /// # Parameters
    ///
    /// * `token` - The JWT token
    ///
    /// # Returns
    ///
    /// The issuer string on success, or the `JwtError` produced by
    /// `extract_claim_from_token` when the claim cannot be read.
    pub fn extract_issuer(token: &str) -> Result<String, JwtError> {
        let issuer = Self::extract_claim_from_token::<String>(token, "iss")?;
        Ok(issuer)
    }
}