use std::sync::{Arc, RwLock};
use axess_clock::{Clock, SystemClock};
use chrono::Duration;
use jsonwebtoken::jwk::{Jwk, JwkSet};
use jsonwebtoken::{Algorithm, Header, encode};
use tokio::sync::RwLock as AsyncRwLock;
use self::primitives::{build_claims_json, enforce_max_ttl_fallible, key_algorithm_to_algorithm};
pub use self::primitives::{IssuanceEvent, IssuanceListener, LocalIdpSigningKey, MintClaims};
#[cfg(any(test, feature = "testing"))]
pub use crate::testing::local_idp::{MockIssuanceListener, RecordedIssuance};
pub mod primitives;
pub mod discovery;
pub use discovery::LocalIdpMetadata;
/// Errors returned by [`LocalIdp`] operations.
///
/// `KE` is the error type of the backing [`LocalIdpKeyStore`].
#[derive(Debug, thiserror::Error)]
pub enum IssuanceError<KE: std::error::Error + Send + Sync + 'static> {
    /// A max-TTL is configured and the token's lifetime (`observed`) is longer
    /// than the configured cap (`max`).
    #[error("token lifetime {observed} exceeds configured max-TTL {max}")]
    LifetimeExceedsCap {
        observed: Duration,
        max: Duration,
    },
    /// The key store failed while loading or rotating keys. The store's error is
    /// exposed through `source()`; it is not included in this error's message.
    #[error("key store error")]
    KeyStore(#[source] KE),
    /// JWT encoding failed; holds the rendered message of the encoder's error.
    #[error("JWT encoding error: {0}")]
    Encoding(String),
}
/// Every signing key held by a [`LocalIdpKeyStore`], as returned by `load_all`.
#[derive(Debug, Clone)]
pub struct LoadedKeys {
    /// The key new tokens are signed with.
    pub current: LocalIdpSigningKey,
    /// Previously-current keys. A `LocalIdp` publishes them in its JWKS
    /// alongside the current key.
    pub historical: Vec<LocalIdpSigningKey>,
}
/// Backend that supplies and persists a [`LocalIdp`]'s signing keys.
///
/// Both methods return explicitly `Send` futures so they can be awaited from
/// multi-threaded async contexts.
pub trait LocalIdpKeyStore: Send + Sync + 'static {
    /// Error type; `LocalIdp` surfaces it as [`IssuanceError::KeyStore`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Loads the current key and all historical keys.
    /// `LocalIdp::from_key_store` calls this to seed its in-memory state.
    fn load_all(&self)
    -> impl std::future::Future<Output = Result<LoadedKeys, Self::Error>> + Send;
    /// Makes `new_current` the current key.
    ///
    /// `LocalIdp::rotate_signing_key` calls this, and on success demotes the
    /// previous current key to historical in memory. A store whose `load_all`
    /// should later reproduce that view needs to do the same, as
    /// `MemoryLocalIdpKeyStore` does.
    fn rotate(
        &self,
        new_current: LocalIdpSigningKey,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
}
/// In-memory [`LocalIdpKeyStore`] with no persistence.
///
/// The keys live behind an `Arc`, so clones share them: a `rotate` through one
/// clone is visible to `load_all` on every other clone.
#[derive(Debug, Clone)]
pub struct MemoryLocalIdpKeyStore {
    inner: Arc<RwLock<LoadedKeys>>,
}
/// Error type of [`MemoryLocalIdpKeyStore`].
///
/// Nothing in this file constructs it. The in-memory store's operations always
/// return `Ok`, and a poisoned lock panics rather than producing this error.
#[derive(Debug, thiserror::Error)]
pub enum MemoryLocalIdpKeyStoreError {
    /// Placeholder variant; not produced by the code in this file.
    #[error("memory key store error (unreachable in L1)")]
    Infallible,
}
impl MemoryLocalIdpKeyStore {
    /// Creates a store whose only key is `current` (no historical keys).
    pub fn with_current(current: LocalIdpSigningKey) -> Self {
        Self::with_keys(current, Vec::new())
    }

    /// Creates a store with `current` as the signing key and `historical` as the
    /// previously-used keys, kept in the order given.
    pub fn with_keys(current: LocalIdpSigningKey, historical: Vec<LocalIdpSigningKey>) -> Self {
        let keys = LoadedKeys { current, historical };
        Self {
            inner: Arc::new(RwLock::new(keys)),
        }
    }
}
impl LocalIdpKeyStore for MemoryLocalIdpKeyStore {
type Error = MemoryLocalIdpKeyStoreError;
fn load_all(
&self,
) -> impl std::future::Future<Output = Result<LoadedKeys, Self::Error>> + Send {
let inner = self.inner.clone();
async move {
let guard = inner
.read()
.expect("MemoryLocalIdpKeyStore lock never poisoned");
Ok(LoadedKeys {
current: guard.current.clone(),
historical: guard.historical.clone(),
})
}
}
fn rotate(
&self,
new_current: LocalIdpSigningKey,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
let inner = self.inner.clone();
async move {
let mut guard = inner
.write()
.expect("MemoryLocalIdpKeyStore lock never poisoned");
let previous = std::mem::replace(&mut guard.current, new_current);
guard.historical.push(previous);
Ok(())
}
}
}
/// Local token issuer. It signs JWTs with keys from a [`LocalIdpKeyStore`] and
/// can serve discovery metadata and a JWKS for them (see `router`).
///
/// Clones share the key state, key store, issuance listener and clock through
/// `Arc`s.
pub struct LocalIdp<K: LocalIdpKeyStore> {
    // Current/historical keys plus the derived JWKS; shared by all clones.
    state: Arc<AsyncRwLock<LocalIdpState>>,
    // Backend that keys are loaded from and rotations are persisted to.
    key_store: Arc<K>,
    // Issuer identifier: passed to `build_claims_json` for every token and
    // advertised in the discovery metadata.
    issuer: String,
    // Public base URL for the endpoints; `None` falls back to `issuer`.
    base_url: Option<String>,
    // Extra (name, value) fields forwarded to the discovery metadata.
    extra_metadata: Vec<(String, serde_json::Value)>,
    // Optional cap on token lifetime, checked on every mint.
    max_ttl: Option<Duration>,
    // Notified after each successful mint.
    issuance_listener: Option<Arc<dyn IssuanceListener>>,
    // Time source for the max-TTL check.
    clock: Arc<dyn Clock>,
}
impl<K: LocalIdpKeyStore> Clone for LocalIdp<K> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
key_store: self.key_store.clone(),
issuer: self.issuer.clone(),
base_url: self.base_url.clone(),
extra_metadata: self.extra_metadata.clone(),
max_ttl: self.max_ttl,
issuance_listener: self.issuance_listener.clone(),
clock: self.clock.clone(),
}
}
}
/// Prints the IdP's configuration only.
///
/// The key state (which holds the signing keys), the key store and the clock are
/// not printed. `finish_non_exhaustive` appends `..` so the output shows that
/// fields were left out. The listener is rendered as `Some("<set>")` or `None`
/// rather than being formatted itself.
impl<K: LocalIdpKeyStore> std::fmt::Debug for LocalIdp<K> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalIdp")
            .field("issuer", &self.issuer)
            .field("base_url", &self.base_url)
            .field("extra_metadata_fields", &self.extra_metadata.len())
            .field("max_ttl", &self.max_ttl)
            .field(
                "issuance_listener",
                &self.issuance_listener.as_ref().map(|_| "<set>"),
            )
            .finish_non_exhaustive()
    }
}
/// Key state shared, behind an async `RwLock`, by all clones of a [`LocalIdp`].
struct LocalIdpState {
    // Key used to sign new tokens.
    signing_key: LocalIdpSigningKey,
    // Previous signing keys, appended on each rotation; still published in `jwks`.
    historical_keys: Vec<LocalIdpSigningKey>,
    // Additional public JWKs to publish. Initialised empty in this file.
    // NOTE(review): presumably populated from another module — confirm.
    extra_public_jwks: Vec<Jwk>,
    // Cached `rebuild_jwks` output over the three fields above; it must be
    // rebuilt whenever any of them changes.
    jwks: JwkSet,
}
impl<K: LocalIdpKeyStore> LocalIdp<K> {
    /// Creates an IdP for `issuer`, seeded with the keys `key_store` currently holds.
    ///
    /// The initial JWKS contains the store's current key followed by its
    /// historical keys. The base URL, extra metadata, max-TTL and issuance
    /// listener start unset, and the clock defaults to [`SystemClock`]. Use the
    /// `with_*` methods to configure them.
    ///
    /// # Errors
    ///
    /// Returns [`IssuanceError::KeyStore`] if [`LocalIdpKeyStore::load_all`] fails.
    pub async fn from_key_store(
        issuer: impl Into<String>,
        key_store: K,
    ) -> Result<Self, IssuanceError<K::Error>> {
        let loaded = key_store
            .load_all()
            .await
            .map_err(IssuanceError::KeyStore)?;
        let jwks = rebuild_jwks(&loaded.current, &loaded.historical, &[]);
        Ok(Self {
            state: Arc::new(AsyncRwLock::new(LocalIdpState {
                signing_key: loaded.current,
                historical_keys: loaded.historical,
                extra_public_jwks: Vec::new(),
                jwks,
            })),
            key_store: Arc::new(key_store),
            issuer: issuer.into(),
            base_url: None,
            extra_metadata: Vec::new(),
            max_ttl: None,
            issuance_listener: None,
            clock: Arc::new(SystemClock),
        })
    }

    /// Sets the base URL used to build the discovery and JWKS URLs.
    /// Without it, the issuer is used.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Adds an extra field to the discovery metadata.
    ///
    /// Fields are stored in insertion order. Adding the same name twice stores
    /// both entries.
    pub fn with_metadata_field(
        mut self,
        name: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        self.extra_metadata.push((name.into(), value));
        self
    }

    /// Caps the lifetime of minted tokens. Minting fails with
    /// [`IssuanceError::LifetimeExceedsCap`] when the claims exceed `ttl`.
    pub fn with_max_ttl(mut self, ttl: Duration) -> Self {
        self.max_ttl = Some(ttl);
        self
    }

    /// Registers a listener that is called after every successful mint.
    pub fn with_issuance_listener(mut self, listener: Arc<dyn IssuanceListener>) -> Self {
        self.issuance_listener = Some(listener);
        self
    }

    /// Replaces the clock used for the max-TTL check.
    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
        self.clock = clock;
        self
    }

    /// Signs `claims` with the current key, starting from a default JWT header.
    ///
    /// # Errors
    ///
    /// Same as [`Self::mint_with_header`].
    pub async fn mint(&self, claims: &MintClaims) -> Result<String, IssuanceError<K::Error>> {
        self.mint_with_header(claims, Header::default()).await
    }

    /// Signs `claims` with the current key, starting from `header`.
    ///
    /// `header.kid` and `header.alg` are always overwritten with the current
    /// key's ID and algorithm; all other header fields are kept. On success the
    /// issuance listener, if one is set, is called synchronously before the
    /// token is returned.
    ///
    /// # Errors
    ///
    /// - [`IssuanceError::LifetimeExceedsCap`] if a max-TTL is configured and the
    ///   claims exceed it, evaluated at the configured clock's `now()`.
    /// - [`IssuanceError::Encoding`] if JWT encoding fails.
    pub async fn mint_with_header(
        &self,
        claims: &MintClaims,
        mut header: Header,
    ) -> Result<String, IssuanceError<K::Error>> {
        // The lifetime check depends only on the claims and the clock, so run it
        // before taking the lock; rejected requests then never wait on key state.
        if let Some(max_ttl) = self.max_ttl {
            let now = self.clock.now();
            enforce_max_ttl_fallible(claims, max_ttl, now)
                .map_err(|(observed, max)| IssuanceError::LifetimeExceedsCap { observed, max })?;
        }
        // Hold the read guard for the rest of the call. This keeps the signing
        // key, the header's `kid`/`alg`, and the key reported to the listener
        // consistent even if a rotation is waiting for the write lock.
        let state = self.state.read().await;
        header.kid = Some(state.signing_key.key_id().to_string());
        header.alg = state.signing_key.algorithm();
        let claims_json = build_claims_json(&self.issuer, claims);
        let key = state.signing_key.encoding_key();
        let token = encode(&header, &claims_json, &key)
            .map_err(|e| IssuanceError::Encoding(e.to_string()))?;
        if let Some(listener) = &self.issuance_listener {
            let event = IssuanceEvent {
                issuer: &self.issuer,
                key_id: state.signing_key.key_id(),
                algorithm: state.signing_key.algorithm(),
                claims,
            };
            listener.on_mint(&event);
        }
        Ok(token)
    }

    /// Makes `new_current` the signing key and moves the previous key into the
    /// historical set, so the previous key stays published in the JWKS.
    ///
    /// The rotation is persisted to the key store first; in-memory state changes
    /// only if that succeeds. Concurrent rotations are serialised. Readers
    /// (minting, JWKS and metadata queries) wait until the rotation completes.
    ///
    /// # Errors
    ///
    /// Returns [`IssuanceError::KeyStore`] if [`LocalIdpKeyStore::rotate`] fails.
    /// The in-memory keys and JWKS are then left unchanged.
    pub async fn rotate_signing_key(
        &self,
        new_current: LocalIdpSigningKey,
    ) -> Result<(), IssuanceError<K::Error>> {
        // Take the write lock *before* persisting. If the store were updated
        // first, two concurrent rotations could persist A then B but apply B
        // then A in memory, leaving the in-memory signing key out of sync with
        // the store. Holding a tokio lock across `.await` is sound; the cost is
        // that readers wait for the store call to finish.
        let mut state = self.state.write().await;
        self.key_store
            .rotate(new_current.clone())
            .await
            .map_err(IssuanceError::KeyStore)?;
        let previous = std::mem::replace(&mut state.signing_key, new_current);
        state.historical_keys.push(previous);
        state.jwks = rebuild_jwks(
            &state.signing_key,
            &state.historical_keys,
            &state.extra_public_jwks,
        );
        Ok(())
    }

    /// The configured issuer identifier.
    pub fn issuer(&self) -> &str {
        &self.issuer
    }

    /// The configured max token lifetime, if any.
    pub fn max_ttl(&self) -> Option<Duration> {
        self.max_ttl
    }

    /// Algorithm of the current signing key.
    pub async fn algorithm(&self) -> Algorithm {
        self.state.read().await.signing_key.algorithm()
    }

    /// Deduplicated signing algorithms of every key this IdP publishes.
    ///
    /// The order is: the current key's algorithm, then the historical keys' in
    /// order, then those declared by extra public JWKs. Extra JWKs without an
    /// algorithm, or with one `key_algorithm_to_algorithm` cannot map, are skipped.
    pub async fn verifier_algorithms(&self) -> Vec<Algorithm> {
        let state = self.state.read().await;
        let mut out = Vec::new();
        // Linear `contains` is fine: the list holds at most a handful of algorithms.
        let push_unique = |a: Algorithm, out: &mut Vec<Algorithm>| {
            if !out.contains(&a) {
                out.push(a);
            }
        };
        push_unique(state.signing_key.algorithm(), &mut out);
        for hk in &state.historical_keys {
            push_unique(hk.algorithm(), &mut out);
        }
        for jwk in &state.extra_public_jwks {
            if let Some(ka) = jwk.common.key_algorithm
                && let Some(alg) = key_algorithm_to_algorithm(ka)
            {
                push_unique(alg, &mut out);
            }
        }
        out
    }

    /// A copy of the current JWKS.
    pub async fn jwks(&self) -> JwkSet {
        self.state.read().await.jwks.clone()
    }

    /// The current JWKS serialised as JSON.
    ///
    /// # Panics
    ///
    /// Panics if serialising the `JwkSet` fails.
    pub async fn jwks_json(&self) -> String {
        serde_json::to_string(&self.state.read().await.jwks)
            .expect("JwkSet serialisation always succeeds")
    }

    /// A snapshot of the current JWKS, wrapped in a new `Arc<RwLock<_>>`.
    ///
    /// The returned handle is not connected to this IdP, so later calls to
    /// [`Self::rotate_signing_key`] are not reflected in it. Call this again
    /// after rotating to get an up-to-date handle.
    pub async fn jwks_handle(&self) -> Arc<RwLock<JwkSet>> {
        Arc::new(RwLock::new(self.state.read().await.jwks.clone()))
    }

    /// Base URL for the published endpoints: the configured base URL, or the
    /// issuer if none was set.
    pub fn base_url(&self) -> &str {
        self.base_url.as_deref().unwrap_or(&self.issuer)
    }

    /// URL of the OpenID discovery document (trailing slashes on the base URL
    /// are removed before the path is appended).
    pub fn discovery_url(&self) -> String {
        format!(
            "{}/.well-known/openid-configuration",
            self.base_url().trim_end_matches('/')
        )
    }

    /// URL of the JWKS document (trailing slashes on the base URL are removed
    /// before the path is appended).
    pub fn jwks_url(&self) -> String {
        format!("{}/jwks.json", self.base_url().trim_end_matches('/'))
    }

    /// Builds the discovery document. It contains the issuer, the JWKS URI, the
    /// signing algorithms from [`Self::verifier_algorithms`] (as their JSON
    /// string names) and the extra metadata fields.
    ///
    /// # Panics
    ///
    /// Panics if an `Algorithm` does not serialise to a JSON string.
    pub async fn metadata(&self) -> discovery::LocalIdpMetadata {
        let algs = self
            .verifier_algorithms()
            .await
            .into_iter()
            .map(|a| {
                serde_json::to_value(a)
                    .ok()
                    .and_then(|v| v.as_str().map(str::to_owned))
                    .expect("jsonwebtoken::Algorithm serialises as a JSON string")
            })
            .collect();
        discovery::LocalIdpMetadata {
            issuer: self.issuer.clone(),
            jwks_uri: self.jwks_url(),
            id_token_signing_alg_values_supported: algs,
            extra: self.extra_metadata.clone(),
        }
    }

    /// Axum router that serves `/.well-known/openid-configuration` and
    /// `/jwks.json`, using a clone of this IdP as its state.
    pub fn router(&self) -> axum::Router<()>
    where
        K: 'static,
    {
        axum::Router::new()
            .route(
                "/.well-known/openid-configuration",
                axum::routing::get(discovery::handlers::openid_configuration::<K>),
            )
            .route(
                "/jwks.json",
                axum::routing::get(discovery::handlers::jwks::<K>),
            )
            .with_state(self.clone())
    }
}
/// Assembles the published key set: the current key first, then the historical
/// keys in order, then the extra public JWKs.
fn rebuild_jwks(
    current: &LocalIdpSigningKey,
    historical: &[LocalIdpSigningKey],
    extra: &[Jwk],
) -> JwkSet {
    // Every adapter in the chain reports an exact size, so `collect` allocates
    // the final length up front.
    let signing_jwks = std::iter::once(current)
        .chain(historical)
        .map(|signing_key| signing_key.jwk().clone());
    JwkSet {
        keys: signing_jwks.chain(extra.iter().cloned()).collect(),
    }
}
#[cfg(test)]
mod tests;