1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::claims::Claims;
10use super::config::JwtConfig;
11use super::encoder::JwtEncoder;
12use super::error::JwtError;
13use super::signer::{HmacSigner, TokenVerifier};
14use super::validation::ValidationConfig;
15
16pub struct JwtDecoder {
21 inner: Arc<JwtDecoderInner>,
22}
23
24struct JwtDecoderInner {
25 verifier: Arc<dyn TokenVerifier>,
26 validation: ValidationConfig,
27}
28
29fn jwt_err(kind: JwtError) -> Error {
30 let status_fn = match kind {
31 JwtError::SigningFailed | JwtError::SerializationFailed => Error::internal,
32 _ => Error::unauthorized,
33 };
34 status_fn("unauthorized").chain(kind).with_code(kind.code())
35}
36
37impl JwtDecoder {
38 pub fn from_config(config: &JwtConfig) -> Self {
42 let signer = HmacSigner::new(config.secret.as_bytes());
43 Self {
44 inner: Arc::new(JwtDecoderInner {
45 verifier: Arc::new(signer),
46 validation: ValidationConfig {
47 leeway: Duration::from_secs(config.leeway),
48 require_issuer: config.issuer.clone(),
49 require_audience: config.audience.clone(),
50 },
51 }),
52 }
53 }
54
55 pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<Claims<T>> {
76 let parts: Vec<&str> = token.splitn(4, '.').collect();
77 if parts.len() != 3 {
78 return Err(jwt_err(JwtError::MalformedToken));
79 }
80
81 let (header_b64, payload_b64, signature_b64) = (parts[0], parts[1], parts[2]);
82
83 let header_bytes =
85 base64url::decode(header_b64).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
86 let header: serde_json::Value =
87 serde_json::from_slice(&header_bytes).map_err(|_| jwt_err(JwtError::InvalidHeader))?;
88
89 let alg = header["alg"]
90 .as_str()
91 .ok_or_else(|| jwt_err(JwtError::InvalidHeader))?;
92 if alg != self.inner.verifier.algorithm_name() {
93 return Err(jwt_err(JwtError::AlgorithmMismatch));
94 }
95
96 let signature =
98 base64url::decode(signature_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
99 let header_payload = format!("{header_b64}.{payload_b64}");
100 self.inner
101 .verifier
102 .verify(header_payload.as_bytes(), &signature)?;
103
104 let payload_bytes =
106 base64url::decode(payload_b64).map_err(|_| jwt_err(JwtError::MalformedToken))?;
107 let claims: Claims<T> = serde_json::from_slice(&payload_bytes)
108 .map_err(|_| jwt_err(JwtError::DeserializationFailed))?;
109
110 let now = std::time::SystemTime::now()
112 .duration_since(std::time::UNIX_EPOCH)
113 .expect("system clock before UNIX epoch")
114 .as_secs();
115 let leeway = self.inner.validation.leeway.as_secs();
116
117 let exp = claims.exp.ok_or_else(|| jwt_err(JwtError::Expired))?;
118 if now > exp + leeway {
119 return Err(jwt_err(JwtError::Expired));
120 }
121
122 if let Some(nbf) = claims.nbf
124 && now + leeway < nbf
125 {
126 return Err(jwt_err(JwtError::NotYetValid));
127 }
128
129 if let Some(ref required_iss) = self.inner.validation.require_issuer {
131 match claims.iss.as_deref() {
132 Some(iss) if iss == required_iss => {}
133 _ => return Err(jwt_err(JwtError::InvalidIssuer)),
134 }
135 }
136
137 if let Some(ref required_aud) = self.inner.validation.require_audience {
139 match claims.aud.as_deref() {
140 Some(aud) if aud == required_aud => {}
141 _ => return Err(jwt_err(JwtError::InvalidAudience)),
142 }
143 }
144
145 Ok(claims)
146 }
147}
148
149impl From<&JwtEncoder> for JwtDecoder {
153 fn from(encoder: &JwtEncoder) -> Self {
154 Self {
155 inner: Arc::new(JwtDecoderInner {
156 verifier: encoder.verifier(),
157 validation: encoder.validation(),
158 }),
159 }
160 }
161}
162
163impl Clone for JwtDecoder {
164 fn clone(&self) -> Self {
165 Self {
166 inner: self.inner.clone(),
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        role: String,
    }

    /// Baseline config: fixed secret, no leeway, no issuer/audience policy.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    fn encode_decode_config() -> (JwtEncoder, JwtDecoder) {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        (encoder, decoder)
    }

    fn make_token(encoder: &JwtEncoder, claims: &Claims<TestClaims>) -> String {
        encoder.encode(claims).unwrap()
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_sub("user_1")
        .with_exp(now_secs() + 3600);
        let token = make_token(&encoder, &claims);
        let decoded: Claims<TestClaims> = decoder.decode(&token).unwrap();
        assert_eq!(decoded.sub, claims.sub);
        assert_eq!(decoded.custom.role, "admin");
    }

    #[test]
    fn rejects_expired_token() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = make_token(&encoder, &claims);
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn respects_leeway_for_exp() {
        let mut config = test_config();
        config.leeway = 30;
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_token_before_nbf() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_nbf(now_secs() + 3600);
        let token = make_token(&encoder, &claims);
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:not_yet_valid"));
    }

    #[test]
    fn rejects_wrong_issuer() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("wrong-app");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn rejects_missing_issuer_when_required() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_issuer"));
    }

    #[test]
    fn accepts_when_no_issuer_policy() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("any-app");
        let token = make_token(&encoder, &claims);
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn accepts_matching_issuer_and_audience() {
        let mut config = test_config();
        config.issuer = Some("expected-app".into());
        config.audience = Some("expected-aud".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_iss("expected-app")
        .with_aud("expected-aud");
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }

    #[test]
    fn rejects_wrong_audience() {
        let mut config = test_config();
        config.audience = Some("expected-aud".into());
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600)
        .with_aud("wrong-aud");
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_audience"));
    }

    #[test]
    fn rejects_tampered_signature() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = make_token(&encoder, &claims);
        // Flip one base64url character near the end, i.e. inside the signature.
        let idx = token.len() - 5;
        let mut bytes = token.into_bytes();
        bytes[idx] = if bytes[idx] == b'A' { b'B' } else { b'A' };
        let token = String::from_utf8(bytes).expect("ASCII-for-ASCII swap keeps the token valid UTF-8");
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_signature"));
    }

    #[test]
    fn rejects_malformed_token() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder
            .decode::<TestClaims>("not.a.valid.token.at.all")
            .unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_token_with_too_few_segments() {
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<TestClaims>("only.two").unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:malformed_token"));
    }

    #[test]
    fn rejects_non_json_header() {
        let decoder = JwtDecoder::from_config(&test_config());
        let header_b64 = base64url::encode(b"not json");
        let token = format!("{header_b64}.e30.c2ln");
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:invalid_header"));
    }

    #[test]
    fn rejects_token_with_wrong_algorithm() {
        let (encoder, _) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let header_bytes = base64url::decode(parts[0]).unwrap();
        let header_str = String::from_utf8(header_bytes).unwrap();
        let tampered_header = header_str.replace("HS256", "RS256");
        let tampered_header_b64 = base64url::encode(tampered_header.as_bytes());
        let tampered_token = format!("{}.{}.{}", tampered_header_b64, parts[1], parts[2]);
        let decoder = JwtDecoder::from_config(&test_config());
        let err = decoder.decode::<TestClaims>(&tampered_token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:algorithm_mismatch"));
    }

    #[test]
    fn rejects_missing_exp() {
        let (encoder, decoder) = encode_decode_config();
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        });
        let token = encoder.encode(&claims).unwrap();
        let err = decoder.decode::<TestClaims>(&token).unwrap_err();
        assert_eq!(err.error_code(), Some("jwt:expired"));
    }

    #[test]
    fn from_encoder_shares_verifier() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from(&encoder);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() + 3600);
        let token = encoder.encode(&claims).unwrap();
        assert!(decoder.decode::<TestClaims>(&token).is_ok());
    }
}