modo/auth/session/jwt/
decoder.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::encoder::JwtEncoder;
11use super::error::JwtError;
12use super::signer::{HmacSigner, TokenVerifier};
13use super::validation::ValidationConfig;
14
15pub struct JwtDecoder {
19 inner: Arc<JwtDecoderInner>,
20}
21
22struct JwtDecoderInner {
23 verifier: Arc<dyn TokenVerifier>,
24 validation: ValidationConfig,
25}
26
27fn jwt_err(kind: JwtError) -> Error {
28 let status_fn = match kind {
29 JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
30 _ => Error::unauthorized,
31 };
32 status_fn("unauthorized").chain(kind).with_code(kind.code())
33}
34
35impl JwtDecoder {
36 pub fn new(verifier: Arc<dyn TokenVerifier>, validation: ValidationConfig) -> Self {
55 Self {
56 inner: Arc::new(JwtDecoderInner {
57 verifier,
58 validation,
59 }),
60 }
61 }
62
63 pub fn from_config(config: &JwtSessionsConfig) -> Self {
67 let signer = HmacSigner::new(config.signing_secret.as_bytes());
68 Self {
69 inner: Arc::new(JwtDecoderInner {
70 verifier: Arc::new(signer),
71 validation: ValidationConfig {
72 leeway: Duration::ZERO,
73 require_issuer: config.issuer.clone(),
74 require_audience: None,
75 },
76 }),
77 }
78 }
79
80 pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
106 let parts: Vec<&str> = token.splitn(4, '.').collect();
107 if parts.len() != 3 {
108 return Err(jwt_err(JwtError::MalformedToken));
109 }
110
111 let (header_b64, payload_b64, signature_b64) = (parts[0], parts[1], parts[2]);
112
113 let header_bytes =
115 base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
116 let header: serde_json::Value =
117 serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
118
119 let alg = header["alg"]
120 .as_str()
121 .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
122 if alg != self.inner.verifier.algorithm_name() {
123 return Err(jwt_err(JwtError::AlgorithmMismatch));
124 }
125
126 let signature =
128 base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
129 let header_payload = format!("{header_b64}.{payload_b64}");
130 self.inner
131 .verifier
132 .verify(header_payload.as_bytes(), &signature)?;
133
134 let payload_bytes =
136 base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
137 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
138 .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
139
140 let now = std::time::SystemTime::now()
142 .duration_since(std::time::UNIX_EPOCH)
143 .expect("system clock before UNIX epoch")
144 .as_secs();
145 let leeway = self.inner.validation.leeway.as_secs();
146
147 let exp = payload
148 .get("exp")
149 .and_then(|v| v.as_u64())
150 .ok_or_else(|| jwt_err(JwtError::Expired))?;
151 if now > exp + leeway {
152 return Err(jwt_err(JwtError::Expired));
153 }
154
155 if let Some(nbf) = payload.get("nbf").and_then(|v| v.as_u64())
157 && now + leeway < nbf
158 {
159 return Err(jwt_err(JwtError::NotYetValid));
160 }
161
162 if let Some(ref required_iss) = self.inner.validation.require_issuer {
164 match payload.get("iss").and_then(|v| v.as_str()) {
165 Some(iss) if iss == required_iss => {}
166 _ => return Err(jwt_err(JwtError::InvalidIssuer)),
167 }
168 }
169
170 if let Some(ref required_aud) = self.inner.validation.require_audience {
172 match payload.get("aud").and_then(|v| v.as_str()) {
173 Some(aud) if aud == required_aud => {}
174 _ => return Err(jwt_err(JwtError::InvalidAudience)),
175 }
176 }
177
178 serde_json::from_value(payload).map_err(|_| jwt_err(JwtError::DeserializationFailed))
180 }
181}
182
183impl From<&JwtEncoder> for JwtDecoder {
187 fn from(encoder: &JwtEncoder) -> Self {
188 Self {
189 inner: Arc::new(JwtDecoderInner {
190 verifier: encoder.verifier(),
191 validation: encoder.validation(),
192 }),
193 }
194 }
195}
196
197impl Clone for JwtDecoder {
198 fn clone(&self) -> Self {
199 Self {
200 inner: self.inner.clone(),
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    use super::super::claims::Claims;
    use super::super::encoder::JwtEncoder;

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        (encoder, decoder)
    }

    /// HMAC decoder for `config` that allows `leeway` of clock skew and requires
    /// neither issuer nor audience.
    fn decoder_with_leeway(config: &JwtSessionsConfig, leeway: Duration) -> JwtDecoder {
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let validation = ValidationConfig {
            leeway,
            require_issuer: None,
            require_audience: None,
        };
        JwtDecoder::new(Arc::new(signer), validation)
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let decoded: Claims = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, claims.sub);
        assert_eq!(decoded.exp, claims.exp);
    }

    #[test]
    fn rejects_expired_token() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn respects_leeway_for_exp() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = decoder_with_leeway(&config, Duration::from_secs(60));

        // Expired 10s ago, which is inside the 60s leeway: accepted.
        let recent = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&recent).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());

        // Expired 120s ago, which is beyond the leeway: rejected.
        let stale = Claims::new().with_exp(now_secs() - 120);
        let token = encoder.encode(&stale).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn rejects_token_before_nbf() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_nbf(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
    }

    #[test]
    fn respects_leeway_for_nbf() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = decoder_with_leeway(&config, Duration::from_secs(60));
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_nbf(now_secs() + 10);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_issuer() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("wrong-app");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn rejects_missing_issuer_when_required() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn accepts_when_no_issuer_policy() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("any-app");
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_audience() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let validation = ValidationConfig::default().with_audience("expected-audience");
        let decoder = JwtDecoder::new(Arc::new(signer), validation);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_aud("wrong-audience");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
    }

    #[test]
    fn rejects_tampered_signature() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let mut token = encoder.encode(&claims).unwrap();
        // Swap one character inside the signature segment. base64url output is pure
        // ASCII, so replacing a one-byte range keeps the String valid without `unsafe`.
        let idx = token.len() - 5;
        let replacement = if token.as_bytes()[idx] == b'A' { "B" } else { "A" };
        token.replace_range(idx..=idx, replacement);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
    }

    #[test]
    fn rejects_malformed_token() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder
            .decode::<Claims>("not.a.valid.token.at.all")
            .unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_too_few_parts() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<Claims>("header.payload").unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_wrong_algorithm() {
        let (encoder, _) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let header_bytes = base64url::decode(parts[0]).unwrap();
        let header_str = String::from_utf8(header_bytes).unwrap();
        let tampered_header = header_str.replace("HS256", "RS256");
        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<Claims>(&tampered_token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
    }

    #[test]
    fn rejects_missing_exp() {
        use super::super::signer::TokenSigner;

        let config = test_config();
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let header = base64url::encode(br#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64url::encode(br#"{"sub":"user_1"}"#);
        let header_payload = format!("{header}.{payload}");
        let sig = signer.sign(header_payload.as_bytes()).unwrap();
        let sig_b64 = base64url::encode(&sig);
        let token = format!("{header_payload}.{sig_b64}");

        let decoder = JwtDecoder::from_config(&config);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn from_encoder_shares_verifier() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from(&encoder);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn decode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let (encoder, decoder) = encode_decode_config();
        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: now_secs() + 3600,
        };
        let token = encoder.encode(&payload).unwrap();
        let decoded: CustomPayload = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, "user_1");
        assert_eq!(decoded.role, "admin");
    }
}