use anyhow::{anyhow, Result};
use base64::Engine;
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use crate::utils::compression;
/// Errors surfaced by this module's verification helpers.
///
/// Four `jsonwebtoken` error kinds get dedicated variants (see the
/// `From<ErrorKind>` conversion); anything else is carried as text in `Other`.
#[derive(Debug, Clone)]
pub enum JwtError {
    /// Mapped from `ErrorKind::InvalidSignature`.
    InvalidSignature,
    /// Mapped from `ErrorKind::ExpiredSignature`.
    ExpiredSignature,
    /// Mapped from `ErrorKind::ImmatureSignature`.
    ImmatureSignature,
    /// Mapped from `ErrorKind::InvalidAlgorithm`.
    InvalidAlgorithm,
    /// Any other failure, described by the contained message.
    Other(String),
}

impl fmt::Display for JwtError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = match self {
            Self::InvalidSignature => "Invalid signature",
            Self::ExpiredSignature => "Expired signature",
            Self::ImmatureSignature => "Immature signature",
            Self::InvalidAlgorithm => "Invalid algorithm",
            // Only this variant needs interpolation, so it writes directly.
            Self::Other(detail) => return write!(f, "JWT error: {detail}"),
        };
        f.write_str(text)
    }
}

impl Error for JwtError {}
/// Maps a `jsonwebtoken` error kind onto this module's [`JwtError`].
///
/// The four kinds this module distinguishes get their own variants; every
/// other kind is kept as its `Debug` rendering inside [`JwtError::Other`].
impl From<ErrorKind> for JwtError {
    fn from(kind: ErrorKind) -> Self {
        match kind {
            ErrorKind::InvalidSignature => Self::InvalidSignature,
            ErrorKind::ExpiredSignature => Self::ExpiredSignature,
            ErrorKind::ImmatureSignature => Self::ImmatureSignature,
            ErrorKind::InvalidAlgorithm => Self::InvalidAlgorithm,
            // Keep the debug form so data-carrying kinds lose no detail.
            unmatched => Self::Other(format!("{unmatched:?}")),
        }
    }
}
/// Output of [`decode`]: the parsed header, the JSON claims and the algorithm
/// resolved from the header's `alg` parameter.
///
/// Producing this value does not involve any signature check.
#[derive(Debug, Clone)]
pub struct DecodedToken {
/// Header parameters parsed from the first token segment.
pub header: HashMap<String, Value>,
/// Claims parsed from the payload segment (after DEFLATE decompression when
/// the header carries `zip: "DEF"`).
pub claims: Value,
/// Algorithm mapped from `alg`. `decode` maps `none` to `HS256`, so inspect
/// `header["alg"]` to recognise unsecured tokens.
pub algorithm: Algorithm,
}
/// Structural view of a compact-serialized JWE, as returned by [`decode_jwe`].
///
/// Only the protected header is decoded; the other four segments keep their
/// original base64url text and nothing is decrypted or authenticated.
#[derive(Debug, Clone)]
pub struct DecodedJweToken {
/// Protected header parsed from the first segment.
pub header: HashMap<String, Value>,
/// Second segment (encrypted key), still base64url text; may be empty.
pub encrypted_key: String,
/// Third segment (initialization vector), still base64url text.
pub iv: String,
/// Fourth segment (ciphertext), still base64url text.
pub ciphertext: String,
/// Fifth segment (authentication tag), still base64url text.
pub tag: String,
/// String value of the header's `alg` parameter.
pub algorithm: String,
/// String value of the header's `enc` parameter.
pub encryption: String,
}
/// Coarse token classification returned by [`detect_token_type`], based only
/// on the number of `.`-separated segments.
///
/// Derives `Eq` and `Hash` so classifications can be used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// Three segments (header.payload.signature).
    Jwt,
    /// Five segments (JWE compact serialization).
    Jwe,
    /// Any other segment count.
    Unknown,
}
/// Signing key material consumed by [`encode_with_options`].
pub enum KeyData<'a> {
/// Shared secret; accepted only with HS256, HS384 or HS512.
Secret(&'a str),
/// PEM-encoded private key for the RSA, RSA-PSS, ECDSA or EdDSA algorithms.
PrivateKeyPem(&'a str),
/// DER-encoded private key for the RSA, RSA-PSS, ECDSA or EdDSA algorithms.
// NOTE(review): `dead_code` is allowed presumably because nothing in the crate
// constructs this variant yet — confirm before removing the attribute.
#[allow(dead_code)]
PrivateKeyDer(&'a [u8]),
/// No key at all; only usable with the `none` algorithm (every signing path
/// rejects it).
None,
}
/// Settings for [`encode_with_options`].
pub struct EncodeOptions<'a> {
/// Algorithm name, matched case-insensitively (e.g. `"HS256"`, `"EdDSA"`, `"none"`).
pub algorithm: &'a str,
/// Signing key; its kind must fit the algorithm family.
pub key_data: KeyData<'a>,
/// Extra header parameters. Signed tokens honour only a subset of keys, while
/// `none` and compressed tokens copy the entries into the header verbatim.
pub header_params: Option<HashMap<&'a str, &'a str>>,
/// When true, the claims JSON is DEFLATE-compressed and the header gets `zip: "DEF"`.
pub compress_payload: bool,
}
impl<'a> Default for EncodeOptions<'a> {
fn default() -> Self {
Self {
algorithm: "HS256",
key_data: KeyData::Secret(""),
header_params: None,
compress_payload: false,
}
}
}
#[allow(dead_code)]
pub fn encode(claims: &Value, secret: &str, alg_str: &str) -> Result<String> {
let options = EncodeOptions {
algorithm: alg_str,
key_data: KeyData::Secret(secret),
header_params: None,
compress_payload: false,
};
encode_with_options(claims, &options)
}
/// Builds a compact JWT from `claims` as described by `options`.
///
/// * `options.algorithm` is matched case-insensitively; `none` yields an
///   unsigned token.
/// * `options.key_data` must fit the algorithm:
///   - a secret for HS256/HS384/HS512;
///   - a PEM or DER private key for the RSA, RSA-PSS, ECDSA and EdDSA families.
/// * `options.header_params`:
///   - for signed tokens, the string header parameters `typ`, `cty`, `kid`,
///     `jku`, `x5u` and `x5t` are applied and every other key is ignored;
///   - for `none` tokens, every entry is copied into the header verbatim
///     (and can override `alg`/`typ`).
/// * `options.compress_payload` hands off to [`encode_compressed_jwt`].
///
/// A `none` token's signature segment is the literal two-character string
/// `''`, not the empty segment RFC 7519 §6 describes. This module's tests pin
/// that format.
///
/// # Errors
/// Fails in any of these cases:
/// * an unknown algorithm name;
/// * a key type that does not fit the algorithm;
/// * an unparsable PEM key;
/// * a serialization or signing failure.
pub fn encode_with_options(claims: &Value, options: &EncodeOptions) -> Result<String> {
    use std::collections::BTreeMap;

    // Normalise once; every algorithm comparison below is case-insensitive.
    let alg_upper = options.algorithm.to_uppercase();
    let algorithm = match alg_upper.as_str() {
        "HS256" => Algorithm::HS256,
        "HS384" => Algorithm::HS384,
        "HS512" => Algorithm::HS512,
        "RS256" => Algorithm::RS256,
        "RS384" => Algorithm::RS384,
        "RS512" => Algorithm::RS512,
        "ES256" => Algorithm::ES256,
        "ES384" => Algorithm::ES384,
        "PS256" => Algorithm::PS256,
        "PS384" => Algorithm::PS384,
        "PS512" => Algorithm::PS512,
        "EDDSA" => Algorithm::EdDSA,
        // `none` has no `jsonwebtoken` counterpart. HS256 is only a placeholder;
        // the unsigned paths return before anything is signed with it.
        "NONE" => Algorithm::HS256,
        _ => {
            return Err(anyhow!(
                "Unsupported algorithm '{}'. Supported algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, PS256, PS384, PS512, EdDSA, none",
                options.algorithm
            ))
        }
    };

    if options.compress_payload {
        return encode_compressed_jwt(claims, options, algorithm);
    }

    if alg_upper == "NONE" {
        // Unsigned token: the header is assembled by hand so arbitrary
        // parameters (including an overriding `alg`/`typ`) are emitted as given.
        let mut header_map = BTreeMap::new();
        header_map.insert("alg".to_string(), Value::String("none".to_string()));
        header_map.insert("typ".to_string(), Value::String("JWT".to_string()));
        if let Some(params) = &options.header_params {
            for (key, value) in params {
                header_map.insert(key.to_string(), Value::String(value.to_string()));
            }
        }
        let header_json = serde_json::to_string(&header_map)?;
        let claims_json = serde_json::to_string(claims)?;
        let encoded_header =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json.as_bytes());
        let encoded_claims =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
        return Ok(format!("{encoded_header}.{encoded_claims}.''"));
    }

    let mut header = Header::new(algorithm);
    if let Some(params) = &options.header_params {
        for (key, value) in params {
            let value = Some(value.to_string());
            match *key {
                "typ" => header.typ = value,
                "cty" => header.cty = value,
                "kid" => header.kid = value,
                "jku" => header.jku = value,
                "x5u" => header.x5u = value,
                "x5t" => header.x5t = value,
                // `jsonwebtoken::Header` has no slot for arbitrary parameters,
                // and `jwk`/`x5c` are structured, so other keys are dropped.
                _ => {}
            }
        }
    }

    let encoding_key = match &options.key_data {
        KeyData::Secret(secret) => match algorithm {
            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
                EncodingKey::from_secret(secret.as_bytes())
            }
            _ => {
                return Err(anyhow!(
                    "Secret key provided but algorithm {:?} is not an HMAC algorithm. Use HS256, HS384, or HS512 for secret keys",
                    algorithm
                ))
            }
        },
        KeyData::PrivateKeyPem(pem) => match algorithm {
            Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512 => EncodingKey::from_rsa_pem(pem.as_bytes())?,
            Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(pem.as_bytes())?,
            Algorithm::EdDSA => EncodingKey::from_ed_pem(pem.as_bytes())?,
            _ => {
                return Err(anyhow!(
                    "Algorithm {:?} not compatible with PEM key",
                    algorithm
                ))
            }
        },
        KeyData::PrivateKeyDer(der) => match algorithm {
            Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512 => EncodingKey::from_rsa_der(der),
            Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_der(der),
            Algorithm::EdDSA => EncodingKey::from_ed_der(der),
            _ => {
                return Err(anyhow!(
                    "Algorithm {:?} not compatible with DER key",
                    algorithm
                ))
            }
        },
        KeyData::None => {
            return Err(anyhow!(
                "No key or secret provided for algorithm {:?}. Please provide a secret (for HMAC) or private key (for RSA/ECDSA/EdDSA)",
                algorithm
            ))
        }
    };

    let token = jsonwebtoken::encode(&header, claims, &encoding_key)?;
    Ok(token)
}
/// Builds a JWT whose payload is DEFLATE-compressed (`zip: "DEF"` header).
///
/// Called by [`encode_with_options`] when `compress_payload` is set, with the
/// already-parsed `algorithm`. That is a placeholder `HS256` when
/// `options.algorithm` is `none`.
///
/// The header is assembled by hand:
/// * `alg` is the canonical algorithm name, or `none` for unsigned tokens;
/// * `typ` is `JWT`;
/// * every `header_params` entry except `zip` is copied over, possibly
///   overriding `alg`/`typ`.
///
/// Supported combinations are `none` (signature segment is the literal
/// `''`, as on the uncompressed path) or an HMAC algorithm (HS256, HS384,
/// HS512) with a secret key.
///
/// # Errors
/// Fails in any of these cases:
/// * serialization or compression fails;
/// * the key is not a secret;
/// * the algorithm is not HMAC;
/// * signing fails.
fn encode_compressed_jwt(
    claims: &Value,
    options: &EncodeOptions,
    algorithm: Algorithm,
) -> Result<String> {
    use std::collections::BTreeMap;

    let is_unsigned = options.algorithm.to_uppercase() == "NONE";

    // Canonical spelling (e.g. "HS256"), matching what `Header::new` writes for
    // uncompressed tokens; `none` has no `Algorithm` variant so it is spelled out.
    let alg_value = if is_unsigned {
        Value::String("none".to_string())
    } else {
        serde_json::to_value(algorithm)?
    };

    let mut header_map = BTreeMap::new();
    header_map.insert("alg".to_string(), alg_value);
    header_map.insert("typ".to_string(), Value::String("JWT".to_string()));
    header_map.insert("zip".to_string(), Value::String("DEF".to_string()));
    if let Some(params) = &options.header_params {
        for (key, value) in params {
            // `zip` stays DEF: the payload is always deflated on this path.
            if *key != "zip" {
                header_map.insert(key.to_string(), Value::String(value.to_string()));
            }
        }
    }

    let claims_json = serde_json::to_string(claims)?;
    let compressed_payload = compression::compress_deflate(claims_json.as_bytes())?;
    let header_json = serde_json::to_string(&header_map)?;
    let encoded_header =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header_json.as_bytes());
    let encoded_payload =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&compressed_payload);
    let signing_input = format!("{encoded_header}.{encoded_payload}");

    if is_unsigned {
        return Ok(format!("{signing_input}.''"));
    }

    let secret = match &options.key_data {
        KeyData::Secret(secret) => *secret,
        _ => {
            return Err(anyhow!(
                "Compressed JWTs currently support only HMAC algorithms (HS256, HS384, HS512) with a secret key"
            ))
        }
    };
    if !matches!(
        algorithm,
        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512
    ) {
        return Err(anyhow!(
            "Secret key provided but algorithm {:?} is not an HMAC algorithm. Use HS256, HS384, or HS512 for secret keys",
            algorithm
        ));
    }

    // `crypto::sign` returns the signature already base64url-encoded without
    // padding, i.e. exactly the third JWS segment.
    let encoding_key = EncodingKey::from_secret(secret.as_bytes());
    let encoded_signature =
        jsonwebtoken::crypto::sign(signing_input.as_bytes(), &encoding_key, algorithm)?;
    Ok(format!("{signing_input}.{encoded_signature}"))
}
/// Classifies `token` by how many `.`-separated segments it has.
///
/// * Three segments → [`TokenType::Jwt`].
/// * Five segments → [`TokenType::Jwe`].
/// * Anything else → [`TokenType::Unknown`].
///
/// Nothing is decoded or validated.
pub fn detect_token_type(token: &str) -> TokenType {
    let segment_count = token.split('.').count();
    if segment_count == 3 {
        TokenType::Jwt
    } else if segment_count == 5 {
        TokenType::Jwe
    } else {
        TokenType::Unknown
    }
}
/// Splits a compact-serialized JWE into its five segments and parses the
/// protected header.
///
/// Nothing is decrypted or integrity-checked. The encrypted key, IV,
/// ciphertext and tag are returned as their original base64url text.
///
/// # Errors
/// Fails in any of these cases:
/// * `token` does not have exactly five segments;
/// * the header is not valid unpadded base64url, UTF-8 or JSON;
/// * `alg` or `enc` is missing or not a string.
pub fn decode_jwe(token: &str) -> Result<DecodedJweToken> {
    let segments: Vec<&str> = token.split('.').collect();
    let (protected, encrypted_key, iv, ciphertext, tag) = match segments.as_slice() {
        [protected, encrypted_key, iv, ciphertext, tag] => {
            (*protected, *encrypted_key, *iv, *ciphertext, *tag)
        }
        _ => {
            return Err(anyhow!(
                "Invalid JWE token format: expected 5 parts, got {}",
                segments.len()
            ))
        }
    };

    let raw_header = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(protected)
        .map_err(|_| anyhow!("Invalid JWE header encoding"))?;
    let header: HashMap<String, Value> = serde_json::from_str(&String::from_utf8(raw_header)?)?;

    // `alg` and `enc` must both be present as JSON strings.
    let string_param = |name: &str| -> Result<String> {
        let raw = header
            .get(name)
            .ok_or_else(|| anyhow!("Missing '{}' in JWE header", name))?;
        raw.as_str()
            .map(str::to_string)
            .ok_or_else(|| anyhow!("'{}' is not a string", name))
    };
    let algorithm = string_param("alg")?;
    let encryption = string_param("enc")?;

    Ok(DecodedJweToken {
        header,
        encrypted_key: encrypted_key.to_string(),
        iv: iv.to_string(),
        ciphertext: ciphertext.to_string(),
        tag: tag.to_string(),
        algorithm,
        encryption,
    })
}
pub fn decode(token: &str) -> Result<DecodedToken> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() < 2 {
return Err(anyhow!(
"Invalid JWT token format: expected at least 2 parts (header.payload), found {} part(s)",
parts.len()
));
}
let header_b64 = parts[0];
let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(header_b64)
.map_err(|_| anyhow!("Invalid header encoding"))?;
let header_str = String::from_utf8(header_bytes)?;
let header: HashMap<String, Value> = serde_json::from_str(&header_str)?;
let alg_value = header
.get("alg")
.ok_or_else(|| anyhow!("Missing 'alg' in header"))?;
let alg_str = alg_value
.as_str()
.ok_or_else(|| anyhow!("'alg' is not a string"))?;
let algorithm = match alg_str.to_uppercase().as_str() {
"HS256" => Algorithm::HS256,
"HS384" => Algorithm::HS384,
"HS512" => Algorithm::HS512,
"RS256" => Algorithm::RS256,
"RS384" => Algorithm::RS384,
"RS512" => Algorithm::RS512,
"ES256" => Algorithm::ES256,
"ES384" => Algorithm::ES384,
"PS256" => Algorithm::PS256,
"PS384" => Algorithm::PS384,
"PS512" => Algorithm::PS512,
"EDDSA" => Algorithm::EdDSA,
"NONE" => Algorithm::HS256, _ => return Err(anyhow!("Unsupported algorithm: {}", alg_str)),
};
let payload_b64 = parts[1];
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(payload_b64)
.map_err(|_| anyhow!("Invalid payload encoding"))?;
let is_compressed = header
.get("zip")
.and_then(|v| v.as_str())
.map(|s| s.to_uppercase() == "DEF")
.unwrap_or(false);
let payload_str = if is_compressed {
let decompressed_bytes = compression::decompress_deflate(&payload_bytes)
.map_err(|e| anyhow!("Failed to decompress payload: {}", e))?;
String::from_utf8(decompressed_bytes)?
} else {
String::from_utf8(payload_bytes)?
};
let claims: Value = serde_json::from_str(&payload_str)?;
Ok(DecodedToken {
header,
claims,
algorithm,
})
}
/// Key material used by [`verify_with_options`] to check a signature.
pub enum VerifyKeyData<'a> {
/// Shared secret; accepted only for HS256, HS384 or HS512 tokens.
Secret(&'a str),
/// PEM-encoded public key for RSA, RSA-PSS, ECDSA or EdDSA tokens.
// NOTE(review): the `dead_code` allowances suggest no non-test code in the
// crate constructs these variants yet — confirm before removing them.
#[allow(dead_code)]
PublicKeyPem(&'a str),
/// DER-encoded public key for RSA, RSA-PSS, ECDSA or EdDSA tokens.
#[allow(dead_code)]
PublicKeyDer(&'a [u8]),
}
/// Settings for [`verify_with_options`].
pub struct VerifyOptions<'a> {
/// Key used to check the token's signature.
pub key_data: VerifyKeyData<'a>,
/// Forwarded to `jsonwebtoken::Validation::validate_exp`; an expired token
/// then makes verification return an error.
pub validate_exp: bool,
/// Forwarded to `jsonwebtoken::Validation::validate_nbf`; a not-yet-valid
/// token then makes verification return an error.
pub validate_nbf: bool,
/// Forwarded to `jsonwebtoken::Validation::leeway` (seconds, per that crate's docs).
pub leeway: u64,
}
impl<'a> Default for VerifyOptions<'a> {
fn default() -> Self {
Self {
key_data: VerifyKeyData::Secret(""),
validate_exp: false,
validate_nbf: false,
leeway: 0,
}
}
}
pub fn verify(token: &str, secret: &str) -> Result<bool> {
let options = VerifyOptions {
key_data: VerifyKeyData::Secret(secret),
..Default::default()
};
verify_with_options(token, &options)
}
/// Builds `jsonwebtoken` validation settings for `algorithm` from `options`.
///
/// `Validation::new` alone enforces more than this module's contract, which
/// is "signature plus the optionally requested `exp`/`nbf` checks":
/// * it lists `exp` as a required claim even when expiry is not checked, so a
///   correctly signed token without `exp` fails;
/// * it rejects any token carrying `aud`, because no expected audience can be
///   supplied through [`VerifyOptions`].
///
/// `exp` is therefore only required when `validate_exp` is set, and audience
/// validation is turned off.
fn create_validation(algorithm: Algorithm, options: &VerifyOptions) -> Validation {
    let mut validation = Validation::new(algorithm);
    validation.validate_exp = options.validate_exp;
    validation.validate_nbf = options.validate_nbf;
    validation.leeway = options.leeway;
    if !options.validate_exp {
        validation.required_spec_claims.remove("exp");
    }
    validation.validate_aud = false;
    validation
}
/// Converts a `jsonwebtoken::decode` outcome into this module's verification
/// result.
///
/// * Success → `Ok(true)`.
/// * Expired or immature tokens → `Err` wrapping the matching [`JwtError`].
///   This lets callers tell a time-window failure apart from a bad signature.
/// * Any other failure → `Ok(false)`.
fn handle_verification_result(
    result: std::result::Result<jsonwebtoken::TokenData<Value>, jsonwebtoken::errors::Error>,
) -> Result<bool> {
    let failure = match result {
        Ok(_) => return Ok(true),
        Err(failure) => failure,
    };
    match JwtError::from(failure.kind().clone()) {
        time_error @ (JwtError::ExpiredSignature | JwtError::ImmatureSignature) => {
            Err(anyhow::anyhow!(time_error))
        }
        _ => Ok(false),
    }
}
pub fn verify_with_options(token: &str, options: &VerifyOptions) -> Result<bool> {
let decoded_token = decode(token)?;
if let Some(alg) = decoded_token.header.get("alg") {
if let Some(alg_str) = alg.as_str() {
if alg_str.to_uppercase() == "NONE" {
return Ok(true);
}
}
}
let parts: Vec<&str> = token.split('.').collect();
if parts.len() < 3 {
return Err(anyhow!("Invalid token format for verification"));
}
let message = format!("{}.{}", parts[0], parts[1]);
let signature_b64 = parts[2];
if signature_b64.is_empty() {
return Ok(false);
}
let signature = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(signature_b64)
.map_err(|_| anyhow!("Invalid signature encoding"))?;
match &options.key_data {
VerifyKeyData::Secret(secret) => {
match decoded_token.algorithm {
Algorithm::HS256 => {
let calculated_sig =
hmac_sha256::HMAC::mac(message.as_bytes(), secret.as_bytes());
if signature != calculated_sig.as_slice() {
return Ok(false); }
if options.validate_exp || options.validate_nbf {
let validation = create_validation(Algorithm::HS256, options);
let decoding_key = DecodingKey::from_secret(secret.as_bytes());
match jsonwebtoken::decode::<Value>(token, &decoding_key, &validation) {
Ok(_) => Ok(true), Err(e) => Err(anyhow::anyhow!(JwtError::from(e.kind().clone()))), }
} else {
Ok(true) }
}
Algorithm::HS384 | Algorithm::HS512 => {
let decoding_key = DecodingKey::from_secret(secret.as_bytes());
let validation = create_validation(decoded_token.algorithm, options);
let result = jsonwebtoken::decode::<Value>(token, &decoding_key, &validation);
Ok(result.is_ok())
}
_ => Err(anyhow!(
"Secret key provided but token uses algorithm {:?}. Secret keys can only verify HMAC algorithms (HS256, HS384, HS512)",
decoded_token.algorithm
)),
}
}
VerifyKeyData::PublicKeyPem(pem) => {
verify_with_public_key_pem(token, pem, decoded_token.algorithm, options)
}
VerifyKeyData::PublicKeyDer(der) => {
verify_with_public_key_der(token, der, decoded_token.algorithm, options)
}
}
}
/// Verifies `token` against a PEM-encoded public key for an RSA, RSA-PSS,
/// ECDSA or EdDSA `algorithm`.
///
/// The PEM is parsed according to the algorithm family. The token is then
/// checked by `jsonwebtoken::decode`, and the outcome is mapped by
/// [`handle_verification_result`].
///
/// # Errors
/// Fails in any of these cases:
/// * `algorithm` is not an asymmetric algorithm;
/// * the PEM cannot be parsed for that family;
/// * the token is expired or not yet valid.
fn verify_with_public_key_pem(
    token: &str,
    pem: &str,
    algorithm: Algorithm,
    options: &VerifyOptions,
) -> Result<bool> {
    let pem_bytes = pem.as_bytes();
    let decoding_key = if matches!(
        algorithm,
        Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512
    ) {
        DecodingKey::from_rsa_pem(pem_bytes)?
    } else if matches!(algorithm, Algorithm::ES256 | Algorithm::ES384) {
        DecodingKey::from_ec_pem(pem_bytes)?
    } else if matches!(algorithm, Algorithm::EdDSA) {
        DecodingKey::from_ed_pem(pem_bytes)?
    } else {
        return Err(anyhow!(
            "Public key provided but algorithm is {:?}",
            algorithm
        ));
    };
    let outcome = jsonwebtoken::decode::<Value>(
        token,
        &decoding_key,
        &create_validation(algorithm, options),
    );
    handle_verification_result(outcome)
}
/// Verifies `token` against a DER-encoded public key for an RSA, RSA-PSS,
/// ECDSA or EdDSA `algorithm`.
///
/// The DER bytes are wrapped according to the algorithm family. The token is
/// then checked by `jsonwebtoken::decode`, and the outcome is mapped by
/// [`handle_verification_result`].
///
/// # Errors
/// Fails if `algorithm` is not an asymmetric algorithm, or if the token is
/// expired or not yet valid.
fn verify_with_public_key_der(
    token: &str,
    der: &[u8],
    algorithm: Algorithm,
    options: &VerifyOptions,
) -> Result<bool> {
    let decoding_key = if matches!(
        algorithm,
        Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512
    ) {
        DecodingKey::from_rsa_der(der)
    } else if matches!(algorithm, Algorithm::ES256 | Algorithm::ES384) {
        DecodingKey::from_ec_der(der)
    } else if matches!(algorithm, Algorithm::EdDSA) {
        DecodingKey::from_ed_der(der)
    } else {
        return Err(anyhow!(
            "Public key provided but algorithm is {:?}",
            algorithm
        ));
    };
    let outcome = jsonwebtoken::decode::<Value>(
        token,
        &decoding_key,
        &create_validation(algorithm, options),
    );
    handle_verification_result(outcome)
}
/// Builds a JWE-*shaped* token for demonstration only.
///
/// **No encryption is performed.** The header advertises `alg: "dir"` and
/// `enc: "A256GCM"`, but:
/// * the ciphertext segment is just the base64url encoding of `payload`;
/// * the IV and tag are fixed dummy bytes;
/// * the encrypted-key segment is empty;
/// * `_recipient_key` is ignored.
///
/// Anyone can read the payload — never use this for real data.
///
/// # Errors
/// Only fails if serializing the header to JSON fails.
pub fn encode_jwe_demo(payload: &str, _recipient_key: &str) -> Result<String> {
    let b64 = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
    let protected = serde_json::to_string(&serde_json::json!({
        "alg": "dir",
        "enc": "A256GCM"
    }))?;
    let segments = [
        b64(protected.as_bytes()),
        String::new(), // no encrypted key
        b64(b"dummy_iv_123456"),
        b64(payload.as_bytes()),
        b64(b"dummy_tag"),
    ];
    Ok(segments.join("."))
}
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine; use chrono::{Duration, Utc};
use serde_json::json;
use std::collections::HashMap;
use std::fs;
const RSA_PRIVATE_KEY_PEM_PATH: &str = "src/jwt/test_rsa_private.pem";
const RSA_PUBLIC_KEY_PEM_PATH: &str = "src/jwt/test_rsa_public.pem"; const EC_PRIVATE_KEY_PEM_PATH: &str = "src/jwt/test_ec_private.pem";
const ED25519_PRIVATE_KEY_PEM_PATH: &str = "src/jwt/test_ed25519_private.pem";
#[test]
fn test_encode_hs256() {
let claims = json!({"user": "test"});
let options = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("test_secret"),
header_params: None,
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_ok());
let token_str = result.unwrap();
let decoded_result = decode(&token_str);
assert!(decoded_result.is_ok());
let decoded_token = decoded_result.unwrap();
assert_eq!(
decoded_token.header.get("alg").unwrap().as_str().unwrap(),
"HS256"
);
assert_eq!(decoded_token.claims, claims);
}
#[test]
fn test_encode_rs256() {
let rsa_private_key = fs::read_to_string(RSA_PRIVATE_KEY_PEM_PATH)
.expect("Should have been able to read the RSA private key file");
let claims = json!({"user": "test_rs256"});
let options = EncodeOptions {
algorithm: "RS256",
key_data: KeyData::PrivateKeyPem(&rsa_private_key),
header_params: None,
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_err());
}
#[test]
fn test_encode_es256() {
let ec_private_key = fs::read_to_string(EC_PRIVATE_KEY_PEM_PATH)
.expect("Should have been able to read the EC private key file");
let claims = json!({"user": "test_es256"});
let options = EncodeOptions {
algorithm: "ES256",
key_data: KeyData::PrivateKeyPem(&ec_private_key),
header_params: None,
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_err());
}
#[test]
fn test_encode_eddsa() {
let ed25519_private_key = fs::read_to_string(ED25519_PRIVATE_KEY_PEM_PATH)
.expect("Should have been able to read the Ed25519 private key file");
let claims = json!({"user": "test_eddsa"});
let options = EncodeOptions {
algorithm: "EdDSA",
key_data: KeyData::PrivateKeyPem(&ed25519_private_key),
header_params: None,
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_err());
}
#[test]
fn test_encode_none_algorithm() {
let claims = json!({"user": "test_none"});
let options = EncodeOptions {
algorithm: "none",
key_data: KeyData::None, header_params: None,
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_ok());
let token_str = result.unwrap();
let parts: Vec<&str> = token_str.split('.').collect();
assert_eq!(parts.len(), 3, "Token should have three parts");
assert_eq!(
parts[2], "''",
"Signature part should be empty for 'none' algorithm"
);
let header_b64 = parts[0];
let header_bytes_result =
base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(header_b64);
assert!(
header_bytes_result.is_ok(),
"Header should be valid Base64Url"
);
let header_bytes = header_bytes_result.unwrap();
let header_str_result = String::from_utf8(header_bytes);
assert!(header_str_result.is_ok(), "Header should be valid UTF-8");
let header_str = header_str_result.unwrap();
let header_json_result: Result<Value, _> = serde_json::from_str(&header_str);
assert!(header_json_result.is_ok(), "Header should be valid JSON");
let header_json = header_json_result.unwrap();
assert_eq!(header_json.get("alg").unwrap().as_str().unwrap(), "none");
}
#[test]
fn test_encode_with_header_params() {
let claims = json!({"user": "test_header_params"});
let mut header_params = HashMap::new();
header_params.insert("kid", "test_key_id");
header_params.insert("custom_param", "custom_value");
let options = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("test_secret_for_header_params"),
header_params: Some(header_params),
compress_payload: false,
};
let result = encode_with_options(&claims, &options);
assert!(result.is_ok());
let token_str = result.unwrap();
let decoded_result = decode(&token_str);
assert!(decoded_result.is_ok());
let decoded_token = decoded_result.unwrap();
assert_eq!(
decoded_token.header.get("alg").unwrap().as_str().unwrap(),
"HS256"
);
let mut header_params_for_cty = HashMap::new();
header_params_for_cty.insert("cty", "test_content_type");
let options_cty = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("test_secret_for_cty"),
header_params: Some(header_params_for_cty),
compress_payload: false,
};
let result_cty = encode_with_options(&claims, &options_cty);
assert!(result_cty.is_ok(), "Encoding with cty should succeed");
let token_cty_str = result_cty.unwrap();
let decoded_cty_result = decode(&token_cty_str);
assert!(
decoded_cty_result.is_ok(),
"Decoding cty token should succeed"
);
let decoded_cty_token = decoded_cty_result.unwrap();
assert_eq!(
decoded_cty_token
.header
.get("cty")
.unwrap()
.as_str()
.unwrap(),
"test_content_type"
);
let mut header_params_for_none = HashMap::new();
header_params_for_none.insert("kid", "test_key_id_for_none");
header_params_for_none.insert("custom_field", "custom_value_for_none");
let options_none_custom = EncodeOptions {
algorithm: "none",
key_data: KeyData::None,
header_params: Some(header_params_for_none),
compress_payload: false,
};
let result_none_custom = encode_with_options(&claims, &options_none_custom);
assert!(
result_none_custom.is_ok(),
"Encoding with none and custom params should succeed"
);
let token_none_custom_str = result_none_custom.unwrap();
let parts_none_custom: Vec<&str> = token_none_custom_str.split('.').collect();
assert_eq!(parts_none_custom.len(), 3);
let header_none_custom_b64 = parts_none_custom[0];
let header_none_custom_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(header_none_custom_b64)
.unwrap();
let header_none_custom_str = String::from_utf8(header_none_custom_bytes).unwrap();
let header_none_custom_json: Value = serde_json::from_str(&header_none_custom_str).unwrap();
assert_eq!(
header_none_custom_json
.get("alg")
.unwrap()
.as_str()
.unwrap(),
"none"
);
assert_eq!(
header_none_custom_json
.get("kid")
.unwrap()
.as_str()
.unwrap(),
"test_key_id_for_none"
);
assert_eq!(
header_none_custom_json
.get("custom_field")
.unwrap()
.as_str()
.unwrap(),
"custom_value_for_none"
);
}
#[test]
fn test_decode_valid_hs256_token() {
let claims = json!({"user": "test_decode_valid"});
let options = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("test_secret_for_decode"),
header_params: None,
compress_payload: false,
};
let encode_result = encode_with_options(&claims, &options);
assert!(
encode_result.is_ok(),
"Token encoding failed for decode test"
);
let token_str = encode_result.unwrap();
let decode_result = decode(&token_str);
assert!(
decode_result.is_ok(),
"Decoding valid token failed. Error: {:?}",
decode_result.err()
);
let decoded_token = decode_result.unwrap();
assert_eq!(
decoded_token.header.get("alg").unwrap().as_str().unwrap(),
"HS256"
);
assert_eq!(decoded_token.claims, claims);
assert_eq!(decoded_token.algorithm, Algorithm::HS256);
}
#[test]
fn test_decode_token_invalid_header_base64() {
let token_str = "!!!!.eyJ1c2VyIjoidGVzdCJ9."; let decode_result = decode(token_str);
assert!(decode_result.is_err());
let err = decode_result.err().unwrap();
assert!(
err.to_string().contains("Invalid header encoding"),
"Unexpected error message: {err}"
);
}
#[test]
fn test_decode_token_invalid_payload_base64() {
let header = json!({"alg": "HS256", "typ": "JWT"});
let encoded_header =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header.to_string().as_bytes());
let token_str = format!("{encoded_header}.!!!!.");
let decode_result = decode(&token_str);
assert!(decode_result.is_err());
let err = decode_result.err().unwrap();
assert!(
err.to_string().contains("Invalid payload encoding"),
"Unexpected error message: {err}"
);
}
#[test]
fn test_decode_token_missing_alg_in_header() {
let header_no_alg = json!({"typ": "JWT"});
let encoded_header_no_alg = base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(header_no_alg.to_string().as_bytes());
let payload = json!({"user": "test"});
let encoded_payload =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
let token_str = format!("{encoded_header_no_alg}.{encoded_payload}.");
let decode_result = decode(&token_str);
assert!(decode_result.is_err());
let err = decode_result.err().unwrap();
assert!(
err.to_string().contains("Missing 'alg' in header"),
"Unexpected error message: {err}"
);
}
#[test]
fn test_decode_token_alg_not_a_string() {
let header_alg_not_string = json!({"alg": 123, "typ": "JWT"});
let encoded_header_alg_not_string = base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(header_alg_not_string.to_string().as_bytes());
let payload = json!({"user": "test"});
let encoded_payload =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
let token_str = format!("{encoded_header_alg_not_string}.{encoded_payload}.");
let decode_result = decode(&token_str);
assert!(decode_result.is_err());
let err = decode_result.err().unwrap();
assert!(
err.to_string().contains("'alg' is not a string"),
"Unexpected error message: {err}"
);
}
#[test]
fn test_decode_invalid_token_format_not_enough_parts() {
let token_str = "invalidtoken";
let decode_result = decode(token_str);
assert!(decode_result.is_err());
let err = decode_result.err().unwrap();
assert!(
err.to_string().contains("Invalid JWT token format")
&& err.to_string().contains("expected at least 2 parts"),
"Unexpected error message: {err}"
);
let token_str_one_dot = "only.onepart";
let decode_result_one_dot = decode(token_str_one_dot);
assert!(decode_result_one_dot.is_err(), "Expected error for 'only.onepart' due to invalid header content (not base64url or not utf8 after decode)");
if let Some(err) = decode_result_one_dot.err() {
assert!(
err.to_string().contains("Invalid header encoding")
|| err.to_string().contains("invalid utf-8"),
"Unexpected error message for 'only.onepart': {err}"
);
}
}
#[test]
fn test_verify_hs256_token_correct_secret() {
let claims = json!({"user": "test_verify_correct"});
let options_encode = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("correct_secret"),
header_params: None,
compress_payload: false,
};
let token_str = encode_with_options(&claims, &options_encode)
.expect("Token encoding failed for verify test");
let options_verify = VerifyOptions {
key_data: VerifyKeyData::Secret("correct_secret"),
..Default::default()
};
let result = verify_with_options(&token_str, &options_verify);
assert!(
result.is_ok(),
"Verification failed for correct secret: {:?}",
result.err()
);
assert!(
result.unwrap(),
"Verification returned false for correct secret"
);
}
#[test]
fn test_verify_hs256_token_incorrect_secret() {
let claims = json!({"user": "test_verify_incorrect"});
let options_encode = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("correct_secret"),
header_params: None,
compress_payload: false,
};
let token_str = encode_with_options(&claims, &options_encode)
.expect("Token encoding failed for verify incorrect secret test");
let options_verify = VerifyOptions {
key_data: VerifyKeyData::Secret("incorrect_secret"),
..Default::default()
};
let result = verify_with_options(&token_str, &options_verify);
assert!(result.is_ok(), "Verification with incorrect secret should not error initially unless key format is wrong, but expect Ok(false). Error: {:?}", result.err());
assert!(
!result.unwrap(),
"Verification returned true for incorrect secret"
);
}
#[test]
fn test_verify_rs256_token_correct_key() {
let header = json!({"alg": "RS256", "typ": "JWT"});
let claims = json!({"user": "test_rs256_verify"});
let encoded_header =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header.to_string());
let encoded_claims =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims.to_string());
let fake_signature =
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("fake_signature");
let token_str = format!("{encoded_header}.{encoded_claims}.{fake_signature}");
let public_key_pem_string = fs::read_to_string(RSA_PUBLIC_KEY_PEM_PATH)
.unwrap_or_else(|_| String::from("-----BEGIN PUBLIC KEY-----\nTHIS IS A SHORT PLACEHOLDER PUBLIC KEY.\nWILL BE REPLACED LATER IF NEEDED.\n-----END PUBLIC KEY-----"));
let options_verify = VerifyOptions {
key_data: VerifyKeyData::PublicKeyPem(&public_key_pem_string),
validate_exp: false,
validate_nbf: false,
leeway: 0,
};
let result = verify_with_options(&token_str, &options_verify);
assert!(
result.is_err(),
"Verification should fail with placeholder RSA public key. Result was: {result:?}"
);
}
#[test]
fn test_verify_none_algorithm_token() {
let claims = json!({"user": "test_none_verify"});
let options_encode = EncodeOptions {
algorithm: "none",
key_data: KeyData::None,
header_params: None,
compress_payload: false,
};
let token_str = encode_with_options(&claims, &options_encode)
.expect("Encoding 'none' algorithm token failed");
let options_verify = VerifyOptions {
key_data: VerifyKeyData::Secret("any_secret_is_ignored_for_none"),
..Default::default()
};
let result = verify_with_options(&token_str, &options_verify);
assert!(
result.is_ok(),
"Verification of 'none' token erred: {:?}",
result.err()
);
assert!(
result.unwrap(),
"Verification of 'none' token returned false"
);
}
#[test]
fn test_verify_token_with_exp_validation_valid() {
let current_time = Utc::now();
let claims = json!({
"user": "test_exp_valid",
"exp": (current_time + Duration::seconds(3600)).timestamp()
});
let options_encode = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("secret_exp_valid"),
header_params: None,
compress_payload: false,
};
let token_str = encode_with_options(&claims, &options_encode)
.expect("Token encoding for exp valid test failed");
let options_verify = VerifyOptions {
key_data: VerifyKeyData::Secret("secret_exp_valid"),
validate_exp: true,
..Default::default()
};
let result = verify_with_options(&token_str, &options_verify);
assert!(
result.is_ok(),
"Verification of valid exp token erred: {:?}",
result.err()
);
assert!(
result.unwrap(),
"Verification of valid exp token returned false"
);
}
#[test]
fn test_verify_token_with_exp_validation_expired() {
let current_time = Utc::now();
let claims = json!({
"user": "test_exp_expired",
"exp": (current_time - Duration::seconds(3600)).timestamp()
});
let options_encode = EncodeOptions {
algorithm: "HS256",
key_data: KeyData::Secret("secret_exp_expired"),
header_params: None,
compress_payload: false,
};
let token_str = encode_with_options(&claims, &options_encode)
.expect("Token encoding for exp expired test failed");
let options_verify = VerifyOptions {
key_data: VerifyKeyData::Secret("secret_exp_expired"),
validate_exp: true,
..Default::default()
};
let result = verify_with_options(&token_str, &options_verify);
assert!(
result.is_err(),
"Verification of expired token should return an error. Result: {result:?}"
);
}
/// An HS256 token whose `nbf` is one hour in the past (and whose `exp` is one
/// hour in the future) must verify as `Ok(true)` with `nbf` validation on.
#[test]
fn test_verify_token_with_nbf_validation_valid() {
    const SECRET: &str = "secret_nbf_valid";

    let now = Utc::now();
    let not_before = (now - Duration::seconds(3600)).timestamp();
    let expires_at = (now + Duration::seconds(3600)).timestamp();
    let claims = json!({
        "user": "test_nbf_valid",
        "nbf": not_before,
        "exp": expires_at
    });

    let encode_opts = EncodeOptions {
        algorithm: "HS256",
        key_data: KeyData::Secret(SECRET),
        header_params: None,
        compress_payload: false,
    };
    let token_str = encode_with_options(&claims, &encode_opts)
        .expect("Token encoding for nbf valid test failed");

    // `exp` checking is switched off so that only the `nbf` path is exercised.
    let verify_opts = VerifyOptions {
        key_data: VerifyKeyData::Secret(SECRET),
        validate_nbf: true,
        validate_exp: false,
        ..Default::default()
    };
    let result = verify_with_options(&token_str, &verify_opts);
    assert!(
        result.is_ok(),
        "Verification of valid nbf token erred: {:?}",
        result.err()
    );
    assert!(
        result.unwrap(),
        "Verification of valid nbf token returned false"
    );
}
/// An HS256 token whose `nbf` is one hour in the future must be rejected with
/// an error when `nbf` validation is enabled.
#[test]
fn test_verify_token_with_nbf_validation_not_yet_valid() {
    const SECRET: &str = "secret_nbf_not_yet_valid";

    let now = Utc::now();
    let not_before = (now + Duration::seconds(3600)).timestamp();
    let expires_at = (now + Duration::seconds(7200)).timestamp();
    let claims = json!({
        "user": "test_nbf_not_yet_valid",
        "nbf": not_before,
        "exp": expires_at
    });

    let encode_opts = EncodeOptions {
        algorithm: "HS256",
        key_data: KeyData::Secret(SECRET),
        header_params: None,
        compress_payload: false,
    };
    let token_str = encode_with_options(&claims, &encode_opts)
        .expect("Token encoding for nbf not yet valid test failed");

    // Only `nbf` is validated; `exp` checking is deliberately disabled.
    let verify_opts = VerifyOptions {
        key_data: VerifyKeyData::Secret(SECRET),
        validate_nbf: true,
        validate_exp: false,
        ..Default::default()
    };
    let result = verify_with_options(&token_str, &verify_opts);
    assert!(
        result.is_err(),
        "Verification of not-yet-valid nbf token should return an error. Result: {result:?}"
    );
}
/// Drives the ES256 verification path with a hand-assembled token that carries
/// a fake signature. The public key is read from the `test_ec_public.pem`
/// fixture, and verification is expected to return `Err` (the assertion message
/// describes the fixture as a placeholder EC key).
#[test]
fn test_verify_es256_token_pathway() {
    let header = json!({"alg": "ES256", "typ": "JWT"});
    let claims = json!({"user": "test_es256_verify"});
    let encoded_header =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(header.to_string());
    let encoded_claims =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims.to_string());
    let fake_signature =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("fake_es256_signature");
    let token_str = format!("{encoded_header}.{encoded_claims}.{fake_signature}");

    // Resolve the fixture against the crate root when built by Cargo, so the
    // test does not depend on the directory the test binary is launched from.
    // Falls back to the current directory, which matches the old behaviour.
    let pem_path = std::path::Path::new(option_env!("CARGO_MANIFEST_DIR").unwrap_or("."))
        .join("src/jwt/test_ec_public.pem");
    let public_key_pem_string = fs::read_to_string(&pem_path).unwrap_or_else(|err| {
        panic!(
            "Should have created/found test_ec_public.pem at {}: {err}",
            pem_path.display()
        )
    });

    let options_verify = VerifyOptions {
        key_data: VerifyKeyData::PublicKeyPem(&public_key_pem_string),
        validate_exp: false,
        validate_nbf: false,
        leeway: 0,
    };
    let result = verify_with_options(&token_str, &options_verify);
    assert!(
        result.is_err(),
        "Verification should return Err with a placeholder EC public key. Result was: {result:?}"
    );
}
/// Supplying a "PEM" that is not PEM-encoded at all must make ES256
/// verification fail.
#[test]
fn test_verify_es256_token_invalid_key_format() {
    let invalid_pem = "this is not a valid pem";
    let token_str =
        "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoidGVzdCJ9.c2lnbmF0dXJl";

    let verify_opts = VerifyOptions {
        key_data: VerifyKeyData::PublicKeyPem(invalid_pem),
        ..Default::default()
    };
    let outcome = verify_with_options(token_str, &verify_opts);
    assert!(
        outcome.is_err(),
        "Verification should fail with an invalid EC key format"
    );
}
/// Encoding an unsigned (`alg: none`) token with `compress_payload: true`
/// must set a `zip: "DEF"` header, and `decode` must return the original
/// claims unchanged.
#[test]
fn test_encode_with_compression() {
    let claims = json!({
        "sub": "test",
        "name": "Test User",
        "description": "This is a test payload for compression testing"
    });
    let options = EncodeOptions {
        algorithm: "none",
        key_data: KeyData::None,
        header_params: None,
        compress_payload: true,
    };
    // `expect` (instead of `assert!(is_ok())` + `unwrap()`) keeps the
    // underlying error in the failure output.
    let token = encode_with_options(&claims, &options)
        .expect("Encoding with compression should succeed");
    let decoded = decode(&token).expect("Decoding compressed token should succeed");
    // Compare as `Option<&str>` so a missing or non-string header reports the
    // actual value instead of panicking on a bare `unwrap`.
    assert_eq!(
        decoded.header.get("zip").and_then(|v| v.as_str()),
        Some("DEF"),
        "compressed token should carry a zip=DEF header"
    );
    assert_eq!(decoded.claims, claims);
}
/// Encoding an HS256 token with `compress_payload: true` must set
/// `zip: "DEF"`, keep `alg: "HS256"` in the header, and round-trip the
/// claims through `decode`.
#[test]
fn test_encode_with_compression_hs256() {
    let claims =
        json!({"sub": "test", "name": "Test User", "data": "Some test data for compression"});
    let options = EncodeOptions {
        algorithm: "HS256",
        key_data: KeyData::Secret("test_secret"),
        header_params: None,
        compress_payload: true,
    };
    // `expect` surfaces the encoder's error instead of a bare assertion failure.
    let token = encode_with_options(&claims, &options)
        .expect("Encoding HS256 with compression should succeed");
    let decoded = decode(&token).expect("Decoding compressed HS256 token should succeed");
    // Header values are compared as `Option<&str>` so missing or non-string
    // entries show up in the assertion diff rather than as an opaque panic.
    assert_eq!(
        decoded.header.get("zip").and_then(|v| v.as_str()),
        Some("DEF"),
        "compressed token should carry a zip=DEF header"
    );
    assert_eq!(
        decoded.header.get("alg").and_then(|v| v.as_str()),
        Some("HS256"),
        "header should record the signing algorithm"
    );
    assert_eq!(decoded.claims, claims);
}
/// A compressed, unsigned token containing nested claims (an array) must
/// decode back to the exact original claims and carry `zip: "DEF"`.
#[test]
fn test_decode_compressed_token() {
    let claims = json!({
        "user": "testuser",
        "role": "admin",
        "permissions": ["read", "write", "delete"]
    });
    let options = EncodeOptions {
        algorithm: "none",
        key_data: KeyData::None,
        header_params: None,
        compress_payload: true,
    };
    let token =
        encode_with_options(&claims, &options).expect("Failed to create compressed token");
    let decoded = decode(&token).expect("Failed to decode compressed token");
    assert_eq!(decoded.claims, claims);
    // Compare as `Option<&str>` so a missing/non-string header is reported
    // with its actual value instead of an uninformative `unwrap` panic.
    assert_eq!(
        decoded.header.get("zip").and_then(|v| v.as_str()),
        Some("DEF"),
        "compressed token should carry a zip=DEF header"
    );
}
/// An HS256 token encoded with payload compression must still verify as
/// `Ok(true)` against the same secret (time-based checks disabled).
#[test]
fn test_compression_preserves_signature_verification() {
    // `Utc`/`Duration` are used unqualified, matching the sibling tests.
    let claims = json!({"sub": "test", "exp": (Utc::now() + Duration::hours(1)).timestamp()});
    let secret = "test_secret_for_compression";
    let options = EncodeOptions {
        algorithm: "HS256",
        key_data: KeyData::Secret(secret),
        header_params: None,
        compress_payload: true,
    };
    let token =
        encode_with_options(&claims, &options).expect("Failed to create compressed token");
    let verify_options = VerifyOptions {
        key_data: VerifyKeyData::Secret(secret),
        validate_exp: false,
        validate_nbf: false,
        leeway: 0,
    };
    // `expect` keeps the verification error in the failure message; a bare
    // `assert!(is_ok())` would discard it.
    let verified = verify_with_options(&token, &verify_options)
        .expect("Verification should succeed for compressed token");
    assert!(verified, "Compressed token should be valid");
}
/// A three-segment compact token is classified as `TokenType::Jwt`.
#[test]
fn test_detect_token_type_jwt() {
    let detected = detect_token_type(
        "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0In0.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
    );
    assert_eq!(detected, TokenType::Jwt);
}
/// A five-segment compact token (with an empty encrypted-key segment) is
/// classified as `TokenType::Jwe`.
#[test]
fn test_detect_token_type_jwe() {
    let detected = detect_token_type(
        "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..ZHVtbXlfaXZfMTIzNDU2.eyJzdWIiOiJ0ZXN0In0.ZHVtbXlfdGFn",
    );
    assert_eq!(detected, TokenType::Jwe);
}
/// A string with eight dot-separated segments is reported as
/// `TokenType::Unknown`.
#[test]
fn test_detect_token_type_unknown() {
    let detected = detect_token_type("invalid.token.format.with.too.many.parts.here");
    assert_eq!(detected, TokenType::Unknown);
}
/// Decodes a compact JWE whose header is `{"alg":"dir","enc":"A256GCM"}` and
/// whose encrypted-key segment is empty (`..`), and checks that the header
/// fields and the remaining segments are split out.
#[test]
fn test_decode_jwe_basic() {
    let jwe_token = "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..ZHVtbXlfaXZfMTIzNDU2.eyJzdWIiOiJ0ZXN0In0.ZHVtbXlfdGFn";
    // `expect` reports the decoder's error; `assert!(is_ok())` would hide it.
    let decoded = decode_jwe(jwe_token).expect("JWE decoding should succeed");
    assert_eq!(decoded.algorithm, "dir");
    assert_eq!(decoded.encryption, "A256GCM");
    // The second segment of the input is empty, so no encrypted key is present.
    assert!(decoded.encrypted_key.is_empty());
    assert!(!decoded.iv.is_empty());
    assert!(!decoded.ciphertext.is_empty());
    assert!(!decoded.tag.is_empty());
}
/// A three-segment (JWT-shaped) string must be rejected by `decode_jwe`.
#[test]
fn test_decode_jwe_invalid_format() {
    let outcome = decode_jwe("invalid.jwt.token");
    assert!(
        outcome.is_err(),
        "JWE decoding should fail for invalid format"
    );
}
/// `encode_jwe_demo` must produce a five-segment compact token that
/// `decode_jwe` can parse back.
#[test]
fn test_encode_jwe_demo() {
    let payload = r#"{"sub":"test","name":"JWE User"}"#;
    // `expect` keeps the encoder's error in the failure output.
    let jwe_token =
        encode_jwe_demo(payload, "test_key").expect("JWE encoding demo should succeed");
    let part_count = jwe_token.split('.').count();
    assert_eq!(part_count, 5, "JWE token should have 5 parts");
    // Round-trip through the decoder; on failure the decode error is reported.
    decode_jwe(&jwe_token).expect("Generated JWE token should be decodable");
}
/// Every `JwtError` variant renders the expected `Display` text.
#[test]
fn test_jwt_error_display() {
    // Table of (error, expected Display output), checked in order.
    let cases = [
        (JwtError::InvalidSignature, "Invalid signature"),
        (JwtError::ExpiredSignature, "Expired signature"),
        (JwtError::ImmatureSignature, "Immature signature"),
        (JwtError::InvalidAlgorithm, "Invalid algorithm"),
        (
            JwtError::Other("test error".to_string()),
            "JWT error: test error",
        ),
    ];
    for (error, expected) in cases {
        assert_eq!(error.to_string(), expected);
    }
}
}