api_tools/security/jwt/
mod.rs

1//! JWT module
2
3pub mod access_token;
4pub mod payload;
5
6use crate::server::axum::response::ApiError;
7use crate::{security::jwt::access_token::AccessToken, value_objects::datetime::UtcDateTime};
8use jsonwebtoken::errors::ErrorKind::ExpiredSignature;
9use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation, decode, encode};
10use serde::{Deserialize, Serialize};
11use std::fmt::{Debug, Formatter};
12use thiserror::Error;
13
/// Default access token lifetime, expressed in minutes.
const JWT_ACCESS_LIFETIME_IN_MINUTES: i64 = 15;

/// Default refresh token lifetime, expressed in hours (7 days of 24 hours).
const JWT_REFRESH_LIFETIME_IN_HOURS: i64 = 7 * 24;
16
17/// JWT errors
18#[derive(Debug, Clone, PartialEq, Error)]
19pub enum JwtError {
20    #[error("Parse token error: {0}")]
21    ParseError(String),
22
23    #[error("Generate token error: {0}")]
24    GenerateError(String),
25
26    #[error("Invalid or unsupported algorithm: {0}")]
27    InvalidAlgorithm(String),
28
29    #[error("Encoding key error: {0}")]
30    EncodingKeyError(String),
31
32    #[error("Decoding key error: {0}")]
33    DecodingKeyError(String),
34
35    #[error("Expired token")]
36    ExpiredToken,
37}
38
39/// JWT error
40impl From<JwtError> for ApiError {
41    fn from(value: JwtError) -> Self {
42        Self::InternalServerError(value.to_string())
43    }
44}
45
46/// JWT representation
47#[derive(Clone)]
48pub struct Jwt {
49    /// The algorithm supported for signing/verifying JWT
50    algorithm: Algorithm,
51
52    /// Access Token lifetime (in minute)
53    /// The default value is 15 minutes.
54    access_lifetime: i64,
55
56    /// Refresh Token lifetime (in hour)
57    /// The default value is 7 days.
58    refresh_lifetime: i64,
59
60    /// Encoding key
61    encoding_key: Option<EncodingKey>,
62
63    /// Decoding key
64    decoding_key: Option<DecodingKey>,
65}
66
67impl Default for Jwt {
68    fn default() -> Self {
69        Self {
70            algorithm: Algorithm::HS512,
71            access_lifetime: JWT_ACCESS_LIFETIME_IN_MINUTES,
72            refresh_lifetime: JWT_REFRESH_LIFETIME_IN_HOURS,
73            encoding_key: None,
74            decoding_key: None,
75        }
76    }
77}
78
79impl Debug for Jwt {
80    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81        write!(
82            f,
83            "JWT => algo: {:?}, access_lifetime: {}, refresh_lifetime: {}",
84            self.algorithm, self.access_lifetime, self.refresh_lifetime
85        )
86    }
87}
88
89impl Jwt {
90    /// Initialize a new `Jwt`
91    pub fn init(
92        algorithm: &str,
93        access_lifetime: i64,
94        refresh_lifetime: i64,
95        secret: Option<&str>,
96        private_key: Option<&str>,
97        public_key: Option<&str>,
98    ) -> Result<Self, JwtError> {
99        let mut jwt = Jwt {
100            algorithm: Self::algorithm_from_str(algorithm)?,
101            access_lifetime,
102            refresh_lifetime,
103            ..Default::default()
104        };
105
106        // Encoding key
107        match (secret, private_key, jwt.use_secret()) {
108            (Some(secret), _, true) => jwt.set_encoding_key(secret.trim())?,
109            (_, Some(private_key), false) => jwt.set_encoding_key(private_key.trim())?,
110            _ => return Err(JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())),
111        }
112
113        // Decoding key
114        match (secret, public_key, jwt.use_secret()) {
115            (Some(secret), _, true) => jwt.set_decoding_key(secret.trim())?,
116            (_, Some(public_key), false) => jwt.set_decoding_key(public_key.trim())?,
117            _ => return Err(JwtError::DecodingKeyError("invalid JWT decoding key".to_owned())),
118        }
119
120        Ok(jwt)
121    }
122
123    /// Get access token lifetime
124    pub fn access_lifetime(&self) -> i64 {
125        self.access_lifetime
126    }
127
128    /// Get refresh token lifetime
129    pub fn refresh_lifetime(&self) -> i64 {
130        self.refresh_lifetime
131    }
132
133    /// Update access token lifetime (in minute)
134    pub fn set_access_lifetime(&mut self, duration: i64) {
135        self.access_lifetime = duration;
136    }
137
138    /// Update refresh token lifetime (in day)
139    pub fn set_refresh_lifetime(&mut self, duration: i64) {
140        self.refresh_lifetime = duration;
141    }
142
143    /// Update encoding key
144    pub fn set_encoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
145        let key = match self.algorithm {
146            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret.as_bytes()),
147            Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(secret.as_bytes())
148                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
149            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
150                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
151            Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
152                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
153            Algorithm::EdDSA => EncodingKey::from_ed_pem(secret.as_bytes())
154                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
155        };
156
157        self.encoding_key = Some(key);
158
159        Ok(())
160    }
161
162    /// Update decoding key
163    pub fn set_decoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
164        let key = match self.algorithm {
165            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret.as_bytes()),
166            Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(secret.as_bytes())
167                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
168            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
169                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
170            Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
171                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
172            Algorithm::EdDSA => DecodingKey::from_ed_pem(secret.as_bytes())
173                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
174        };
175
176        self.decoding_key = Some(key);
177
178        Ok(())
179    }
180
181    /// Generate JWT
182    pub fn generate<P: Debug + Serialize>(&self, payload: P, expired_at: UtcDateTime) -> Result<AccessToken, JwtError> {
183        let header = jsonwebtoken::Header::new(self.algorithm);
184
185        match self.encoding_key.clone() {
186            Some(encoding_key) => {
187                let token = encode(&header, &payload, &encoding_key)
188                    .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?;
189
190                Ok(AccessToken { token, expired_at })
191            }
192            _ => Err(JwtError::EncodingKeyError("empty key".to_owned())),
193        }
194    }
195
196    /// Parse JWT
197    pub fn parse<P: Debug + for<'de> Deserialize<'de>>(&self, token: &AccessToken) -> Result<P, JwtError> {
198        let validation = Validation::new(self.algorithm);
199
200        match self.decoding_key.clone() {
201            Some(decoding_key) => {
202                let token = decode::<P>(&token.token, &decoding_key, &validation).map_err(|err| match err.kind() {
203                    ExpiredSignature => JwtError::ExpiredToken,
204                    _ => JwtError::DecodingKeyError(err.to_string()),
205                })?;
206
207                Ok(token.claims)
208            }
209            _ => Err(JwtError::DecodingKeyError("empty key".to_owned())),
210        }
211    }
212
213    /// Return true if a secret key is used instead of a pair of keys
214    fn use_secret(&self) -> bool {
215        self.algorithm == Algorithm::HS256 || self.algorithm == Algorithm::HS384 || self.algorithm == Algorithm::HS512
216    }
217
218    /// Convert `&str` to `Algorithm`
219    fn algorithm_from_str(algo: &str) -> Result<Algorithm, JwtError> {
220        Ok(match algo {
221            "HS256" => Algorithm::HS256,
222            "HS384" => Algorithm::HS384,
223            "HS512" => Algorithm::HS512,
224            "ES256" => Algorithm::ES256,
225            "ES384" => Algorithm::ES384,
226            _ => {
227                return Err(JwtError::InvalidAlgorithm(algo.to_string()));
228            }
229        })
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default `Jwt` with only the algorithm overridden.
    fn jwt_with(algorithm: Algorithm) -> Jwt {
        Jwt {
            algorithm,
            ..Jwt::default()
        }
    }

    #[test]
    fn test_jwt_use_secret() {
        // The default algorithm (HS512) is an HMAC one.
        assert!(Jwt::default().use_secret());

        assert!(!jwt_with(Algorithm::ES256).use_secret());
        assert!(jwt_with(Algorithm::HS256).use_secret());
    }

    #[test]
    fn test_jwt_algorithm_from_str() {
        let accepted = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
        ];
        for (name, expected) in accepted {
            assert_eq!(Jwt::algorithm_from_str(name).unwrap(), expected);
        }

        // Any other name is rejected and echoed back in the error.
        assert_eq!(
            Jwt::algorithm_from_str("ES512").unwrap_err(),
            JwtError::InvalidAlgorithm("ES512".to_string())
        );
    }

    #[test]
    fn test_jwt_default() {
        let Jwt {
            algorithm,
            access_lifetime,
            refresh_lifetime,
            encoding_key,
            decoding_key,
        } = Jwt::default();

        assert_eq!(algorithm, Algorithm::HS512);
        assert_eq!(access_lifetime, JWT_ACCESS_LIFETIME_IN_MINUTES);
        assert_eq!(refresh_lifetime, JWT_REFRESH_LIFETIME_IN_HOURS);
        assert!(encoding_key.is_none());
        assert!(decoding_key.is_none());
    }

    #[test]
    fn test_jwt_debug() {
        let rendered = format!("{:?}", Jwt::default());

        // 7 * 24 hours = 168.
        assert_eq!(
            rendered,
            "JWT => algo: HS512, access_lifetime: 15, refresh_lifetime: 168"
        );
    }
}