Skip to main content

modo/auth/session/jwt/
decoder.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::encoder::JwtEncoder;
11use super::error::JwtError;
12use super::signer::{HmacSigner, TokenVerifier};
13use super::validation::ValidationConfig;
14
15/// JWT token decoder. Verifies signatures and validates claims.
16///
17/// All validation is synchronous. Cloning is cheap — state is stored behind `Arc`.
18pub struct JwtDecoder {
19    inner: Arc<JwtDecoderInner>,
20}
21
22struct JwtDecoderInner {
23    verifier: Arc<dyn TokenVerifier>,
24    validation: ValidationConfig,
25}
26
27pub(super) fn jwt_err(kind: JwtError) -> Error {
28    let status_fn = match kind {
29        JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
30        _ => Error::unauthorized,
31    };
32    status_fn("unauthorized").chain(kind).with_code(kind.code())
33}
34
35/// Build an `Error::unauthorized` carrying a stable error code string.
36/// Used by middleware/extractors for `auth:*` codes that don't have a typed source.
37pub(super) fn auth_err(code: &'static str) -> Error {
38    Error::unauthorized("unauthorized").with_code(code)
39}
40
41impl JwtDecoder {
42    /// Creates a `JwtDecoder` with an explicit verifier and validation policy.
43    ///
44    /// Use this constructor when you need full control over the validation
45    /// config, e.g. to set `require_audience` or `leeway`.
46    ///
47    /// # Example
48    ///
49    /// ```rust,ignore
50    /// use std::sync::Arc;
51    /// use modo::auth::session::jwt::{HmacSigner, JwtDecoder, ValidationConfig};
52    ///
53    /// let signer = HmacSigner::new(b"my-secret");
54    /// let validation = ValidationConfig {
55    ///     require_audience: Some("my-app".into()),
56    ///     ..ValidationConfig::default()
57    /// };
58    /// let decoder = JwtDecoder::new(Arc::new(signer), validation);
59    /// ```
60    pub fn new(verifier: Arc<dyn TokenVerifier>, validation: ValidationConfig) -> Self {
61        Self {
62            inner: Arc::new(JwtDecoderInner {
63                verifier,
64                validation,
65            }),
66        }
67    }
68
69    /// Creates a `JwtDecoder` from YAML configuration.
70    ///
71    /// Uses `HmacSigner` (HS256) with the configured secret.
72    pub fn from_config(config: &JwtSessionsConfig) -> Self {
73        let signer = HmacSigner::new(config.signing_secret.as_bytes());
74        Self {
75            inner: Arc::new(JwtDecoderInner {
76                verifier: Arc::new(signer),
77                validation: ValidationConfig {
78                    leeway: Duration::ZERO,
79                    require_issuer: config.issuer.clone(),
80                    require_audience: None,
81                },
82            }),
83        }
84    }
85
86    /// Decodes and validates a JWT token string, returning `T`.
87    ///
88    /// The system auth flow passes [`Claims`](super::claims::Claims) as `T` and
89    /// gets a `Claims` back. Custom auth flows can pass any
90    /// `DeserializeOwned` struct directly.
91    ///
92    /// Validation order:
93    /// 1. Split into 3 parts (`header.payload.signature`)
94    /// 2. Decode header, check algorithm matches the verifier
95    /// 3. Verify HMAC signature
96    /// 4. Decode payload into JSON value
97    /// 5. Enforce `exp` (always required; missing `exp` is treated as expired)
98    /// 6. Check `nbf` (if present)
99    /// 7. Check `iss` (if `require_issuer` is configured)
100    /// 8. Check `aud` (if `require_audience` is configured)
101    /// 9. Deserialize validated JSON value into `T`
102    ///
103    /// Clock skew tolerance (`leeway`) is applied to steps 5 and 6.
104    ///
105    /// # Errors
106    ///
107    /// Returns `Error::unauthorized` with a [`JwtError`](super::JwtError) source for:
108    /// malformed tokens, invalid headers, algorithm mismatch, invalid signatures,
109    /// expired tokens, not-yet-valid tokens, issuer mismatch, or audience mismatch.
110    /// Missing `exp` is treated as expired.
111    pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
112        let mut iter = token.splitn(4, '.');
113        let (Some(header_b64), Some(payload_b64), Some(signature_b64), None) =
114            (iter.next(), iter.next(), iter.next(), iter.next())
115        else {
116            return Err(jwt_err(JwtError::MalformedToken));
117        };
118
119        let header_bytes =
120            base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
121        let header: serde_json::Value =
122            serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
123
124        let alg = header["alg"]
125            .as_str()
126            .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
127        if alg != self.inner.verifier.algorithm_name() {
128            return Err(jwt_err(JwtError::AlgorithmMismatch));
129        }
130
131        let signature =
132            base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
133        let mut header_payload = Vec::with_capacity(header_b64.len() + 1 + payload_b64.len());
134        header_payload.extend_from_slice(header_b64.as_bytes());
135        header_payload.push(b'.');
136        header_payload.extend_from_slice(payload_b64.as_bytes());
137        self.inner.verifier.verify(&header_payload, &signature)?;
138
139        let payload_bytes =
140            base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
141        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
142            .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
143
144        let now = std::time::SystemTime::now()
145            .duration_since(std::time::UNIX_EPOCH)
146            .expect("system clock before UNIX epoch")
147            .as_secs();
148        let leeway = self.inner.validation.leeway.as_secs();
149
150        let exp = payload
151            .get("exp")
152            .and_then(|v| v.as_u64())
153            .ok_or_else(|| jwt_err(JwtError::Expired))?;
154        if now > exp + leeway {
155            return Err(jwt_err(JwtError::Expired));
156        }
157
158        if let Some(nbf) = payload.get("nbf").and_then(|v| v.as_u64())
159            && now + leeway < nbf
160        {
161            return Err(jwt_err(JwtError::NotYetValid));
162        }
163
164        if let Some(ref required_iss) = self.inner.validation.require_issuer {
165            match payload.get("iss").and_then(|v| v.as_str()) {
166                Some(iss) if iss == required_iss => {}
167                _ => return Err(jwt_err(JwtError::InvalidIssuer)),
168            }
169        }
170
171        if let Some(ref required_aud) = self.inner.validation.require_audience {
172            match payload.get("aud").and_then(|v| v.as_str()) {
173                Some(aud) if aud == required_aud => {}
174                _ => return Err(jwt_err(JwtError::InvalidAudience)),
175            }
176        }
177
178        serde_json::from_value(payload).map_err(|_| jwt_err(JwtError::DeserializationFailed))
179    }
180}
181
182/// Creates a `JwtDecoder` that shares the signing key and validation config
183/// of an existing `JwtEncoder`. Useful when encoder and decoder are wired
184/// from the same `JwtConfig` value.
185impl From<&JwtEncoder> for JwtDecoder {
186    fn from(encoder: &JwtEncoder) -> Self {
187        Self {
188            inner: Arc::new(JwtDecoderInner {
189                verifier: encoder.verifier(),
190                validation: encoder.validation(),
191            }),
192        }
193    }
194}
195
196impl Clone for JwtDecoder {
197    fn clone(&self) -> Self {
198        Self {
199            inner: self.inner.clone(),
200        }
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    use super::super::claims::Claims;
    use super::super::encoder::JwtEncoder;

    /// Default sessions config with a fixed signing secret.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Encoder/decoder pair built from the same test config.
    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        (encoder, decoder)
    }

    /// Current UNIX time in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let decoded: Claims = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, claims.sub);
        assert_eq!(decoded.exp, claims.exp);
    }

    #[test]
    fn rejects_expired_token() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn respects_leeway_for_exp() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();

        // `from_config` uses zero leeway: a token expired 10s ago is rejected.
        let strict = JwtDecoder::from_config(&config);
        let err = strict.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));

        // With 60s of leeway the same token is still inside the tolerance window.
        let lenient = JwtDecoder::new(
            Arc::new(HmacSigner::new(config.signing_secret.as_bytes())),
            ValidationConfig {
                leeway: Duration::from_secs(60),
                require_issuer: None,
                require_audience: None,
            },
        );
        assert!(lenient.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_token_before_nbf() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_nbf(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
    }

    #[test]
    fn rejects_wrong_issuer() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("wrong-app");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn rejects_missing_issuer_when_required() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn accepts_when_no_issuer_policy() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("any-app");
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_audience() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let validation = super::super::validation::ValidationConfig::default()
            .with_audience("expected-audience");
        let decoder = JwtDecoder::new(Arc::new(signer), validation);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_aud("wrong-audience");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature").and(Some("jwt:invalid_audience")));
    }

    #[test]
    fn rejects_tampered_signature() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let mut token = encoder.encode(&claims).unwrap();
        // Swap one character well inside the signature segment. The token is
        // base64url text (ASCII), so a one-byte range is always a char boundary.
        let idx = token.len() - 5;
        let replacement = if token.as_bytes()[idx] == b'A' {
            "B"
        } else {
            "A"
        };
        token.replace_range(idx..=idx, replacement);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
    }

    #[test]
    fn rejects_malformed_token() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder
            .decode::<Claims>("not.a.valid.token.at.all")
            .unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_wrong_algorithm() {
        let (encoder, _) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        // Replace HS256 with RS256 in the header
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let header_bytes = base64url::decode(parts[0]).unwrap();
        let header_str = String::from_utf8(header_bytes).unwrap();
        let tampered_header = header_str.replace("HS256", "RS256");
        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<Claims>(&tampered_token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
    }

    #[test]
    fn rejects_missing_exp() {
        // JwtEncoder::from_config() always injects exp via access_ttl_secs,
        // so a token without exp can only be produced via the signer directly.
        // The decoder still rejects tokens without exp — verified here via a
        // manually crafted token.
        use super::super::signer::{HmacSigner, TokenSigner};
        use crate::encoding::base64url;

        let config = test_config();
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let header = base64url::encode(br#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64url::encode(br#"{"sub":"user_1"}"#); // no exp
        let header_payload = format!("{header}.{payload}");
        let sig = signer.sign(header_payload.as_bytes()).unwrap();
        let sig_b64 = base64url::encode(&sig);
        let token = format!("{header_payload}.{sig_b64}");

        let decoder = JwtDecoder::from_config(&config);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn from_encoder_shares_verifier() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from(&encoder);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn decode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let (encoder, decoder) = encode_decode_config();
        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: now_secs() + 3600,
        };
        let token = encoder.encode(&payload).unwrap();
        let decoded: CustomPayload = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, "user_1");
        assert_eq!(decoded.role, "admin");
    }
}