Skip to main content

modo/auth/session/jwt/decoder.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::encoder::JwtEncoder;
11use super::error::JwtError;
12use super::signer::{HmacSigner, TokenVerifier};
13use super::validation::ValidationConfig;
14
/// JWT token decoder. Verifies signatures and validates claims.
///
/// All validation is synchronous: `decode` is a plain (non-async) method.
/// Cloning is cheap — state is stored behind `Arc`, so every clone shares the
/// same verifier and the same validation policy.
pub struct JwtDecoder {
    // Shared state; cloning the decoder only bumps this reference count.
    inner: Arc<JwtDecoderInner>,
}
21
/// State shared by every clone of a [`JwtDecoder`].
struct JwtDecoderInner {
    /// Checks token signatures. Its `algorithm_name()` is the only `alg`
    /// header value that `decode` accepts.
    verifier: Arc<dyn TokenVerifier>,
    /// Claim policy: clock-skew leeway plus the optional required issuer and
    /// audience.
    validation: ValidationConfig,
}
26
27fn jwt_err(kind: JwtError) -> Error {
28    let status_fn = match kind {
29        JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
30        _ => Error::unauthorized,
31    };
32    status_fn("unauthorized").chain(kind).with_code(kind.code())
33}
34
35impl JwtDecoder {
36    /// Creates a `JwtDecoder` with an explicit verifier and validation policy.
37    ///
38    /// Use this constructor when you need full control over the validation
39    /// config, e.g. to set `require_audience` or `leeway`.
40    ///
41    /// # Example
42    ///
43    /// ```rust,ignore
44    /// use std::sync::Arc;
45    /// use modo::auth::session::jwt::{HmacSigner, JwtDecoder, ValidationConfig};
46    ///
47    /// let signer = HmacSigner::new(b"my-secret");
48    /// let validation = ValidationConfig {
49    ///     require_audience: Some("my-app".into()),
50    ///     ..ValidationConfig::default()
51    /// };
52    /// let decoder = JwtDecoder::new(Arc::new(signer), validation);
53    /// ```
54    pub fn new(verifier: Arc<dyn TokenVerifier>, validation: ValidationConfig) -> Self {
55        Self {
56            inner: Arc::new(JwtDecoderInner {
57                verifier,
58                validation,
59            }),
60        }
61    }
62
63    /// Creates a `JwtDecoder` from YAML configuration.
64    ///
65    /// Uses `HmacSigner` (HS256) with the configured secret.
66    pub fn from_config(config: &JwtSessionsConfig) -> Self {
67        let signer = HmacSigner::new(config.signing_secret.as_bytes());
68        Self {
69            inner: Arc::new(JwtDecoderInner {
70                verifier: Arc::new(signer),
71                validation: ValidationConfig {
72                    leeway: Duration::ZERO,
73                    require_issuer: config.issuer.clone(),
74                    require_audience: None,
75                },
76            }),
77        }
78    }
79
80    /// Decodes and validates a JWT token string, returning `T`.
81    ///
82    /// The system auth flow passes [`Claims`](super::claims::Claims) as `T` and
83    /// gets a `Claims` back. Custom auth flows can pass any
84    /// `DeserializeOwned` struct directly.
85    ///
86    /// Validation order:
87    /// 1. Split into 3 parts (`header.payload.signature`)
88    /// 2. Decode header, check algorithm matches the verifier
89    /// 3. Verify HMAC signature
90    /// 4. Decode payload into JSON value
91    /// 5. Enforce `exp` (always required; missing `exp` is treated as expired)
92    /// 6. Check `nbf` (if present)
93    /// 7. Check `iss` (if `require_issuer` is configured)
94    /// 8. Check `aud` (if `require_audience` is configured)
95    /// 9. Deserialize validated JSON value into `T`
96    ///
97    /// Clock skew tolerance (`leeway`) is applied to steps 5 and 6.
98    ///
99    /// # Errors
100    ///
101    /// Returns `Error::unauthorized` with a [`JwtError`](super::JwtError) source for:
102    /// malformed tokens, invalid headers, algorithm mismatch, invalid signatures,
103    /// expired tokens, not-yet-valid tokens, issuer mismatch, or audience mismatch.
104    /// Missing `exp` is treated as expired.
105    pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
106        let parts: Vec<&str> = token.splitn(4, '.').collect();
107        if parts.len() != 3 {
108            return Err(jwt_err(JwtError::MalformedToken));
109        }
110
111        let (header_b64, payload_b64, signature_b64) = (parts[0], parts[1], parts[2]);
112
113        // Decode and verify header
114        let header_bytes =
115            base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
116        let header: serde_json::Value =
117            serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
118
119        let alg = header["alg"]
120            .as_str()
121            .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
122        if alg != self.inner.verifier.algorithm_name() {
123            return Err(jwt_err(JwtError::AlgorithmMismatch));
124        }
125
126        // Verify signature
127        let signature =
128            base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
129        let header_payload = format!("{header_b64}.{payload_b64}");
130        self.inner
131            .verifier
132            .verify(header_payload.as_bytes(), &signature)?;
133
134        // Decode payload into a JSON value for validation
135        let payload_bytes =
136            base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
137        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
138            .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
139
140        // Validate exp (always required)
141        let now = std::time::SystemTime::now()
142            .duration_since(std::time::UNIX_EPOCH)
143            .expect("system clock before UNIX epoch")
144            .as_secs();
145        let leeway = self.inner.validation.leeway.as_secs();
146
147        let exp = payload
148            .get("exp")
149            .and_then(|v| v.as_u64())
150            .ok_or_else(|| jwt_err(JwtError::Expired))?;
151        if now > exp + leeway {
152            return Err(jwt_err(JwtError::Expired));
153        }
154
155        // Validate nbf (if present)
156        if let Some(nbf) = payload.get("nbf").and_then(|v| v.as_u64())
157            && now + leeway < nbf
158        {
159            return Err(jwt_err(JwtError::NotYetValid));
160        }
161
162        // Validate iss (if policy requires it)
163        if let Some(ref required_iss) = self.inner.validation.require_issuer {
164            match payload.get("iss").and_then(|v| v.as_str()) {
165                Some(iss) if iss == required_iss => {}
166                _ => return Err(jwt_err(JwtError::InvalidIssuer)),
167            }
168        }
169
170        // Validate aud (if policy requires it)
171        if let Some(ref required_aud) = self.inner.validation.require_audience {
172            match payload.get("aud").and_then(|v| v.as_str()) {
173                Some(aud) if aud == required_aud => {}
174                _ => return Err(jwt_err(JwtError::InvalidAudience)),
175            }
176        }
177
178        // Deserialize into the requested type
179        serde_json::from_value(payload).map_err(|_| jwt_err(JwtError::DeserializationFailed))
180    }
181}
182
183/// Creates a `JwtDecoder` that shares the signing key and validation config
184/// of an existing `JwtEncoder`. Useful when encoder and decoder are wired
185/// from the same `JwtConfig` value.
186impl From<&JwtEncoder> for JwtDecoder {
187    fn from(encoder: &JwtEncoder) -> Self {
188        Self {
189            inner: Arc::new(JwtDecoderInner {
190                verifier: encoder.verifier(),
191                validation: encoder.validation(),
192            }),
193        }
194    }
195}
196
197impl Clone for JwtDecoder {
198    fn clone(&self) -> Self {
199        Self {
200            inner: self.inner.clone(),
201        }
202    }
203}
204
205#[cfg(test)]
206mod tests {
207    use super::*;
208    use serde::{Deserialize, Serialize};
209
210    use super::super::claims::Claims;
211    use super::super::encoder::JwtEncoder;
212
213    fn test_config() -> JwtSessionsConfig {
214        JwtSessionsConfig {
215            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
216            ..JwtSessionsConfig::default()
217        }
218    }
219
220    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
221        let config = test_config();
222        let encoder = JwtEncoder::from_config(&config);
223        let decoder = JwtDecoder::from_config(&config);
224        (encoder, decoder)
225    }
226
227    fn now_secs() -> u64 {
228        std::time::SystemTime::now()
229            .duration_since(std::time::UNIX_EPOCH)
230            .unwrap()
231            .as_secs()
232    }
233
234    #[test]
235    fn encode_decode_roundtrip() {
236        let (encoder, decoder) = encode_decode_config();
237        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
238        let token = encoder.encode(&claims).unwrap();
239        let decoded: Claims = decoder.decode(&token).unwrap();
240        assert_eq!(decoded.sub, claims.sub);
241        assert_eq!(decoded.exp, claims.exp);
242    }
243
244    #[test]
245    fn rejects_expired_token() {
246        let (encoder, decoder) = encode_decode_config();
247        let claims = Claims::new().with_exp(now_secs() - 10);
248        let token = encoder.encode(&claims).unwrap();
249        let err = decoder.decode::<Claims>(&token).unwrap_err();
250        assert_eq!(err.error_code(), Some("jwt:expired"));
251    }
252
253    #[test]
254    fn respects_leeway_for_exp() {
255        // leeway is 0 in the new config — a token expired even 1s ago is rejected.
256        let (encoder, decoder) = encode_decode_config();
257        let claims = Claims::new().with_exp(now_secs() - 10);
258        let token = encoder.encode(&claims).unwrap();
259        // With leeway=0, a token expired 10s ago must be rejected.
260        let err = decoder.decode::<Claims>(&token).unwrap_err();
261        assert_eq!(err.error_code(), Some("jwt:expired"));
262    }
263
264    #[test]
265    fn rejects_token_before_nbf() {
266        let (encoder, decoder) = encode_decode_config();
267        let claims = Claims::new()
268            .with_exp(now_secs() + 3600)
269            .with_nbf(now_secs() + 3600);
270        let token = encoder.encode(&claims).unwrap();
271        let err = decoder.decode::<Claims>(&token).unwrap_err();
272        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
273    }
274
275    #[test]
276    fn rejects_wrong_issuer() {
277        let mut config = test_config();
278        config.issuer = Some("expected-app".into());
279        let encoder = JwtEncoder::from_config(&config);
280        let decoder = JwtDecoder::from_config(&config);
281        let claims = Claims::new()
282            .with_exp(now_secs() + 3600)
283            .with_iss("wrong-app");
284        let token = encoder.encode(&claims).unwrap();
285        let err = decoder.decode::<Claims>(&token).unwrap_err();
286        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
287    }
288
289    #[test]
290    fn rejects_missing_issuer_when_required() {
291        let mut config = test_config();
292        config.issuer = Some("expected-app".into());
293        let encoder = JwtEncoder::from_config(&config);
294        let decoder = JwtDecoder::from_config(&config);
295        let claims = Claims::new().with_exp(now_secs() + 3600);
296        let token = encoder.encode(&claims).unwrap();
297        let err = decoder.decode::<Claims>(&token).unwrap_err();
298        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
299    }
300
301    #[test]
302    fn accepts_when_no_issuer_policy() {
303        let (encoder, decoder) = encode_decode_config();
304        let claims = Claims::new()
305            .with_exp(now_secs() + 3600)
306            .with_iss("any-app");
307        let token = encoder.encode(&claims).unwrap();
308        assert!(decoder.decode::<Claims>(&token).is_ok());
309    }
310
311    #[test]
312    fn rejects_wrong_audience() {
313        let config = test_config();
314        let encoder = JwtEncoder::from_config(&config);
315        let signer = HmacSigner::new(config.signing_secret.as_bytes());
316        let validation = super::super::validation::ValidationConfig::default()
317            .with_audience("expected-audience");
318        let decoder = JwtDecoder::new(Arc::new(signer), validation);
319        let claims = Claims::new()
320            .with_exp(now_secs() + 3600)
321            .with_aud("wrong-audience");
322        let token = encoder.encode(&claims).unwrap();
323        let err = decoder.decode::<Claims>(&token).unwrap_err();
324        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
325    }
326
327    #[test]
328    fn rejects_tampered_signature() {
329        let (encoder, decoder) = encode_decode_config();
330        let claims = Claims::new().with_exp(now_secs() + 3600);
331        let mut token = encoder.encode(&claims).unwrap();
332        // Flip a character well inside the signature (not in base64 padding region)
333        let idx = token.len() - 5;
334        let original = token.as_bytes()[idx];
335        let replacement = if original == b'A' { b'B' } else { b'A' };
336        // SAFETY: replacing one ASCII byte with another ASCII byte
337        unsafe { token.as_bytes_mut()[idx] = replacement };
338        let err = decoder.decode::<Claims>(&token).unwrap_err();
339        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
340    }
341
342    #[test]
343    fn rejects_malformed_token() {
344        let decoder = JwtDecoder::from_config(&test_config());
345        let err = decoder
346            .decode::<Claims>("not.a.valid.token.at.all")
347            .unwrap_err();
348        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
349    }
350
351    #[test]
352    fn rejects_token_with_wrong_algorithm() {
353        let (encoder, _) = encode_decode_config();
354        let claims = Claims::new().with_exp(now_secs() + 3600);
355        let token = encoder.encode(&claims).unwrap();
356        // Replace HS256 with RS256 in the header
357        let parts: Vec<&str> = token.splitn(3, '.').collect();
358        let header_bytes = base64url::decode(parts[0]).unwrap();
359        let header_str = String::from_utf8(header_bytes).unwrap();
360        let tampered_header = header_str.replace("HS256", "RS256");
361        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
362        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
363        let decoder = JwtDecoder::from_config(&test_config());
364        let err = decoder.decode::<Claims>(&tampered_token).unwrap_err();
365        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
366    }
367
368    #[test]
369    fn rejects_missing_exp() {
370        // JwtEncoder::from_config() always injects exp via access_ttl_secs,
371        // so a token without exp can only be produced via the signer directly.
372        // The decoder still rejects tokens without exp — verified here via a
373        // manually crafted token.
374        use super::super::signer::{HmacSigner, TokenSigner};
375        use crate::encoding::base64url;
376
377        let config = test_config();
378        let signer = HmacSigner::new(config.signing_secret.as_bytes());
379        let header = base64url::encode(br#"{"alg":"HS256","typ":"JWT"}"#);
380        let payload = base64url::encode(br#"{"sub":"user_1"}"#); // no exp
381        let header_payload = format!("{header}.{payload}");
382        let sig = signer.sign(header_payload.as_bytes()).unwrap();
383        let sig_b64 = base64url::encode(&sig);
384        let token = format!("{header_payload}.{sig_b64}");
385
386        let decoder = JwtDecoder::from_config(&config);
387        let err = decoder.decode::<Claims>(&token).unwrap_err();
388        assert_eq!(err.error_code(), Some("jwt:expired"));
389    }
390
391    #[test]
392    fn from_encoder_shares_verifier() {
393        let config = test_config();
394        let encoder = JwtEncoder::from_config(&config);
395        let decoder = JwtDecoder::from(&encoder);
396        let claims = Claims::new().with_exp(now_secs() + 3600);
397        let token = encoder.encode(&claims).unwrap();
398        assert!(decoder.decode::<Claims>(&token).is_ok());
399    }
400
401    #[test]
402    fn decode_custom_struct_directly() {
403        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
404        struct CustomPayload {
405            sub: String,
406            role: String,
407            exp: u64,
408        }
409
410        let (encoder, decoder) = encode_decode_config();
411        let payload = CustomPayload {
412            sub: "user_1".into(),
413            role: "admin".into(),
414            exp: now_secs() + 3600,
415        };
416        let token = encoder.encode(&payload).unwrap();
417        let decoded: CustomPayload = decoder.decode(&token).unwrap();
418        assert_eq!(decoded.sub, "user_1");
419        assert_eq!(decoded.role, "admin");
420    }
421}