use crate::{claims_error::ClaimsError, traits::KeyProvider};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use jsonwebtoken::{DecodingKey, Header, Validation, decode, decode_header};
use serde::Deserialize;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, Deserialize)]
struct Jwk {
kid: String,
kty: String,
#[serde(rename = "use")]
#[allow(dead_code)]
use_: Option<String>,
n: String,
e: String,
#[allow(dead_code)]
alg: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct JwksResponse {
keys: Vec<Jwk>,
}
type HeaderExtrasHandler = dyn Fn(&str, &Value) -> Option<String> + Send + Sync;
const STANDARD_HEADER_FIELDS: &[&str] = &[
"typ", "alg", "cty", "jku", "jwk", "kid", "x5u", "x5c", "x5t", "x5t#S256", "crit", "enc",
"zip", "url", "nonce", "epk", "apu", "apv", "iv", "tag", "p2s", "p2c", "b64",
];
#[must_use]
pub struct JwksKeyProvider {
jwks_uri: String,
keys: Arc<ArcSwap<HashMap<String, DecodingKey>>>,
refresh_state: Arc<RwLock<RefreshState>>,
client: modkit_http::HttpClient,
refresh_interval: Duration,
max_backoff: Duration,
on_demand_refresh_cooldown: Duration,
header_extras_handler: Option<Arc<HeaderExtrasHandler>>,
}
#[derive(Debug, Default)]
struct RefreshState {
last_refresh: Option<Instant>,
last_on_demand_refresh: Option<Instant>,
consecutive_failures: u32,
last_error: Option<String>,
failed_kids: HashSet<String>,
}
impl JwksKeyProvider {
pub fn new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
Self::with_http_timeout(jwks_uri, Duration::from_secs(10))
}
pub fn with_http_timeout(
jwks_uri: impl Into<String>,
timeout: Duration,
) -> Result<Self, modkit_http::HttpError> {
let client = modkit_http::HttpClient::builder()
.timeout(timeout)
.retry(None) .build()?;
Ok(Self {
jwks_uri: jwks_uri.into(),
keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
refresh_state: Arc::new(RwLock::new(RefreshState::default())),
client,
refresh_interval: Duration::from_secs(300), max_backoff: Duration::from_secs(3600), on_demand_refresh_cooldown: Duration::from_secs(60), header_extras_handler: None,
})
}
pub fn try_new(jwks_uri: impl Into<String>) -> Result<Self, modkit_http::HttpError> {
Self::new(jwks_uri)
}
pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
self.refresh_interval = interval;
self
}
pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
self.max_backoff = max_backoff;
self
}
pub fn with_on_demand_refresh_cooldown(mut self, cooldown: Duration) -> Self {
self.on_demand_refresh_cooldown = cooldown;
self
}
pub fn with_header_extras_stringified(self) -> Self {
self.with_header_extras_handler(|_, v| Some(v.to_string()))
}
pub fn with_header_extras_handler(
mut self,
handler: impl Fn(&str, &Value) -> Option<String> + Send + Sync + 'static,
) -> Self {
self.header_extras_handler = Some(Arc::new(handler));
self
}
async fn fetch_jwks(&self) -> Result<HashMap<String, DecodingKey>, ClaimsError> {
let jwks: JwksResponse = self
.client
.get(&self.jwks_uri)
.send()
.await
.map_err(|e| map_http_error(&e))?
.json()
.await
.map_err(|e| map_http_error(&e))?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
if jwk.kty == "RSA" {
let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
.map_err(|e| ClaimsError::JwksFetchFailed(format!("Invalid RSA key: {e}")))?;
keys.insert(jwk.kid, key);
}
}
if keys.is_empty() {
return Err(ClaimsError::JwksFetchFailed(
"No valid RSA keys found in JWKS".into(),
));
}
Ok(keys)
}
fn calculate_backoff(&self, failures: u32) -> Duration {
let base = Duration::from_secs(60); let exponential = base * 2u32.pow(failures.min(10)); exponential.min(self.max_backoff)
}
async fn should_refresh(&self) -> bool {
let state = self.refresh_state.read().await;
match state.last_refresh {
None => true, Some(last) => {
let elapsed = last.elapsed();
if state.consecutive_failures == 0 {
elapsed >= self.refresh_interval
} else {
elapsed >= self.calculate_backoff(state.consecutive_failures)
}
}
}
}
async fn perform_refresh(&self) -> Result<(), ClaimsError> {
match self.fetch_jwks().await {
Ok(new_keys) => {
self.keys.store(Arc::new(new_keys));
let mut state = self.refresh_state.write().await;
state.last_refresh = Some(Instant::now());
state.consecutive_failures = 0;
state.last_error = None;
Ok(())
}
Err(e) => {
let mut state = self.refresh_state.write().await;
state.last_refresh = Some(Instant::now());
state.consecutive_failures += 1;
state.last_error = Some(e.to_string());
Err(e)
}
}
}
fn key_exists(&self, kid: &str) -> bool {
let keys = self.keys.load();
keys.contains_key(kid)
}
async fn check_refresh_throttle(&self, kid: &str) -> Result<(), ClaimsError> {
let state = self.refresh_state.read().await;
if let Some(last_on_demand) = state.last_on_demand_refresh {
let elapsed = last_on_demand.elapsed();
if elapsed < self.on_demand_refresh_cooldown {
let remaining = self.on_demand_refresh_cooldown.saturating_sub(elapsed);
tracing::debug!(
kid = kid,
remaining_secs = remaining.as_secs(),
"On-demand JWKS refresh throttled (cooldown active)"
);
if state.failed_kids.contains(kid) {
tracing::warn!(
kid = kid,
"Unknown kid repeatedly requested despite recent refresh attempts"
);
}
return Err(ClaimsError::UnknownKeyId(kid.to_owned()));
}
}
Ok(())
}
async fn handle_refresh_success(&self, kid: &str) -> Result<(), ClaimsError> {
let mut state = self.refresh_state.write().await;
state.last_on_demand_refresh = Some(Instant::now());
if self.key_exists(kid) {
state.failed_kids.remove(kid);
} else {
state.failed_kids.insert(kid.to_owned());
tracing::warn!(
kid = kid,
"Kid still not found after on-demand JWKS refresh"
);
}
Ok(())
}
async fn handle_refresh_failure(&self, kid: &str, error: ClaimsError) -> ClaimsError {
let mut state = self.refresh_state.write().await;
state.last_on_demand_refresh = Some(Instant::now());
state.failed_kids.insert(kid.to_owned());
error
}
async fn on_demand_refresh(&self, kid: &str) -> Result<(), ClaimsError> {
if self.key_exists(kid) {
return Ok(());
}
self.check_refresh_throttle(kid).await?;
tracing::info!(
kid = kid,
"Performing on-demand JWKS refresh for unknown kid"
);
match self.perform_refresh().await {
Ok(()) => self.handle_refresh_success(kid).await,
Err(e) => Err(self.handle_refresh_failure(kid, e).await),
}
}
fn get_key(&self, kid: &str) -> Option<DecodingKey> {
let keys = self.keys.load();
keys.get(kid).cloned()
}
fn validate_token(
token: &str,
key: &DecodingKey,
header: &Header,
) -> Result<Value, ClaimsError> {
let mut validation = Validation::new(header.alg);
validation.validate_exp = false;
validation.validate_nbf = false;
validation.validate_aud = false;
let empty_claims: &[&str] = &[];
validation.set_required_spec_claims(empty_claims);
let token_data = decode::<Value>(token, key, &validation)
.map_err(|e| ClaimsError::DecodeFailed(format!("JWT validation failed: {e}")))?;
Ok(token_data.claims)
}
}
#[async_trait]
impl KeyProvider for JwksKeyProvider {
fn name(&self) -> &'static str {
"jwks"
}
async fn validate_and_decode(&self, token: &str) -> Result<(Header, Value), ClaimsError> {
let token = token.trim_start_matches("Bearer ").trim();
let header = match &self.header_extras_handler {
Some(handler) => decode_header_with_handler(token, handler.as_ref()),
None => decode_header(token),
}
.map_err(|e| ClaimsError::DecodeFailed(format!("Invalid JWT header: {e}")))?;
let kid = header
.kid
.as_ref()
.ok_or_else(|| ClaimsError::DecodeFailed("Missing kid in JWT header".into()))?;
let key = if let Some(k) = self.get_key(kid) {
k
} else {
self.on_demand_refresh(kid).await?;
self.get_key(kid)
.ok_or_else(|| ClaimsError::UnknownKeyId(kid.clone()))?
};
let claims = Self::validate_token(token, &key, &header)?;
Ok((header, claims))
}
async fn refresh_keys(&self) -> Result<(), ClaimsError> {
if self.should_refresh().await {
self.perform_refresh().await
} else {
Ok(())
}
}
}
pub async fn run_jwks_refresh_task(
provider: Arc<JwksKeyProvider>,
cancellation_token: CancellationToken,
) {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
tokio::select! {
() = cancellation_token.cancelled() => {
tracing::info!("JWKS refresh task shutting down");
break;
}
_ = interval.tick() => {
if let Err(e) = provider.refresh_keys().await {
tracing::warn!("JWKS refresh failed: {}", e);
}
}
}
}
}
fn decode_header_with_handler(
token: &str,
handler: &dyn Fn(&str, &Value) -> Option<String>,
) -> Result<Header, jsonwebtoken::errors::Error> {
let header_b64 = token
.split('.')
.next()
.ok_or(jsonwebtoken::errors::ErrorKind::InvalidToken)?;
let header_bytes = URL_SAFE_NO_PAD
.decode(header_b64.trim_end_matches('='))
.map_err(jsonwebtoken::errors::ErrorKind::Base64)?;
let mut json: serde_json::Map<String, Value> = serde_json::from_slice(&header_bytes)?;
json.retain(|key, value| {
if STANDARD_HEADER_FIELDS.contains(&key.as_str()) || value.is_string() {
return true;
}
match handler(key, value) {
Some(s) => {
*value = Value::String(s);
true
}
None => false,
}
});
Ok(serde_json::from_value(Value::Object(json))?)
}
fn map_http_error(e: &modkit_http::HttpError) -> ClaimsError {
ClaimsError::JwksFetchFailed(crate::http_error::format_http_error(e, "JWKS"))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use httpmock::prelude::*;
fn test_provider_with_http(uri: &str) -> JwksKeyProvider {
let client = modkit_http::HttpClient::builder()
.timeout(Duration::from_secs(5))
.retry(None)
.build()
.expect("failed to create test HTTP client");
JwksKeyProvider {
jwks_uri: uri.to_owned(),
keys: Arc::new(ArcSwap::from_pointee(HashMap::new())),
refresh_state: Arc::new(RwLock::new(RefreshState::default())),
client,
refresh_interval: Duration::from_secs(300),
max_backoff: Duration::from_secs(3600),
on_demand_refresh_cooldown: Duration::from_secs(60),
header_extras_handler: None,
}
}
fn test_provider(uri: &str) -> JwksKeyProvider {
JwksKeyProvider::new(uri).expect("failed to create test provider")
}
fn valid_jwks_json() -> &'static str {
r#"{
"keys": [{
"kty": "RSA",
"kid": "test-key-1",
"use": "sig",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB",
"alg": "RS256"
}]
}"#
}
#[tokio::test]
async fn test_calculate_backoff() {
let provider = test_provider("https://example.com/jwks");
assert_eq!(provider.calculate_backoff(0), Duration::from_secs(60));
assert_eq!(provider.calculate_backoff(1), Duration::from_secs(120));
assert_eq!(provider.calculate_backoff(2), Duration::from_secs(240));
assert_eq!(provider.calculate_backoff(3), Duration::from_secs(480));
assert_eq!(provider.calculate_backoff(100), provider.max_backoff);
}
#[tokio::test]
async fn test_should_refresh_on_first_call() {
let provider = test_provider("https://example.com/jwks");
assert!(provider.should_refresh().await);
}
#[tokio::test]
async fn test_key_storage() {
let provider = test_provider("https://example.com/jwks");
assert!(provider.get_key("test-kid").is_none());
let mut keys = HashMap::new();
keys.insert("test-kid".to_owned(), DecodingKey::from_secret(b"secret"));
provider.keys.store(Arc::new(keys));
assert!(provider.get_key("test-kid").is_some());
}
#[tokio::test]
async fn test_on_demand_refresh_returns_ok_when_key_exists() {
let provider = test_provider("https://example.com/jwks");
let mut keys = HashMap::new();
keys.insert(
"existing-kid".to_owned(),
DecodingKey::from_secret(b"secret"),
);
provider.keys.store(Arc::new(keys));
let result = provider.on_demand_refresh("existing-kid").await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_try_new_returns_result() {
let result = JwksKeyProvider::try_new("https://example.com/jwks");
assert!(result.is_ok());
}
#[tokio::test]
async fn test_fetch_jwks_success_with_valid_json() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(valid_jwks_json());
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
let result = provider.perform_refresh().await;
assert!(result.is_ok(), "Expected success, got: {result:?}");
assert!(
provider.get_key("test-key-1").is_some(),
"Expected key 'test-key-1' to be stored"
);
mock.assert();
}
#[tokio::test]
async fn test_fetch_jwks_http_404_error_mapping() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(404).body("Not Found");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
let result = provider.perform_refresh().await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = err.to_string();
assert!(
err_msg.contains("JWKS HTTP 404"),
"Expected error to contain 'JWKS HTTP 404', got: {err_msg}"
);
assert!(
!err_msg.to_lowercase().contains("parse"),
"HTTP status error should not mention 'parse', got: {err_msg}"
);
mock.assert();
}
#[tokio::test]
async fn test_fetch_jwks_http_500_error_mapping() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(500).body("Internal Server Error");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
let result = provider.perform_refresh().await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = err.to_string();
assert!(
err_msg.contains("JWKS HTTP 500"),
"Expected error to contain 'JWKS HTTP 500', got: {err_msg}"
);
mock.assert();
}
#[tokio::test]
async fn test_fetch_jwks_invalid_json_error_mapping() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body("this is not valid json");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
let result = provider.perform_refresh().await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = err.to_string();
assert!(
err_msg.contains("JWKS JSON parse failed"),
"Expected error to contain 'JWKS JSON parse failed', got: {err_msg}"
);
mock.assert();
}
#[tokio::test]
async fn test_fetch_jwks_empty_keys_error() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"keys": []}"#);
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
let result = provider.perform_refresh().await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = err.to_string();
assert!(
err_msg.contains("No valid RSA keys"),
"Expected error about no RSA keys, got: {err_msg}"
);
mock.assert();
}
#[tokio::test]
async fn test_on_demand_refresh_respects_cooldown() {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(404).body("Not Found");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url)
.with_on_demand_refresh_cooldown(Duration::from_secs(60));
let result1 = provider.on_demand_refresh("test-kid").await;
assert!(result1.is_err());
let result2 = provider.on_demand_refresh("test-kid").await;
assert!(result2.is_err());
match result2.unwrap_err() {
ClaimsError::UnknownKeyId(_) => {}
other => panic!("Expected UnknownKeyId during cooldown, got: {other:?}"),
}
mock.assert_calls(1);
}
#[tokio::test]
async fn test_on_demand_refresh_tracks_failed_kids() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(404).body("Not Found");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url)
.with_on_demand_refresh_cooldown(Duration::from_millis(100));
let result = provider.on_demand_refresh("failed-kid").await;
assert!(result.is_err());
let state = provider.refresh_state.read().await;
assert!(state.failed_kids.contains("failed-kid"));
}
#[tokio::test]
async fn test_perform_refresh_updates_state_on_failure() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(500).body("Server Error");
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
{
let mut state = provider.refresh_state.write().await;
state.consecutive_failures = 3;
state.last_error = Some("Previous error".to_owned());
}
_ = provider.perform_refresh().await;
let state = provider.refresh_state.read().await;
assert_eq!(state.consecutive_failures, 4);
assert!(state.last_error.is_some());
}
#[tokio::test]
async fn test_perform_refresh_resets_state_on_success() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(valid_jwks_json());
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url);
{
let mut state = provider.refresh_state.write().await;
state.consecutive_failures = 5;
state.last_error = Some("Previous error".to_owned());
}
let result = provider.perform_refresh().await;
assert!(result.is_ok());
let state = provider.refresh_state.read().await;
assert_eq!(state.consecutive_failures, 0);
assert!(state.last_error.is_none());
}
#[tokio::test]
async fn test_validate_and_decode_with_missing_kid() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(valid_jwks_json());
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url)
.with_on_demand_refresh_cooldown(Duration::from_millis(100));
let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im5vbmV4aXN0ZW50LWtpZCJ9.\
eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";
let result = provider.validate_and_decode(token).await;
assert!(result.is_err());
match result.unwrap_err() {
ClaimsError::UnknownKeyId(kid) => {
assert_eq!(kid, "nonexistent-kid");
}
other => panic!("Expected UnknownKeyId, got: {other:?}"),
}
}
#[test]
fn test_decode_header_with_handler_coerces_non_string_extras() {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let header_json = r#"{"alg":"RS256","eap":1,"iri":"some-string-id","irn":["role_a"],"kid":"kid-1","typ":"at+jwt"}"#;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header_b64}.{payload_b64}.fake");
let header = decode_header_with_handler(&token, &|_key, value| Some(value.to_string()))
.expect("should handle non-standard header fields");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
assert_eq!(header.kid.as_deref(), Some("kid-1"));
assert_eq!(header.typ.as_deref(), Some("at+jwt"));
assert_eq!(header.extras.get("eap").map(String::as_str), Some("1"));
assert_eq!(
header.extras.get("irn").map(String::as_str),
Some(r#"["role_a"]"#)
);
assert_eq!(
header.extras.get("iri").map(String::as_str),
Some("some-string-id")
);
}
#[test]
fn test_decode_header_with_handler_can_drop_fields() {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let header_json = r#"{"alg":"RS256","eap":1,"iri":"keep-me","kid":"kid-1","typ":"JWT"}"#;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let token = format!("{header_b64}.e30.fake");
let header = decode_header_with_handler(&token, &|_key, _value| None)
.expect("should succeed when handler drops non-string fields");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
assert!(!header.extras.contains_key("eap"));
assert_eq!(
header.extras.get("iri").map(String::as_str),
Some("keep-me")
);
}
#[tokio::test]
async fn test_with_header_extras_stringified_coerces_non_string_extras() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(valid_jwks_json());
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url).with_header_extras_stringified();
let header_json =
r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1,"irn":["role_a"]}"#;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header_b64}.{payload_b64}.AAAA");
let result = provider.validate_and_decode(&token).await;
let err = result.expect_err("fake signature should fail validation");
match &err {
ClaimsError::DecodeFailed(msg) => {
assert!(
msg.contains("JWT validation failed"),
"Expected signature-validation error, got: {msg}"
);
}
other => panic!("Expected DecodeFailed, got: {other:?}"),
}
}
#[tokio::test]
async fn test_validate_and_decode_uses_header_extras_handler() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/jwks");
then.status(200)
.header("content-type", "application/json")
.body(valid_jwks_json());
});
let jwks_url = server.url("/jwks");
let provider = test_provider_with_http(&jwks_url)
.with_header_extras_handler(|_key, value| Some(value.to_string()));
let header_json = r#"{"alg":"RS256","kid":"test-key-1","typ":"JWT","eap":1}"#;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let payload_b64 = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header_b64}.{payload_b64}.AAAA");
let result = provider.validate_and_decode(&token).await;
let err = result.expect_err("fake signature should fail validation");
match &err {
ClaimsError::DecodeFailed(msg) => {
assert!(
msg.contains("JWT validation failed"),
"Expected signature-validation error, got: {msg}"
);
}
other => panic!("Expected DecodeFailed, got: {other:?}"),
}
}
#[test]
fn test_decode_header_without_handler_rejects_non_string_extras() {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let header_json = r#"{"alg":"RS256","eap":1,"kid":"kid-1","typ":"JWT"}"#;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let token = format!("{header_b64}.e30.fake");
let result = decode_header(&token);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("invalid type: integer"),
"expected type error, got: {err}"
);
}
}