modo/auth/session/jwt/
decoder.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::encoder::JwtEncoder;
11use super::error::JwtError;
12use super::signer::{HmacSigner, TokenVerifier};
13use super::validation::ValidationConfig;
14
/// Verifies and decodes three-part `header.payload.signature` JWTs.
///
/// Cloning is cheap: every clone shares the same reference-counted inner state
/// (signature verifier plus claim-validation rules).
pub struct JwtDecoder {
    /// Shared verifier and validation settings.
    inner: Arc<JwtDecoderInner>,
}
21
/// State shared by all clones of a `JwtDecoder`.
struct JwtDecoderInner {
    /// Checks token signatures; its `algorithm_name()` must equal the header `alg`.
    verifier: Arc<dyn TokenVerifier>,
    /// Clock leeway plus optional required issuer / audience applied to the payload.
    validation: ValidationConfig,
}
26
27pub(super) fn jwt_err(kind: JwtError) -> Error {
28 let status_fn = match kind {
29 JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
30 _ => Error::unauthorized,
31 };
32 status_fn("unauthorized").chain(kind).with_code(kind.code())
33}
34
35pub(super) fn auth_err(code: &'static str) -> Error {
38 Error::unauthorized("unauthorized").with_code(code)
39}
40
41impl JwtDecoder {
42 pub fn new(verifier: Arc<dyn TokenVerifier>, validation: ValidationConfig) -> Self {
61 Self {
62 inner: Arc::new(JwtDecoderInner {
63 verifier,
64 validation,
65 }),
66 }
67 }
68
69 pub fn from_config(config: &JwtSessionsConfig) -> Self {
73 let signer = HmacSigner::new(config.signing_secret.as_bytes());
74 Self {
75 inner: Arc::new(JwtDecoderInner {
76 verifier: Arc::new(signer),
77 validation: ValidationConfig {
78 leeway: Duration::ZERO,
79 require_issuer: config.issuer.clone(),
80 require_audience: None,
81 },
82 }),
83 }
84 }
85
86 pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
112 let mut iter = token.splitn(4, '.');
113 let (Some(header_b64), Some(payload_b64), Some(signature_b64), None) =
114 (iter.next(), iter.next(), iter.next(), iter.next())
115 else {
116 return Err(jwt_err(JwtError::MalformedToken));
117 };
118
119 let header_bytes =
120 base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
121 let header: serde_json::Value =
122 serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
123
124 let alg = header["alg"]
125 .as_str()
126 .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
127 if alg != self.inner.verifier.algorithm_name() {
128 return Err(jwt_err(JwtError::AlgorithmMismatch));
129 }
130
131 let signature =
132 base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
133 let mut header_payload = Vec::with_capacity(header_b64.len() + 1 + payload_b64.len());
134 header_payload.extend_from_slice(header_b64.as_bytes());
135 header_payload.push(b'.');
136 header_payload.extend_from_slice(payload_b64.as_bytes());
137 self.inner.verifier.verify(&header_payload, &signature)?;
138
139 let payload_bytes =
140 base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
141 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
142 .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
143
144 let now = std::time::SystemTime::now()
145 .duration_since(std::time::UNIX_EPOCH)
146 .expect("system clock before UNIX epoch")
147 .as_secs();
148 let leeway = self.inner.validation.leeway.as_secs();
149
150 let exp = payload
151 .get("exp")
152 .and_then(|v| v.as_u64())
153 .ok_or_else(|| jwt_err(JwtError::Expired))?;
154 if now > exp + leeway {
155 return Err(jwt_err(JwtError::Expired));
156 }
157
158 if let Some(nbf) = payload.get("nbf").and_then(|v| v.as_u64())
159 && now + leeway < nbf
160 {
161 return Err(jwt_err(JwtError::NotYetValid));
162 }
163
164 if let Some(ref required_iss) = self.inner.validation.require_issuer {
165 match payload.get("iss").and_then(|v| v.as_str()) {
166 Some(iss) if iss == required_iss => {}
167 _ => return Err(jwt_err(JwtError::InvalidIssuer)),
168 }
169 }
170
171 if let Some(ref required_aud) = self.inner.validation.require_audience {
172 match payload.get("aud").and_then(|v| v.as_str()) {
173 Some(aud) if aud == required_aud => {}
174 _ => return Err(jwt_err(JwtError::InvalidAudience)),
175 }
176 }
177
178 serde_json::from_value(payload).map_err(|_| jwt_err(JwtError::DeserializationFailed))
179 }
180}
181
182impl From<&JwtEncoder> for JwtDecoder {
186 fn from(encoder: &JwtEncoder) -> Self {
187 Self {
188 inner: Arc::new(JwtDecoderInner {
189 verifier: encoder.verifier(),
190 validation: encoder.validation(),
191 }),
192 }
193 }
194}
195
196impl Clone for JwtDecoder {
197 fn clone(&self) -> Self {
198 Self {
199 inner: self.inner.clone(),
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    use super::super::claims::Claims;
    use super::super::encoder::JwtEncoder;

    /// Config with a fixed HMAC secret and every other setting at its default.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Encoder/decoder pair built from the same `test_config()`.
    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        (encoder, decoder)
    }

    /// Current UNIX time in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let decoded: Claims = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, claims.sub);
        assert_eq!(decoded.exp, claims.exp);
    }

    #[test]
    fn rejects_expired_token() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn respects_leeway_for_exp() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();

        // `from_config` uses zero leeway, so a token that expired 10s ago is rejected...
        let strict = JwtDecoder::from_config(&config);
        let err = strict.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));

        // ...while a 30s leeway tolerates the same amount of clock skew.
        let lenient = JwtDecoder::new(
            Arc::new(HmacSigner::new(config.signing_secret.as_bytes())),
            ValidationConfig {
                leeway: Duration::from_secs(30),
                require_issuer: None,
                require_audience: None,
            },
        );
        assert!(lenient.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_token_before_nbf() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_nbf(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
    }

    #[test]
    fn rejects_wrong_issuer() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("wrong-app");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn rejects_missing_issuer_when_required() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn accepts_when_no_issuer_policy() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_iss("any-app");
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_audience() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let validation = ValidationConfig::default().with_audience("expected-audience");
        let decoder = JwtDecoder::new(Arc::new(signer), validation);
        let claims = Claims::new()
            .with_exp(now_secs() + 3600)
            .with_aud("wrong-audience");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
    }

    #[test]
    fn rejects_tampered_signature() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let mut token = encoder.encode(&claims).unwrap();
        // Swap one character inside the signature segment for a different valid
        // base64url character. A JWT is pure ASCII, so the one-byte range is always
        // a char boundary and no `unsafe` byte mutation is needed.
        let idx = token.len() - 5;
        let replacement = if token.as_bytes()[idx] == b'A' { "B" } else { "A" };
        token.replace_range(idx..=idx, replacement);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
    }

    #[test]
    fn rejects_malformed_token() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder
            .decode::<Claims>("not.a.valid.token.at.all")
            .unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_wrong_algorithm() {
        let (encoder, _) = encode_decode_config();
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        // Rewrite the header's `alg` while keeping the original payload and signature.
        let header_bytes = base64url::decode(parts[0]).unwrap();
        let header_str = String::from_utf8(header_bytes).unwrap();
        let tampered_header = header_str.replace("HS256", "RS256");
        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<Claims>(&tampered_token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
    }

    #[test]
    fn rejects_missing_exp() {
        use super::super::signer::{HmacSigner, TokenSigner};

        // Hand-build a correctly signed token whose payload has no `exp` claim.
        let config = test_config();
        let signer = HmacSigner::new(config.signing_secret.as_bytes());
        let header = base64url::encode(br#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64url::encode(br#"{"sub":"user_1"}"#);
        let header_payload = format!("{header}.{payload}");
        let sig = signer.sign(header_payload.as_bytes()).unwrap();
        let sig_b64 = base64url::encode(&sig);
        let token = format!("{header_payload}.{sig_b64}");

        let decoder = JwtDecoder::from_config(&config);
        let err = decoder.decode::<Claims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn from_encoder_shares_verifier() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from(&encoder);
        let claims = Claims::new().with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<Claims>(&token).is_ok());
    }

    #[test]
    fn decode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let (encoder, decoder) = encode_decode_config();
        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: now_secs() + 3600,
        };
        let token = encoder.encode(&payload).unwrap();
        let decoded: CustomPayload = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, "user_1");
        assert_eq!(decoded.role, "admin");
    }
}