// api_tools/server/axum/security/jwt/mod.rs

//! JWT module: key configuration, token generation and token parsing for JSON Web Tokens.

3pub mod access_token;
4pub mod payload;
5
6use crate::server::axum::response::ApiError;
7use crate::server::axum::security::jwt::access_token::AccessToken;
8use crate::value_objects::datetime::UtcDateTime;
9use jsonwebtoken::errors::ErrorKind::ExpiredSignature;
10use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation, decode, encode};
11use serde::{Deserialize, Serialize};
12use std::fmt::{Debug, Formatter};
13use thiserror::Error;
14
/// Default access-token lifetime, expressed in minutes (15 minutes).
const JWT_ACCESS_LIFETIME_IN_MINUTES: i64 = 15;
/// Default refresh-token lifetime, expressed in hours (7 days × 24 hours).
const JWT_REFRESH_LIFETIME_IN_HOURS: i64 = 24 * 7;
17
18/// JWT errors
19#[derive(Debug, Clone, PartialEq, Error)]
20pub enum JwtError {
21    #[error("Parse token error: {0}")]
22    ParseError(String),
23
24    #[error("Generate token error: {0}")]
25    GenerateError(String),
26
27    #[error("Invalid or unsupported algorithm: {0}")]
28    InvalidAlgorithm(String),
29
30    #[error("Encoding key error: {0}")]
31    EncodingKeyError(String),
32
33    #[error("Decoding key error: {0}")]
34    DecodingKeyError(String),
35
36    #[error("Expired token")]
37    ExpiredToken,
38}
39
40/// JWT error
41impl From<JwtError> for ApiError {
42    fn from(value: JwtError) -> Self {
43        Self::InternalServerError(value.to_string())
44    }
45}
46
47/// JWT representation
48#[derive(Clone)]
49pub struct Jwt {
50    /// The algorithm supported for signing/verifying JWT
51    algorithm: Algorithm,
52
53    /// Access Token lifetime (in minute)
54    /// The default value is 15 minutes.
55    access_lifetime: i64,
56
57    /// Refresh Token lifetime (in hour)
58    /// The default value is 7 days.
59    refresh_lifetime: i64,
60
61    /// Encoding key
62    encoding_key: Option<EncodingKey>,
63
64    /// Decoding key
65    decoding_key: Option<DecodingKey>,
66}
67
68impl Default for Jwt {
69    fn default() -> Self {
70        Self {
71            algorithm: Algorithm::HS512,
72            access_lifetime: JWT_ACCESS_LIFETIME_IN_MINUTES,
73            refresh_lifetime: JWT_REFRESH_LIFETIME_IN_HOURS,
74            encoding_key: None,
75            decoding_key: None,
76        }
77    }
78}
79
80impl Debug for Jwt {
81    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
82        write!(
83            f,
84            "JWT => algo: {:?}, access_lifetime: {}, refresh_lifetime: {}",
85            self.algorithm, self.access_lifetime, self.refresh_lifetime
86        )
87    }
88}
89
90impl Jwt {
91    /// Initialize a new `Jwt`
92    pub fn init(
93        algorithm: &str,
94        access_lifetime: i64,
95        refresh_lifetime: i64,
96        secret: Option<&str>,
97        private_key: Option<&str>,
98        public_key: Option<&str>,
99    ) -> Result<Self, JwtError> {
100        let mut jwt = Jwt {
101            algorithm: Self::algorithm_from_str(algorithm)?,
102            access_lifetime,
103            refresh_lifetime,
104            ..Default::default()
105        };
106
107        // Encoding key
108        match (secret, private_key, jwt.use_secret()) {
109            (Some(secret), _, true) => jwt.set_encoding_key(secret.trim())?,
110            (_, Some(private_key), false) => jwt.set_encoding_key(private_key.trim())?,
111            _ => return Err(JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())),
112        }
113
114        // Decoding key
115        match (secret, public_key, jwt.use_secret()) {
116            (Some(secret), _, true) => jwt.set_decoding_key(secret.trim())?,
117            (_, Some(public_key), false) => jwt.set_decoding_key(public_key.trim())?,
118            _ => return Err(JwtError::DecodingKeyError("invalid JWT decoding key".to_owned())),
119        }
120
121        Ok(jwt)
122    }
123
124    /// Get access token lifetime
125    pub fn access_lifetime(&self) -> i64 {
126        self.access_lifetime
127    }
128
129    /// Get refresh token lifetime
130    pub fn refresh_lifetime(&self) -> i64 {
131        self.refresh_lifetime
132    }
133
134    /// Update access token lifetime (in minute)
135    pub fn set_access_lifetime(&mut self, duration: i64) {
136        self.access_lifetime = duration;
137    }
138
139    /// Update refresh token lifetime (in day)
140    pub fn set_refresh_lifetime(&mut self, duration: i64) {
141        self.refresh_lifetime = duration;
142    }
143
144    /// Update encoding key
145    pub fn set_encoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
146        let key = match self.algorithm {
147            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret.as_bytes()),
148            Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(secret.as_bytes())
149                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
150            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
151                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
152            Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
153                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
154            Algorithm::EdDSA => EncodingKey::from_ed_pem(secret.as_bytes())
155                .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
156        };
157
158        self.encoding_key = Some(key);
159
160        Ok(())
161    }
162
163    /// Update decoding key
164    pub fn set_decoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
165        let key = match self.algorithm {
166            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret.as_bytes()),
167            Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(secret.as_bytes())
168                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
169            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
170                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
171            Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
172                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
173            Algorithm::EdDSA => DecodingKey::from_ed_pem(secret.as_bytes())
174                .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
175        };
176
177        self.decoding_key = Some(key);
178
179        Ok(())
180    }
181
182    /// Generate JWT
183    pub fn generate<P: Debug + Serialize>(&self, payload: P, expired_at: UtcDateTime) -> Result<AccessToken, JwtError> {
184        let header = jsonwebtoken::Header::new(self.algorithm);
185
186        match self.encoding_key.clone() {
187            Some(encoding_key) => {
188                let token = encode(&header, &payload, &encoding_key)
189                    .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?;
190
191                Ok(AccessToken { token, expired_at })
192            }
193            _ => Err(JwtError::EncodingKeyError("empty key".to_owned())),
194        }
195    }
196
197    /// Parse JWT
198    pub fn parse<P: Debug + for<'de> Deserialize<'de>>(&self, token: &AccessToken) -> Result<P, JwtError> {
199        let validation = Validation::new(self.algorithm);
200
201        match self.decoding_key.clone() {
202            Some(decoding_key) => {
203                let token = decode::<P>(&token.token, &decoding_key, &validation).map_err(|err| match err.kind() {
204                    ExpiredSignature => JwtError::ExpiredToken,
205                    _ => JwtError::DecodingKeyError(err.to_string()),
206                })?;
207
208                Ok(token.claims)
209            }
210            _ => Err(JwtError::DecodingKeyError("empty key".to_owned())),
211        }
212    }
213
214    /// Return true if a secret key is used instead of a pair of keys
215    fn use_secret(&self) -> bool {
216        self.algorithm == Algorithm::HS256 || self.algorithm == Algorithm::HS384 || self.algorithm == Algorithm::HS512
217    }
218
219    /// Convert `&str` to `Algorithm`
220    fn algorithm_from_str(algo: &str) -> Result<Algorithm, JwtError> {
221        Ok(match algo {
222            "HS256" => Algorithm::HS256,
223            "HS384" => Algorithm::HS384,
224            "HS512" => Algorithm::HS512,
225            "ES256" => Algorithm::ES256,
226            "ES384" => Algorithm::ES384,
227            _ => {
228                return Err(JwtError::InvalidAlgorithm(algo.to_string()));
229            }
230        })
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jwt_use_secret() {
        // The default algorithm (HS512) is HMAC-based.
        assert!(Jwt::default().use_secret());

        let mut config = Jwt::default();
        config.algorithm = Algorithm::ES256;
        assert!(!config.use_secret(), "ES256 relies on a key pair");

        config.algorithm = Algorithm::HS256;
        assert!(config.use_secret(), "HS256 relies on a shared secret");
    }

    #[test]
    fn test_jwt_algorithm_from_str() {
        let supported = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
        ];
        for &(name, expected) in supported.iter() {
            assert_eq!(Jwt::algorithm_from_str(name).unwrap(), expected, "parsing {}", name);
        }

        // ES512 has no counterpart in `Algorithm`, so it must be rejected with its name attached.
        assert_eq!(
            Jwt::algorithm_from_str("ES512"),
            Err(JwtError::InvalidAlgorithm("ES512".to_string()))
        );
    }

    #[test]
    fn test_jwt_default() {
        let Jwt {
            algorithm,
            access_lifetime,
            refresh_lifetime,
            encoding_key,
            decoding_key,
        } = Jwt::default();

        assert_eq!(algorithm, Algorithm::HS512);
        assert_eq!(access_lifetime, JWT_ACCESS_LIFETIME_IN_MINUTES);
        assert_eq!(refresh_lifetime, JWT_REFRESH_LIFETIME_IN_HOURS);
        assert!(encoding_key.is_none());
        assert!(decoding_key.is_none());
    }

    #[test]
    fn test_jwt_debug() {
        let rendered = format!("{:?}", Jwt::default());
        let expected = format!("JWT => algo: HS512, access_lifetime: 15, refresh_lifetime: {}", 7 * 24);

        assert_eq!(rendered, expected);
    }
}