1use std::sync::Arc;
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use derive_builder::Builder;
6use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation};
7use serde::de::DeserializeOwned;
8use tokio::sync::Notify;
9
10use crate::{Error, JwtDecoder};
11
12const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(60 * 60); const DEFAULT_RETRY_COUNT: usize = 3; const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1); #[derive(Debug, Clone, Builder)]
17pub struct RemoteJwksDecoderConfig {
18 #[builder(default = "DEFAULT_CACHE_DURATION")]
20 pub cache_duration: std::time::Duration,
21 #[builder(default = "DEFAULT_RETRY_COUNT")]
23 pub retry_count: usize,
24 #[builder(default = "DEFAULT_BACKOFF")]
26 pub backoff: std::time::Duration,
27}
28
29impl Default for RemoteJwksDecoderConfig {
30 fn default() -> Self {
31 Self {
32 cache_duration: DEFAULT_CACHE_DURATION,
33 retry_count: DEFAULT_RETRY_COUNT,
34 backoff: DEFAULT_BACKOFF,
35 }
36 }
37}
38
39impl RemoteJwksDecoderConfig {
40 pub fn builder() -> RemoteJwksDecoderConfigBuilder {
44 RemoteJwksDecoderConfigBuilder::default()
45 }
46}
47
/// JWT decoder that verifies tokens against keys downloaded from a remote JWKS endpoint.
///
/// Construct it through the generated `RemoteJwksDecoderBuilder`. `jwks_url` and
/// `validation` have no builder default and must be supplied; every other field
/// falls back to the default shown on its attribute.
#[derive(Clone, Builder)]
pub struct RemoteJwksDecoder {
    /// URL the JWKS document is fetched from.
    jwks_url: String,
    /// Refresh interval, attempt limit and backoff used when fetching keys.
    #[builder(default = "RemoteJwksDecoderConfig::default()")]
    config: RemoteJwksDecoderConfig,
    /// Decoding keys indexed by key id (`kid`); keys without a `kid` are stored
    /// under the empty string. Shared between clones through the `Arc`.
    #[builder(default = "Arc::new(DashMap::new())")]
    keys_cache: Arc<DashMap<String, DecodingKey>>,
    /// Validation rules handed to `jsonwebtoken::decode` for every token.
    validation: Validation,
    /// HTTP client used to download the JWKS document.
    #[builder(default = "reqwest::Client::new()")]
    client: reqwest::Client,
    /// Signalled with `notify_waiters` after each successful key refresh.
    #[builder(default = "Arc::new(Notify::new())")]
    initialized: Arc<Notify>,
}
70
71impl RemoteJwksDecoder {
72 pub fn new(jwks_url: String) -> Result<Self, Error> {
74 RemoteJwksDecoderBuilder::default()
75 .jwks_url(jwks_url)
76 .build()
77 .map_err(|e| Error::Configuration(e.to_string()))
78 }
79
80 pub fn builder() -> RemoteJwksDecoderBuilder {
84 RemoteJwksDecoderBuilder::default()
85 }
86
87 async fn refresh_keys(&self) -> Result<(), Error> {
92 let max_attempts = self.config.retry_count;
93 let mut attempt = 0;
94 let mut err = None;
95
96 while attempt < max_attempts {
97 match self.refresh_keys_once().await {
98 Ok(_) => return Ok(()),
99 Err(e) => {
100 err = Some(e);
101 attempt += 1;
102 tokio::time::sleep(self.config.backoff).await;
103 }
104 }
105 }
106
107 Err(Error::JwksRefresh {
108 message: "Failed to refresh JWKS after multiple attempts".to_string(),
109 retry_count: max_attempts,
110 source: err.map(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>),
111 })
112 }
113
114 async fn refresh_keys_once(&self) -> Result<(), Error> {
117 let jwks = self
118 .client
119 .get(&self.jwks_url)
120 .send()
121 .await?
122 .json::<JwkSet>()
123 .await?;
124
125 let mut new_keys = Vec::new();
127 for jwk in jwks.keys.iter() {
128 let key_id = jwk.common.key_id.to_owned();
129 let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?;
130 new_keys.push((key_id.unwrap_or_default(), key));
131 }
132
133 self.keys_cache.clear();
135 for (kid, key) in new_keys {
136 self.keys_cache.insert(kid, key);
137 }
138
139 self.initialized.notify_waiters();
141
142 Ok(())
143 }
144
145 pub async fn refresh_keys_periodically(&self) {
151 loop {
152 tracing::info!("Refreshing JWKS");
153 match self.refresh_keys().await {
154 Ok(_) => {}
155 Err(err) => {
156 tracing::error!(
158 "Failed to refresh JWKS after {} attempts: {:?}",
159 self.config.retry_count,
160 err
161 );
162 }
163 }
164 tokio::time::sleep(self.config.cache_duration).await;
165 }
166 }
167
168 async fn ensure_initialized(&self) {
170 self.initialized.notified().await;
171 }
172}
173
174#[async_trait]
175impl<T> JwtDecoder<T> for RemoteJwksDecoder
176where
177 T: for<'de> DeserializeOwned,
178{
179 async fn decode(&self, token: &str) -> Result<TokenData<T>, Error> {
180 self.ensure_initialized().await;
181 let header = jsonwebtoken::decode_header(token)?;
182 let target_kid = header.kid;
183
184 if let Some(ref kid) = target_kid {
185 if let Some(key) = self.keys_cache.get(kid) {
186 return Ok(jsonwebtoken::decode::<T>(
187 token,
188 key.value(),
189 &self.validation,
190 )?);
191 }
192 return Err(Error::KeyNotFound(Some(kid.clone())));
193 }
194 return Err(Error::KeyNotFound(None));
195 }
196}