api_tools/server/axum/security/jwt/
mod.rs1pub mod access_token;
4pub mod payload;
5
6use crate::server::axum::response::ApiError;
7use crate::server::axum::security::jwt::access_token::AccessToken;
8use crate::value_objects::datetime::UtcDateTime;
9use jsonwebtoken::errors::ErrorKind::ExpiredSignature;
10use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation, decode, encode};
11use serde::{Deserialize, Serialize};
12use std::fmt::{Debug, Formatter};
13use thiserror::Error;
14
15const JWT_ACCESS_LIFETIME_IN_MINUTES: i64 = 15; const JWT_REFRESH_LIFETIME_IN_HOURS: i64 = 7 * 24; #[derive(Debug, Clone, PartialEq, Error)]
20pub enum JwtError {
21 #[error("Parse token error: {0}")]
22 ParseError(String),
23
24 #[error("Generate token error: {0}")]
25 GenerateError(String),
26
27 #[error("Invalid or unsupported algorithm: {0}")]
28 InvalidAlgorithm(String),
29
30 #[error("Encoding key error: {0}")]
31 EncodingKeyError(String),
32
33 #[error("Decoding key error: {0}")]
34 DecodingKeyError(String),
35
36 #[error("Expired token")]
37 ExpiredToken,
38}
39
40impl From<JwtError> for ApiError {
42 fn from(value: JwtError) -> Self {
43 Self::InternalServerError(value.to_string())
44 }
45}
46
47#[derive(Clone)]
49pub struct Jwt {
50 algorithm: Algorithm,
52
53 access_lifetime: i64,
56
57 refresh_lifetime: i64,
60
61 encoding_key: Option<EncodingKey>,
63
64 decoding_key: Option<DecodingKey>,
66}
67
68impl Default for Jwt {
69 fn default() -> Self {
70 Self {
71 algorithm: Algorithm::HS512,
72 access_lifetime: JWT_ACCESS_LIFETIME_IN_MINUTES,
73 refresh_lifetime: JWT_REFRESH_LIFETIME_IN_HOURS,
74 encoding_key: None,
75 decoding_key: None,
76 }
77 }
78}
79
80impl Debug for Jwt {
81 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
82 write!(
83 f,
84 "JWT => algo: {:?}, access_lifetime: {}, refresh_lifetime: {}",
85 self.algorithm, self.access_lifetime, self.refresh_lifetime
86 )
87 }
88}
89
90impl Jwt {
91 pub fn init(
93 algorithm: &str,
94 access_lifetime: i64,
95 refresh_lifetime: i64,
96 secret: Option<&str>,
97 private_key: Option<&str>,
98 public_key: Option<&str>,
99 ) -> Result<Self, JwtError> {
100 let mut jwt = Jwt {
101 algorithm: Self::algorithm_from_str(algorithm)?,
102 access_lifetime,
103 refresh_lifetime,
104 ..Default::default()
105 };
106
107 match (secret, private_key, jwt.use_secret()) {
109 (Some(secret), _, true) => jwt.set_encoding_key(secret.trim())?,
110 (_, Some(private_key), false) => jwt.set_encoding_key(private_key.trim())?,
111 _ => return Err(JwtError::EncodingKeyError("invalid JWT encoding key".to_owned())),
112 }
113
114 match (secret, public_key, jwt.use_secret()) {
116 (Some(secret), _, true) => jwt.set_decoding_key(secret.trim())?,
117 (_, Some(public_key), false) => jwt.set_decoding_key(public_key.trim())?,
118 _ => return Err(JwtError::DecodingKeyError("invalid JWT decoding key".to_owned())),
119 }
120
121 Ok(jwt)
122 }
123
124 pub fn access_lifetime(&self) -> i64 {
126 self.access_lifetime
127 }
128
129 pub fn refresh_lifetime(&self) -> i64 {
131 self.refresh_lifetime
132 }
133
134 pub fn set_access_lifetime(&mut self, duration: i64) {
136 self.access_lifetime = duration;
137 }
138
139 pub fn set_refresh_lifetime(&mut self, duration: i64) {
141 self.refresh_lifetime = duration;
142 }
143
144 pub fn set_encoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
146 let key = match self.algorithm {
147 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret.as_bytes()),
148 Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(secret.as_bytes())
149 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
150 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
151 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
152 Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => EncodingKey::from_rsa_pem(secret.as_bytes())
153 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
154 Algorithm::EdDSA => EncodingKey::from_ed_pem(secret.as_bytes())
155 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?,
156 };
157
158 self.encoding_key = Some(key);
159
160 Ok(())
161 }
162
163 pub fn set_decoding_key(&mut self, secret: &str) -> Result<(), JwtError> {
165 let key = match self.algorithm {
166 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret.as_bytes()),
167 Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(secret.as_bytes())
168 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
169 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
170 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
171 Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => DecodingKey::from_rsa_pem(secret.as_bytes())
172 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
173 Algorithm::EdDSA => DecodingKey::from_ed_pem(secret.as_bytes())
174 .map_err(|err| JwtError::DecodingKeyError(err.to_string()))?,
175 };
176
177 self.decoding_key = Some(key);
178
179 Ok(())
180 }
181
182 pub fn generate<P: Debug + Serialize>(&self, payload: P, expired_at: UtcDateTime) -> Result<AccessToken, JwtError> {
184 let header = jsonwebtoken::Header::new(self.algorithm);
185
186 match self.encoding_key.clone() {
187 Some(encoding_key) => {
188 let token = encode(&header, &payload, &encoding_key)
189 .map_err(|err| JwtError::EncodingKeyError(err.to_string()))?;
190
191 Ok(AccessToken { token, expired_at })
192 }
193 _ => Err(JwtError::EncodingKeyError("empty key".to_owned())),
194 }
195 }
196
197 pub fn parse<P: Debug + for<'de> Deserialize<'de>>(&self, token: &AccessToken) -> Result<P, JwtError> {
199 let validation = Validation::new(self.algorithm);
200
201 match self.decoding_key.clone() {
202 Some(decoding_key) => {
203 let token = decode::<P>(&token.token, &decoding_key, &validation).map_err(|err| match err.kind() {
204 ExpiredSignature => JwtError::ExpiredToken,
205 _ => JwtError::DecodingKeyError(err.to_string()),
206 })?;
207
208 Ok(token.claims)
209 }
210 _ => Err(JwtError::DecodingKeyError("empty key".to_owned())),
211 }
212 }
213
214 fn use_secret(&self) -> bool {
216 self.algorithm == Algorithm::HS256 || self.algorithm == Algorithm::HS384 || self.algorithm == Algorithm::HS512
217 }
218
219 fn algorithm_from_str(algo: &str) -> Result<Algorithm, JwtError> {
221 Ok(match algo {
222 "HS256" => Algorithm::HS256,
223 "HS384" => Algorithm::HS384,
224 "HS512" => Algorithm::HS512,
225 "ES256" => Algorithm::ES256,
226 "ES384" => Algorithm::ES384,
227 _ => {
228 return Err(JwtError::InvalidAlgorithm(algo.to_string()));
229 }
230 })
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jwt_use_secret() {
        // The default algorithm (HS512) is HMAC-based.
        assert!(Jwt::default().use_secret());

        let mut jwt = Jwt::default();
        for (algorithm, expected) in [(Algorithm::ES256, false), (Algorithm::HS256, true)] {
            jwt.algorithm = algorithm;
            assert_eq!(jwt.use_secret(), expected, "{algorithm:?}");
        }
    }

    #[test]
    fn test_jwt_algorithm_from_str() {
        let supported = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
        ];
        for (name, expected) in supported {
            assert_eq!(Jwt::algorithm_from_str(name).unwrap(), expected);
        }

        assert_eq!(
            Jwt::algorithm_from_str("ES512"),
            Err(JwtError::InvalidAlgorithm("ES512".to_string()))
        );
    }

    #[test]
    fn test_jwt_default() {
        let Jwt {
            algorithm,
            access_lifetime,
            refresh_lifetime,
            encoding_key,
            decoding_key,
        } = Jwt::default();

        assert_eq!(algorithm, Algorithm::HS512);
        assert_eq!(access_lifetime, JWT_ACCESS_LIFETIME_IN_MINUTES);
        assert_eq!(refresh_lifetime, JWT_REFRESH_LIFETIME_IN_HOURS);
        assert!(encoding_key.is_none());
        assert!(decoding_key.is_none());
    }

    #[test]
    fn test_jwt_debug() {
        let expected = format!("JWT => algo: HS512, access_lifetime: 15, refresh_lifetime: {}", 7 * 24);
        assert_eq!(format!("{:?}", Jwt::default()), expected);
    }
}