1pub use jsonwebtoken;
2
3use std::{error, fmt};
4
5use jsonwebkey::Algorithm;
6use jsonwebtoken::{self as jwt, errors::Error as JwtError, Header as JwtHeader};
7use serde::de::DeserializeOwned;
8
9use crate::JsonWebKeySet;
10
11pub trait DecryptExt {
12 fn decrypt<JC>(
13 &self,
14 token: impl AsRef<str>,
15 skip_validate_exp: impl Into<Option<bool>>,
16 algorithms_supported: impl Into<Option<Vec<Algorithm>>>,
17 ) -> Result<(JwtHeader, JC), DecryptError>
18 where
19 JC: DeserializeOwned;
20}
21
22#[derive(Debug)]
23pub enum DecryptError {
24 DecodeHeaderFailed(JwtError),
25 KidMissing,
26 KidNotFound,
27 DecodeFailed(JwtError),
28}
29impl fmt::Display for DecryptError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 write!(f, "{:?}", self)
32 }
33}
34impl error::Error for DecryptError {}
35
36impl DecryptExt for JsonWebKeySet {
40 fn decrypt<JC>(
41 &self,
42 token: impl AsRef<str>,
43 skip_validate_exp: impl Into<Option<bool>>,
44 algorithms_supported: impl Into<Option<Vec<Algorithm>>>,
45 ) -> Result<(JwtHeader, JC), DecryptError>
46 where
47 JC: DeserializeOwned,
48 {
49 let token = token.as_ref();
50
51 let jwt_header = jwt::decode_header(token).map_err(DecryptError::DecodeHeaderFailed)?;
52
53 let kid = jwt_header.kid.ok_or(DecryptError::KidMissing)?;
54 let jwt_alg = jwt_header.alg;
55
56 let jwk = self
57 .keys
58 .iter()
59 .find(|jwk| {
60 jwk.key_id == Some(kid.to_owned()) && jwk.algorithm.map(Into::into) == Some(jwt_alg)
61 })
62 .or_else(|| {
63 self.keys
64 .iter()
65 .find(|jwk| jwk.key_id == Some(kid.to_owned()))
66 })
67 .ok_or(DecryptError::KidNotFound)?;
68
69 let jwt_key = jwk.key.to_decoding_key();
70
71 let mut jwt_validation = jwt::Validation::default();
72
73 if let Some(skip_validate_exp) = skip_validate_exp.into() {
74 if skip_validate_exp {
75 jwt_validation.validate_exp = false;
76 }
77 }
78
79 if let Some(jwk_alg) = jwk.algorithm {
80 jwt_validation.algorithms = vec![jwk_alg.into()];
81 }
82 if let Some(algorithms_supported) = algorithms_supported.into() {
83 jwt_validation.algorithms = algorithms_supported.into_iter().map(Into::into).collect()
84 }
85
86 let jwt::TokenData { header, claims } =
87 jwt::decode(token, &jwt_key, &jwt_validation).map_err(DecryptError::DecodeFailed)?;
88
89 Ok((header, claims))
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Map, Value};

    /// Claims decoded into a loosely typed JSON object so any field can be inspected.
    type Claims = Map<String, Value>;

    /// Parses a JWK set fixture, panicking if it is malformed.
    fn parse_key_set(json: &str) -> JsonWebKeySet {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_decrypt_with_apple_oidc() {
        let jwk_set = parse_key_set(include_str!("../tests/oidc_keys_json_files/apple.json"));
        let id_token = include_str!("../tests/oidc_id_token_files/apple.txt");

        // `true` skips the `exp` check — presumably because the stored fixture token has expired.
        let decoded: (_, Claims) = jwk_set.decrypt(id_token, true, None).unwrap();
        let (header, claims) = decoded;

        assert_eq!(header.kid.as_deref(), Some("eXaunmL"));
        assert_eq!(header.alg, jwt::Algorithm::RS256);

        let issuer = claims.get("iss").and_then(Value::as_str);
        assert_eq!(issuer, Some("https://appleid.apple.com"));
    }

    #[test]
    fn test_decrypt_with_microsoft_oidc() {
        let jwk_set = parse_key_set(include_str!("../tests/oidc_keys_json_files/microsoft.json"));
        let id_token = include_str!("../tests/oidc_id_token_files/microsoft.txt");

        let allowed = vec![Algorithm::RS256];
        let decoded: (_, Claims) = jwk_set.decrypt(id_token, true, allowed).unwrap();
        let (header, claims) = decoded;

        assert_eq!(header.kid.as_deref(), Some("bW8ZcMjBCnJZS-ibX5UQDNStvx4"));
        assert_eq!(header.alg, jwt::Algorithm::RS256);

        let issuer = claims.get("iss").and_then(Value::as_str);
        assert_eq!(
            issuer,
            Some("https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0")
        );
    }
}