Skip to main content

meerkat_mobkit/auth/
mod.rs

1//! JWT validation, JWKS caching, and OIDC discovery for API authentication.
2
3pub mod peer_keys;
4
5pub use peer_keys::{
6    GatewayPeerKeyError, GatewayPeerKeys, KEY_FILE_NAME as GATEWAY_PEER_KEY_FILE,
7    PubkeyDecodeError, decode_pubkey_b64,
8};
9
10use std::fmt::{Display, Formatter};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use base64::Engine;
15use base64::engine::general_purpose::URL_SAFE_NO_PAD;
16use hmac::{Hmac, Mac};
17use ring::signature::{self, RsaPublicKeyComponents, UnparsedPublicKey};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use sha2::Sha256;
21use tokio::sync::RwLock;
22
/// HMAC-SHA256 instance used to verify `HS256` token signatures.
type HmacSha256 = Hmac<Sha256>;
24
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub struct JwtValidationConfig {
27    pub shared_secret: String,
28    pub issuer: Option<String>,
29    pub audience: Option<String>,
30    pub now_epoch_seconds: u64,
31    pub leeway_seconds: u64,
32}
33
/// Claim checks applied after a token's signature has been verified.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwtClaimsValidationConfig {
    /// When set, the token's `iss` claim must equal this value exactly.
    pub issuer: Option<String>,
    /// When set, this value must be one of the token's `aud` entries.
    pub audience: Option<String>,
    /// Current time in seconds since the Unix epoch; `exp`/`nbf` are compared against it.
    pub now_epoch_seconds: u64,
    /// Seconds of clock skew tolerated when checking `exp` and `nbf`.
    pub leeway_seconds: u64,
}
41
/// Key material used to verify a JWT signature. Each variant accepts exactly
/// one `alg`: `HS256`, `RS256` and `ES256` respectively.
///
/// `Debug` is implemented by hand so the symmetric `Hs256` secret is never
/// printed (only its length is shown). Public RSA/EC components are printed
/// as-is.
#[derive(Clone, PartialEq, Eq)]
pub enum JwtVerificationKey {
    /// Raw HMAC-SHA256 secret bytes.
    Hs256(Vec<u8>),
    /// RSA public key as big-endian modulus and exponent bytes.
    Rs256 { modulus: Vec<u8>, exponent: Vec<u8> },
    /// Uncompressed P-256 point: `0x04 || X || Y`.
    Es256P256 { public_key: Vec<u8> },
}

impl std::fmt::Debug for JwtVerificationKey {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Hs256(secret) => f
                .debug_tuple("Hs256")
                .field(&format_args!("<redacted {} bytes>", secret.len()))
                .finish(),
            Self::Rs256 { modulus, exponent } => f
                .debug_struct("Rs256")
                .field("modulus", modulus)
                .field("exponent", exponent)
                .finish(),
            Self::Es256P256 { public_key } => f
                .debug_struct("Es256P256")
                .field("public_key", public_key)
                .finish(),
        }
    }
}
48
/// Reasons a token is rejected by this module's validators.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtValidationError {
    /// The token is not exactly three `.`-separated segments.
    InvalidFormat,
    /// A segment is not valid unpadded base64url.
    InvalidBase64,
    /// The header or payload could not be decoded as the expected JSON.
    InvalidJson,
    /// The header `alg` is not one the verification key supports; carries that `alg`.
    UnsupportedAlgorithm(String),
    /// The signature did not verify against the key.
    InvalidSignature,
    /// `exp` is earlier than the current time minus the leeway.
    Expired,
    /// `nbf` is later than the current time plus the leeway.
    NotYetValid,
    /// The `iss` claim differs from the configured issuer.
    IssuerMismatch,
    /// The configured audience is not among the token's `aud` entries.
    AudienceMismatch,
}
61
/// Claims extracted from a token that passed signature and claim checks.
///
/// String fields are `None` when the claim is absent or not a JSON string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ValidatedJwt {
    /// `sub` claim.
    pub subject: Option<String>,
    /// `email` claim.
    pub email: Option<String>,
    /// `provider` claim.
    pub provider: Option<String>,
    /// `actor_type` claim.
    pub actor_type: Option<String>,
    /// `iss` claim.
    pub issuer: Option<String>,
    /// `aud` claim normalised to a list: a single string becomes one entry and
    /// non-string array entries are dropped.
    pub audience: Vec<String>,
    /// `exp` claim, in seconds since the Unix epoch.
    pub expires_at_epoch_seconds: Option<u64>,
    /// `nbf` claim, in seconds since the Unix epoch.
    pub not_before_epoch_seconds: Option<u64>,
}
73
/// The JOSE header members this module reads; other members are ignored.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct JwtHeader {
    /// Signing algorithm name, e.g. `HS256`, `RS256` or `ES256`.
    alg: String,
    /// Optional key ID used to pick a JWK from a key set.
    #[serde(default)]
    kid: Option<String>,
}
80
/// The members of an OIDC discovery document that this module uses.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier advertised by the provider.
    pub issuer: String,
    /// URL of the provider's JWKS.
    pub jwks_uri: String,
}

/// A JSON Web Key Set: the `keys` array of a JWKS document.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwksDocument {
    pub keys: Vec<Jwk>,
}

/// One entry of a JWKS `keys` array. Only the members this module uses are
/// modelled; unknown members are ignored during deserialization.
///
/// NOTE(review): the derived `Debug` prints `k`, which holds symmetric key
/// material for `oct` keys — avoid logging `Jwk` values as-is.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Jwk {
    /// Key ID, matched against the token header's `kid`.
    pub kid: Option<String>,
    /// Key type; `oct`, `RSA` and `EC` are the types keys can be built from.
    pub kty: String,
    /// Algorithm the key is restricted to, if declared.
    #[serde(default)]
    pub alg: Option<String>,
    /// `oct`: base64url-encoded symmetric secret.
    #[serde(default)]
    pub k: Option<String>,
    /// `RSA`: base64url-encoded modulus.
    #[serde(default)]
    pub n: Option<String>,
    /// `RSA`: base64url-encoded public exponent.
    #[serde(default)]
    pub e: Option<String>,
    /// `EC`: curve name; only `P-256` is supported.
    #[serde(default)]
    pub crv: Option<String>,
    /// `EC`: base64url-encoded X coordinate (32 bytes for P-256).
    #[serde(default)]
    pub x: Option<String>,
    /// `EC`: base64url-encoded Y coordinate (32 bytes for P-256).
    #[serde(default)]
    pub y: Option<String>,
}
111
/// Errors from parsing OIDC discovery/JWKS documents and turning JWKs into keys.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OidcContractError {
    /// The document is not valid JSON of the expected shape.
    InvalidJson,
    /// The discovery document's `issuer` is blank.
    MissingIssuer,
    /// The discovery document's `jwks_uri` is blank.
    MissingJwksUri,
    /// The JWKS `keys` array is empty.
    MissingKeys,
    /// No JWK matches the token's `kid`/`alg`.
    NoMatchingKey,
    /// The JWK's `kty` does not fit the requested algorithm; carries the `kty`.
    UnsupportedKeyType(String),
    /// The token's `alg` is not supported; carries the `alg`.
    UnsupportedJwtAlgorithm(String),
    /// An `oct` JWK has no `k`.
    MissingSymmetricKeyMaterial,
    /// An `RSA` JWK is missing `n` or `e`.
    MissingRsaKeyMaterial,
    /// An `EC` JWK is missing `crv`, `x` or `y`.
    MissingEcKeyMaterial,
    /// A key component is not valid unpadded base64url.
    InvalidKeyEncoding,
    /// Symmetric key bytes are unusable (e.g. not UTF-8 where a text secret is required).
    InvalidSymmetricKeyMaterial,
    /// The decoded RSA modulus or exponent is unusable (e.g. empty).
    InvalidRsaKeyMaterial,
    /// An EC coordinate does not decode to 32 bytes.
    InvalidEcKeyMaterial,
    /// The EC curve is not `P-256`; carries the curve name.
    UnsupportedEllipticCurve(String),
}

/// `alg` and `kid` from a token header, as returned by [`inspect_jwt_header`].
/// They are read WITHOUT verifying the signature.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct JwtHeaderView {
    pub alg: String,
    pub kid: Option<String>,
}
136
137pub fn validate_jwt_locally(
138    token: &str,
139    config: &JwtValidationConfig,
140) -> Result<ValidatedJwt, JwtValidationError> {
141    let claims_config = JwtClaimsValidationConfig {
142        issuer: config.issuer.clone(),
143        audience: config.audience.clone(),
144        now_epoch_seconds: config.now_epoch_seconds,
145        leeway_seconds: config.leeway_seconds,
146    };
147    let key = JwtVerificationKey::Hs256(config.shared_secret.as_bytes().to_vec());
148    validate_jwt_with_verification_key(token, &key, &claims_config)
149}
150
151pub fn validate_jwt_with_verification_key(
152    token: &str,
153    verification_key: &JwtVerificationKey,
154    config: &JwtClaimsValidationConfig,
155) -> Result<ValidatedJwt, JwtValidationError> {
156    let parsed = parse_jwt(token)?;
157    if !key_supports_algorithm(verification_key, parsed.header.alg.as_str()) {
158        return Err(JwtValidationError::UnsupportedAlgorithm(parsed.header.alg));
159    }
160    verify_signature(
161        verification_key,
162        parsed.header.alg.as_str(),
163        parsed.signing_input.as_bytes(),
164        &parsed.signature,
165    )?;
166    validate_claims(&parsed.claims, config)
167}
168
169pub fn build_jwt_verification_key(
170    key: &Jwk,
171    alg: &str,
172) -> Result<JwtVerificationKey, OidcContractError> {
173    match alg {
174        "HS256" => build_hs256_key(key),
175        "RS256" => build_rs256_key(key),
176        "ES256" => build_es256_key(key),
177        unsupported => Err(OidcContractError::UnsupportedJwtAlgorithm(
178            unsupported.to_string(),
179        )),
180    }
181}
182
183fn build_hs256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
184    if key.kty != "oct" {
185        return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
186    }
187    let encoded = key
188        .k
189        .as_deref()
190        .ok_or(OidcContractError::MissingSymmetricKeyMaterial)?;
191    let bytes = URL_SAFE_NO_PAD
192        .decode(encoded)
193        .map_err(|_| OidcContractError::InvalidKeyEncoding)?;
194    Ok(JwtVerificationKey::Hs256(bytes))
195}
196
197fn build_rs256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
198    if key.kty != "RSA" {
199        return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
200    }
201    let modulus = key
202        .n
203        .as_deref()
204        .ok_or(OidcContractError::MissingRsaKeyMaterial)
205        .and_then(decode_key_component)?;
206    let exponent = key
207        .e
208        .as_deref()
209        .ok_or(OidcContractError::MissingRsaKeyMaterial)
210        .and_then(decode_key_component)?;
211    if modulus.is_empty() || exponent.is_empty() {
212        return Err(OidcContractError::InvalidRsaKeyMaterial);
213    }
214    Ok(JwtVerificationKey::Rs256 { modulus, exponent })
215}
216
217fn build_es256_key(key: &Jwk) -> Result<JwtVerificationKey, OidcContractError> {
218    if key.kty != "EC" {
219        return Err(OidcContractError::UnsupportedKeyType(key.kty.clone()));
220    }
221    let curve = key
222        .crv
223        .as_deref()
224        .ok_or(OidcContractError::MissingEcKeyMaterial)?;
225    if curve != "P-256" {
226        return Err(OidcContractError::UnsupportedEllipticCurve(
227            curve.to_string(),
228        ));
229    }
230
231    let x = key
232        .x
233        .as_deref()
234        .ok_or(OidcContractError::MissingEcKeyMaterial)
235        .and_then(decode_key_component)?;
236    let y = key
237        .y
238        .as_deref()
239        .ok_or(OidcContractError::MissingEcKeyMaterial)
240        .and_then(decode_key_component)?;
241    if x.len() != 32 || y.len() != 32 {
242        return Err(OidcContractError::InvalidEcKeyMaterial);
243    }
244
245    let mut public_key = Vec::with_capacity(65);
246    public_key.push(0x04);
247    public_key.extend_from_slice(&x);
248    public_key.extend_from_slice(&y);
249    Ok(JwtVerificationKey::Es256P256 { public_key })
250}
251
252fn decode_key_component(encoded: &str) -> Result<Vec<u8>, OidcContractError> {
253    URL_SAFE_NO_PAD
254        .decode(encoded)
255        .map_err(|_| OidcContractError::InvalidKeyEncoding)
256}
257
/// A compact JWS split into its decoded parts, as produced by `parse_jwt`.
struct ParsedJwt {
    /// Decoded JOSE header.
    header: JwtHeader,
    /// Decoded payload JSON; not yet checked to be a JSON object.
    claims: Value,
    /// The original encoded `header.payload` text that the signature covers.
    signing_input: String,
    /// Raw signature bytes decoded from the third segment.
    signature: Vec<u8>,
}
264
265fn parse_jwt(token: &str) -> Result<ParsedJwt, JwtValidationError> {
266    let parts: Vec<&str> = token.split('.').collect();
267    if parts.len() != 3 {
268        return Err(JwtValidationError::InvalidFormat);
269    }
270
271    let header_bytes = URL_SAFE_NO_PAD
272        .decode(parts[0])
273        .map_err(|_| JwtValidationError::InvalidBase64)?;
274    let payload_bytes = URL_SAFE_NO_PAD
275        .decode(parts[1])
276        .map_err(|_| JwtValidationError::InvalidBase64)?;
277    let signature = URL_SAFE_NO_PAD
278        .decode(parts[2])
279        .map_err(|_| JwtValidationError::InvalidBase64)?;
280
281    let header: JwtHeader =
282        serde_json::from_slice(&header_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
283    let claims: Value =
284        serde_json::from_slice(&payload_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
285    let signing_input = format!("{}.{}", parts[0], parts[1]);
286
287    Ok(ParsedJwt {
288        header,
289        claims,
290        signing_input,
291        signature,
292    })
293}
294
295fn key_supports_algorithm(key: &JwtVerificationKey, alg: &str) -> bool {
296    matches!(
297        (key, alg),
298        (JwtVerificationKey::Hs256(_), "HS256")
299            | (JwtVerificationKey::Rs256 { .. }, "RS256")
300            | (JwtVerificationKey::Es256P256 { .. }, "ES256")
301    )
302}
303
304fn verify_signature(
305    verification_key: &JwtVerificationKey,
306    alg: &str,
307    signing_input: &[u8],
308    signature: &[u8],
309) -> Result<(), JwtValidationError> {
310    match (verification_key, alg) {
311        (JwtVerificationKey::Hs256(secret), "HS256") => {
312            let mut mac = HmacSha256::new_from_slice(secret)
313                .map_err(|_| JwtValidationError::InvalidSignature)?;
314            mac.update(signing_input);
315            // Pre-fix: byte-wise loop short-circuited on first mismatch
316            // — a remote timing oracle for HMAC tag forgery one byte at
317            // a time. `Mac::verify_slice` performs the comparison in
318            // constant time via the underlying HMAC implementation.
319            mac.verify_slice(signature)
320                .map_err(|_| JwtValidationError::InvalidSignature)?;
321            Ok(())
322        }
323        (JwtVerificationKey::Rs256 { modulus, exponent }, "RS256") => {
324            let components = RsaPublicKeyComponents {
325                n: modulus.as_slice(),
326                e: exponent.as_slice(),
327            };
328            components
329                .verify(
330                    &signature::RSA_PKCS1_2048_8192_SHA256,
331                    signing_input,
332                    signature,
333                )
334                .map_err(|_| JwtValidationError::InvalidSignature)
335        }
336        (JwtVerificationKey::Es256P256 { public_key }, "ES256") => {
337            UnparsedPublicKey::new(&signature::ECDSA_P256_SHA256_FIXED, public_key.as_slice())
338                .verify(signing_input, signature)
339                .map_err(|_| JwtValidationError::InvalidSignature)
340        }
341        (_, unsupported) => Err(JwtValidationError::UnsupportedAlgorithm(
342            unsupported.to_string(),
343        )),
344    }
345}
346
347fn validate_claims(
348    claims: &Value,
349    config: &JwtClaimsValidationConfig,
350) -> Result<ValidatedJwt, JwtValidationError> {
351    let exp = claims.get("exp").and_then(Value::as_u64);
352    let nbf = claims.get("nbf").and_then(Value::as_u64);
353    let iss = claims
354        .get("iss")
355        .and_then(Value::as_str)
356        .map(ToString::to_string);
357    let aud = extract_audiences(claims);
358
359    if let Some(exp) = exp {
360        let threshold = config
361            .now_epoch_seconds
362            .saturating_sub(config.leeway_seconds);
363        if exp < threshold {
364            return Err(JwtValidationError::Expired);
365        }
366    }
367    if let Some(nbf) = nbf {
368        let threshold = config
369            .now_epoch_seconds
370            .saturating_add(config.leeway_seconds);
371        if nbf > threshold {
372            return Err(JwtValidationError::NotYetValid);
373        }
374    }
375
376    if let Some(expected_issuer) = &config.issuer
377        && iss.as_deref() != Some(expected_issuer.as_str())
378    {
379        return Err(JwtValidationError::IssuerMismatch);
380    }
381    if let Some(expected_audience) = &config.audience
382        && !aud.iter().any(|entry| entry == expected_audience)
383    {
384        return Err(JwtValidationError::AudienceMismatch);
385    }
386
387    Ok(ValidatedJwt {
388        subject: claims
389            .get("sub")
390            .and_then(Value::as_str)
391            .map(ToString::to_string),
392        email: claims
393            .get("email")
394            .and_then(Value::as_str)
395            .map(ToString::to_string),
396        provider: claims
397            .get("provider")
398            .and_then(Value::as_str)
399            .map(ToString::to_string),
400        actor_type: claims
401            .get("actor_type")
402            .and_then(Value::as_str)
403            .map(ToString::to_string),
404        issuer: iss,
405        audience: aud,
406        expires_at_epoch_seconds: exp,
407        not_before_epoch_seconds: nbf,
408    })
409}
410
411pub fn inspect_jwt_header(token: &str) -> Result<JwtHeaderView, JwtValidationError> {
412    let parts: Vec<&str> = token.split('.').collect();
413    if parts.len() != 3 {
414        return Err(JwtValidationError::InvalidFormat);
415    }
416    let header_bytes = URL_SAFE_NO_PAD
417        .decode(parts[0])
418        .map_err(|_| JwtValidationError::InvalidBase64)?;
419    let header: JwtHeader =
420        serde_json::from_slice(&header_bytes).map_err(|_| JwtValidationError::InvalidJson)?;
421    Ok(JwtHeaderView {
422        alg: header.alg,
423        kid: header.kid,
424    })
425}
426
427pub fn parse_oidc_discovery_json(
428    json_text: &str,
429) -> Result<OidcDiscoveryDocument, OidcContractError> {
430    let doc: OidcDiscoveryDocument =
431        serde_json::from_str(json_text).map_err(|_| OidcContractError::InvalidJson)?;
432    if doc.issuer.trim().is_empty() {
433        return Err(OidcContractError::MissingIssuer);
434    }
435    if doc.jwks_uri.trim().is_empty() {
436        return Err(OidcContractError::MissingJwksUri);
437    }
438    Ok(doc)
439}
440
441pub fn parse_jwks_json(json_text: &str) -> Result<JwksDocument, OidcContractError> {
442    let doc: JwksDocument =
443        serde_json::from_str(json_text).map_err(|_| OidcContractError::InvalidJson)?;
444    if doc.keys.is_empty() {
445        return Err(OidcContractError::MissingKeys);
446    }
447    Ok(doc)
448}
449
450pub fn select_jwk_for_token<'a>(
451    jwks: &'a JwksDocument,
452    kid: Option<&str>,
453    alg: &str,
454) -> Result<&'a Jwk, OidcContractError> {
455    if let Some(kid) = kid {
456        return jwks
457            .keys
458            .iter()
459            .find(|key| {
460                key.kid.as_deref() == Some(kid)
461                    && key.alg.as_deref().is_none_or(|key_alg| key_alg == alg)
462            })
463            .ok_or(OidcContractError::NoMatchingKey);
464    }
465
466    jwks.keys
467        .iter()
468        .find(|key| key.alg.as_deref().is_none_or(|key_alg| key_alg == alg))
469        .ok_or(OidcContractError::NoMatchingKey)
470}
471
472pub fn extract_hs256_shared_secret(key: &Jwk) -> Result<String, OidcContractError> {
473    let bytes = match build_hs256_key(key)? {
474        JwtVerificationKey::Hs256(bytes) => bytes,
475        _ => return Err(OidcContractError::InvalidSymmetricKeyMaterial),
476    };
477    String::from_utf8(bytes).map_err(|_| OidcContractError::InvalidSymmetricKeyMaterial)
478}
479
480fn extract_audiences(claims: &Value) -> Vec<String> {
481    match claims.get("aud") {
482        Some(Value::String(aud)) => vec![aud.clone()],
483        Some(Value::Array(values)) => values
484            .iter()
485            .filter_map(Value::as_str)
486            .map(ToString::to_string)
487            .collect(),
488        _ => Vec::new(),
489    }
490}
491
492// ---------------------------------------------------------------------------
493// JwksCache — runtime JWKS cache with discovery, rotation, kid-miss refresh
494// ---------------------------------------------------------------------------
495
496#[derive(Debug)]
497pub enum JwksCacheError {
498    Discovery(OidcContractError),
499    Http(String),
500    Validation(JwtValidationError),
501    NoMatchingKey,
502    NotInitialized,
503}
504
505impl Display for JwksCacheError {
506    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
507        match self {
508            Self::Discovery(err) => write!(f, "OIDC discovery error: {err:?}"),
509            Self::Http(msg) => write!(f, "HTTP fetch error: {msg}"),
510            Self::Validation(err) => write!(f, "JWT validation error: {err:?}"),
511            Self::NoMatchingKey => write!(f, "no matching JWK for token"),
512            Self::NotInitialized => write!(f, "JWKS cache not initialized"),
513        }
514    }
515}
516
517impl std::error::Error for JwksCacheError {}
518
/// Settings for a [`JwksCache`].
#[derive(Debug, Clone)]
pub struct JwksCacheConfig {
    /// URL of the OIDC discovery document whose `jwks_uri` locates the keys.
    pub discovery_url: String,
    /// Maximum age of the cached key set before the next validation refetches it.
    pub refresh_interval: Duration,
    /// Timeout applied to the cache's HTTP client.
    pub http_timeout: Duration,
    /// Required `iss` claim value, when set.
    pub issuer: Option<String>,
    /// Value that must appear in the token's `aud` claim, when set.
    pub audience: Option<String>,
    /// Seconds of clock skew tolerated when checking `exp` and `nbf`.
    pub leeway_seconds: u64,
}

impl JwksCacheConfig {
    /// Creates a config for `discovery_url` with these defaults:
    /// - refresh every hour;
    /// - 10-second HTTP timeout;
    /// - no issuer or audience requirement;
    /// - 60 seconds of clock-skew leeway.
    pub fn new(discovery_url: String) -> Self {
        const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60);
        const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(10);
        const DEFAULT_LEEWAY_SECONDS: u64 = 60;
        Self {
            discovery_url,
            refresh_interval: DEFAULT_REFRESH_INTERVAL,
            http_timeout: DEFAULT_HTTP_TIMEOUT,
            issuer: None,
            audience: None,
            leeway_seconds: DEFAULT_LEEWAY_SECONDS,
        }
    }
}
541
/// Mutable cache state, guarded by the `RwLock` in [`JwksCache`].
struct JwksCacheInner {
    /// Most recently fetched key set; `None` until the first successful refresh.
    jwks: Option<JwksDocument>,
    /// When `jwks` was last replaced; `None` until the first successful refresh.
    last_refresh: Option<Instant>,
}

/// JWKS cache that loads keys via OIDC discovery on first use and refreshes
/// them after `refresh_interval`. It may also re-fetch when a token's key is
/// not found. Clones share the same cached state through `Arc`.
#[derive(Clone)]
pub struct JwksCache {
    /// Cached key set and refresh timestamp, shared by all clones.
    inner: Arc<RwLock<JwksCacheInner>>,
    /// Immutable settings, shared by all clones.
    config: Arc<JwksCacheConfig>,
    /// Client used for discovery and JWKS fetches.
    http_client: reqwest::Client,
}
553
554impl JwksCache {
555    pub fn new(config: JwksCacheConfig) -> Self {
556        let http_client = reqwest::Client::builder()
557            .timeout(config.http_timeout)
558            .build()
559            .unwrap_or_else(|_| reqwest::Client::new());
560        Self {
561            inner: Arc::new(RwLock::new(JwksCacheInner {
562                jwks: None,
563                last_refresh: None,
564            })),
565            config: Arc::new(config),
566            http_client,
567        }
568    }
569
570    /// Fetch OIDC discovery document and JWKS keys. Called on first use or periodic refresh.
571    pub async fn refresh(&self) -> Result<(), JwksCacheError> {
572        let discovery_json = fetch_json(&self.http_client, &self.config.discovery_url).await?;
573        let discovery =
574            parse_oidc_discovery_json(&discovery_json).map_err(JwksCacheError::Discovery)?;
575
576        let jwks_json = fetch_json(&self.http_client, &discovery.jwks_uri).await?;
577        let jwks = parse_jwks_json(&jwks_json).map_err(JwksCacheError::Discovery)?;
578
579        let mut inner = self.inner.write().await;
580        inner.jwks = Some(jwks);
581        inner.last_refresh = Some(Instant::now());
582        Ok(())
583    }
584
585    /// Validate a JWT token using cached JWKS. Refreshes on kid miss.
586    pub async fn validate_token(&self, token: &str) -> Result<ValidatedJwt, JwksCacheError> {
587        self.maybe_refresh().await?;
588
589        let header = inspect_jwt_header(token).map_err(JwksCacheError::Validation)?;
590
591        // First attempt: try to find key in current cache.
592        match self.try_validate(token, &header).await {
593            Ok(jwt) => return Ok(jwt),
594            Err(JwksCacheError::NoMatchingKey) => {
595                // Kid miss — force refresh and retry once.
596            }
597            Err(err) => return Err(err),
598        }
599
600        self.refresh().await?;
601        self.try_validate(token, &header).await
602    }
603
604    async fn try_validate(
605        &self,
606        token: &str,
607        header: &JwtHeaderView,
608    ) -> Result<ValidatedJwt, JwksCacheError> {
609        let inner = self.inner.read().await;
610        let jwks = inner.jwks.as_ref().ok_or(JwksCacheError::NotInitialized)?;
611
612        let jwk = select_jwk_for_token(jwks, header.kid.as_deref(), &header.alg)
613            .map_err(|_| JwksCacheError::NoMatchingKey)?;
614
615        let verification_key =
616            build_jwt_verification_key(jwk, &header.alg).map_err(JwksCacheError::Discovery)?;
617
618        let now = std::time::SystemTime::now()
619            .duration_since(std::time::UNIX_EPOCH)
620            .unwrap_or_default()
621            .as_secs();
622
623        let claims_config = JwtClaimsValidationConfig {
624            issuer: self.config.issuer.clone(),
625            audience: self.config.audience.clone(),
626            now_epoch_seconds: now,
627            leeway_seconds: self.config.leeway_seconds,
628        };
629
630        validate_jwt_with_verification_key(token, &verification_key, &claims_config)
631            .map_err(JwksCacheError::Validation)
632    }
633
634    /// Check if cache needs periodic refresh and refresh if needed.
635    async fn maybe_refresh(&self) -> Result<(), JwksCacheError> {
636        let needs_refresh = {
637            let inner = self.inner.read().await;
638            match inner.last_refresh {
639                Some(last) => last.elapsed() >= self.config.refresh_interval,
640                None => true,
641            }
642        };
643        if needs_refresh {
644            self.refresh().await?;
645        }
646        Ok(())
647    }
648}
649
650async fn fetch_json(client: &reqwest::Client, url: &str) -> Result<String, JwksCacheError> {
651    let response = client
652        .get(url)
653        .send()
654        .await
655        .map_err(|err| JwksCacheError::Http(format!("{err}")))?;
656    let status = response.status();
657    if !status.is_success() {
658        return Err(JwksCacheError::Http(format!("HTTP {status} from {url}")));
659    }
660    response
661        .text()
662        .await
663        .map_err(|err| JwksCacheError::Http(format!("{err}")))
664}