1use std::sync::Arc;
22use std::time::{Duration, Instant};
23
24use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
25use serde::Deserialize;
26use serde_json::Value;
27use tokio::sync::Mutex;
28
/// Why a token was rejected, or why it could not be checked at all.
#[derive(Debug)]
pub enum JwtVerifyError {
    /// An algorithm outside the accepted set: either the token header's `alg`
    /// (its `Debug` name) or a description of an unusable JWK key type.
    UnsupportedAlg(String),
    /// The JWT header has no `kid`, so no signing key can be selected.
    MissingKid,
    /// No JWK with this `kid` exists in the key set.
    UnknownKid { kid: String },
    /// The JWKS document could not be fetched or decoded; carries the cause.
    JwksFetchFailed(String),
    /// Header/key decoding, signature verification, or registered-claim
    /// validation failed; carries the underlying message.
    Invalid(String),
    /// A claim the caller relies on is absent or has the wrong JSON type.
    ClaimMissing(&'static str),
}

impl std::fmt::Display for JwtVerifyError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant message: no formatting machinery needed.
            Self::MissingKid => out.write_str("missing kid in JWT header"),
            Self::UnsupportedAlg(alg) => write!(out, "unsupported algorithm: {alg}"),
            // `{:?}` quotes and escapes the kid, which comes from an untrusted header.
            Self::UnknownKid { kid } => write!(out, "kid {kid:?} not in JWKS after refresh"),
            Self::JwksFetchFailed(cause) => write!(out, "JWKS fetch failed: {cause}"),
            Self::Invalid(cause) => {
                write!(out, "signature or claim validation failed: {cause}")
            }
            Self::ClaimMissing(claim) => write!(out, "required claim missing: {claim}"),
        }
    }
}

/// Marker impl: no wrapped `source()`; the cause is carried as text.
impl std::error::Error for JwtVerifyError {}
57
58#[derive(Debug, Clone, Deserialize)]
62pub struct MinimalClaims {
63 #[serde(rename = "tenant_id")]
64 pub tenant_id: String,
65}
66
/// Result of a successful `JwtVerifier::verify`: the signature and registered
/// claims have been checked, and the claims used by this service are extracted.
#[derive(Debug, Clone)]
pub struct VerifiedToken {
    /// Value of the `tenant_id` claim; verification fails if it is absent or not
    /// a string.
    pub tenant_id: String,
    /// The `plan` claim, if present and a string.
    pub plan: Option<String>,
    /// String entries of the `roles` array claim. Non-string entries are skipped;
    /// a missing or non-array `roles` claim yields an empty list.
    pub roles: Vec<String>,
    /// The `jti` claim, if present and a string.
    pub jti: Option<String>,
    /// The `sub` claim, if present and a string.
    pub sub: Option<String>,
    /// The complete decoded payload, for claims not surfaced above.
    pub claims: Value,
}
76
/// Settings for `JwtVerifier`; can be built from environment variables with
/// `JwtVerifierConfig::from_env`.
#[derive(Debug, Clone)]
pub struct JwtVerifierConfig {
    /// Required value of the token's `iss` claim.
    pub issuer: String,
    /// Required value of the token's `aud` claim.
    pub audience: String,
    /// URL of the JWKS document containing the signing keys.
    pub jwks_url: String,
    /// How long a fetched JWKS is served from cache before it is fetched again.
    pub jwks_ttl: Duration,
    /// Clock-skew allowance, in seconds, applied to time-based claim checks
    /// (`exp`, `nbf`).
    pub leeway_secs: u64,
    /// Whether JWT verification should be enforced.
    /// NOTE(review): not read in this file. Presumably callers use it to decide
    /// whether a verification failure rejects the request; confirm at call sites.
    pub enforce: bool,
}
97
98impl JwtVerifierConfig {
99 pub fn from_env() -> Option<Self> {
103 let jwks_url = std::env::var("AXON_JWT_JWKS_URL").ok().filter(|s| !s.is_empty())?;
104 let issuer = std::env::var("AXON_JWT_ISSUER")
105 .unwrap_or_else(|_| "https://auth.bemarking.com".into());
106 let audience =
107 std::env::var("AXON_JWT_AUDIENCE").unwrap_or_else(|_| "axon-api".into());
108 let jwks_ttl_secs: u64 = std::env::var("AXON_JWT_JWKS_TTL_SECONDS")
109 .ok()
110 .and_then(|v| v.parse().ok())
111 .unwrap_or(600);
112 let leeway_secs: u64 = std::env::var("AXON_JWT_LEEWAY_SECONDS")
113 .ok()
114 .and_then(|v| v.parse().ok())
115 .unwrap_or(60);
116 let enforce = std::env::var("AXON_ENFORCE_JWT_VERIFICATION")
117 .ok()
118 .map(|v| matches!(v.as_str(), "1" | "true" | "TRUE" | "yes"))
119 .unwrap_or(true);
120 Some(Self {
121 issuer,
122 audience,
123 jwks_url,
124 jwks_ttl: Duration::from_secs(jwks_ttl_secs),
125 leeway_secs,
126 enforce,
127 })
128 }
129}
130
/// One key from a JWKS document. Only the members used here are deserialized;
/// any other members present in the JSON are ignored.
///
/// NOTE(review): `kid` and `kty` are non-optional, so a single key in the
/// document lacking either makes the whole JWKS fail to deserialize.
#[derive(Debug, Clone, Deserialize)]
struct JwksEntry {
    /// Key ID, matched against the `kid` in the JWT header.
    kid: String,
    /// Key type; only `"RSA"` keys are accepted for verification.
    kty: String,
    /// Algorithm the key is intended for, if the JWKS advertises one.
    alg: Option<String>,
    /// RSA modulus as it appears in the JWK, passed to
    /// `DecodingKey::from_rsa_components` (presumably base64url, per RFC 7518).
    n: Option<String>,
    /// RSA public exponent, in the same encoding as `n`.
    e: Option<String>,
}
141
/// Top-level JWKS JSON object, shaped as `{"keys": [ ... ]}`.
#[derive(Debug, Clone, Deserialize)]
struct JwksDocument {
    /// Every key published by the endpoint.
    keys: Vec<JwksEntry>,
}
146
/// Key set from one successful JWKS fetch, plus the time it was stored.
struct CacheSlot {
    /// When the keys were stored; its age is compared against the client's `ttl`.
    loaded_at: Instant,
    /// Keys from the fetched document.
    keys: Vec<JwksEntry>,
}
151
/// Fetches a JWKS document over HTTP and caches its keys in memory.
pub struct JwksClient {
    /// JWKS endpoint, requested with `GET` and `Accept: application/json`.
    url: String,
    /// How long a fetched key set is served from the cache.
    ttl: Duration,
    /// HTTP client for JWKS requests (built with a 5 s timeout in `new`).
    http: reqwest::Client,
    /// Most recently stored key set; `None` until the first successful fetch.
    /// A `tokio` mutex, so waiting for it `.await`s instead of blocking a thread.
    slot: Mutex<Option<CacheSlot>>,
}
159
160impl JwksClient {
161 pub fn new(url: String, ttl: Duration) -> Self {
162 Self {
163 url,
164 ttl,
165 http: reqwest::Client::builder()
166 .timeout(Duration::from_secs(5))
167 .build()
168 .expect("reqwest client"),
169 slot: Mutex::new(None),
170 }
171 }
172
173 async fn resolve_key(&self, kid: &str) -> Result<JwksEntry, JwtVerifyError> {
174 {
175 let slot = self.slot.lock().await;
176 if let Some(c) = slot.as_ref() {
177 if c.loaded_at.elapsed() < self.ttl {
178 if let Some(k) = c.keys.iter().find(|k| k.kid == kid) {
179 return Ok(k.clone());
180 }
181 }
182 }
183 }
184 self.refresh().await?;
185 let slot = self.slot.lock().await;
186 let cache = slot.as_ref().ok_or_else(|| {
187 JwtVerifyError::JwksFetchFailed("empty cache after refresh".into())
188 })?;
189 cache
190 .keys
191 .iter()
192 .find(|k| k.kid == kid)
193 .cloned()
194 .ok_or_else(|| JwtVerifyError::UnknownKid { kid: kid.to_string() })
195 }
196
197 async fn refresh(&self) -> Result<(), JwtVerifyError> {
198 let resp = self
199 .http
200 .get(&self.url)
201 .header("Accept", "application/json")
202 .send()
203 .await
204 .map_err(|e| JwtVerifyError::JwksFetchFailed(e.to_string()))?;
205 if !resp.status().is_success() {
206 return Err(JwtVerifyError::JwksFetchFailed(format!(
207 "HTTP {}",
208 resp.status()
209 )));
210 }
211 let doc: JwksDocument = resp
212 .json()
213 .await
214 .map_err(|e| JwtVerifyError::JwksFetchFailed(e.to_string()))?;
215 let mut slot = self.slot.lock().await;
216 *slot = Some(CacheSlot {
217 loaded_at: Instant::now(),
218 keys: doc.keys,
219 });
220 Ok(())
221 }
222}
223
/// Verifies RSA-signed JWTs against the keys published at a JWKS URL.
pub struct JwtVerifier {
    /// Expected issuer/audience, leeway and JWKS settings.
    cfg: JwtVerifierConfig,
    /// JWKS client holding the cached key set.
    jwks: Arc<JwksClient>,
}
230
231impl JwtVerifier {
232 pub fn new(cfg: JwtVerifierConfig) -> Self {
233 let jwks = Arc::new(JwksClient::new(cfg.jwks_url.clone(), cfg.jwks_ttl));
234 Self { cfg, jwks }
235 }
236
237 pub fn config(&self) -> &JwtVerifierConfig {
238 &self.cfg
239 }
240
241 pub async fn verify(&self, token: &str) -> Result<VerifiedToken, JwtVerifyError> {
242 let header =
243 decode_header(token).map_err(|e| JwtVerifyError::Invalid(e.to_string()))?;
244 let alg = match header.alg {
245 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => header.alg,
246 other => {
247 return Err(JwtVerifyError::UnsupportedAlg(format!("{other:?}")));
248 }
249 };
250 let kid = header.kid.ok_or(JwtVerifyError::MissingKid)?;
251 let entry = self.jwks.resolve_key(&kid).await?;
252
253 if entry.kty != "RSA" {
254 return Err(JwtVerifyError::UnsupportedAlg(format!(
255 "non-RSA JWK kty={}",
256 entry.kty
257 )));
258 }
259
260 let n = entry.n.ok_or_else(|| {
261 JwtVerifyError::Invalid("JWK missing modulus".into())
262 })?;
263 let e = entry.e.ok_or_else(|| {
264 JwtVerifyError::Invalid("JWK missing exponent".into())
265 })?;
266 let key = DecodingKey::from_rsa_components(&n, &e)
267 .map_err(|err| JwtVerifyError::Invalid(err.to_string()))?;
268
269 let mut validation = Validation::new(alg);
270 validation.set_issuer(&[self.cfg.issuer.clone()]);
271 validation.set_audience(&[self.cfg.audience.clone()]);
272 validation.leeway = self.cfg.leeway_secs;
273 validation.validate_exp = true;
274 validation.validate_nbf = true;
275 validation.required_spec_claims =
276 ["iss", "aud", "exp", "iat", "sub"].iter().map(|s| s.to_string()).collect();
277
278 let data = decode::<Value>(token, &key, &validation)
279 .map_err(|err| JwtVerifyError::Invalid(err.to_string()))?;
280 let claims = data.claims;
281
282 let tenant_id = claims
283 .get("tenant_id")
284 .and_then(|v| v.as_str())
285 .ok_or(JwtVerifyError::ClaimMissing("tenant_id"))?
286 .to_string();
287 let plan = claims.get("plan").and_then(|v| v.as_str()).map(String::from);
288 let roles = claims
289 .get("roles")
290 .and_then(|v| v.as_array())
291 .map(|arr| {
292 arr.iter()
293 .filter_map(|v| v.as_str().map(String::from))
294 .collect()
295 })
296 .unwrap_or_default();
297 let jti = claims.get("jti").and_then(|v| v.as_str()).map(String::from);
298 let sub = claims.get("sub").and_then(|v| v.as_str()).map(String::from);
299
300 Ok(VerifiedToken {
301 tenant_id,
302 plan,
303 roles,
304 jti,
305 sub,
306 claims,
307 })
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variable `JwtVerifierConfig::from_env` reads; each test snapshots
    /// and restores all of them.
    const ENV_VARS: [&str; 6] = [
        "AXON_JWT_JWKS_URL",
        "AXON_JWT_ISSUER",
        "AXON_JWT_AUDIENCE",
        "AXON_JWT_JWKS_TTL_SECONDS",
        "AXON_JWT_LEEWAY_SECONDS",
        "AXON_ENFORCE_JWT_VERIFICATION",
    ];

    /// The default harness runs tests on parallel threads but the process
    /// environment is global, so every env-mutating test must hold this lock.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Holds `ENV_LOCK` for the life of a test and restores `ENV_VARS` on drop.
    struct EnvSandbox {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
        // Fields drop after `Drop::drop`, so the lock outlives the restore.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvSandbox {
        fn new() -> Self {
            // A failed assertion poisons the lock; the restore still ran, so recover.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = ENV_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvSandbox {
        fn drop(&mut self) {
            for (name, value) in &self.saved {
                match value {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn config_from_env_requires_jwks_url() {
        let _env = EnvSandbox::new();
        std::env::remove_var("AXON_JWT_JWKS_URL");
        assert!(JwtVerifierConfig::from_env().is_none());
        std::env::set_var("AXON_JWT_JWKS_URL", "");
        assert!(JwtVerifierConfig::from_env().is_none());
    }

    #[test]
    fn config_from_env_reads_values() {
        let _env = EnvSandbox::new();
        std::env::set_var("AXON_JWT_JWKS_URL", "https://x/jwks.json");
        std::env::set_var("AXON_JWT_ISSUER", "https://x");
        std::env::set_var("AXON_JWT_AUDIENCE", "x-api");
        std::env::set_var("AXON_JWT_JWKS_TTL_SECONDS", "30");
        std::env::set_var("AXON_JWT_LEEWAY_SECONDS", "5");
        let cfg = JwtVerifierConfig::from_env().unwrap();
        assert_eq!(cfg.jwks_url, "https://x/jwks.json");
        assert_eq!(cfg.issuer, "https://x");
        assert_eq!(cfg.audience, "x-api");
        assert_eq!(cfg.jwks_ttl, Duration::from_secs(30));
        assert_eq!(cfg.leeway_secs, 5);
    }

    #[test]
    fn config_from_env_applies_defaults() {
        let _env = EnvSandbox::new();
        for name in ENV_VARS {
            std::env::remove_var(name);
        }
        std::env::set_var("AXON_JWT_JWKS_URL", "https://x/jwks.json");
        let cfg = JwtVerifierConfig::from_env().unwrap();
        assert_eq!(cfg.issuer, "https://auth.bemarking.com");
        assert_eq!(cfg.audience, "axon-api");
        assert_eq!(cfg.jwks_ttl, Duration::from_secs(600));
        assert_eq!(cfg.leeway_secs, 60);
        assert!(cfg.enforce);
    }

    #[test]
    fn config_from_env_parses_enforce_flag() {
        let _env = EnvSandbox::new();
        std::env::set_var("AXON_JWT_JWKS_URL", "https://x/jwks.json");
        std::env::set_var("AXON_ENFORCE_JWT_VERIFICATION", "1");
        assert!(JwtVerifierConfig::from_env().unwrap().enforce);
        std::env::set_var("AXON_ENFORCE_JWT_VERIFICATION", "true");
        assert!(JwtVerifierConfig::from_env().unwrap().enforce);
        std::env::set_var("AXON_ENFORCE_JWT_VERIFICATION", "0");
        assert!(!JwtVerifierConfig::from_env().unwrap().enforce);
        std::env::set_var("AXON_ENFORCE_JWT_VERIFICATION", "false");
        assert!(!JwtVerifierConfig::from_env().unwrap().enforce);
    }
}