use std::sync::{Arc, Mutex};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use rustauth_core::context::create_auth_context_with_adapter;
use rustauth_core::db::MemoryAdapter;
use rustauth_plugins::jwt::{
jwt, sign_jwt, verify_jwt, verify_jwt_with_options, JwtClaims, JwtJwksOptions, JwtOptions,
JwtSignHandler, JwtSigningOptions,
};
use serde_json::{json, Value};
use super::helpers::*;
#[tokio::test]
async fn custom_signer_receives_defaulted_claims() -> Result<(), Box<dyn std::error::Error>> {
let captured = Arc::new(Mutex::new(None::<JwtClaims>));
let signer: JwtSignHandler = Arc::new({
let captured = Arc::clone(&captured);
move |claims| {
let captured = Arc::clone(&captured);
Box::pin(async move {
*captured.lock().map_err(|error| {
rustauth_core::error::RustAuthError::Api(error.to_string())
})? = Some(claims);
Ok("remote.jwt.signature".to_owned())
})
}
});
let options = JwtOptions {
jwks: JwtJwksOptions {
remote_url: Some("https://example.com/jwks?tenant=one".to_owned()),
..JwtJwksOptions::default()
},
jwt: JwtSigningOptions {
sign: Some(signer),
..JwtSigningOptions::default()
},
..JwtOptions::default()
};
let context = create_auth_context_with_adapter(
options_with_plugin(jwt(options.clone())?),
Arc::new(MemoryAdapter::new()),
)?;
let mut claims = JwtClaims::new();
claims.insert("sub".to_owned(), json!("user_1"));
let token = sign_jwt(&context, claims, Some(options)).await?;
let claims = captured
.lock()
.map_err(|error| error.to_string())?
.clone()
.ok_or("missing captured claims")?;
assert_eq!(token, "remote.jwt.signature");
assert_eq!(claims["sub"], "user_1");
assert_eq!(claims["iss"], TEST_BASE_URL);
assert_eq!(claims["aud"], TEST_BASE_URL);
assert!(claims.get("iat").and_then(Value::as_i64).is_some());
assert!(claims.get("exp").and_then(Value::as_i64).is_some());
Ok(())
}
#[tokio::test]
async fn verify_returns_none_for_invalid_claim_or_key_cases(
) -> Result<(), Box<dyn std::error::Error>> {
let adapter = Arc::new(MemoryAdapter::new());
let context = create_auth_context_with_adapter(
options_with_plugin(jwt(JwtOptions::default())?),
adapter,
)?;
let mut valid_claims = JwtClaims::new();
valid_claims.insert("sub".to_owned(), json!("user_1"));
let valid = sign_jwt(&context, valid_claims, None).await?;
assert!(verify_jwt(&context, &valid, None).await?.is_some());
let mut no_sub = JwtClaims::new();
no_sub.insert("custom".to_owned(), json!("value"));
let no_sub_token = sign_jwt(&context, no_sub, None).await?;
assert!(verify_jwt(&context, &no_sub_token, None).await?.is_none());
let mut wrong_aud = JwtClaims::new();
wrong_aud.insert("sub".to_owned(), json!("user_1"));
wrong_aud.insert("aud".to_owned(), json!("https://wrong.example"));
let wrong_aud_token = sign_jwt(&context, wrong_aud, None).await?;
assert!(verify_jwt(&context, &wrong_aud_token, None)
.await?
.is_none());
let mut expired = JwtClaims::new();
expired.insert("sub".to_owned(), json!("user_1"));
expired.insert("exp".to_owned(), json!(1));
let expired_token = sign_jwt(&context, expired, None).await?;
assert!(verify_jwt(&context, &expired_token, None).await?.is_none());
assert!(verify_jwt(&context, "malformed", None).await?.is_none());
assert!(verify_jwt(&context, &remove_kid(&valid)?, None)
.await?
.is_none());
assert!(verify_jwt(&context, &replace_kid(&valid, "unknown")?, None)
.await?
.is_none());
assert!(verify_jwt(&context, &replace_alg(&valid, "none")?, None)
.await?
.is_none());
Ok(())
}
/// Signs a token with an audience override and checks that `verify_jwt` with
/// no options returns `None`, while `verify_jwt_with_options` given the same
/// options returns `Some`.
#[tokio::test]
async fn verify_with_options_accepts_custom_audience() -> Result<(), Box<dyn std::error::Error>> {
    let store = Arc::new(MemoryAdapter::new());
    let plugin = jwt(JwtOptions::default())?;
    let context = create_auth_context_with_adapter(options_with_plugin(plugin), store)?;

    // The context was built with default options; the audience is overridden per call.
    let signing = JwtSigningOptions {
        audience: Some(vec!["https://api.example".to_owned()]),
        ..JwtSigningOptions::default()
    };
    let audience_options = JwtOptions {
        jwt: signing,
        ..JwtOptions::default()
    };

    let mut subject = JwtClaims::new();
    subject.insert("sub".to_owned(), json!("user_1"));
    let token = sign_jwt(&context, subject, Some(audience_options.clone())).await?;

    let with_defaults = verify_jwt(&context, &token, None).await?;
    assert!(with_defaults.is_none());

    let with_audience =
        verify_jwt_with_options(&context, &token, &audience_options, None).await?;
    assert!(with_audience.is_some());
    Ok(())
}
/// Signs a token with issuer, audience and expiration overrides passed
/// directly to `sign_jwt`, then checks that `verify_jwt` with no options
/// returns `None` and that `verify_jwt_with_options` with the same audience
/// returns claims whose `aud` is the overridden audience.
#[tokio::test]
async fn direct_sign_jwt_override_options_are_preserved() -> Result<(), Box<dyn std::error::Error>>
{
    let store = Arc::new(MemoryAdapter::new());
    let plugin = jwt(JwtOptions::default())?;
    let context = create_auth_context_with_adapter(options_with_plugin(plugin), store)?;

    let signing_overrides = JwtOptions {
        jwt: JwtSigningOptions {
            issuer: Some("https://issuer.example".to_owned()),
            audience: Some(vec!["https://api.example".to_owned()]),
            expiration_time: Some(rustauth_plugins::jwt::TimeInput::Duration("30m".to_owned())),
            ..JwtSigningOptions::default()
        },
        ..JwtOptions::default()
    };

    let mut subject = JwtClaims::new();
    subject.insert("sub".to_owned(), json!("user_1"));
    let token = sign_jwt(&context, subject, Some(signing_overrides)).await?;

    let with_defaults = verify_jwt(&context, &token, None).await?;
    assert!(with_defaults.is_none());

    // Only the audience is set in the verification options; the issuer string is
    // passed as the separate last argument (presumably the expected issuer — confirm
    // against `verify_jwt_with_options`).
    let verification = JwtOptions {
        jwt: JwtSigningOptions {
            audience: Some(vec!["https://api.example".to_owned()]),
            ..JwtSigningOptions::default()
        },
        ..JwtOptions::default()
    };
    let verified = verify_jwt_with_options(
        &context,
        &token,
        &verification,
        Some("https://issuer.example"),
    )
    .await?
    .ok_or("missing verified claims")?;

    assert_eq!(verified["aud"], "https://api.example");
    Ok(())
}
/// Returns a copy of `token` whose header has `kid` set to the given value.
///
/// The first dot-separated segment is base64url-decoded (no padding), parsed
/// as JSON, modified, and re-encoded. The second and third segments are copied
/// through unchanged, so the signature no longer covers the new header.
///
/// # Errors
///
/// Fails if the header segment is not valid base64url or JSON, if the decoded
/// header is not a JSON object, or if the payload or signature segment is
/// missing.
fn replace_kid(token: &str, kid: &str) -> Result<String, Box<dyn std::error::Error>> {
    let mut segments = token.split('.');
    // `split` always yields at least one item, so this `ok_or` is purely defensive.
    let encoded_header = segments.next().ok_or("missing header")?;
    let mut header: Value = serde_json::from_slice(&URL_SAFE_NO_PAD.decode(encoded_header)?)?;
    // Go through `as_object_mut` rather than `header["kid"] = ..`: indexing a
    // non-object `Value` by string panics, whereas this returns an error.
    header
        .as_object_mut()
        .ok_or("header must be object")?
        .insert("kid".to_owned(), json!(kid));
    let reencoded = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
    let payload = segments.next().ok_or("missing payload")?;
    let signature = segments.next().ok_or("missing signature")?;
    Ok(format!("{}.{}.{}", reencoded, payload, signature))
}
fn remove_kid(token: &str) -> Result<String, Box<dyn std::error::Error>> {
let parts = token.split('.').collect::<Vec<_>>();
let mut header: Value = serde_json::from_slice(&URL_SAFE_NO_PAD.decode(parts[0])?)?;
header
.as_object_mut()
.ok_or("header must be object")?
.remove("kid");
Ok(format!(
"{}.{}.{}",
URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?),
parts.get(1).ok_or("missing payload")?,
parts.get(2).ok_or("missing signature")?
))
}
/// Returns a copy of `token` whose header has `alg` set to the given value.
///
/// The first dot-separated segment is base64url-decoded (no padding), parsed
/// as JSON, modified, and re-encoded. The second and third segments are copied
/// through unchanged, so the signature no longer covers the new header.
///
/// # Errors
///
/// Fails if the header segment is not valid base64url or JSON, if the decoded
/// header is not a JSON object, or if the payload or signature segment is
/// missing.
fn replace_alg(token: &str, alg: &str) -> Result<String, Box<dyn std::error::Error>> {
    let mut segments = token.split('.');
    // `split` always yields at least one item, so this `ok_or` is purely defensive.
    let encoded_header = segments.next().ok_or("missing header")?;
    let mut header: Value = serde_json::from_slice(&URL_SAFE_NO_PAD.decode(encoded_header)?)?;
    // Go through `as_object_mut` rather than `header["alg"] = ..`: indexing a
    // non-object `Value` by string panics, whereas this returns an error.
    header
        .as_object_mut()
        .ok_or("header must be object")?
        .insert("alg".to_owned(), json!(alg));
    let reencoded = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
    let payload = segments.next().ok_or("missing payload")?;
    let signature = segments.next().ok_or("missing signature")?;
    Ok(format!("{}.{}.{}", reencoded, payload, signature))
}