Skip to main content

assay_auth/
jwt.rs

1//! JWT issuance + verification with key rotation backed by
2//! `auth.jwks_keys`.
3//!
4//! Plan 11 reference: `jsonwebtoken` 10 with kid-based key lookup so old
5//! tokens still verify after a key rotation. We default to EdDSA
6//! (Ed25519) — small keys, fast signatures, no PKCS#1 footguns.
7//!
8//! Lifecycle:
9//! 1. Boot loads keys with [`JwtConfig::load_from_postgres`] /
10//!    [`JwtConfig::load_from_sqlite`]. The row with `rotated_at IS NULL`
11//!    becomes the active signing key; rotated rows become history
12//!    (verify-only).
13//! 2. [`JwtConfig::issue`] signs new tokens with the active key, putting
14//!    its `kid` in the JWT header.
15//! 3. [`JwtConfig::verify`] looks up the signing key by `kid` (active
16//!    first, then history), validates `iss` and `aud`, returns the
17//!    [`jsonwebtoken::TokenData`].
18//! 4. [`JwtConfig::rotate_postgres`] / [`JwtConfig::rotate_sqlite`]
19//!    generate a fresh Ed25519 keypair, persist it, mark the old active
20//!    key rotated, and swap the in-memory state atomically.
21
22use std::sync::Arc;
23
24use jsonwebtoken::{
25    Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, decode_header,
26    encode,
27};
28use parking_lot::RwLock;
29use serde::Serialize;
30use serde::de::DeserializeOwned;
31
32use crate::error::{Error, Result};
33
34/// Active signing key + its decoding twin. Held by the `Inner` state
35/// behind a `RwLock` so [`JwtConfig::rotate_postgres`] /
36/// [`JwtConfig::rotate_sqlite`] can swap the active key under callers
37/// already verifying inflight tokens.
38pub struct ActiveKey {
39    pub kid: String,
40    pub alg: Algorithm,
41    pub encoding_key: EncodingKey,
42    pub decoding_key: DecodingKey,
43    pub expires_at: Option<f64>,
44}
45
46/// Verify-only entry. Older keys live here so already-issued tokens
47/// validate until they expire on their own.
48pub struct HistoryKey {
49    pub kid: String,
50    pub alg: Algorithm,
51    pub decoding_key: DecodingKey,
52}
53
54struct Inner {
55    active: Option<ActiveKey>,
56    history: Vec<HistoryKey>,
57    issuer: String,
58    audience: Vec<String>,
59}
60
61/// Cheap-to-clone JWT configuration. Wrap with `Arc` internally so all
62/// clones share the same active-key + history view.
63#[derive(Clone)]
64pub struct JwtConfig {
65    inner: Arc<RwLock<Inner>>,
66}
67
68impl JwtConfig {
69    /// Empty configuration — no keys yet. Caller must populate via
70    /// [`JwtConfig::load_from_postgres`] / [`JwtConfig::load_from_sqlite`]
71    /// or [`JwtConfig::set_active`] before issuing tokens.
72    pub fn new(issuer: String, audience: Vec<String>) -> Self {
73        Self {
74            inner: Arc::new(RwLock::new(Inner {
75                active: None,
76                history: Vec::new(),
77                issuer,
78                audience,
79            })),
80        }
81    }
82
83    /// Replace the in-memory active key + history. Useful in tests where
84    /// we want a single ephemeral keypair without round-tripping the DB.
85    pub fn set_active(&self, active: ActiveKey, history: Vec<HistoryKey>) {
86        let mut guard = self.inner.write();
87        guard.active = Some(active);
88        guard.history = history;
89    }
90
91    /// Sign `claims` with the active key. The active key's `kid` is
92    /// written into the JWT header so verify can look it up.
93    pub fn issue<T: Serialize>(&self, claims: &T) -> Result<String> {
94        let guard = self.inner.read();
95        let active = guard
96            .active
97            .as_ref()
98            .ok_or_else(|| Error::Jwt("no active jwt key configured".to_string()))?;
99        let mut header = Header::new(active.alg);
100        header.kid = Some(active.kid.clone());
101        encode(&header, claims, &active.encoding_key).map_err(map_jwt_err)
102    }
103
104    /// Verify `token` and decode its claims. Looks up the decoding key
105    /// by header `kid` (active first, then history), validates `iss` and
106    /// the audience list against the in-memory configuration.
107    pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<TokenData<T>> {
108        let header = decode_header(token).map_err(map_jwt_err)?;
109        let kid = header
110            .kid
111            .as_deref()
112            .ok_or_else(|| Error::Jwt("token has no kid header".to_string()))?;
113        let guard = self.inner.read();
114        let (alg, decoding_key) = lookup_decoding_key(&guard, kid)
115            .ok_or_else(|| Error::Jwt(format!("unknown kid {kid}")))?;
116        let mut validation = Validation::new(alg);
117        validation.set_issuer(std::slice::from_ref(&guard.issuer));
118        if !guard.audience.is_empty() {
119            validation.set_audience(&guard.audience);
120        }
121        decode::<T>(token, decoding_key, &validation).map_err(map_jwt_err)
122    }
123
124    /// Borrow the active key's `kid` (cheap clone). Useful in tests and
125    /// for telemetry.
126    pub fn active_kid(&self) -> Option<String> {
127        self.inner.read().active.as_ref().map(|k| k.kid.clone())
128    }
129
130    /// Configured issuer string. Plan-locked: every JWT this config
131    /// signs must carry this `iss` claim. Useful for downstream callers
132    /// (e.g. the BW-compat shim in assay-vault) that mint their own
133    /// claim shapes but still need `verify` to accept the token.
134    pub fn issuer(&self) -> String {
135        self.inner.read().issuer.clone()
136    }
137
138    /// Configured audience list (cheap clone — typical size 1).
139    pub fn audience(&self) -> Vec<String> {
140        self.inner.read().audience.clone()
141    }
142
143    /// Load every key from `auth.jwks_keys` into memory. The row with
144    /// `rotated_at IS NULL` becomes active; the rest become history.
145    /// `private_pem_encrypted` is treated as plaintext PEM for now —
146    /// encryption-at-rest is a later phase.
147    #[cfg(feature = "backend-postgres")]
148    pub async fn load_from_postgres(&self, pool: &sqlx::PgPool) -> Result<()> {
149        use sqlx::Row;
150        let rows = sqlx::query(
151            "SELECT kid, alg, private_pem_encrypted, rotated_at, expires_at
152             FROM auth.jwks_keys
153             ORDER BY created_at",
154        )
155        .fetch_all(pool)
156        .await
157        .map_err(|e| Error::Backend(anyhow::anyhow!("load auth.jwks_keys (pg): {e}")))?;
158
159        let mut active = None;
160        let mut history = Vec::new();
161        for row in rows {
162            let kid: String = row.get("kid");
163            let alg_str: String = row.get("alg");
164            let pem: Option<Vec<u8>> = row.get("private_pem_encrypted");
165            let rotated_at: Option<f64> = row.get("rotated_at");
166            let expires_at: Option<f64> = row.get("expires_at");
167            let alg = parse_alg(&alg_str)?;
168            let pem = pem.ok_or_else(|| {
169                Error::Jwt(format!("auth.jwks_keys row {kid} has no private key"))
170            })?;
171            let (encoding_key, decoding_key) = build_keys(alg, &pem)?;
172            if rotated_at.is_none() && active.is_none() {
173                active = Some(ActiveKey {
174                    kid: kid.clone(),
175                    alg,
176                    encoding_key,
177                    decoding_key,
178                    expires_at,
179                });
180            } else {
181                history.push(HistoryKey {
182                    kid,
183                    alg,
184                    decoding_key,
185                });
186            }
187        }
188        let mut guard = self.inner.write();
189        guard.active = active;
190        guard.history = history;
191        Ok(())
192    }
193
194    /// SQLite mirror of [`JwtConfig::load_from_postgres`].
195    #[cfg(feature = "backend-sqlite")]
196    pub async fn load_from_sqlite(&self, pool: &sqlx::SqlitePool) -> Result<()> {
197        use sqlx::Row;
198        let rows = sqlx::query(
199            "SELECT kid, alg, private_pem_encrypted, rotated_at, expires_at
200             FROM auth.jwks_keys
201             ORDER BY created_at",
202        )
203        .fetch_all(pool)
204        .await
205        .map_err(|e| Error::Backend(anyhow::anyhow!("load auth.jwks_keys (sqlite): {e}")))?;
206
207        let mut active = None;
208        let mut history = Vec::new();
209        for row in rows {
210            let kid: String = row.get("kid");
211            let alg_str: String = row.get("alg");
212            let pem: Option<Vec<u8>> = row.get("private_pem_encrypted");
213            let rotated_at: Option<f64> = row.get("rotated_at");
214            let expires_at: Option<f64> = row.get("expires_at");
215            let alg = parse_alg(&alg_str)?;
216            let pem = pem.ok_or_else(|| {
217                Error::Jwt(format!("auth.jwks_keys row {kid} has no private key"))
218            })?;
219            let (encoding_key, decoding_key) = build_keys(alg, &pem)?;
220            if rotated_at.is_none() && active.is_none() {
221                active = Some(ActiveKey {
222                    kid: kid.clone(),
223                    alg,
224                    encoding_key,
225                    decoding_key,
226                    expires_at,
227                });
228            } else {
229                history.push(HistoryKey {
230                    kid,
231                    alg,
232                    decoding_key,
233                });
234            }
235        }
236        let mut guard = self.inner.write();
237        guard.active = active;
238        guard.history = history;
239        Ok(())
240    }
241
242    /// Generate a fresh Ed25519 keypair, INSERT it into `auth.jwks_keys`
243    /// as the new active row, mark the prior active row rotated, and
244    /// swap the in-memory state. Returns the new `kid`.
245    #[cfg(feature = "backend-postgres")]
246    pub async fn rotate_postgres(&self, pool: &sqlx::PgPool) -> Result<String> {
247        let GeneratedKey {
248            kid,
249            alg,
250            private_pem,
251            public_jwk,
252        } = generate_ed25519_key();
253        let (encoding_key, decoding_key) = build_keys(alg, private_pem.as_bytes())?;
254        let now = now_secs();
255        let mut tx = pool
256            .begin()
257            .await
258            .map_err(|e| Error::Backend(anyhow::anyhow!("begin tx (pg rotate): {e}")))?;
259        sqlx::query(
260            "UPDATE auth.jwks_keys SET rotated_at = $1 WHERE rotated_at IS NULL",
261        )
262        .bind(now)
263        .execute(&mut *tx)
264        .await
265        .map_err(|e| Error::Backend(anyhow::anyhow!("mark old key rotated (pg): {e}")))?;
266        sqlx::query(
267            "INSERT INTO auth.jwks_keys
268                 (kid, alg, public_jwk, private_pem_encrypted, created_at, rotated_at, expires_at)
269             VALUES ($1, $2, $3::jsonb, $4, $5, NULL, NULL)",
270        )
271        .bind(&kid)
272        .bind(alg_str(alg))
273        .bind(public_jwk.to_string())
274        .bind(private_pem.as_bytes())
275        .bind(now)
276        .execute(&mut *tx)
277        .await
278        .map_err(|e| Error::Backend(anyhow::anyhow!("insert new key (pg): {e}")))?;
279        tx.commit()
280            .await
281            .map_err(|e| Error::Backend(anyhow::anyhow!("commit tx (pg rotate): {e}")))?;
282        // Swap in-memory.
283        let mut guard = self.inner.write();
284        if let Some(prev) = guard.active.take() {
285            guard.history.push(HistoryKey {
286                kid: prev.kid,
287                alg: prev.alg,
288                decoding_key: prev.decoding_key,
289            });
290        }
291        guard.active = Some(ActiveKey {
292            kid: kid.clone(),
293            alg,
294            encoding_key,
295            decoding_key,
296            expires_at: None,
297        });
298        Ok(kid)
299    }
300
301    /// SQLite mirror of [`JwtConfig::rotate_postgres`].
302    #[cfg(feature = "backend-sqlite")]
303    pub async fn rotate_sqlite(&self, pool: &sqlx::SqlitePool) -> Result<String> {
304        let GeneratedKey {
305            kid,
306            alg,
307            private_pem,
308            public_jwk,
309        } = generate_ed25519_key();
310        let (encoding_key, decoding_key) = build_keys(alg, private_pem.as_bytes())?;
311        let now = now_secs();
312        let mut tx = pool
313            .begin()
314            .await
315            .map_err(|e| Error::Backend(anyhow::anyhow!("begin tx (sqlite rotate): {e}")))?;
316        sqlx::query(
317            "UPDATE auth.jwks_keys SET rotated_at = ? WHERE rotated_at IS NULL",
318        )
319        .bind(now)
320        .execute(&mut *tx)
321        .await
322        .map_err(|e| Error::Backend(anyhow::anyhow!("mark old key rotated (sqlite): {e}")))?;
323        sqlx::query(
324            "INSERT INTO auth.jwks_keys
325                 (kid, alg, public_jwk, private_pem_encrypted, created_at, rotated_at, expires_at)
326             VALUES (?, ?, ?, ?, ?, NULL, NULL)",
327        )
328        .bind(&kid)
329        .bind(alg_str(alg))
330        .bind(public_jwk.to_string())
331        .bind(private_pem.as_bytes())
332        .bind(now)
333        .execute(&mut *tx)
334        .await
335        .map_err(|e| Error::Backend(anyhow::anyhow!("insert new key (sqlite): {e}")))?;
336        tx.commit()
337            .await
338            .map_err(|e| Error::Backend(anyhow::anyhow!("commit tx (sqlite rotate): {e}")))?;
339        let mut guard = self.inner.write();
340        if let Some(prev) = guard.active.take() {
341            guard.history.push(HistoryKey {
342                kid: prev.kid,
343                alg: prev.alg,
344                decoding_key: prev.decoding_key,
345            });
346        }
347        guard.active = Some(ActiveKey {
348            kid: kid.clone(),
349            alg,
350            encoding_key,
351            decoding_key,
352            expires_at: None,
353        });
354        Ok(kid)
355    }
356}
357
358fn lookup_decoding_key<'a>(
359    inner: &'a Inner,
360    kid: &str,
361) -> Option<(Algorithm, &'a DecodingKey)> {
362    if let Some(active) = &inner.active
363        && active.kid == kid
364    {
365        return Some((active.alg, &active.decoding_key));
366    }
367    inner
368        .history
369        .iter()
370        .find(|h| h.kid == kid)
371        .map(|h| (h.alg, &h.decoding_key))
372}
373
374/// Build encoding+decoding keys from a stored Ed25519 PKCS#8 private
375/// key PEM. `from_ed_pem` for the [`DecodingKey`] expects a *public*
376/// PEM, so we derive the SPKI public PEM from the private key first.
377fn build_keys(alg: Algorithm, pem: &[u8]) -> Result<(EncodingKey, DecodingKey)> {
378    match alg {
379        Algorithm::EdDSA => {
380            let enc = EncodingKey::from_ed_pem(pem).map_err(map_jwt_err)?;
381            let public_pem = ed25519_public_pem_from_private(pem)?;
382            let dec = DecodingKey::from_ed_pem(public_pem.as_bytes()).map_err(map_jwt_err)?;
383            Ok((enc, dec))
384        }
385        // RSA / ECDSA paths land when the operator brings their own key
386        // material; for v0.14.0 phase 4 we ship Ed25519 only.
387        other => Err(Error::Jwt(format!(
388            "unsupported jwt algorithm {other:?} (only EdDSA in phase 4)"
389        ))),
390    }
391}
392
393/// Derive the SPKI (subjectPublicKeyInfo) PEM for an Ed25519 keypair
394/// from the private PKCS#8 PEM. Done by re-parsing the private key with
395/// `ed25519_dalek` and re-encoding only the public half.
396fn ed25519_public_pem_from_private(private_pem: &[u8]) -> Result<String> {
397    use ed25519_dalek::SigningKey;
398    use ed25519_dalek::pkcs8::DecodePrivateKey;
399    use ed25519_dalek::pkcs8::spki::EncodePublicKey;
400
401    let pem_str = std::str::from_utf8(private_pem)
402        .map_err(|e| Error::Jwt(format!("ed25519 private PEM utf8: {e}")))?;
403    let signing = SigningKey::from_pkcs8_pem(pem_str)
404        .map_err(|e| Error::Jwt(format!("parse ed25519 private PEM: {e}")))?;
405    let verifying = signing.verifying_key();
406    verifying
407        .to_public_key_pem(ed25519_dalek::pkcs8::spki::der::pem::LineEnding::LF)
408        .map_err(|e| Error::Jwt(format!("encode ed25519 public PEM: {e}")))
409}
410
411fn parse_alg(name: &str) -> Result<Algorithm> {
412    match name {
413        "EdDSA" => Ok(Algorithm::EdDSA),
414        other => Err(Error::Jwt(format!(
415            "unknown jwt algorithm {other:?} (only EdDSA in phase 4)"
416        ))),
417    }
418}
419
420fn alg_str(alg: Algorithm) -> &'static str {
421    match alg {
422        Algorithm::EdDSA => "EdDSA",
423        // Other variants are unreachable today (build_keys / parse_alg
424        // gate Ed25519 only). Spell them out so future expansion is a
425        // compile-error.
426        _ => "EdDSA",
427    }
428}
429
430fn map_jwt_err(e: jsonwebtoken::errors::Error) -> Error {
431    Error::Jwt(e.to_string())
432}
433
/// Current wall-clock time as fractional seconds since the Unix epoch.
/// A clock set before 1970 yields `0.0` instead of an error.
fn now_secs() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs_f64()
}
440
441/// Output of the in-process Ed25519 key generator. Only used by the
442/// rotation helpers; ephemeral test keys go through
443/// [`generate_ephemeral_ed25519`] instead so the PEM never round-trips.
444struct GeneratedKey {
445    kid: String,
446    alg: Algorithm,
447    private_pem: String,
448    public_jwk: serde_json::Value,
449}
450
451fn generate_ed25519_key() -> GeneratedKey {
452    use ed25519_dalek::SigningKey;
453    use ed25519_dalek::pkcs8::EncodePrivateKey;
454
455    let signing = SigningKey::generate(&mut rand_core_06::OsRng);
456    let private_pem = signing
457        .to_pkcs8_pem(ed25519_dalek::pkcs8::spki::der::pem::LineEnding::LF)
458        .expect("ed25519 PKCS#8 PEM encoding")
459        .to_string();
460    let verifying = signing.verifying_key();
461    let pub_bytes = verifying.to_bytes();
462    let kid = format!(
463        "kid_{}",
464        data_encoding::BASE64URL_NOPAD.encode(&pub_bytes[..16])
465    );
466    let public_jwk = serde_json::json!({
467        "kty": "OKP",
468        "crv": "Ed25519",
469        "alg": "EdDSA",
470        "kid": kid,
471        "use": "sig",
472        "x": data_encoding::BASE64URL_NOPAD.encode(&pub_bytes),
473    });
474    GeneratedKey {
475        kid,
476        alg: Algorithm::EdDSA,
477        private_pem,
478        public_jwk,
479    }
480}
481
482/// Generate an ephemeral Ed25519 [`ActiveKey`] without touching any DB.
483/// Used by tests and by short-lived deployments that don't need
484/// rotation persistence.
485pub fn generate_ephemeral_ed25519(kid: impl Into<String>) -> Result<ActiveKey> {
486    let GeneratedKey { private_pem, .. } = generate_ed25519_key();
487    let (encoding_key, decoding_key) = build_keys(Algorithm::EdDSA, private_pem.as_bytes())?;
488    Ok(ActiveKey {
489        kid: kid.into(),
490        alg: Algorithm::EdDSA,
491        encoding_key,
492        decoding_key,
493        expires_at: None,
494    })
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Claims {
        sub: String,
        iss: String,
        aud: String,
        exp: usize,
    }

    /// Config with issuer/audience and a single ephemeral key `kid_test`.
    fn config_with_active(issuer: &str, audience: &[&str]) -> JwtConfig {
        let cfg = JwtConfig::new(
            issuer.to_string(),
            audience.iter().map(|s| s.to_string()).collect(),
        );
        let active = generate_ephemeral_ed25519("kid_test").unwrap();
        cfg.set_active(active, Vec::new());
        cfg
    }

    fn future_exp() -> usize {
        (now_secs() as usize) + 3600
    }

    fn past_exp() -> usize {
        (now_secs() as usize).saturating_sub(3600)
    }

    /// Claims for subject `u`, issuer `assay`, with the given `aud`/`exp`.
    fn claims_for(aud: &str, exp: usize) -> Claims {
        Claims {
            sub: "u".to_string(),
            iss: "assay".to_string(),
            aud: aud.to_string(),
            exp,
        }
    }

    #[test]
    fn issue_and_verify_round_trip() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let claims = Claims {
            sub: "user_alice".to_string(),
            iss: "assay".to_string(),
            aud: "assay-engine".to_string(),
            exp: future_exp(),
        };
        let token = cfg.issue(&claims).unwrap();
        let data = cfg.verify::<Claims>(&token).unwrap();
        assert_eq!(data.claims, claims);
        assert_eq!(data.header.kid.as_deref(), Some("kid_test"));
    }

    #[test]
    fn wrong_audience_is_rejected() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let token = cfg.issue(&claims_for("someone-else", future_exp())).unwrap();
        let result = cfg.verify::<Claims>(&token);
        assert!(matches!(result, Err(Error::Jwt(_))));
    }

    #[test]
    fn wrong_issuer_is_rejected() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let token = cfg
            .issue(&Claims {
                sub: "u".to_string(),
                iss: "someone-else".to_string(),
                aud: "assay-engine".to_string(),
                exp: future_exp(),
            })
            .unwrap();
        let result = cfg.verify::<Claims>(&token);
        assert!(matches!(result, Err(Error::Jwt(_))));
    }

    #[test]
    fn expired_token_is_rejected() {
        let cfg = config_with_active("assay", &["assay-engine"]);
        let token = cfg.issue(&claims_for("assay-engine", past_exp())).unwrap();
        let result = cfg.verify::<Claims>(&token);
        assert!(matches!(result, Err(Error::Jwt(_))));
    }

    #[test]
    fn unknown_kid_is_rejected() {
        let cfg_a = config_with_active("assay", &["assay-engine"]);
        let token = cfg_a.issue(&claims_for("assay-engine", future_exp())).unwrap();
        // A fresh config with a different active key must not know the
        // prior token's kid, so verification fails.
        let cfg_b = JwtConfig::new("assay".to_string(), vec!["assay-engine".to_string()]);
        let other = generate_ephemeral_ed25519("kid_b").unwrap();
        cfg_b.set_active(other, Vec::new());
        let result = cfg_b.verify::<Claims>(&token);
        assert!(matches!(result, Err(Error::Jwt(_))));
    }

    #[test]
    fn rotated_key_in_history_still_verifies() {
        let cfg = JwtConfig::new("assay".to_string(), vec!["assay-engine".to_string()]);
        let old = generate_ephemeral_ed25519("kid_old").unwrap();
        let old_decoding = old.decoding_key.clone();
        cfg.set_active(old, Vec::new());
        let token = cfg.issue(&claims_for("assay-engine", future_exp())).unwrap();

        // Swap in a new signing key and demote the old one to history.
        let fresh = generate_ephemeral_ed25519("kid_new").unwrap();
        cfg.set_active(
            fresh,
            vec![HistoryKey {
                kid: "kid_old".to_string(),
                alg: Algorithm::EdDSA,
                decoding_key: old_decoding,
            }],
        );
        assert_eq!(cfg.active_kid().as_deref(), Some("kid_new"));

        let data = cfg.verify::<Claims>(&token).unwrap();
        assert_eq!(data.header.kid.as_deref(), Some("kid_old"));
    }

    #[test]
    fn issue_without_active_key_fails() {
        let cfg = JwtConfig::new("assay".to_string(), vec!["assay-engine".to_string()]);
        assert_eq!(cfg.active_kid(), None);
        let result = cfg.issue(&claims_for("assay-engine", future_exp()));
        assert!(matches!(result, Err(Error::Jwt(_))));
    }
}