use crate::{
apis::jwks_api::{Jwks, JwksKey},
clerk::Clerk,
validators::authorizer::ClerkError,
};
use arc_swap::{ArcSwap, Guard};
use async_trait::async_trait;
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, SystemTime},
};
#[async_trait]
pub trait JwksProvider {
type Error: Into<ClerkError>;
async fn get_key(&self, kid: &str) -> Result<JwksKey, Self::Error>;
}
/// Failure modes of the JWKS providers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JwksProviderError {
    /// No key with the requested `kid` is available.
    UnknownKey,
    /// Fetching the JWKS from the Clerk API failed.
    JwksApi,
}

impl std::fmt::Display for JwksProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::UnknownKey => "no key in the JWKS matches the requested key id",
            Self::JwksApi => "failed to fetch the JWKS",
        };
        f.write_str(message)
    }
}

// Lets the error participate in `Box<dyn Error>` / `?`-based error chains.
impl std::error::Error for JwksProviderError {}
impl From<JwksProviderError> for ClerkError {
fn from(e: JwksProviderError) -> Self {
match e {
JwksProviderError::UnknownKey => ClerkError::Unauthorized("Error: Invalid JWT!".into()),
JwksProviderError::JwksApi => ClerkError::InternalServerError(String::from("Error: Could not fetch JWKS!")),
}
}
}
pub struct JwksProviderNoCache {
clerk_client: Clerk,
}
impl JwksProviderNoCache {
pub fn new(clerk_client: Clerk) -> Self {
Self { clerk_client }
}
}
#[async_trait]
impl JwksProvider for JwksProviderNoCache {
type Error = JwksProviderError;
async fn get_key(&self, kid: &str) -> Result<JwksKey, JwksProviderError> {
let jwks = Jwks::get_jwks(&self.clerk_client).await.map_err(|_| JwksProviderError::JwksApi)?;
jwks.keys.into_iter().find(|k| k.kid == kid).ok_or(JwksProviderError::UnknownKey)
}
}
/// Policy for [`MemoryCacheJwksProvider`] when a requested `kid` is missing from a cache that
/// has not otherwise expired.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefreshOnUnknown {
    /// Never refetch; an unknown key is reported immediately.
    Never,
    /// Refetch only if the cached key set is at least this old (or the system clock appears to
    /// have moved backwards since the last fetch).
    Ratelimit(Duration),
    /// Refetch on every unknown key.
    Always,
}

/// Caching options for [`MemoryCacheJwksProvider`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryCacheJwksProviderOptions {
    /// Maximum age of the cached key set before the next lookup refetches it.
    /// `None` disables time-based expiry.
    pub expire_after: Option<Duration>,
    /// What to do when a requested `kid` is not in the cached key set.
    pub refresh_on_unknown: RefreshOnUnknown,
}

impl Default for MemoryCacheJwksProviderOptions {
    /// Expire the cache after one hour; on an unknown key, refetch only if the cached key set
    /// is at least five minutes old.
    fn default() -> Self {
        Self {
            expire_after: Some(Duration::from_secs(60 * 60)),
            refresh_on_unknown: RefreshOnUnknown::Ratelimit(Duration::from_secs(60 * 5)),
        }
    }
}
/// Immutable snapshot of the cached key set; a new snapshot is built on every refresh.
struct MemoryCacheJwksProviderState {
    /// Cached keys indexed by their `kid`.
    keys: HashMap<String, JwksKey>,
    /// Wall-clock time of the fetch that produced `keys`; `UNIX_EPOCH` means "never fetched".
    last_updated: SystemTime,
}

impl MemoryCacheJwksProviderState {
    /// Returns `true` for the placeholder state that exists before the first fetch.
    fn is_uninitialized(&self) -> bool {
        SystemTime::UNIX_EPOCH == self.last_updated
    }

    /// Returns `true` when the snapshot is at least `expire_after` old.
    ///
    /// `None` means "never expires". If the clock now reads earlier than `last_updated`
    /// (`elapsed` fails), the snapshot is treated as expired.
    fn is_expired(&self, expire_after: Option<Duration>) -> bool {
        match expire_after {
            None => false,
            Some(max_age) => match self.last_updated.elapsed() {
                Ok(age) => age >= max_age,
                Err(_) => true,
            },
        }
    }
}
pub struct MemoryCacheJwksProvider {
clerk_client: Clerk,
options: MemoryCacheJwksProviderOptions,
state: ArcSwap<MemoryCacheJwksProviderState>,
}
impl MemoryCacheJwksProvider {
pub fn new(clerk_client: Clerk) -> Self {
Self::new_with_options(clerk_client, MemoryCacheJwksProviderOptions::default())
}
pub fn new_with_options(clerk_client: Clerk, options: MemoryCacheJwksProviderOptions) -> Self {
let initial_state = MemoryCacheJwksProviderState {
keys: HashMap::new(),
last_updated: SystemTime::UNIX_EPOCH, };
Self {
clerk_client,
options,
state: ArcSwap::new(Arc::new(initial_state)),
}
}
async fn refresh(&self) -> Result<Arc<MemoryCacheJwksProviderState>, JwksProviderError> {
let jwks_model = Jwks::get_jwks(&self.clerk_client).await.map_err(|_| JwksProviderError::JwksApi)?;
let keys = jwks_model.keys.into_iter().map(|k| (k.kid.clone(), k)).collect();
let state = MemoryCacheJwksProviderState {
keys,
last_updated: SystemTime::now(),
};
Ok(Arc::new(state))
}
}
#[async_trait]
impl JwksProvider for MemoryCacheJwksProvider {
type Error = JwksProviderError;
async fn get_key(&self, kid: &str) -> Result<JwksKey, Self::Error> {
let state = self.state.load();
let mut refreshed = false;
let state = if state.is_uninitialized() || state.is_expired(self.options.expire_after) {
let new_state = self.refresh().await?;
self.state.swap(new_state.clone());
refreshed = true;
Guard::from_inner(new_state)
} else {
state
};
let maybe_key = state.keys.get(kid).cloned();
if let Some(key) = maybe_key {
Ok(key)
} else {
if refreshed {
return Err(JwksProviderError::UnknownKey);
}
if let RefreshOnUnknown::Never = self.options.refresh_on_unknown {
return Err(JwksProviderError::UnknownKey);
}
if let RefreshOnUnknown::Ratelimit(min_age) = self.options.refresh_on_unknown {
if !state.is_expired(Some(min_age)) {
return Err(JwksProviderError::UnknownKey);
}
}
let new_state = self.refresh().await?;
self.state.swap(new_state.clone());
new_state.keys.get(kid).cloned().ok_or(JwksProviderError::UnknownKey)
}
}
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{apis::jwks_api::JwksKey, ClerkConfiguration};
    use std::collections::HashMap;

    /// Test double that serves a fixed, in-memory key set without any network access.
    ///
    /// NOTE(review): declared `pub` inside a `pub(crate)` module, presumably so other test
    /// modules in the crate can reuse it — TODO confirm.
    pub struct StaticJwksProvider {
        keys: HashMap<String, JwksKey>,
    }
    impl StaticJwksProvider {
        /// Builds a provider that knows exactly one key, indexed by its `kid`.
        pub fn from_key(key: JwksKey) -> Self {
            let mut keys = HashMap::new();
            keys.insert(key.kid.clone(), key);
            Self { keys }
        }
    }
    #[async_trait]
    impl JwksProvider for StaticJwksProvider {
        type Error = JwksProviderError;
        /// Returns a clone of the stored key, or `UnknownKey` if `kid` is not stored.
        async fn get_key(&self, kid: &str) -> Result<JwksKey, Self::Error> {
            self.keys.get(kid).cloned().ok_or(JwksProviderError::UnknownKey)
        }
    }
    // JWKS response body served by the mock server: a single RS256 signing key whose `kid`
    // equals `MOCK_KID`.
    const MOCK_JWKS_BODY: &str = r#"{
"keys": [{
"use": "sig",
"kty": "RSA",
"kid": "bc63c2e9-5d1c-4e32-9b62-178f60409abd",
"alg": "RS256",
"n": "tgY-zUiCj6p4gDLZos28PJXyBimvDCnvlCxpG8jktCJdbw1VrsAR1tqmz3XrKnpXgKuWaBnoAh9SpslSN-lQhHT_KVHgVUQShMETybKrNx9DoeRChwen26n35BOLCtZE7yUamUGrQpcL4DsL8ZmWcllOLCWjHRenuXJohoQO7jKN9tao2mUpsRor-2O1xKZ_YesCDkDHw7ood4lfNKvDONB8gYIENlOJgAbAKPxTdmnkEraUgZVGeaS7FeB59A_ibj9VnXyqpHmhabSf5xskuA9EJiQn6c3781uGqcF2CS0E4I576oJGsKeKo5AgF2duuDnPd67bVRNvmLH5kDF_Ow",
"e": "AQAB"
}]
}"#;
    // The `kid` of the only key in `MOCK_JWKS_BODY`.
    const MOCK_KID: &str = "bc63c2e9-5d1c-4e32-9b62-178f60409abd";
    // Uncached provider: one lookup -> exactly one JWKS request, returning the matching key.
    #[tokio::test]
    async fn test_simple_jwks_provider_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(1).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = JwksProviderNoCache::new(clerk);
        let res = jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        assert_eq!(res.kid, MOCK_KID);
        mock.assert_async().await;
    }
    // Uncached provider: every lookup hits the API, so three lookups -> three requests.
    #[tokio::test]
    async fn test_simple_jwks_provider_repeat() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(3).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = JwksProviderNoCache::new(clerk);
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        mock.assert_async().await;
    }
    // Uncached provider: a `kid` absent from the response yields `UnknownKey`.
    // The request count is not asserted in this test.
    #[tokio::test]
    async fn test_simple_jwks_provider_unknown_key() {
        let mut server = mockito::Server::new_async().await;
        server.mock("GET", "/v1/jwks").with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = JwksProviderNoCache::new(clerk);
        let res = jwks.get_key("unknown key").await.expect_err("should fail");
        assert_eq!(res, JwksProviderError::UnknownKey)
    }
    // Cached provider: the first lookup fills the cache with a single request.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(1).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new(clerk);
        let res = jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        assert_eq!(res.kid, MOCK_KID);
        mock.assert_async().await;
    }
    // Cached provider with default options: repeated lookups of a known key are served
    // from memory, so only one request is made.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_caching() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(1).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new(clerk);
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        mock.assert_async().await;
    }
    // `RefreshOnUnknown::Never`: an unknown `kid` fails without a second request.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_unknown_never() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(1).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new_with_options(
            clerk,
            MemoryCacheJwksProviderOptions {
                refresh_on_unknown: RefreshOnUnknown::Never,
                ..Default::default()
            },
        );
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key("unknown key").await.expect_err("should fail");
        mock.assert_async().await;
    }
    // `RefreshOnUnknown::Always`: 1 initial fetch + 1 refetch per unknown `kid` = 3 requests.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_unknown_refresh() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(3).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new_with_options(
            clerk,
            MemoryCacheJwksProviderOptions {
                refresh_on_unknown: RefreshOnUnknown::Always,
                ..Default::default()
            },
        );
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key("unknown key 1").await.expect_err("should fail");
        jwks.get_key("unknown key 2").await.expect_err("should fail");
        mock.assert_async().await;
    }
    // `RefreshOnUnknown::Ratelimit(1s)` timeline (3 requests total):
    //   1. initial fetch for MOCK_KID;
    //   2. after sleeping 2s, "unknown key 1" refetches; keys 2 and 3 are rate-limited;
    //   3. after sleeping 2s again, "unknown key 4" refetches; keys 5 and 6 are rate-limited.
    // NOTE: relies on real wall-clock sleeps, because cache timestamps come from `SystemTime`.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_unknown_ratelimit() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(3).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new_with_options(
            clerk,
            MemoryCacheJwksProviderOptions {
                refresh_on_unknown: RefreshOnUnknown::Ratelimit(Duration::from_secs(1)),
                ..Default::default()
            },
        );
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        tokio::time::sleep(Duration::from_secs(2)).await;
        jwks.get_key("unknown key 1").await.expect_err("should fail");
        jwks.get_key("unknown key 2").await.expect_err("should fail");
        jwks.get_key("unknown key 3").await.expect_err("should fail");
        tokio::time::sleep(Duration::from_secs(2)).await;
        jwks.get_key("unknown key 4").await.expect_err("should fail");
        jwks.get_key("unknown key 5").await.expect_err("should fail");
        jwks.get_key("unknown key 6").await.expect_err("should fail");
        mock.assert_async().await;
    }
    // `expire_after = 1s`: three quick lookups share one fetch. After a 2s sleep the cache is
    // expired, so exactly one more fetch happens (2 requests total). This assumes the first
    // three lookups finish within one second.
    #[tokio::test]
    async fn test_memory_cache_jwks_provider_expires() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("GET", "/v1/jwks").expect(2).with_body(MOCK_JWKS_BODY).create_async().await;
        let config = ClerkConfiguration {
            base_path: format!("{}/v1", server.url()),
            ..Default::default()
        };
        let clerk = Clerk::new(config);
        let jwks = MemoryCacheJwksProvider::new_with_options(
            clerk,
            MemoryCacheJwksProviderOptions {
                expire_after: Some(Duration::from_secs(1)),
                ..Default::default()
            },
        );
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        tokio::time::sleep(Duration::from_secs(2)).await;
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        jwks.get_key(MOCK_KID).await.expect("should retrieve key");
        mock.assert_async().await;
    }
}