1pub mod access_token;
4pub mod payload;
5
6use crate::server::axum::response::ApiError;
7use crate::{security::jwt::access_token::AccessToken, value_objects::datetime::UtcDateTime};
8use jsonwebtoken::errors::ErrorKind::ExpiredSignature;
9use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation, decode, encode};
10use serde::{Deserialize, Serialize};
11use std::fmt::{Debug, Formatter};
12use thiserror::Error;
13
14const JWT_ACCESS_LIFETIME_IN_MINUTES: i64 = 15; const JWT_REFRESH_LIFETIME_IN_HOURS: i64 = 7 * 24; #[derive(Debug, Clone, PartialEq, Error)]
19pub enum JwtError {
20 #[error("Parse token error: {0}")]
21 ParseError(String),
22
23 #[error("Generate token error: {0}")]
24 GenerateError(String),
25
26 #[error("Invalid or unsupported algorithm: {0}")]
27 InvalidAlgorithm(String),
28
29 #[error("Encoding key error: {0}")]
30 EncodingKeyError(String),
31
32 #[error("Decoding key error: {0}")]
33 DecodingKeyError(String),
34
35 #[error("Expired token")]
36 ExpiredToken,
37}
38
39impl From<JwtError> for ApiError {
41 fn from(value: JwtError) -> Self {
42 Self::InternalServerError(value.to_string())
43 }
44}
45
46#[derive(Clone)]
48pub struct Jwt {
49 algorithm: Algorithm,
51
52 access_lifetime: i64,
55
56 refresh_lifetime: i64,
59
60 encoding_key: Option<EncodingKey>,
62
63 decoding_key: Option<DecodingKey>,
65}
66
67impl Default for Jwt {
68 fn default() -> Self {
69 Self {
70 algorithm: Algorithm::HS512,
71 access_lifetime: JWT_ACCESS_LIFETIME_IN_MINUTES,
72 refresh_lifetime: JWT_REFRESH_LIFETIME_IN_HOURS,
73 encoding_key: None,
74 decoding_key: None,
75 }
76 }
77}
78
79impl Debug for Jwt {
80 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81 write!(
82 f,
83 "JWT => algo: {:?}, access_lifetime: {}, refresh_lifetime: {}",
84 self.algorithm, self.access_lifetime, self.refresh_lifetime
85 )
86 }
87}
88
89impl Jwt {
90 pub fn init(
92 algorithm: &str,
93 access_lifetime: i64,
94 refresh_lifetime: i64,
95 secret: Option<&str>,
96 private_key: Option<&str>,
97 public_key: Option<&str>,
98 ) -> Result<Self, JwtError> {
99 let mut jwt = Jwt {
100 algorithm: Self::algorithm_from_str(algorithm)?,
101 access_lifetime,
102 refresh_lifetime,
103 ..Default::default()
104 };
105
106 match (secret, private_key, jwt.use_secret()) {
108 (Some(secret), _, true) => jwt.set_encoding_key(secret.trim())?,
109 (_, Some(private_key), false) => jwt.set_encoding_key(private_key.trim())?,
110 _ => return Err(JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())),
111 }
112
113 match (secret, public_key, jwt.use_secret()) {
115 (Some(secret), _, true) => jwt.set_decoding_key(secret.trim())?,
116 (_, Some(public_key), false) => jwt.set_decoding_key(public_key.trim())?,
117 _ => return Err(JwtError::DecodingKeyError("invalid JWT decoding key".to_owned())),
118 }
119
120 Ok(jwt)
121 }
122
123 pub fn access_lifetime(&self) -> i64 {
125 self.access_lifetime
126 }
127
128 pub fn refresh_lifetime(&self) -> i64 {
130 self.refresh_lifetime
131 }
132
133 pub fn set_access_lifetime(&mut self, duration: i64) {
135 self.access_lifetime = duration;
136 }
137
138 pub fn set_refresh_lifetime(&mut self, duration: i64) {
140 self.refresh_lifetime = duration;
141 }
142
143 pub fn set_encoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
145 let key = match self.algorithm {
146 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret.as_bytes()),
147 Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(secret.as_bytes())
148 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
149 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
150 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
151 Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
152 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
153 Algorithm::EdDSA => EncodingKey::from_ed_pem(secret.as_bytes())
154 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
155 };
156
157 self.encoding_key = Some(key);
158
159 Ok(())
160 }
161
162 pub fn set_decoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
164 let key = match self.algorithm {
165 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret.as_bytes()),
166 Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(secret.as_bytes())
167 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
168 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
169 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
170 Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
171 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
172 Algorithm::EdDSA => DecodingKey::from_ed_pem(secret.as_bytes())
173 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
174 };
175
176 self.decoding_key = Some(key);
177
178 Ok(())
179 }
180
181 pub fn generate<P: Debug + Serialize>(&self, payload: P, expired_at: UtcDateTime) -> Result<AccessToken, JwtError> {
183 let header = jsonwebtoken::Header::new(self.algorithm);
184
185 match self.encoding_key.clone() {
186 Some(encoding_key) => {
187 let token = encode(&header, &payload, &encoding_key)
188 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?;
189
190 Ok(AccessToken { token, expired_at })
191 }
192 _ => Err(JwtError::EncodingKeyError("empty key".to_owned())),
193 }
194 }
195
196 pub fn parse<P: Debug + for<'de> Deserialize<'de>>(&self, token: &AccessToken) -> Result<P, JwtError> {
198 let validation = Validation::new(self.algorithm);
199
200 match self.decoding_key.clone() {
201 Some(decoding_key) => {
202 let token = decode::<P>(&token.token, &decoding_key, &validation).map_err(|err| match err.kind() {
203 ExpiredSignature => JwtError::ExpiredToken,
204 _ => JwtError::DecodingKeyError(err.to_string()),
205 })?;
206
207 Ok(token.claims)
208 }
209 _ => Err(JwtError::DecodingKeyError("empty key".to_owned())),
210 }
211 }
212
213 fn use_secret(&self) -> bool {
215 self.algorithm == Algorithm::HS256 || self.algorithm == Algorithm::HS384 || self.algorithm == Algorithm::HS512
216 }
217
218 fn algorithm_from_str(algo: &str) -> Result<Algorithm, JwtError> {
220 Ok(match algo {
221 "HS256" => Algorithm::HS256,
222 "HS384" => Algorithm::HS384,
223 "HS512" => Algorithm::HS512,
224 "ES256" => Algorithm::ES256,
225 "ES384" => Algorithm::ES384,
226 _ => {
227 return Err(JwtError::InvalidAlgorithm(algo.to_string()));
228 }
229 })
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jwt_use_secret() {
        let mut jwt = Jwt::default();
        assert!(jwt.use_secret(), "default HS512 is HMAC-based");

        for algorithm in [Algorithm::HS256, Algorithm::HS384, Algorithm::HS512] {
            jwt.algorithm = algorithm;
            assert!(jwt.use_secret(), "{algorithm:?} should use a shared secret");
        }

        for algorithm in [
            Algorithm::ES256,
            Algorithm::ES384,
            Algorithm::RS256,
            Algorithm::PS256,
            Algorithm::EdDSA,
        ] {
            jwt.algorithm = algorithm;
            assert!(!jwt.use_secret(), "{algorithm:?} should use a key pair");
        }
    }

    #[test]
    fn test_jwt_algorithm_from_str() {
        for (name, expected) in [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
        ] {
            assert_eq!(Jwt::algorithm_from_str(name), Ok(expected));
        }

        assert_eq!(
            Jwt::algorithm_from_str("ES512"),
            Err(JwtError::InvalidAlgorithm("ES512".to_string()))
        );
        assert!(Jwt::algorithm_from_str("hs256").is_err(), "matching is case-sensitive");
        assert!(Jwt::algorithm_from_str("").is_err());
    }

    #[test]
    fn test_jwt_default() {
        let jwt = Jwt::default();
        assert_eq!(jwt.algorithm, Algorithm::HS512);
        assert_eq!(jwt.access_lifetime, JWT_ACCESS_LIFETIME_IN_MINUTES);
        assert_eq!(jwt.refresh_lifetime, JWT_REFRESH_LIFETIME_IN_HOURS);
        assert!(jwt.encoding_key.is_none());
        assert!(jwt.decoding_key.is_none());
    }

    #[test]
    fn test_jwt_debug() {
        let jwt = Jwt::default();
        let debug_str = format!("{:?}", jwt);

        assert_eq!(
            debug_str,
            format!("JWT => algo: HS512, access_lifetime: 15, refresh_lifetime: {}", 7 * 24)
        );
    }

    #[test]
    fn test_jwt_init_with_hmac_secret() {
        let jwt = Jwt::init("HS256", 30, 48, Some("  top-secret  "), None, None).expect("valid HMAC configuration");

        assert_eq!(jwt.algorithm, Algorithm::HS256);
        assert_eq!(jwt.access_lifetime(), 30);
        assert_eq!(jwt.refresh_lifetime(), 48);
        assert!(jwt.encoding_key.is_some());
        assert!(jwt.decoding_key.is_some());
    }

    #[test]
    fn test_jwt_init_rejects_unknown_algorithm() {
        assert_eq!(
            Jwt::init("none", 15, 168, Some("secret"), None, None).unwrap_err(),
            JwtError::InvalidAlgorithm("none".to_string())
        );
    }

    #[test]
    fn test_jwt_init_requires_matching_key_material() {
        // HMAC algorithm configured with a key pair instead of a secret.
        assert_eq!(
            Jwt::init("HS512", 15, 168, None, Some("private"), Some("public")).unwrap_err(),
            JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())
        );

        // Asymmetric algorithm configured with only a shared secret.
        assert_eq!(
            Jwt::init("ES256", 15, 168, Some("secret"), None, None).unwrap_err(),
            JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())
        );

        // Asymmetric algorithm with a private key that is not valid PEM.
        assert!(matches!(
            Jwt::init("ES256", 15, 168, None, Some("not a pem"), Some("not a pem")),
            Err(JwtError::EncodingKeyError(_))
        ));
    }

    #[test]
    fn test_jwt_lifetime_setters() {
        let mut jwt = Jwt::default();
        jwt.set_access_lifetime(5);
        jwt.set_refresh_lifetime(24);

        assert_eq!(jwt.access_lifetime(), 5);
        assert_eq!(jwt.refresh_lifetime(), 24);
    }
}