use std::sync::{Arc, Mutex, RwLock};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::jwk::{Jwk, JwkSet};
use jsonwebtoken::{Algorithm, Header, encode};
#[cfg(test)]
use crate::local_idp::primitives::LocalIdpKeyError;
use crate::local_idp::primitives::{
IssuanceEvent, IssuanceListener, LocalIdpSigningKey, MintClaims, build_claims_json,
enforce_max_ttl, key_algorithm_to_algorithm, rebuild_jwks,
};
/// Owned snapshot of one [`IssuanceEvent`] captured by [`MockIssuanceListener`].
///
/// `IssuanceEvent` borrows from the minting call, so every field here is copied or
/// cloned out of the event (and its claims) so that it can outlive the mint.
#[derive(Debug, Clone)]
pub struct RecordedIssuance {
    /// Issuer the token was minted under.
    pub issuer: String,
    /// Key id of the signing key that produced the token.
    pub key_id: String,
    /// Signing algorithm of that key.
    pub algorithm: Algorithm,
    /// Subject taken from the mint claims.
    pub subject: String,
    /// Audience values taken from the mint claims.
    pub audience: Vec<String>,
    /// Expiry taken from the mint claims.
    pub expires_at: DateTime<Utc>,
    /// Optional not-before time from the mint claims.
    pub not_before: Option<DateTime<Utc>>,
    /// Optional issued-at time from the mint claims.
    pub issued_at: Option<DateTime<Utc>>,
    /// Optional JWT id from the mint claims.
    pub jwt_id: Option<String>,
}
/// Test double for [`IssuanceListener`] that records every mint it is notified about.
///
/// `on_mint` is called through `&self` (the fixture holds listeners as
/// `Arc<dyn IssuanceListener>`), so the event log sits behind a `Mutex`.
#[derive(Default)]
pub struct MockIssuanceListener {
    // Recorded events, appended in the order `on_mint` is called.
    events: Mutex<Vec<RecordedIssuance>>,
}
impl MockIssuanceListener {
    /// Creates a listener with an empty event log.
    pub fn new() -> Self {
        Default::default()
    }

    /// Acquires the event log; a poisoned mutex is treated as a test bug and panics.
    fn lock_events(&self) -> std::sync::MutexGuard<'_, Vec<RecordedIssuance>> {
        self.events
            .lock()
            .expect("MockIssuanceListener mutex never poisoned in tests")
    }

    /// Returns a copy of every recorded event, oldest first.
    pub fn events(&self) -> Vec<RecordedIssuance> {
        let log = self.lock_events();
        log.to_vec()
    }

    /// Number of events recorded so far.
    pub fn count(&self) -> usize {
        let log = self.lock_events();
        log.len()
    }

    /// Discards all recorded events.
    pub fn clear(&self) {
        let mut log = self.lock_events();
        log.clear();
    }
}
impl std::fmt::Debug for MockIssuanceListener {
    /// Renders as `MockIssuanceListener { count: N }`; the events themselves are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MockIssuanceListener");
        builder.field("count", &self.count());
        builder.finish()
    }
}
impl IssuanceListener for MockIssuanceListener {
    /// Copies the borrowed event into an owned [`RecordedIssuance`] and appends it to the log.
    fn on_mint(&self, event: &IssuanceEvent<'_>) {
        let claims = &event.claims;
        // Build the owned record first so the lock is held only for the push.
        let record = RecordedIssuance {
            issuer: event.issuer.to_string(),
            key_id: event.key_id.to_string(),
            algorithm: event.algorithm,
            subject: claims.subject.clone(),
            audience: claims.audience.clone(),
            expires_at: claims.expires_at,
            not_before: claims.not_before,
            issued_at: claims.issued_at,
            jwt_id: claims.jwt_id.clone(),
        };
        let mut log = self
            .events
            .lock()
            .expect("MockIssuanceListener mutex never poisoned in tests");
        log.push(record);
    }
}
/// Local identity-provider fixture: mints signed JWTs and exposes the matching JWKS.
///
/// All state lives behind an `Arc`, so cloning is cheap. The `with_*` /
/// `rotate_signing_key` builders never mutate shared state: each returns a new fixture
/// around a freshly allocated inner value, so existing clones keep their configuration.
#[derive(Clone)]
pub struct LocalIdpFixture {
    inner: Arc<LocalIdpInner>,
}
/// Immutable state shared by clones of a [`LocalIdpFixture`].
struct LocalIdpInner {
    // Key used to sign every minted token.
    signing_key: LocalIdpSigningKey,
    // Additional keys passed to `rebuild_jwks`; never used for signing.
    // `rotate_signing_key` appends the previous signing key here.
    historical_keys: Vec<LocalIdpSigningKey>,
    // Caller-supplied public JWKs passed to `rebuild_jwks`; also scanned by
    // `verifier_algorithms`.
    extra_public_jwks: Vec<Jwk>,
    // Issuer passed to `build_claims_json` and reported to the listener.
    issuer: String,
    // When set, every mint first calls `enforce_max_ttl` with this cap.
    max_ttl: Option<Duration>,
    // Notified after each successful encode.
    issuance_listener: Option<Arc<dyn IssuanceListener>>,
    // Cached output of `rebuild_jwks` over the three key fields above; must be
    // recomputed whenever any of them changes.
    jwks: JwkSet,
}
impl LocalIdpFixture {
pub fn new(issuer: impl Into<String>) -> Self {
Self::with_signing_key(issuer, LocalIdpSigningKey::generate_rsa())
}
pub fn with_algorithm(issuer: impl Into<String>, algorithm: Algorithm) -> Self {
let signing_key = match algorithm {
Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
LocalIdpSigningKey::generate_rsa_with_algorithm(algorithm)
}
Algorithm::ES256 => LocalIdpSigningKey::generate_es256(),
other => panic!(
"LocalIdpFixture::with_algorithm supports RS256/RS384/RS512 and ES256; got {other:?}"
),
};
Self::with_signing_key(issuer, signing_key)
}
pub fn with_signing_key(issuer: impl Into<String>, signing_key: LocalIdpSigningKey) -> Self {
let jwks = rebuild_jwks(&signing_key, &[], &[]);
Self {
inner: Arc::new(LocalIdpInner {
signing_key,
historical_keys: Vec::new(),
extra_public_jwks: Vec::new(),
issuer: issuer.into(),
max_ttl: None,
issuance_listener: None,
jwks,
}),
}
}
pub fn with_historical_signing_key(self, key: LocalIdpSigningKey) -> Self {
let mut historical = self.inner.historical_keys.clone();
historical.push(key);
let jwks = rebuild_jwks(
&self.inner.signing_key,
&historical,
&self.inner.extra_public_jwks,
);
Self {
inner: Arc::new(LocalIdpInner {
signing_key: self.inner.signing_key.clone(),
historical_keys: historical,
extra_public_jwks: self.inner.extra_public_jwks.clone(),
issuer: self.inner.issuer.clone(),
max_ttl: self.inner.max_ttl,
issuance_listener: self.inner.issuance_listener.clone(),
jwks,
}),
}
}
pub fn with_extra_public_jwk(self, jwk: Jwk) -> Self {
let mut extra = self.inner.extra_public_jwks.clone();
extra.push(jwk);
let jwks = rebuild_jwks(&self.inner.signing_key, &self.inner.historical_keys, &extra);
Self {
inner: Arc::new(LocalIdpInner {
signing_key: self.inner.signing_key.clone(),
historical_keys: self.inner.historical_keys.clone(),
extra_public_jwks: extra,
issuer: self.inner.issuer.clone(),
max_ttl: self.inner.max_ttl,
issuance_listener: self.inner.issuance_listener.clone(),
jwks,
}),
}
}
pub fn rotate_signing_key(self, new_key: LocalIdpSigningKey) -> Self {
let mut historical = self.inner.historical_keys.clone();
historical.push(self.inner.signing_key.clone());
let jwks = rebuild_jwks(&new_key, &historical, &self.inner.extra_public_jwks);
Self {
inner: Arc::new(LocalIdpInner {
signing_key: new_key,
historical_keys: historical,
extra_public_jwks: self.inner.extra_public_jwks.clone(),
issuer: self.inner.issuer.clone(),
max_ttl: self.inner.max_ttl,
issuance_listener: self.inner.issuance_listener.clone(),
jwks,
}),
}
}
pub fn with_max_ttl(self, ttl: Duration) -> Self {
Self {
inner: Arc::new(LocalIdpInner {
signing_key: self.inner.signing_key.clone(),
historical_keys: self.inner.historical_keys.clone(),
extra_public_jwks: self.inner.extra_public_jwks.clone(),
issuer: self.inner.issuer.clone(),
max_ttl: Some(ttl),
issuance_listener: self.inner.issuance_listener.clone(),
jwks: self.inner.jwks.clone(),
}),
}
}
pub fn max_ttl(&self) -> Option<Duration> {
self.inner.max_ttl
}
pub fn with_issuance_listener(self, listener: Arc<dyn IssuanceListener>) -> Self {
Self {
inner: Arc::new(LocalIdpInner {
signing_key: self.inner.signing_key.clone(),
historical_keys: self.inner.historical_keys.clone(),
extra_public_jwks: self.inner.extra_public_jwks.clone(),
issuer: self.inner.issuer.clone(),
max_ttl: self.inner.max_ttl,
issuance_listener: Some(listener),
jwks: self.inner.jwks.clone(),
}),
}
}
pub fn issuance_listener(&self) -> Option<&Arc<dyn IssuanceListener>> {
self.inner.issuance_listener.as_ref()
}
pub fn with_key_id(self, key_id: impl Into<String>) -> Self {
let new_signing_key = self.inner.signing_key.clone().with_key_id(key_id);
let jwks = rebuild_jwks(
&new_signing_key,
&self.inner.historical_keys,
&self.inner.extra_public_jwks,
);
Self {
inner: Arc::new(LocalIdpInner {
signing_key: new_signing_key,
historical_keys: self.inner.historical_keys.clone(),
extra_public_jwks: self.inner.extra_public_jwks.clone(),
issuer: self.inner.issuer.clone(),
max_ttl: self.inner.max_ttl,
issuance_listener: self.inner.issuance_listener.clone(),
jwks,
}),
}
}
pub fn issuer(&self) -> &str {
&self.inner.issuer
}
pub fn key_id(&self) -> &str {
self.inner.signing_key.key_id()
}
pub fn algorithm(&self) -> Algorithm {
self.inner.signing_key.algorithm()
}
pub fn verifier_algorithms(&self) -> Vec<Algorithm> {
let mut out = Vec::new();
let push_unique = |a: Algorithm, out: &mut Vec<Algorithm>| {
if !out.contains(&a) {
out.push(a);
}
};
push_unique(self.inner.signing_key.algorithm(), &mut out);
for hk in &self.inner.historical_keys {
push_unique(hk.algorithm(), &mut out);
}
for jwk in &self.inner.extra_public_jwks {
if let Some(ka) = jwk.common.key_algorithm
&& let Some(alg) = key_algorithm_to_algorithm(ka)
{
push_unique(alg, &mut out);
}
}
out
}
pub fn signing_key(&self) -> &LocalIdpSigningKey {
&self.inner.signing_key
}
pub fn jwks(&self) -> &JwkSet {
&self.inner.jwks
}
pub fn jwks_json(&self) -> String {
serde_json::to_string(&self.inner.jwks).expect("JwkSet serialisation always succeeds")
}
pub fn jwks_handle(&self) -> Arc<RwLock<JwkSet>> {
Arc::new(RwLock::new(self.inner.jwks.clone()))
}
pub fn mint(&self, claims: &MintClaims) -> String {
self.mint_with_header(claims, &mut Header::new(self.inner.signing_key.algorithm()))
}
pub fn mint_with_header(&self, claims: &MintClaims, header: &mut Header) -> String {
if let Some(max_ttl) = self.inner.max_ttl {
enforce_max_ttl(claims, max_ttl);
}
header.kid = Some(self.inner.signing_key.key_id().to_string());
header.alg = self.inner.signing_key.algorithm();
let claims_json = build_claims_json(&self.inner.issuer, claims);
let key = self.inner.signing_key.encoding_key();
let token =
encode(header, &claims_json, &key).expect("JWT encode never fails on valid inputs");
if let Some(listener) = &self.inner.issuance_listener {
let event = IssuanceEvent {
issuer: &self.inner.issuer,
key_id: self.inner.signing_key.key_id(),
algorithm: self.inner.signing_key.algorithm(),
claims,
};
listener.on_mint(&event);
}
token
}
pub fn mint_jwt_svid(
&self,
trust_domain: &str,
service: &str,
tenant: &str,
audience: impl Into<String>,
ttl: Duration,
) -> String {
let now = Utc::now();
let exp = now + ttl;
let subject = format!("spiffe://{trust_domain}/{service}/{tenant}");
let claims = MintClaims::new(subject, exp)
.with_audience(audience)
.with_issued_at(now);
self.mint(&claims)
}
}
impl std::fmt::Debug for LocalIdpFixture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocalIdpFixture")
.field("issuer", &self.inner.issuer)
.field("signing_key", &self.inner.signing_key)
.finish()
}
}
// Test-only submodules, compiled only under `cfg(test)`; each lives in its own file.
#[cfg(test)]
mod audit_hook;
#[cfg(test)]
mod es256;
#[cfg(test)]
mod max_ttl;
#[cfg(test)]
mod mint_verify;
#[cfg(test)]
mod multi_key;
#[cfg(test)]
mod signing_key;