1use std::sync::Arc;
23
24use jsonwebtoken::{
25 Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, decode_header,
26 encode,
27};
28use parking_lot::RwLock;
29use serde::Serialize;
30use serde::de::DeserializeOwned;
31
32use crate::error::{Error, Result};
33
/// The key currently used to sign new tokens (it also verifies the tokens it signed).
pub struct ActiveKey {
    /// Key id written into the `kid` header of every token this key signs.
    pub kid: String,
    /// Signing algorithm. `build_keys` in this module only accepts `Algorithm::EdDSA`.
    pub alg: Algorithm,
    /// Private half, used by `JwtConfig::issue`.
    pub encoding_key: EncodingKey,
    /// Public half, used by `JwtConfig::verify` when a token's `kid` matches `kid`.
    pub decoding_key: DecodingKey,
    /// Value of the `expires_at` column when loaded from `auth.jwks_keys`; `None` for keys
    /// created by rotation or generated ephemerally.
    /// NOTE(review): presumably seconds since the UNIX epoch, like `rotated_at` — confirm.
    /// It is not consulted by signing or verification in the code shown here.
    pub expires_at: Option<f64>,
}
45
/// A retired key kept only for verification, so tokens it signed before a rotation still
/// verify. It carries no private half and can never sign.
pub struct HistoryKey {
    /// Key id matched against a token's `kid` header.
    pub kid: String,
    /// Algorithm the key was used with; verification is pinned to it.
    pub alg: Algorithm,
    /// Public key used to check signatures.
    pub decoding_key: DecodingKey,
}
53
/// Mutable state shared by every clone of a `JwtConfig`, guarded by its `RwLock`.
struct Inner {
    /// Key used to sign new tokens; `None` until one is installed.
    active: Option<ActiveKey>,
    /// Retired keys that can still verify tokens but never sign.
    history: Vec<HistoryKey>,
    /// Expected `iss` claim, passed to `Validation::set_issuer` on every `verify`.
    issuer: String,
    /// Accepted `aud` values; `verify` only calls `Validation::set_audience` when this
    /// list is non-empty.
    audience: Vec<String>,
}
60
/// Issues and verifies JWTs using an active signing key plus a history of retired keys.
///
/// Cloning is cheap and shares state: the only field is an `Arc`, so every clone sees the
/// same key set, and a rotation or reload through one clone is visible through all of them.
#[derive(Clone)]
pub struct JwtConfig {
    inner: Arc<RwLock<Inner>>,
}
67
68impl JwtConfig {
69 pub fn new(issuer: String, audience: Vec<String>) -> Self {
73 Self {
74 inner: Arc::new(RwLock::new(Inner {
75 active: None,
76 history: Vec::new(),
77 issuer,
78 audience,
79 })),
80 }
81 }
82
83 pub fn set_active(&self, active: ActiveKey, history: Vec<HistoryKey>) {
86 let mut guard = self.inner.write();
87 guard.active = Some(active);
88 guard.history = history;
89 }
90
91 pub fn issue<T: Serialize>(&self, claims: &T) -> Result<String> {
94 let guard = self.inner.read();
95 let active = guard
96 .active
97 .as_ref()
98 .ok_or_else(|| Error::Jwt("no active jwt key configured".to_string()))?;
99 let mut header = Header::new(active.alg);
100 header.kid = Some(active.kid.clone());
101 encode(&header, claims, &active.encoding_key).map_err(map_jwt_err)
102 }
103
104 pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<TokenData<T>> {
108 let header = decode_header(token).map_err(map_jwt_err)?;
109 let kid = header
110 .kid
111 .as_deref()
112 .ok_or_else(|| Error::Jwt("token has no kid header".to_string()))?;
113 let guard = self.inner.read();
114 let (alg, decoding_key) = lookup_decoding_key(&guard, kid)
115 .ok_or_else(|| Error::Jwt(format!("unknown kid {kid}")))?;
116 let mut validation = Validation::new(alg);
117 validation.set_issuer(std::slice::from_ref(&guard.issuer));
118 if !guard.audience.is_empty() {
119 validation.set_audience(&guard.audience);
120 }
121 decode::<T>(token, decoding_key, &validation).map_err(map_jwt_err)
122 }
123
124 pub fn active_kid(&self) -> Option<String> {
127 self.inner.read().active.as_ref().map(|k| k.kid.clone())
128 }
129
130 pub fn issuer(&self) -> String {
135 self.inner.read().issuer.clone()
136 }
137
138 pub fn audience(&self) -> Vec<String> {
140 self.inner.read().audience.clone()
141 }
142
143 #[cfg(feature = "backend-postgres")]
148 pub async fn load_from_postgres(&self, pool: &sqlx::PgPool) -> Result<()> {
149 use sqlx::Row;
150 let rows = sqlx::query(
151 "SELECT kid, alg, private_pem_encrypted, rotated_at, expires_at
152 FROM auth.jwks_keys
153 ORDER BY created_at",
154 )
155 .fetch_all(pool)
156 .await
157 .map_err(|e| Error::Backend(anyhow::anyhow!("load auth.jwks_keys (pg): {e}")))?;
158
159 let mut active = None;
160 let mut history = Vec::new();
161 for row in rows {
162 let kid: String = row.get("kid");
163 let alg_str: String = row.get("alg");
164 let pem: Option<Vec<u8>> = row.get("private_pem_encrypted");
165 let rotated_at: Option<f64> = row.get("rotated_at");
166 let expires_at: Option<f64> = row.get("expires_at");
167 let alg = parse_alg(&alg_str)?;
168 let pem = pem.ok_or_else(|| {
169 Error::Jwt(format!("auth.jwks_keys row {kid} has no private key"))
170 })?;
171 let (encoding_key, decoding_key) = build_keys(alg, &pem)?;
172 if rotated_at.is_none() && active.is_none() {
173 active = Some(ActiveKey {
174 kid: kid.clone(),
175 alg,
176 encoding_key,
177 decoding_key,
178 expires_at,
179 });
180 } else {
181 history.push(HistoryKey {
182 kid,
183 alg,
184 decoding_key,
185 });
186 }
187 }
188 let mut guard = self.inner.write();
189 guard.active = active;
190 guard.history = history;
191 Ok(())
192 }
193
194 #[cfg(feature = "backend-sqlite")]
196 pub async fn load_from_sqlite(&self, pool: &sqlx::SqlitePool) -> Result<()> {
197 use sqlx::Row;
198 let rows = sqlx::query(
199 "SELECT kid, alg, private_pem_encrypted, rotated_at, expires_at
200 FROM auth.jwks_keys
201 ORDER BY created_at",
202 )
203 .fetch_all(pool)
204 .await
205 .map_err(|e| Error::Backend(anyhow::anyhow!("load auth.jwks_keys (sqlite): {e}")))?;
206
207 let mut active = None;
208 let mut history = Vec::new();
209 for row in rows {
210 let kid: String = row.get("kid");
211 let alg_str: String = row.get("alg");
212 let pem: Option<Vec<u8>> = row.get("private_pem_encrypted");
213 let rotated_at: Option<f64> = row.get("rotated_at");
214 let expires_at: Option<f64> = row.get("expires_at");
215 let alg = parse_alg(&alg_str)?;
216 let pem = pem.ok_or_else(|| {
217 Error::Jwt(format!("auth.jwks_keys row {kid} has no private key"))
218 })?;
219 let (encoding_key, decoding_key) = build_keys(alg, &pem)?;
220 if rotated_at.is_none() && active.is_none() {
221 active = Some(ActiveKey {
222 kid: kid.clone(),
223 alg,
224 encoding_key,
225 decoding_key,
226 expires_at,
227 });
228 } else {
229 history.push(HistoryKey {
230 kid,
231 alg,
232 decoding_key,
233 });
234 }
235 }
236 let mut guard = self.inner.write();
237 guard.active = active;
238 guard.history = history;
239 Ok(())
240 }
241
242 #[cfg(feature = "backend-postgres")]
246 pub async fn rotate_postgres(&self, pool: &sqlx::PgPool) -> Result<String> {
247 let GeneratedKey {
248 kid,
249 alg,
250 private_pem,
251 public_jwk,
252 } = generate_ed25519_key();
253 let (encoding_key, decoding_key) = build_keys(alg, private_pem.as_bytes())?;
254 let now = now_secs();
255 let mut tx = pool
256 .begin()
257 .await
258 .map_err(|e| Error::Backend(anyhow::anyhow!("begin tx (pg rotate): {e}")))?;
259 sqlx::query(
260 "UPDATE auth.jwks_keys SET rotated_at = $1 WHERE rotated_at IS NULL",
261 )
262 .bind(now)
263 .execute(&mut *tx)
264 .await
265 .map_err(|e| Error::Backend(anyhow::anyhow!("mark old key rotated (pg): {e}")))?;
266 sqlx::query(
267 "INSERT INTO auth.jwks_keys
268 (kid, alg, public_jwk, private_pem_encrypted, created_at, rotated_at, expires_at)
269 VALUES ($1, $2, $3::jsonb, $4, $5, NULL, NULL)",
270 )
271 .bind(&kid)
272 .bind(alg_str(alg))
273 .bind(public_jwk.to_string())
274 .bind(private_pem.as_bytes())
275 .bind(now)
276 .execute(&mut *tx)
277 .await
278 .map_err(|e| Error::Backend(anyhow::anyhow!("insert new key (pg): {e}")))?;
279 tx.commit()
280 .await
281 .map_err(|e| Error::Backend(anyhow::anyhow!("commit tx (pg rotate): {e}")))?;
282 let mut guard = self.inner.write();
284 if let Some(prev) = guard.active.take() {
285 guard.history.push(HistoryKey {
286 kid: prev.kid,
287 alg: prev.alg,
288 decoding_key: prev.decoding_key,
289 });
290 }
291 guard.active = Some(ActiveKey {
292 kid: kid.clone(),
293 alg,
294 encoding_key,
295 decoding_key,
296 expires_at: None,
297 });
298 Ok(kid)
299 }
300
301 #[cfg(feature = "backend-sqlite")]
303 pub async fn rotate_sqlite(&self, pool: &sqlx::SqlitePool) -> Result<String> {
304 let GeneratedKey {
305 kid,
306 alg,
307 private_pem,
308 public_jwk,
309 } = generate_ed25519_key();
310 let (encoding_key, decoding_key) = build_keys(alg, private_pem.as_bytes())?;
311 let now = now_secs();
312 let mut tx = pool
313 .begin()
314 .await
315 .map_err(|e| Error::Backend(anyhow::anyhow!("begin tx (sqlite rotate): {e}")))?;
316 sqlx::query(
317 "UPDATE auth.jwks_keys SET rotated_at = ? WHERE rotated_at IS NULL",
318 )
319 .bind(now)
320 .execute(&mut *tx)
321 .await
322 .map_err(|e| Error::Backend(anyhow::anyhow!("mark old key rotated (sqlite): {e}")))?;
323 sqlx::query(
324 "INSERT INTO auth.jwks_keys
325 (kid, alg, public_jwk, private_pem_encrypted, created_at, rotated_at, expires_at)
326 VALUES (?, ?, ?, ?, ?, NULL, NULL)",
327 )
328 .bind(&kid)
329 .bind(alg_str(alg))
330 .bind(public_jwk.to_string())
331 .bind(private_pem.as_bytes())
332 .bind(now)
333 .execute(&mut *tx)
334 .await
335 .map_err(|e| Error::Backend(anyhow::anyhow!("insert new key (sqlite): {e}")))?;
336 tx.commit()
337 .await
338 .map_err(|e| Error::Backend(anyhow::anyhow!("commit tx (sqlite rotate): {e}")))?;
339 let mut guard = self.inner.write();
340 if let Some(prev) = guard.active.take() {
341 guard.history.push(HistoryKey {
342 kid: prev.kid,
343 alg: prev.alg,
344 decoding_key: prev.decoding_key,
345 });
346 }
347 guard.active = Some(ActiveKey {
348 kid: kid.clone(),
349 alg,
350 encoding_key,
351 decoding_key,
352 expires_at: None,
353 });
354 Ok(kid)
355 }
356}
357
358fn lookup_decoding_key<'a>(
359 inner: &'a Inner,
360 kid: &str,
361) -> Option<(Algorithm, &'a DecodingKey)> {
362 if let Some(active) = &inner.active
363 && active.kid == kid
364 {
365 return Some((active.alg, &active.decoding_key));
366 }
367 inner
368 .history
369 .iter()
370 .find(|h| h.kid == kid)
371 .map(|h| (h.alg, &h.decoding_key))
372}
373
374fn build_keys(alg: Algorithm, pem: &[u8]) -> Result<(EncodingKey, DecodingKey)> {
378 match alg {
379 Algorithm::EdDSA => {
380 let enc = EncodingKey::from_ed_pem(pem).map_err(map_jwt_err)?;
381 let public_pem = ed25519_public_pem_from_private(pem)?;
382 let dec = DecodingKey::from_ed_pem(public_pem.as_bytes()).map_err(map_jwt_err)?;
383 Ok((enc, dec))
384 }
385 other => Err(Error::Jwt(format!(
388 "unsupported jwt algorithm {other:?} (only EdDSA in phase 4)"
389 ))),
390 }
391}
392
393fn ed25519_public_pem_from_private(private_pem: &[u8]) -> Result<String> {
397 use ed25519_dalek::SigningKey;
398 use ed25519_dalek::pkcs8::DecodePrivateKey;
399 use ed25519_dalek::pkcs8::spki::EncodePublicKey;
400
401 let pem_str = std::str::from_utf8(private_pem)
402 .map_err(|e| Error::Jwt(format!("ed25519 private PEM utf8: {e}")))?;
403 let signing = SigningKey::from_pkcs8_pem(pem_str)
404 .map_err(|e| Error::Jwt(format!("parse ed25519 private PEM: {e}")))?;
405 let verifying = signing.verifying_key();
406 verifying
407 .to_public_key_pem(ed25519_dalek::pkcs8::spki::der::pem::LineEnding::LF)
408 .map_err(|e| Error::Jwt(format!("encode ed25519 public PEM: {e}")))
409}
410
411fn parse_alg(name: &str) -> Result<Algorithm> {
412 match name {
413 "EdDSA" => Ok(Algorithm::EdDSA),
414 other => Err(Error::Jwt(format!(
415 "unknown jwt algorithm {other:?} (only EdDSA in phase 4)"
416 ))),
417 }
418}
419
420fn alg_str(alg: Algorithm) -> &'static str {
421 match alg {
422 Algorithm::EdDSA => "EdDSA",
423 _ => "EdDSA",
427 }
428}
429
430fn map_jwt_err(e: jsonwebtoken::errors::Error) -> Error {
431 Error::Jwt(e.to_string())
432}
433
/// Current wall-clock time as fractional seconds since the UNIX epoch.
///
/// Returns `0.0` if the system clock reports a time before the epoch, instead of failing.
fn now_secs() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        Err(_) => 0.0,
    }
}
440
/// Freshly generated key material, as returned by `generate_ed25519_key`.
struct GeneratedKey {
    /// `"kid_"` plus the unpadded base64url encoding of the first 16 public-key bytes.
    kid: String,
    /// Always `Algorithm::EdDSA` when produced by `generate_ed25519_key`.
    alg: Algorithm,
    /// PKCS#8 private-key PEM with LF line endings.
    private_pem: String,
    /// Public key as an OKP/Ed25519 JWK object (`kty`, `crv`, `alg`, `kid`, `use`, `x`).
    public_jwk: serde_json::Value,
}
450
451fn generate_ed25519_key() -> GeneratedKey {
452 use ed25519_dalek::SigningKey;
453 use ed25519_dalek::pkcs8::EncodePrivateKey;
454
455 let signing = SigningKey::generate(&mut rand_core_06::OsRng);
456 let private_pem = signing
457 .to_pkcs8_pem(ed25519_dalek::pkcs8::spki::der::pem::LineEnding::LF)
458 .expect("ed25519 PKCS#8 PEM encoding")
459 .to_string();
460 let verifying = signing.verifying_key();
461 let pub_bytes = verifying.to_bytes();
462 let kid = format!(
463 "kid_{}",
464 data_encoding::BASE64URL_NOPAD.encode(&pub_bytes[..16])
465 );
466 let public_jwk = serde_json::json!({
467 "kty": "OKP",
468 "crv": "Ed25519",
469 "alg": "EdDSA",
470 "kid": kid,
471 "use": "sig",
472 "x": data_encoding::BASE64URL_NOPAD.encode(&pub_bytes),
473 });
474 GeneratedKey {
475 kid,
476 alg: Algorithm::EdDSA,
477 private_pem,
478 public_jwk,
479 }
480}
481
482pub fn generate_ephemeral_ed25519(kid: impl Into<String>) -> Result<ActiveKey> {
486 let GeneratedKey { private_pem, .. } = generate_ed25519_key();
487 let (encoding_key, decoding_key) = build_keys(Algorithm::EdDSA, private_pem.as_bytes())?;
488 Ok(ActiveKey {
489 kid: kid.into(),
490 alg: Algorithm::EdDSA,
491 encoding_key,
492 decoding_key,
493 expires_at: None,
494 })
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Registered claims carried by every token minted in these tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Claims {
        sub: String,
        iss: String,
        aud: String,
        exp: usize,
    }

    /// Builds a config for `issuer`/`audience` with a fresh ephemeral key `kid_test` active.
    fn config_with_active(issuer: &str, audience: &[&str]) -> JwtConfig {
        let audience: Vec<String> = audience.iter().map(|aud| aud.to_string()).collect();
        let cfg = JwtConfig::new(issuer.to_string(), audience);
        let key = generate_ephemeral_ed25519("kid_test").unwrap();
        cfg.set_active(key, Vec::new());
        cfg
    }

    /// Claims issued by `"assay"` for subject `sub`, audience `aud`, expiring at `exp`.
    fn claims(sub: &str, aud: &str, exp: usize) -> Claims {
        Claims {
            sub: sub.to_owned(),
            iss: "assay".to_owned(),
            aud: aud.to_owned(),
            exp,
        }
    }

    /// One hour from now, in whole seconds since the UNIX epoch.
    fn future_exp() -> usize {
        now_secs() as usize + 3600
    }

    /// One hour ago (saturating at zero), in whole seconds since the UNIX epoch.
    fn past_exp() -> usize {
        (now_secs() as usize).saturating_sub(3600)
    }

    #[test]
    fn issue_and_verify_round_trip() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let expected = claims("user_alice", "assay-engine", future_exp());
        let token = cfg.issue(&expected).unwrap();
        let data = cfg.verify::<Claims>(&token).unwrap();
        assert_eq!(data.claims, expected);
        assert_eq!(data.header.kid.as_deref(), Some("kid_test"));
    }

    #[test]
    fn wrong_audience_is_rejected() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let token = cfg
            .issue(&claims("u", "someone-else", future_exp()))
            .unwrap();
        let outcome = cfg.verify::<Claims>(&token);
        assert!(matches!(outcome, Err(Error::Jwt(_))));
    }

    #[test]
    fn expired_token_is_rejected() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let token = cfg
            .issue(&claims("u", "assay-engine", past_exp()))
            .unwrap();
        let outcome = cfg.verify::<Claims>(&token);
        assert!(matches!(outcome, Err(Error::Jwt(_))));
    }

    #[test]
    fn unknown_kid_is_rejected() {
        let issuing_cfg = config_with_active("assay", &["assay-engine"]);
        let token = issuing_cfg
            .issue(&claims("u", "assay-engine", future_exp()))
            .unwrap();
        // An independent config holding only its own key cannot resolve `kid_test`.
        let verifying_cfg =
            JwtConfig::new("assay".to_string(), vec!["assay-engine".to_string()]);
        let unrelated_key = generate_ephemeral_ed25519("kid_b").unwrap();
        verifying_cfg.set_active(unrelated_key, Vec::new());
        let outcome = verifying_cfg.verify::<Claims>(&token);
        assert!(matches!(outcome, Err(Error::Jwt(_))));
    }
}