1pub mod error;
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8pub type Result<T> = std::result::Result<T, Error>;
9
10#[derive(Debug, Deserialize)]
12pub struct JwtCfg {
13 pub access_secret: String,
14 pub refresh_secret: String,
15 pub audience: String,
16 pub access_token_duration: usize,
17 pub refresh_token_duration: usize,
18 pub access_key_validate_exp: bool,
19 pub refresh_key_validate_exp: bool,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
24pub struct Claims {
25 pub aud: String,
26 pub sub: String,
27 pub exp: usize,
28 pub iat: usize,
29}
30
31impl Claims {
32 pub fn new(aud: String, sub: String, exp: usize, iat: usize) -> Self {
34 Self { aud, sub, exp, iat }
35 }
36}
37
38enum TokenKind {
40 Access,
41 Refesh,
42}
43
44#[derive(Clone)]
46pub struct Jwt {
47 header: Header,
48 encoding_access_key: EncodingKey,
49 encoding_refresh_key: EncodingKey,
50 decoding_access_key: DecodingKey,
51 decoding_refresh_key: DecodingKey,
52 validation_access_key: Validation,
53 validation_refresh_key: Validation,
54 aud: String,
55 access_token_duration: usize,
56 refresh_token_duration: usize,
57}
58
59impl Jwt {
60 pub fn new(cfg: JwtCfg) -> Self {
70 let header = Header::default();
71 let encoding_access_key = EncodingKey::from_secret(cfg.access_secret.as_bytes());
72 let encoding_refresh_key = EncodingKey::from_secret(cfg.refresh_secret.as_bytes());
73 let decoding_access_key = DecodingKey::from_secret(cfg.access_secret.as_bytes());
74 let decoding_refresh_key = DecodingKey::from_secret(cfg.refresh_secret.as_bytes());
75 let mut validation_access_key = Validation::default();
76 validation_access_key.set_audience(std::slice::from_ref(&cfg.audience));
77 let mut validation_refresh_key = validation_access_key.clone();
78 validation_access_key.validate_exp = cfg.access_key_validate_exp;
79 validation_refresh_key.validate_exp = cfg.refresh_key_validate_exp;
80 validation_refresh_key.required_spec_claims.clear();
81 Self {
82 header,
83 encoding_access_key,
84 encoding_refresh_key,
85 decoding_access_key,
86 decoding_refresh_key,
87 validation_access_key,
88 validation_refresh_key,
89 aud: cfg.audience,
90 access_token_duration: cfg.access_token_duration,
91 refresh_token_duration: cfg.refresh_token_duration,
92 }
93 }
94
95 pub fn generate_token_pair(&self, sub: String) -> Result<(String, String)> {
105 let access_token = self.generate_token(&TokenKind::Access, &sub)?;
106 let refresh_token = self.generate_token(&TokenKind::Refesh, &sub)?;
107 Ok((access_token, refresh_token))
108 }
109
110 pub fn generate_access_token(&self, sub: String) -> Result<String> {
120 self.generate_token(&TokenKind::Access, &sub)
121 }
122
123 pub fn refresh_access_token(&self, refresh_token: &str) -> Result<String> {
133 let claims = self.validate_refresh_token(refresh_token)?;
134 self.generate_access_token(claims.sub)
135 }
136
137 pub fn validate_access_token(&self, token: &str) -> Result<Claims> {
147 self.validate_token(&TokenKind::Access, token)
148 .map(|data| data.claims)
149 }
150
151 pub fn validate_refresh_token(&self, token: &str) -> Result<Claims> {
161 self.validate_token(&TokenKind::Refesh, token)
162 .map(|data| data.claims)
163 }
164
165 fn generate_token(&self, kind: &TokenKind, sub: &str) -> Result<String> {
176 let duration = self.get_token_duration(kind);
177 let (iat, exp) = self.generate_timestamps(duration);
178 let key = self.select_encoding_key(kind);
179 let claims = self.create_claims(sub, iat, exp);
180 encode(&self.header, &claims, key).map_err(|e| Error::AuthError(e.to_string().into()))
181 }
182
183 fn validate_token(&self, kind: &TokenKind, token: &str) -> Result<TokenData<Claims>> {
194 let (key, validation) = self.select_decoding_key_and_validation(kind);
195 decode::<Claims>(token, key, validation).map_err(|e| Error::AuthError(e.to_string().into()))
196 }
197
198 fn get_token_duration(&self, kind: &TokenKind) -> usize {
208 match kind {
209 TokenKind::Access => self.access_token_duration,
210 TokenKind::Refesh => self.refresh_token_duration,
211 }
212 }
213
214 fn generate_timestamps(&self, duration: usize) -> (usize, usize) {
224 generate_expired_time(duration)
225 }
226
227 fn select_encoding_key(&self, kind: &TokenKind) -> &EncodingKey {
237 match kind {
238 TokenKind::Access => &self.encoding_access_key,
239 TokenKind::Refesh => &self.encoding_refresh_key,
240 }
241 }
242
243 fn create_claims(&self, sub: &str, iat: usize, exp: usize) -> Claims {
255 Claims::new(self.aud.clone(), sub.to_string(), exp, iat)
256 }
257
258 fn select_decoding_key_and_validation(&self, kind: &TokenKind) -> (&DecodingKey, &Validation) {
268 match kind {
269 TokenKind::Access => (&self.decoding_access_key, &self.validation_access_key),
270 TokenKind::Refesh => (&self.decoding_refresh_key, &self.validation_refresh_key),
271 }
272 }
273}
274
275fn generate_expired_time(duration: usize) -> (usize, usize) {
285 let now = Utc::now();
286 let iat = now.timestamp() as usize;
287 let exp = (now
288 + Duration::try_seconds(i64::try_from(duration).expect("duration overflow"))
289 .expect("duration out of range"))
290 .timestamp() as usize;
291 (iat, exp)
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration as StdDuration, SystemTime, UNIX_EPOCH};

    const AUDIENCE: &str = "test_audience";
    const SUBJECT: &str = "test_sub";
    const ACCESS_SECRET: &str = "access_secret";

    fn setup_jwt() -> Jwt {
        Jwt::new(JwtCfg {
            access_secret: ACCESS_SECRET.to_string(),
            refresh_secret: "refresh_secret".to_string(),
            audience: AUDIENCE.to_string(),
            access_token_duration: 3600,
            refresh_token_duration: 86400,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        })
    }

    /// Unix timestamp (seconds) for `offset_secs` seconds before now.
    fn secs_ago(offset_secs: u64) -> usize {
        (SystemTime::now() - StdDuration::from_secs(offset_secs))
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as usize
    }

    /// Signs `claims` with `secret` directly, bypassing `Jwt`, so tests can forge tokens.
    fn sign(secret: &str, claims: &Claims) -> String {
        encode(&Header::default(), claims, &EncodingKey::from_secret(secret.as_bytes())).unwrap()
    }

    fn assert_auth_error(result: Result<Claims>) {
        assert!(
            matches!(&result, Err(Error::AuthError(_))),
            "expected AuthError, got {result:?}"
        );
    }

    #[test]
    fn test_generate_token_pair() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        assert!(!access_token.is_empty());
        assert!(!refresh_token.is_empty());
    }

    #[test]
    fn test_generate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token(SUBJECT.to_string()).unwrap();

        assert!(!access_token.is_empty());
    }

    #[test]
    fn test_validate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token(SUBJECT.to_string()).unwrap();

        let claims = jwt
            .validate_access_token(&access_token)
            .expect("a freshly issued access token must validate");
        assert_eq!(claims.aud, AUDIENCE);
        assert_eq!(claims.sub, SUBJECT);
        assert_eq!(claims.exp - claims.iat, 3600);
    }

    #[test]
    fn test_validate_refresh_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        let claims = jwt
            .validate_refresh_token(&refresh_token)
            .expect("a freshly issued refresh token must validate");
        assert_eq!(claims.aud, AUDIENCE);
        assert_eq!(claims.sub, SUBJECT);
        assert_eq!(claims.exp - claims.iat, 86400);
    }

    #[test]
    fn test_expired_access_token() {
        let jwt = setup_jwt();
        let claims = Claims::new(
            AUDIENCE.to_string(),
            SUBJECT.to_string(),
            secs_ago(3600),
            secs_ago(7200),
        );

        assert_auth_error(jwt.validate_access_token(&sign(ACCESS_SECRET, &claims)));
    }

    #[test]
    fn test_invalid_access_token() {
        let jwt = setup_jwt();

        assert_auth_error(jwt.validate_access_token("invalid_token"));
    }

    #[test]
    fn test_wrong_audience_is_rejected() {
        let jwt = setup_jwt();
        let (iat, exp) = generate_expired_time(3600);
        let claims = Claims::new("other_audience".to_string(), SUBJECT.to_string(), exp, iat);

        assert_auth_error(jwt.validate_access_token(&sign(ACCESS_SECRET, &claims)));
    }

    #[test]
    fn test_tokens_are_not_interchangeable() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        assert_auth_error(jwt.validate_refresh_token(&access_token));
        assert_auth_error(jwt.validate_access_token(&refresh_token));
        assert!(jwt.refresh_access_token(&access_token).is_err());
    }

    #[test]
    fn test_refresh_access_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        let new_access_token = jwt.refresh_access_token(&refresh_token).unwrap();
        let claims = jwt
            .validate_access_token(&new_access_token)
            .expect("a refreshed access token must validate");
        assert_eq!(claims.sub, SUBJECT);
    }

    #[test]
    fn test_generate_expired_time_spans_duration() {
        let (iat, exp) = generate_expired_time(60);

        assert!(iat > 0);
        assert_eq!(exp - iat, 60);
    }
}