use base64::Engine;
use jsonwebtoken::{decode_header, Algorithm, DecodingKey};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use crate::common::error::JwtError;
/// Classification of a token, with a fixed human-readable label available through `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Displayed as "ID token".
    IdToken,
    /// Displayed as "Access token".
    AccessToken,
    /// Displayed as "Unknown token".
    Unknown,
    /// Displayed as "No token".
    None,
}

impl fmt::Display for TokenType {
    /// Writes the variant's label verbatim; formatter width/fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            TokenType::IdToken => "ID token",
            TokenType::AccessToken => "Access token",
            TokenType::Unknown => "Unknown token",
            TokenType::None => "No token",
        };
        f.write_str(label)
    }
}
/// Owned, serde-(de)serializable representation of a JWT header.
///
/// Unlike an optional `kid`, both `alg` and `kid` are mandatory here, so
/// deserializing a header object that lacks either one fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtHeader {
/// Signing algorithm name as a string (the `alg` header entry).
pub alg: String,
/// Key identifier (the `kid` header entry).
pub kid: String,
/// Token type; stored under the JSON key `typ` rather than the field name.
#[serde(rename = "typ")]
pub token_type: Option<String>,
/// Remaining header entries. `flatten` makes them siblings of `alg`/`kid`/`typ`
/// in the serialized object instead of nesting them under their own key.
#[serde(flatten)]
pub additional_headers: HashMap<String, Value>,
}
/// Field-less type that groups the JWT parsing helpers as associated functions;
/// none of them take `self`, so no instance is needed to call them.
pub struct TokenParser;
impl TokenParser {
/// Decodes the header segment of a compact JWT and converts it into a [`JwtHeader`].
///
/// No key is involved, so nothing in this function checks the signature. The
/// checks run in this order: non-empty input, exactly two `.` separators, a
/// header that `decode_header` accepts, a present `kid`, and `alg == RS256`.
///
/// Only `alg`, `kid`, `typ` and (when present) `cty` are carried over; `cty` is
/// stored in `additional_headers` under the key `"cty"`.
///
/// # Errors
///
/// * `JwtError::MissingToken` if `token` is empty.
/// * `JwtError::ParseError` with part `"token"` if the token does not have
///   exactly three dot-separated parts.
/// * `JwtError::ParseError` with part `"header"` if the header cannot be
///   decoded or has no `kid`.
/// * `JwtError::InvalidClaim` for claim `"alg"` if the algorithm is not RS256.
pub fn parse_token_header(token: &str) -> Result<JwtHeader, JwtError> {
    if token.is_empty() {
        return Err(JwtError::MissingToken);
    }
    // A dot-free token has zero separators, so the count alone covers that case.
    if token.matches('.').count() != 2 {
        return Err(JwtError::ParseError {
            part: Some("token".to_string()),
            error: "Invalid token format: expected 3 parts separated by dots".to_string(),
        });
    }

    let header = decode_header(token).map_err(|e| JwtError::ParseError {
        part: Some("header".to_string()),
        error: format!("Failed to decode header: {}", e),
    })?;

    // The `kid` check intentionally precedes the algorithm check so that a
    // header failing both still reports the missing key id.
    let kid = header.kid.ok_or_else(|| JwtError::ParseError {
        part: Some("header".to_string()),
        error: "Missing 'kid' in token header".to_string(),
    })?;

    if header.alg != Algorithm::RS256 {
        return Err(JwtError::InvalidClaim {
            claim: "alg".to_string(),
            reason: format!("Unsupported algorithm: {:?}, expected RS256", header.alg),
            value: Some(format!("{:?}", header.alg)),
        });
    }
    // Only RS256 reaches this point, so the name is a constant; a per-algorithm
    // lookup here could never take any other branch.
    let alg = "RS256".to_string();

    // The decoded header is owned, so its optional fields are moved out rather than cloned.
    let mut additional_headers = HashMap::new();
    if let Some(cty) = header.cty {
        additional_headers.insert("cty".to_string(), Value::String(cty));
    }

    Ok(JwtHeader {
        alg,
        kid,
        token_type: header.typ,
        additional_headers,
    })
}
/// Decodes `token` with `jsonwebtoken::decode` using `key` and `validation`,
/// returning the claims deserialized as `T`.
///
/// Decode failures are translated into `JwtError` values. For expiry,
/// not-before and issuer failures the offending claim is re-read from the raw
/// payload (via `extract_claim_from_token`) to enrich the error; if that
/// re-read fails, a less detailed error is returned instead.
///
/// # Errors
///
/// * `JwtError::MissingToken` if `token` is empty.
/// * `JwtError::ParseError` with part `"token"` if the token does not have
///   exactly three dot-separated parts.
/// * `JwtError::ExpiredToken` / `JwtError::TokenNotYetValid`, carrying the
///   claim value and current Unix time only when the claim could be re-read.
/// * `JwtError::InvalidSignature`.
/// * `JwtError::InvalidIssuer`, or `JwtError::InvalidClaim` for `"iss"` when
///   the issuer claim could not be re-read.
/// * `JwtError::InvalidClaim` for `"aud"`, `"sub"` or `"alg"`.
/// * `JwtError::ParseError` with part `"claims"` for every other decode error.
pub fn parse_token_claims<T: for<'de> Deserialize<'de>>(
    token: &str,
    key: &DecodingKey,
    validation: &jsonwebtoken::Validation,
) -> Result<T, JwtError> {
    use jsonwebtoken::errors::ErrorKind;

    if token.is_empty() {
        return Err(JwtError::MissingToken);
    }
    let separator_count = token.matches('.').count();
    if separator_count != 2 {
        return Err(JwtError::ParseError {
            part: Some("token".to_string()),
            error: "Invalid token format: expected 3 parts separated by dots".to_string(),
        });
    }

    // Seconds since the Unix epoch; a clock set before the epoch reads as 0.
    let unix_now = || {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    };

    match jsonwebtoken::decode::<T>(token, key, validation) {
        Ok(decoded) => Ok(decoded.claims),
        Err(err) => Err(match err.kind() {
            ErrorKind::ExpiredSignature => {
                // The clock is consulted only when `exp` could be read back.
                let exp = Self::extract_claim_from_token::<u64>(token, "exp").ok();
                JwtError::ExpiredToken {
                    exp,
                    current_time: exp.map(|_| unix_now()),
                }
            }
            ErrorKind::InvalidSignature => JwtError::InvalidSignature,
            ErrorKind::InvalidIssuer => {
                match Self::extract_claim_from_token::<String>(token, "iss") {
                    Ok(actual) => {
                        // Reports the first issuer yielded by iterating the
                        // configured issuers, or an empty string if none are set.
                        let expected = validation
                            .iss
                            .as_ref()
                            .and_then(|allowed| allowed.iter().next())
                            .cloned()
                            .unwrap_or_default();
                        JwtError::InvalidIssuer { expected, actual }
                    }
                    Err(_) => JwtError::InvalidClaim {
                        claim: "iss".to_string(),
                        reason: "Invalid issuer".to_string(),
                        value: None,
                    },
                }
            }
            ErrorKind::InvalidAudience => JwtError::InvalidClaim {
                claim: "aud".to_string(),
                reason: "Invalid audience".to_string(),
                value: None,
            },
            ErrorKind::InvalidSubject => JwtError::InvalidClaim {
                claim: "sub".to_string(),
                reason: "Invalid subject".to_string(),
                value: None,
            },
            ErrorKind::ImmatureSignature => {
                // Same shape as the expiry case, keyed on `nbf`.
                let nbf = Self::extract_claim_from_token::<u64>(token, "nbf").ok();
                JwtError::TokenNotYetValid {
                    nbf,
                    current_time: nbf.map(|_| unix_now()),
                }
            }
            ErrorKind::InvalidAlgorithm => JwtError::InvalidClaim {
                claim: "alg".to_string(),
                reason: "Invalid algorithm".to_string(),
                value: None,
            },
            _ => JwtError::ParseError {
                part: Some("claims".to_string()),
                error: format!("Failed to decode token: {}", err),
            },
        }),
    }
}
pub fn extract_claim_from_token<T: for<'de> Deserialize<'de>>(
token: &str,
claim_name: &str,
) -> Result<T, JwtError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(JwtError::ParseError {
part: Some("token".to_string()),
error: "Invalid token format: expected 3 parts separated by dots".to_string(),
});
}
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|e| JwtError::ParseError {
part: Some("payload".to_string()),
error: format!("Invalid base64 in payload: {}", e),
})?;
let payload: serde_json::Value =
serde_json::from_slice(&payload).map_err(|e| JwtError::ParseError {
part: Some("payload".to_string()),
error: format!("Invalid JSON in payload: {}", e),
})?;
let claim = payload
.get(claim_name)
.ok_or_else(|| JwtError::ParseError {
part: Some("payload".to_string()),
error: format!("Claim '{}' not found in payload", claim_name),
})?;
serde_json::from_value(claim.clone()).map_err(|e| JwtError::ParseError {
part: Some("payload".to_string()),
error: format!("Failed to deserialize claim '{}': {}", claim_name, e),
})
}
/// Returns the `iss` claim of `token` as a string, read from the raw payload
/// through `extract_claim_from_token`.
///
/// # Errors
///
/// Forwards whatever `extract_claim_from_token` returns for the `"iss"` claim.
pub fn extract_issuer(token: &str) -> Result<String, JwtError> {
    let issuer = Self::extract_claim_from_token::<String>(token, "iss")?;
    Ok(issuer)
}
}