// jwk_set/decrypt.rs

1pub use jsonwebtoken;
2
3use std::{error, fmt};
4
5use jsonwebkey::Algorithm;
6use jsonwebtoken::{self as jwt, errors::Error as JwtError, Header as JwtHeader};
7use serde::de::DeserializeOwned;
8
9use crate::JsonWebKeySet;
10
11pub trait DecryptExt {
12    fn decrypt<JC>(
13        &self,
14        token: impl AsRef<str>,
15        skip_validate_exp: impl Into<Option<bool>>,
16        algorithms_supported: impl Into<Option<Vec<Algorithm>>>,
17    ) -> Result<(JwtHeader, JC), DecryptError>
18    where
19        JC: DeserializeOwned;
20}
21
22#[derive(Debug)]
23pub enum DecryptError {
24    DecodeHeaderFailed(JwtError),
25    KidMissing,
26    KidNotFound,
27    DecodeFailed(JwtError),
28}
29impl fmt::Display for DecryptError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        write!(f, "{:?}", self)
32    }
33}
34impl error::Error for DecryptError {}
35
36//
37//
38//
39impl DecryptExt for JsonWebKeySet {
40    fn decrypt<JC>(
41        &self,
42        token: impl AsRef<str>,
43        skip_validate_exp: impl Into<Option<bool>>,
44        algorithms_supported: impl Into<Option<Vec<Algorithm>>>,
45    ) -> Result<(JwtHeader, JC), DecryptError>
46    where
47        JC: DeserializeOwned,
48    {
49        let token = token.as_ref();
50
51        let jwt_header = jwt::decode_header(token).map_err(DecryptError::DecodeHeaderFailed)?;
52
53        let kid = jwt_header.kid.ok_or(DecryptError::KidMissing)?;
54        let jwt_alg = jwt_header.alg;
55
56        let jwk = self
57            .keys
58            .iter()
59            .find(|jwk| {
60                jwk.key_id == Some(kid.to_owned()) && jwk.algorithm.map(Into::into) == Some(jwt_alg)
61            })
62            .or_else(|| {
63                self.keys
64                    .iter()
65                    .find(|jwk| jwk.key_id == Some(kid.to_owned()))
66            })
67            .ok_or(DecryptError::KidNotFound)?;
68
69        let jwt_key = jwk.key.to_decoding_key();
70
71        let mut jwt_validation = jwt::Validation::default();
72
73        if let Some(skip_validate_exp) = skip_validate_exp.into() {
74            if skip_validate_exp {
75                jwt_validation.validate_exp = false;
76            }
77        }
78
79        if let Some(jwk_alg) = jwk.algorithm {
80            jwt_validation.algorithms = vec![jwk_alg.into()];
81        }
82        if let Some(algorithms_supported) = algorithms_supported.into() {
83            jwt_validation.algorithms = algorithms_supported.into_iter().map(Into::into).collect()
84        }
85
86        let jwt::TokenData { header, claims } =
87            jwt::decode(token, &jwt_key, &jwt_validation).map_err(DecryptError::DecodeFailed)?;
88
89        Ok((header, claims))
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Map, Value};

    /// Claims are decoded into a generic JSON object so any provider's payload fits.
    type Claims = Map<String, Value>;

    /// Parses a JWK set fixture (panics if the fixture is malformed).
    fn parse_key_set(json: &str) -> JsonWebKeySet {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_decrypt_with_apple_oidc() {
        let key_set = parse_key_set(include_str!("../tests/oidc_keys_json_files/apple.json"));
        let token = include_str!("../tests/oidc_id_token_files/apple.txt");

        // `exp` is skipped; no explicit algorithm list, so the JWK's own `alg` is used.
        let decoded: (_, Claims) = key_set.decrypt(token, true, None).unwrap();
        let (header, claims) = decoded;

        assert_eq!(header.kid.as_deref(), Some("eXaunmL"));
        assert_eq!(header.alg, jwt::Algorithm::RS256);
        assert_eq!(claims["iss"], "https://appleid.apple.com");
    }

    #[test]
    fn test_decrypt_with_microsoft_oidc() {
        let key_set =
            parse_key_set(include_str!("../tests/oidc_keys_json_files/microsoft.json"));
        let token = include_str!("../tests/oidc_id_token_files/microsoft.txt");

        // `exp` is skipped; RS256 is passed explicitly as the only accepted algorithm.
        let accepted = vec![Algorithm::RS256];
        let decoded: (_, Claims) = key_set.decrypt(token, true, accepted).unwrap();
        let (header, claims) = decoded;

        assert_eq!(header.kid.as_deref(), Some("bW8ZcMjBCnJZS-ibX5UQDNStvx4"));
        assert_eq!(header.alg, jwt::Algorithm::RS256);
        assert_eq!(
            claims["iss"],
            "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0"
        );
    }
}