Skip to main content

axon/
jwt_verifier.rs

1//! JWT signature verification + JWKS client.
2//!
3//! Closes the §Fase 10.e gap where `tenant.rs` previously extracted
4//! `tenant_id` from the payload without checking the signature.
5//!
6//! Wire contract (must match Python `axon_enterprise.jwt_issuer`):
7//!
8//! - `alg` ∈ { RS256, RS384, RS512 }. HS* / `none` / ES* rejected.
9//! - `iss` must equal the configured issuer.
10//! - `aud` must contain the configured audience.
11//! - `exp` / `nbf` / `iat` validated with configurable clock-skew leeway.
12//! - `tenant_id` claim is required — absence means the token is
13//!   structurally valid but not usable for tenant extraction, so we
14//!   treat it as rejection.
15//!
16//! JWKS is fetched lazily and cached for `jwks_ttl` seconds. On a `kid`
17//! miss we force-refresh once — matches Python's behaviour so IdP
18//! rotation (new kid published minutes before first use) works
19//! transparently.
20
21use std::sync::Arc;
22use std::time::{Duration, Instant};
23
24use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
25use serde::Deserialize;
26use serde_json::Value;
27use tokio::sync::Mutex;
28
/// Reasons a bearer token can be rejected by [`JwtVerifier::verify`].
///
/// The middleware maps every variant to HTTP 401. Implemented by hand
/// (no `thiserror`) to keep the crate's dependency surface minimal.
#[derive(Debug)]
pub enum JwtVerifyError {
    /// Header `alg` outside the RS* allow-list, or a JWK whose `kty` is not RSA.
    UnsupportedAlg(String),
    /// The JWT header carries no `kid`, so no JWKS key can be selected.
    MissingKid,
    /// The `kid` could not be found in the JWKS document.
    UnknownKid { kid: String },
    /// Retrieving or decoding the JWKS document failed.
    JwksFetchFailed(String),
    /// Header decoding, key construction, signature or claim validation failed.
    Invalid(String),
    /// A claim required for tenant extraction is absent.
    ClaimMissing(&'static str),
}

impl std::fmt::Display for JwtVerifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingKid => f.write_str("missing kid in JWT header"),
            Self::UnsupportedAlg(alg) => write!(f, "unsupported algorithm: {}", alg),
            Self::UnknownKid { kid: key_id } => {
                write!(f, "kid {:?} not in JWKS after refresh", key_id)
            }
            Self::JwksFetchFailed(detail) => write!(f, "JWKS fetch failed: {}", detail),
            Self::Invalid(detail) => {
                write!(f, "signature or claim validation failed: {}", detail)
            }
            Self::ClaimMissing(claim) => write!(f, "required claim missing: {}", claim),
        }
    }
}

impl std::error::Error for JwtVerifyError {}
57
58/// Subset of JWT claims the tenant extractor cares about. Additional
59/// claims present in the token (roles, plan, jti, …) are available
60/// via [`VerifiedToken::claims`].
61#[derive(Debug, Clone, Deserialize)]
62pub struct MinimalClaims {
63    #[serde(rename = "tenant_id")]
64    pub tenant_id: String,
65}
66
67#[derive(Debug, Clone)]
68pub struct VerifiedToken {
69    pub tenant_id: String,
70    pub plan: Option<String>,
71    pub roles: Vec<String>,
72    pub jti: Option<String>,
73    pub sub: Option<String>,
74    pub claims: Value,
75}
76
/// Configuration resolved at startup from env vars.
#[derive(Debug, Clone)]
pub struct JwtVerifierConfig {
    /// Expected `iss` claim.
    pub issuer: String,
    /// Expected `aud` claim.
    pub audience: String,
    /// Absolute URL of the JWKS document (typically
    /// `https://auth.<host>/.well-known/jwks.json`).
    pub jwks_url: String,
    /// Duration a cached JWKS document is trusted without a refresh.
    pub jwks_ttl: Duration,
    /// Clock-skew leeway in seconds applied to exp / nbf / iat.
    pub leeway_secs: u64,
    /// When true, missing JWTs cause a 401. When false (transitional mode
    /// for pre-10.e deployments still rolling out), the middleware falls
    /// back to header-based + unverified-payload extraction with a warning
    /// log.
    ///
    /// `from_env` defaults this to `true` and only an explicit falsy value
    /// of `AXON_ENFORCE_JWT_VERIFICATION` turns it off, so a typo or an
    /// unexpected spelling can never silently disable verification.
    pub enforce: bool,
}

impl JwtVerifierConfig {
    /// Build the config from env vars.
    ///
    /// Returns `None` when `AXON_JWT_JWKS_URL` is unset or blank. The
    /// middleware treats `None` as "no verifier configured" and keeps the
    /// legacy behaviour for tests / OSS users.
    ///
    /// Values are whitespace-trimmed, and a blank value counts as unset:
    ///
    /// - `AXON_JWT_ISSUER` — default `https://auth.bemarking.com`
    /// - `AXON_JWT_AUDIENCE` — default `axon-api`
    /// - `AXON_JWT_JWKS_TTL_SECONDS` — default `600` (also used when unparsable)
    /// - `AXON_JWT_LEEWAY_SECONDS` — default `60` (also used when unparsable)
    /// - `AXON_ENFORCE_JWT_VERIFICATION` — default `true`; only `0`, `false`,
    ///   `no` or `off` (case-insensitive) disable enforcement
    pub fn from_env() -> Option<Self> {
        let jwks_url = env_non_empty("AXON_JWT_JWKS_URL")?;
        let issuer = env_non_empty("AXON_JWT_ISSUER")
            .unwrap_or_else(|| "https://auth.bemarking.com".into());
        let audience = env_non_empty("AXON_JWT_AUDIENCE").unwrap_or_else(|| "axon-api".into());
        let jwks_ttl_secs = env_u64("AXON_JWT_JWKS_TTL_SECONDS").unwrap_or(600);
        let leeway_secs = env_u64("AXON_JWT_LEEWAY_SECONDS").unwrap_or(60);
        // Fail closed: anything other than an explicit "off" keeps enforcement on.
        let enforce = env_non_empty("AXON_ENFORCE_JWT_VERIFICATION")
            .map(|v| !is_falsy(&v))
            .unwrap_or(true);
        Some(Self {
            issuer,
            audience,
            jwks_url,
            jwks_ttl: Duration::from_secs(jwks_ttl_secs),
            leeway_secs,
            enforce,
        })
    }
}

/// Read env var `key`, trimmed; unset and blank values both yield `None`.
fn env_non_empty(key: &str) -> Option<String> {
    std::env::var(key)
        .ok()
        .map(|v| v.trim().to_owned())
        .filter(|v| !v.is_empty())
}

/// Read env var `key` as a `u64`; unset, blank or unparsable values yield `None`.
fn env_u64(key: &str) -> Option<u64> {
    env_non_empty(key)?.parse().ok()
}

/// Whether `value` is an explicit "off" switch (case-insensitive).
fn is_falsy(value: &str) -> bool {
    matches!(
        value.trim().to_ascii_lowercase().as_str(),
        "0" | "false" | "no" | "off"
    )
}
130
131// ── JWKS cache ──────────────────────────────────────────────────────────────
132
133#[derive(Debug, Clone, Deserialize)]
134struct JwksEntry {
135    kid: String,
136    kty: String,
137    alg: Option<String>,
138    n: Option<String>,
139    e: Option<String>,
140}
141
142#[derive(Debug, Clone, Deserialize)]
143struct JwksDocument {
144    keys: Vec<JwksEntry>,
145}
146
147struct CacheSlot {
148    loaded_at: Instant,
149    keys: Vec<JwksEntry>,
150}
151
152/// Thread-safe JWKS fetcher with TTL + rotation-on-miss.
153pub struct JwksClient {
154    url: String,
155    ttl: Duration,
156    http: reqwest::Client,
157    slot: Mutex<Option<CacheSlot>>,
158}
159
160impl JwksClient {
161    pub fn new(url: String, ttl: Duration) -> Self {
162        Self {
163            url,
164            ttl,
165            http: reqwest::Client::builder()
166                .timeout(Duration::from_secs(5))
167                .build()
168                .expect("reqwest client"),
169            slot: Mutex::new(None),
170        }
171    }
172
173    async fn resolve_key(&self, kid: &str) -> Result<JwksEntry, JwtVerifyError> {
174        {
175            let slot = self.slot.lock().await;
176            if let Some(c) = slot.as_ref() {
177                if c.loaded_at.elapsed() < self.ttl {
178                    if let Some(k) = c.keys.iter().find(|k| k.kid == kid) {
179                        return Ok(k.clone());
180                    }
181                }
182            }
183        }
184        self.refresh().await?;
185        let slot = self.slot.lock().await;
186        let cache = slot.as_ref().ok_or_else(|| {
187            JwtVerifyError::JwksFetchFailed("empty cache after refresh".into())
188        })?;
189        cache
190            .keys
191            .iter()
192            .find(|k| k.kid == kid)
193            .cloned()
194            .ok_or_else(|| JwtVerifyError::UnknownKid { kid: kid.to_string() })
195    }
196
197    async fn refresh(&self) -> Result<(), JwtVerifyError> {
198        let resp = self
199            .http
200            .get(&self.url)
201            .header("Accept", "application/json")
202            .send()
203            .await
204            .map_err(|e| JwtVerifyError::JwksFetchFailed(e.to_string()))?;
205        if !resp.status().is_success() {
206            return Err(JwtVerifyError::JwksFetchFailed(format!(
207                "HTTP {}",
208                resp.status()
209            )));
210        }
211        let doc: JwksDocument = resp
212            .json()
213            .await
214            .map_err(|e| JwtVerifyError::JwksFetchFailed(e.to_string()))?;
215        let mut slot = self.slot.lock().await;
216        *slot = Some(CacheSlot {
217            loaded_at: Instant::now(),
218            keys: doc.keys,
219        });
220        Ok(())
221    }
222}
223
224// ── Verifier ────────────────────────────────────────────────────────────────
225
226pub struct JwtVerifier {
227    cfg: JwtVerifierConfig,
228    jwks: Arc<JwksClient>,
229}
230
231impl JwtVerifier {
232    pub fn new(cfg: JwtVerifierConfig) -> Self {
233        let jwks = Arc::new(JwksClient::new(cfg.jwks_url.clone(), cfg.jwks_ttl));
234        Self { cfg, jwks }
235    }
236
237    pub fn config(&self) -> &JwtVerifierConfig {
238        &self.cfg
239    }
240
241    pub async fn verify(&self, token: &str) -> Result<VerifiedToken, JwtVerifyError> {
242        let header =
243            decode_header(token).map_err(|e| JwtVerifyError::Invalid(e.to_string()))?;
244        let alg = match header.alg {
245            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => header.alg,
246            other => {
247                return Err(JwtVerifyError::UnsupportedAlg(format!("{other:?}")));
248            }
249        };
250        let kid = header.kid.ok_or(JwtVerifyError::MissingKid)?;
251        let entry = self.jwks.resolve_key(&kid).await?;
252
253        if entry.kty != "RSA" {
254            return Err(JwtVerifyError::UnsupportedAlg(format!(
255                "non-RSA JWK kty={}",
256                entry.kty
257            )));
258        }
259
260        let n = entry.n.ok_or_else(|| {
261            JwtVerifyError::Invalid("JWK missing modulus".into())
262        })?;
263        let e = entry.e.ok_or_else(|| {
264            JwtVerifyError::Invalid("JWK missing exponent".into())
265        })?;
266        let key = DecodingKey::from_rsa_components(&n, &e)
267            .map_err(|err| JwtVerifyError::Invalid(err.to_string()))?;
268
269        let mut validation = Validation::new(alg);
270        validation.set_issuer(&[self.cfg.issuer.clone()]);
271        validation.set_audience(&[self.cfg.audience.clone()]);
272        validation.leeway = self.cfg.leeway_secs;
273        validation.validate_exp = true;
274        validation.validate_nbf = true;
275        validation.required_spec_claims =
276            ["iss", "aud", "exp", "iat", "sub"].iter().map(|s| s.to_string()).collect();
277
278        let data = decode::<Value>(token, &key, &validation)
279            .map_err(|err| JwtVerifyError::Invalid(err.to_string()))?;
280        let claims = data.claims;
281
282        let tenant_id = claims
283            .get("tenant_id")
284            .and_then(|v| v.as_str())
285            .ok_or(JwtVerifyError::ClaimMissing("tenant_id"))?
286            .to_string();
287        let plan = claims.get("plan").and_then(|v| v.as_str()).map(String::from);
288        let roles = claims
289            .get("roles")
290            .and_then(|v| v.as_array())
291            .map(|arr| {
292                arr.iter()
293                    .filter_map(|v| v.as_str().map(String::from))
294                    .collect()
295            })
296            .unwrap_or_default();
297        let jti = claims.get("jti").and_then(|v| v.as_str()).map(String::from);
298        let sub = claims.get("sub").and_then(|v| v.as_str()).map(String::from);
299
300        Ok(VerifiedToken {
301            tenant_id,
302            plan,
303            roles,
304            jti,
305            sub,
306            claims,
307        })
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Env vars these tests mutate; they are captured and restored per test.
    const TOUCHED_VARS: [&str; 3] = ["AXON_JWT_JWKS_URL", "AXON_JWT_ISSUER", "AXON_JWT_AUDIENCE"];

    /// Serialises the tests below. The environment is process-global and
    /// the test harness runs `#[test]` functions on parallel threads, so
    /// unsynchronised set/remove calls race each other.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Values of `TOUCHED_VARS` captured on construction and written back on
    /// drop, so a test leaves the environment as it found it even if an
    /// assertion panics.
    struct EnvRestore(Vec<(&'static str, Option<String>)>);

    impl EnvRestore {
        fn capture() -> Self {
            EnvRestore(
                TOUCHED_VARS
                    .iter()
                    .map(|&name| (name, std::env::var(name).ok()))
                    .collect(),
            )
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (name, value) in &self.0 {
                match value {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    /// Acquire `ENV_LOCK`, recovering it if an earlier test panicked while
    /// holding it.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn config_from_env_requires_jwks_url() {
        // Declaration order matters: `_restore` drops first, so the env is
        // restored while the lock is still held.
        let _guard = lock_env();
        let _restore = EnvRestore::capture();
        std::env::remove_var("AXON_JWT_JWKS_URL");
        assert!(JwtVerifierConfig::from_env().is_none());
        std::env::set_var("AXON_JWT_JWKS_URL", "");
        assert!(JwtVerifierConfig::from_env().is_none());
    }

    #[test]
    fn config_from_env_reads_values() {
        let _guard = lock_env();
        let _restore = EnvRestore::capture();
        std::env::set_var("AXON_JWT_JWKS_URL", "https://x/jwks.json");
        std::env::set_var("AXON_JWT_ISSUER", "https://x");
        std::env::set_var("AXON_JWT_AUDIENCE", "x-api");
        let cfg = JwtVerifierConfig::from_env().unwrap();
        assert_eq!(cfg.jwks_url, "https://x/jwks.json");
        assert_eq!(cfg.issuer, "https://x");
        assert_eq!(cfg.audience, "x-api");
    }
}