toolcraft_jwt/
lib.rs

1pub mod error;
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6
7use crate::error::Error;
8pub type Result<T> = std::result::Result<T, Error>;
9
10/// Struct representing the JWT configuration parameters.
11#[derive(Debug, Deserialize)]
12pub struct JwtCfg {
13    pub access_secret: String,
14    pub refresh_secret: String,
15    pub audience: String,
16    pub access_token_duration: usize,
17    pub refresh_token_duration: usize,
18    pub access_key_validate_exp: bool,
19    pub refresh_key_validate_exp: bool,
20}
21
22/// Represents the JWT claims.
23#[derive(Debug, Serialize, Deserialize)]
24pub struct Claims {
25    pub aud: String,
26    pub sub: String,
27    pub exp: usize,
28    pub iat: usize,
29}
30
31impl Claims {
32    /// Creates a new `Claims` instance.
33    pub fn new(aud: String, sub: String, exp: usize, iat: usize) -> Self {
34        Self { aud, sub, exp, iat }
35    }
36}
37
38/// Enum representing the type of token: ACCESS or REFRESH.
39enum TokenKind {
40    Access,
41    Refesh,
42}
43
44/// Struct representing the JWT configuration and operations.
45#[derive(Clone)]
46pub struct Jwt {
47    header: Header,
48    encoding_access_key: EncodingKey,
49    encoding_refresh_key: EncodingKey,
50    decoding_access_key: DecodingKey,
51    decoding_refresh_key: DecodingKey,
52    validation_access_key: Validation,
53    validation_refresh_key: Validation,
54    aud: String,
55    access_token_duration: usize,
56    refresh_token_duration: usize,
57}
58
59impl Jwt {
60    /// Creates a new `Jwt` instance from the given configuration.
61    ///
62    /// # Arguments
63    ///
64    /// * `cfg` - A `JwtCfg` struct containing the JWT configuration.
65    ///
66    /// # Returns
67    ///
68    /// * A new `Jwt` instance.
69    pub fn new(cfg: JwtCfg) -> Self {
70        let header = Header::default();
71        let encoding_access_key = EncodingKey::from_secret(cfg.access_secret.as_bytes());
72        let encoding_refresh_key = EncodingKey::from_secret(cfg.refresh_secret.as_bytes());
73        let decoding_access_key = DecodingKey::from_secret(cfg.access_secret.as_bytes());
74        let decoding_refresh_key = DecodingKey::from_secret(cfg.refresh_secret.as_bytes());
75        let mut validation_access_key = Validation::default();
76        validation_access_key.set_audience(std::slice::from_ref(&cfg.audience));
77        let mut validation_refresh_key = validation_access_key.clone();
78        validation_access_key.validate_exp = cfg.access_key_validate_exp;
79        validation_refresh_key.validate_exp = cfg.refresh_key_validate_exp;
80        validation_refresh_key.required_spec_claims.clear();
81        Self {
82            header,
83            encoding_access_key,
84            encoding_refresh_key,
85            decoding_access_key,
86            decoding_refresh_key,
87            validation_access_key,
88            validation_refresh_key,
89            aud: cfg.audience,
90            access_token_duration: cfg.access_token_duration,
91            refresh_token_duration: cfg.refresh_token_duration,
92        }
93    }
94
95    /// Generates a pair of access and refresh tokens.
96    ///
97    /// # Arguments
98    ///
99    /// * `sub` - The subject for which the tokens are generated.
100    ///
101    /// # Returns
102    ///
103    /// * A `Result` containing a tuple of the access token and the refresh token, or an `Error`.
104    pub fn generate_token_pair(&self, sub: String) -> Result<(String, String)> {
105        let access_token = self.generate_token(&TokenKind::Access, &sub)?;
106        let refresh_token = self.generate_token(&TokenKind::Refesh, &sub)?;
107        Ok((access_token, refresh_token))
108    }
109
110    /// Generates an access token.
111    ///
112    /// # Arguments
113    ///
114    /// * `sub` - The subject for which the access token is generated.
115    ///
116    /// # Returns
117    ///
118    /// * A `Result` containing the generated access token as a string, or an `Error`.
119    pub fn generate_access_token(&self, sub: String) -> Result<String> {
120        self.generate_token(&TokenKind::Access, &sub)
121    }
122
123    /// Refreshes an access token using a refresh token.
124    ///
125    /// # Arguments
126    ///
127    /// * `refresh_token` - The refresh token used to generate a new access token.
128    ///
129    /// # Returns
130    ///
131    /// * A `Result` containing the new access token, or an `Error`.
132    pub fn refresh_access_token(&self, refresh_token: &str) -> Result<String> {
133        let claims = self.validate_refresh_token(refresh_token)?;
134        self.generate_access_token(claims.sub)
135    }
136
137    /// Validates an access token.
138    ///
139    /// # Arguments
140    ///
141    /// * `token` - The access token to validate.
142    ///
143    /// # Returns
144    ///
145    /// * A `Result` containing the `Claims` if validation is successful, or an `Error`.
146    pub fn validate_access_token(&self, token: &str) -> Result<Claims> {
147        self.validate_token(&TokenKind::Access, token)
148            .map(|data| data.claims)
149    }
150
151    /// Validates a refresh token.
152    ///
153    /// # Arguments
154    ///
155    /// * `token` - The refresh token to validate.
156    ///
157    /// # Returns
158    ///
159    /// * A `Result` containing the `Claims` if validation is successful, or an `Error`.
160    pub fn validate_refresh_token(&self, token: &str) -> Result<Claims> {
161        self.validate_token(&TokenKind::Refesh, token)
162            .map(|data| data.claims)
163    }
164
165    /// Generates a token based on the token kind and subject.
166    ///
167    /// # Arguments
168    ///
169    /// * `kind` - The type of token (ACCESS or REFRESH).
170    /// * `sub` - The subject for which the token is generated.
171    ///
172    /// # Returns
173    ///
174    /// * A `Result` containing the generated token as a string, or an `Error`.
175    fn generate_token(&self, kind: &TokenKind, sub: &str) -> Result<String> {
176        let duration = self.get_token_duration(kind);
177        let (iat, exp) = self.generate_timestamps(duration);
178        let key = self.select_encoding_key(kind);
179        let claims = self.create_claims(sub, iat, exp);
180        encode(&self.header, &claims, key).map_err(|e| Error::AuthError(e.to_string().into()))
181    }
182
183    /// Validates a token based on the token kind.
184    ///
185    /// # Arguments
186    ///
187    /// * `kind` - The type of token (ACCESS or REFRESH).
188    /// * `token` - The token to validate.
189    ///
190    /// # Returns
191    ///
192    /// * A `Result` containing `TokenData<Claims>` if validation is successful, or an `Error`.
193    fn validate_token(&self, kind: &TokenKind, token: &str) -> Result<TokenData<Claims>> {
194        let (key, validation) = self.select_decoding_key_and_validation(kind);
195        decode::<Claims>(token, key, validation).map_err(|e| Error::AuthError(e.to_string().into()))
196    }
197
198    /// Selects the appropriate token duration based on the token kind.
199    ///
200    /// # Arguments
201    ///
202    /// * `kind` - The type of token (ACCESS or REFRESH).
203    ///
204    /// # Returns
205    ///
206    /// * The token duration in seconds.
207    fn get_token_duration(&self, kind: &TokenKind) -> usize {
208        match kind {
209            TokenKind::Access => self.access_token_duration,
210            TokenKind::Refesh => self.refresh_token_duration,
211        }
212    }
213
214    /// Generates the issued at (iat) and expiration (exp) times based on the provided duration.
215    ///
216    /// # Arguments
217    ///
218    /// * `duration` - The duration in seconds for which the token is valid.
219    ///
220    /// # Returns
221    ///
222    /// * A tuple containing the issued at time and expiration time as UNIX timestamps.
223    fn generate_timestamps(&self, duration: usize) -> (usize, usize) {
224        generate_expired_time(duration)
225    }
226
227    /// Selects the appropriate encoding key based on the token kind.
228    ///
229    /// # Arguments
230    ///
231    /// * `kind` - The type of token (ACCESS or REFRESH).
232    ///
233    /// # Returns
234    ///
235    /// * A reference to the selected `EncodingKey`.
236    fn select_encoding_key(&self, kind: &TokenKind) -> &EncodingKey {
237        match kind {
238            TokenKind::Access => &self.encoding_access_key,
239            TokenKind::Refesh => &self.encoding_refresh_key,
240        }
241    }
242
243    /// Creates a new `Claims` instance for the given subject, issued at time, and expiration time.
244    ///
245    /// # Arguments
246    ///
247    /// * `sub` - The subject for which the claims are generated.
248    /// * `iat` - The issued at time.
249    /// * `exp` - The expiration time.
250    ///
251    /// # Returns
252    ///
253    /// * A new `Claims` instance.
254    fn create_claims(&self, sub: &str, iat: usize, exp: usize) -> Claims {
255        Claims::new(self.aud.clone(), sub.to_string(), exp, iat)
256    }
257
258    /// Selects the appropriate decoding key and validation based on the token kind.
259    ///
260    /// # Arguments
261    ///
262    /// * `kind` - The type of token (ACCESS or REFRESH).
263    ///
264    /// # Returns
265    ///
266    /// * A tuple containing a reference to the selected `DecodingKey` and `Validation`.
267    fn select_decoding_key_and_validation(&self, kind: &TokenKind) -> (&DecodingKey, &Validation) {
268        match kind {
269            TokenKind::Access => (&self.decoding_access_key, &self.validation_access_key),
270            TokenKind::Refesh => (&self.decoding_refresh_key, &self.validation_refresh_key),
271        }
272    }
273}
274
275/// Generates the issued at (iat) and expiration (exp) times based on the provided duration.
276///
277/// # Arguments
278///
279/// * `duration` - The duration in seconds for which the token is valid.
280///
281/// # Returns
282///
283/// * A tuple containing the issued at time and expiration time as UNIX timestamps.
284fn generate_expired_time(duration: usize) -> (usize, usize) {
285    let now = Utc::now();
286    let iat = now.timestamp() as usize;
287    let exp = (now + Duration::seconds(duration as i64)).timestamp() as usize;
288    (iat, exp)
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    const ACCESS_SECRET: &str = "access_secret";
    const REFRESH_SECRET: &str = "refresh_secret";
    const AUDIENCE: &str = "test_audience";
    const SUBJECT: &str = "test_sub";
    const ACCESS_TTL: usize = 3600; // 1 hour
    const REFRESH_TTL: usize = 86400; // 1 day

    /// Builds a `Jwt` with distinct access/refresh secrets and expiry checks enabled.
    fn setup_jwt() -> Jwt {
        Jwt::new(JwtCfg {
            access_secret: ACCESS_SECRET.to_string(),
            refresh_secret: REFRESH_SECRET.to_string(),
            audience: AUDIENCE.to_string(),
            access_token_duration: ACCESS_TTL,
            refresh_token_duration: REFRESH_TTL,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        })
    }

    /// Current UNIX time in seconds.
    fn unix_now() -> usize {
        Utc::now().timestamp() as usize
    }

    /// Signs `claims` with `secret` using the default header, bypassing `Jwt`.
    fn sign(claims: &Claims, secret: &str) -> String {
        encode(&Header::default(), claims, &EncodingKey::from_secret(secret.as_bytes()))
            .expect("signing test claims should not fail")
    }

    /// Asserts that `result` failed with `Error::AuthError`.
    fn assert_auth_error<T: std::fmt::Debug>(result: Result<T>) {
        assert!(
            matches!(&result, Err(Error::AuthError(_))),
            "expected AuthError, got {result:?}"
        );
    }

    #[test]
    fn test_generate_token_pair() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) =
            jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        assert!(!access_token.is_empty());
        assert!(!refresh_token.is_empty());
        assert_ne!(access_token, refresh_token);
    }

    #[test]
    fn test_generate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token(SUBJECT.to_string()).unwrap();

        assert!(!access_token.is_empty());
    }

    #[test]
    fn test_validate_access_token() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token(SUBJECT.to_string()).unwrap();
        let claims = jwt
            .validate_access_token(&access_token)
            .expect("fresh access token should validate");

        assert_eq!(claims.aud, AUDIENCE);
        assert_eq!(claims.sub, SUBJECT);
        assert_eq!(claims.exp - claims.iat, ACCESS_TTL);
    }

    #[test]
    fn test_validate_refresh_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();
        let claims = jwt
            .validate_refresh_token(&refresh_token)
            .expect("fresh refresh token should validate");

        assert_eq!(claims.aud, AUDIENCE);
        assert_eq!(claims.sub, SUBJECT);
        assert_eq!(claims.exp - claims.iat, REFRESH_TTL);
    }

    #[test]
    fn test_expired_access_token() {
        let jwt = setup_jwt();
        let now = unix_now();
        // Expired an hour ago, far beyond any small validation leeway.
        let claims = Claims::new(AUDIENCE.to_string(), SUBJECT.to_string(), now - 3600, now - 7200);

        assert_auth_error(jwt.validate_access_token(&sign(&claims, ACCESS_SECRET)));
    }

    #[test]
    fn test_wrong_audience_is_rejected() {
        let jwt = setup_jwt();
        let now = unix_now();
        let claims = Claims::new("other_audience".to_string(), SUBJECT.to_string(), now + 3600, now);

        assert_auth_error(jwt.validate_access_token(&sign(&claims, ACCESS_SECRET)));
    }

    #[test]
    fn test_invalid_access_token() {
        let jwt = setup_jwt();

        assert_auth_error(jwt.validate_access_token("invalid_token"));
    }

    #[test]
    fn test_token_kinds_are_not_interchangeable() {
        let jwt = setup_jwt();
        let (access_token, refresh_token) =
            jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        assert_auth_error(jwt.validate_access_token(&refresh_token));
        assert_auth_error(jwt.validate_refresh_token(&access_token));
    }

    #[test]
    fn test_refresh_access_token() {
        let jwt = setup_jwt();
        let (_, refresh_token) = jwt.generate_token_pair(SUBJECT.to_string()).unwrap();

        let new_access_token = jwt.refresh_access_token(&refresh_token).unwrap();
        let claims = jwt
            .validate_access_token(&new_access_token)
            .expect("refreshed access token should validate");

        assert_eq!(claims.sub, SUBJECT);
    }

    #[test]
    fn test_refresh_with_access_token_fails() {
        let jwt = setup_jwt();
        let access_token = jwt.generate_access_token(SUBJECT.to_string()).unwrap();

        assert_auth_error(jwt.refresh_access_token(&access_token));
    }
}