Skip to main content

modo/auth/jwt/
decoder.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::claims::Claims;
10use super::config::JwtConfig;
11use super::encoder::JwtEncoder;
12use super::error::JwtError;
13use super::signer::{HmacSigner, TokenVerifier};
14use super::validation::ValidationConfig;
15
/// JWT token decoder. Verifies signatures and validates claims.
///
/// All validation is synchronous — revocation checks happen in [`JwtLayer`](super::middleware::JwtLayer).
/// Cloning is cheap — state is stored behind `Arc`.
pub struct JwtDecoder {
    // Shared decoder state (signature verifier plus claim-validation settings).
    inner: Arc<JwtDecoderInner>,
}
23
/// State shared (behind an `Arc`) by every clone of a [`JwtDecoder`].
struct JwtDecoderInner {
    /// Verifies token signatures and reports the algorithm name that the
    /// token header's `alg` field must match.
    verifier: Arc<dyn TokenVerifier>,
    /// Claim validation settings: clock-skew leeway plus the optional
    /// required issuer and audience.
    validation: ValidationConfig,
}
28
29fn jwt_err(kind: JwtError) -> Error {
30    let status_fn = match kind {
31        JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
32        _ => Error::unauthorized,
33    };
34    status_fn("unauthorized").chain(kind).with_code(kind.code())
35}
36
37impl JwtDecoder {
38    /// Creates a `JwtDecoder` from YAML configuration.
39    ///
40    /// Uses `HmacSigner` (HS256) with the configured secret.
41    pub fn from_config(config: &JwtConfig) -> Self {
42        let signer = HmacSigner::new(config.secret.as_bytes());
43        Self {
44            inner: Arc::new(JwtDecoderInner {
45                verifier: Arc::new(signer),
46                validation: ValidationConfig {
47                    leeway: Duration::from_secs(config.leeway),
48                    require_issuer: config.issuer.clone(),
49                    require_audience: config.audience.clone(),
50                },
51            }),
52        }
53    }
54
55    /// Decodes and validates a JWT token string, returning typed `Claims<T>`.
56    ///
57    /// Validation order:
58    /// 1. Split into 3 parts (`header.payload.signature`)
59    /// 2. Decode header, check algorithm matches the verifier
60    /// 3. Verify HMAC signature
61    /// 4. Decode and deserialize payload into `Claims<T>`
62    /// 5. Enforce `exp` (always required; missing `exp` is treated as expired)
63    /// 6. Check `nbf` (if present)
64    /// 7. Check `iss` (if `require_issuer` is configured)
65    /// 8. Check `aud` (if `require_audience` is configured)
66    ///
67    /// Clock skew tolerance (`leeway`) is applied to steps 5 and 6.
68    ///
69    /// # Errors
70    ///
71    /// Returns `Error::unauthorized` with a [`JwtError`](super::JwtError) source for:
72    /// malformed tokens, invalid headers, algorithm mismatch, invalid signatures,
73    /// expired tokens, not-yet-valid tokens, issuer mismatch, or audience mismatch.
74    /// Missing `exp` is treated as expired.
75    pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<Claims<T>> {
76        let parts: Vec<&str> = token.splitn(4, '.').collect();
77        if parts.len() != 3 {
78            return Err(jwt_err(JwtError::MalformedToken));
79        }
80
81        let (header_b64, payload_b64, signature_b64) = (parts[0], parts[1], parts[2]);
82
83        // Decode and verify header
84        let header_bytes =
85            base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
86        let header: serde_json::Value =
87            serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
88
89        let alg = header["alg"]
90            .as_str()
91            .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
92        if alg != self.inner.verifier.algorithm_name() {
93            return Err(jwt_err(JwtError::AlgorithmMismatch));
94        }
95
96        // Verify signature
97        let signature =
98            base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
99        let header_payload = format!("{header_b64}.{payload_b64}");
100        self.inner
101            .verifier
102            .verify(header_payload.as_bytes(), &signature)?;
103
104        // Decode payload
105        let payload_bytes =
106            base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
107        let claims: Claims<T> = serde_json::from_slice(&payload_bytes)
108            .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
109
110        // Validate exp (always required)
111        let now = std::time::SystemTime::now()
112            .duration_since(std::time::UNIX_EPOCH)
113            .expect("system clock before UNIX epoch")
114            .as_secs();
115        let leeway = self.inner.validation.leeway.as_secs();
116
117        let exp = claims.exp.ok_or_else(|| jwt_err(JwtError::Expired))?;
118        if now > exp + leeway {
119            return Err(jwt_err(JwtError::Expired));
120        }
121
122        // Validate nbf (if present)
123        if let Some(nbf) = claims.nbf
124            && now + leeway < nbf
125        {
126            return Err(jwt_err(JwtError::NotYetValid));
127        }
128
129        // Validate iss (if policy requires it)
130        if let Some(ref required_iss) = self.inner.validation.require_issuer {
131            match claims.iss.as_deref() {
132                Some(iss) if iss == required_iss => {}
133                _ => return Err(jwt_err(JwtError::InvalidIssuer)),
134            }
135        }
136
137        // Validate aud (if policy requires it)
138        if let Some(ref required_aud) = self.inner.validation.require_audience {
139            match claims.aud.as_deref() {
140                Some(aud) if aud == required_aud => {}
141                _ => return Err(jwt_err(JwtError::InvalidAudience)),
142            }
143        }
144
145        Ok(claims)
146    }
147}
148
149/// Creates a `JwtDecoder` that shares the signing key and validation config
150/// of an existing `JwtEncoder`. Useful when encoder and decoder are wired
151/// from the same `JwtConfig` value.
152impl From<&JwtEncoder> for JwtDecoder {
153    fn from(encoder: &JwtEncoder) -> Self {
154        Self {
155            inner: Arc::new(JwtDecoderInner {
156                verifier: encoder.verifier(),
157                validation: encoder.validation(),
158            }),
159        }
160    }
161}
162
163impl Clone for JwtDecoder {
164    fn clone(&self) -> Self {
165        Self {
166            inner: self.inner.clone(),
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal custom claims payload used by every test.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        role: String,
    }

    /// Baseline config: no leeway, no issuer/audience policy.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    /// Encoder/decoder pair built from the same baseline config.
    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        (encoder, decoder)
    }

    fn make_token(encoder: &JwtEncoder, claims: &Claims<TestClaims>) -> String {
        encoder.encode(claims).unwrap()
    }

    /// Current UNIX time in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_sub("user_1")
        .with_exp(now_secs() + 3600);
        let token = make_token(&encoder, &claims);
        let decoded: Claims<TestClaims> = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, claims.sub);
        assert_eq!(decoded.custom.role, "admin");
    }

    #[test]
    fn rejects_expired_token() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = make_token(&encoder, &claims);
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn respects_leeway_for_exp() {
        let mut config = test_config();
        config.leeway = 30;
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        // Token expired 10s ago, but leeway is 30s — should be accepted
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_token_before_nbf() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_nbf(now_secs() + 3600);
        let token = make_token(&encoder, &claims);
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
    }

    #[test]
    fn respects_leeway_for_nbf() {
        let mut config = test_config();
        config.leeway = 30;
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        // Token becomes valid in 10s, but leeway is 30s — should be accepted
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_nbf(now_secs() + 10);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_issuer() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("wrong-app");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn rejects_missing_issuer_when_required() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn accepts_when_no_issuer_policy() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("any-app");
        let token = make_token(&encoder, &claims);
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_audience() {
        let mut config = test_config();
        config.audience = Some("expected-aud".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_aud("wrong-aud");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
    }

    #[test]
    fn rejects_missing_audience_when_required() {
        let mut config = test_config();
        config.audience = Some("expected-aud".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
    }

    #[test]
    fn accepts_matching_issuer_and_audience() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        config.audience = Some("expected-aud".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("expected-app")
        .with_aud("expected-aud");
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_tampered_signature() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let mut token = make_token(&encoder, &claims);
        // Flip a character well inside the signature (not in base64 padding region).
        // The token is pure ASCII, so a one-byte range is always a valid char
        // boundary and no `unsafe` byte poking is needed.
        let idx = token.len() - 5;
        let replacement = if token.as_bytes()[idx] == b'A' {
            "B"
        } else {
            "A"
        };
        token.replace_range(idx..=idx, replacement);
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
    }

    #[test]
    fn rejects_malformed_token() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder
            .decode::<TestClaims>("not.a.valid.token.at.all")
            .unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_wrong_algorithm() {
        let (encoder, _) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        // Replace HS256 with RS256 in the header
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let header_bytes = base64url::decode(parts[0]).unwrap();
        let header_str = String::from_utf8(header_bytes).unwrap();
        let tampered_header = header_str.replace("HS256", "RS256");
        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<TestClaims>(&tampered_token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
    }

    #[test]
    fn rejects_missing_exp() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        });
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn from_encoder_shares_verifier() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from(&encoder);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }
}