Skip to main content

toolcraft_jwt/
jwt.rs

1use std::{fs, path::Path};
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{
5    Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{AccessTokenVerifier, Result, error::Error};
11
12/// Struct representing the JWT configuration parameters.
13#[derive(Debug, Deserialize)]
14pub struct JwtCfg {
15    #[serde(default)]
16    pub key_dir: Option<String>,
17    #[serde(default)]
18    pub access_private_key_pem: Option<String>,
19    #[serde(default)]
20    pub access_public_key_pem: Option<String>,
21    #[serde(default)]
22    pub refresh_private_key_pem: Option<String>,
23    #[serde(default)]
24    pub refresh_public_key_pem: Option<String>,
25    pub issuer: String,
26    pub audience: String,
27    pub access_token_duration: usize,
28    pub refresh_token_duration: usize,
29    pub access_key_validate_exp: bool,
30    pub refresh_key_validate_exp: bool,
31}
32
/// Represents the JWT claims.
///
/// Field order is the serialization order of the token payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Issuer (`iss`); `Jwt` fills it from its configured issuer and checks it on validation.
    pub iss: String,
    /// Audience (`aud`); `Jwt` fills it from its configured audience and checks it on validation.
    pub aud: String,
    /// Subject (`sub`) the token was issued for.
    pub sub: String,
    /// Expiration time (`exp`); `Jwt` writes a Unix timestamp in seconds.
    pub exp: usize,
    /// Issued-at time (`iat`); `Jwt` writes a Unix timestamp in seconds.
    pub iat: usize,
    /// Optional application-defined payload; omitted from the serialized token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ext: Option<Value>,
}
44
/// Access and refresh tokens issued together by `Jwt::generate_token_pair`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
    /// Signed access token, as returned by `jsonwebtoken::encode`.
    pub access_token: String,
    /// Signed refresh token; can be exchanged via `Jwt::refresh_access_token`.
    pub refresh_token: String,
}
50
51impl Claims {
52    /// Creates a new `Claims` instance.
53    pub fn new(iss: String, aud: String, sub: String, exp: usize, iat: usize) -> Self {
54        Self::new_with_ext(iss, aud, sub, exp, iat, None)
55    }
56
57    /// Creates a new `Claims` instance with custom extension payload.
58    pub fn new_with_ext(
59        iss: String,
60        aud: String,
61        sub: String,
62        exp: usize,
63        iat: usize,
64        ext: Option<Value>,
65    ) -> Self {
66        Self {
67            iss,
68            aud,
69            sub,
70            exp,
71            iat,
72            ext,
73        }
74    }
75}
76
77/// Enum representing the type of token: ACCESS or REFRESH.
78enum TokenKind {
79    Access,
80    Refesh,
81}
82
83/// Struct representing the JWT configuration and operations.
84#[derive(Clone)]
85pub struct Jwt {
86    header: Header,
87    encoding_access_key: EncodingKey,
88    encoding_refresh_key: EncodingKey,
89    decoding_access_key: DecodingKey,
90    decoding_refresh_key: DecodingKey,
91    validation_access_key: Validation,
92    validation_refresh_key: Validation,
93    iss: String,
94    aud: String,
95    access_token_duration: usize,
96    refresh_token_duration: usize,
97}
98
99impl Jwt {
100    /// Creates a new `Jwt` instance from the given configuration.
101    pub fn new(cfg: JwtCfg) -> Self {
102        Self::try_new(cfg).expect("invalid jwt config")
103    }
104
105    /// Creates a new `Jwt` instance from the given configuration.
106    pub fn try_new(cfg: JwtCfg) -> Result<Self> {
107        let (
108            access_private_key_pem,
109            access_public_key_pem,
110            refresh_private_key_pem,
111            refresh_public_key_pem,
112        ) = resolve_key_material(&cfg)?;
113        let encoding_access_key = EncodingKey::from_ed_pem(access_private_key_pem.as_bytes())?;
114        let encoding_refresh_key = EncodingKey::from_ed_pem(refresh_private_key_pem.as_bytes())?;
115        let decoding_access_key = DecodingKey::from_ed_pem(access_public_key_pem.as_bytes())?;
116        let decoding_refresh_key = DecodingKey::from_ed_pem(refresh_public_key_pem.as_bytes())?;
117
118        let header = Header::new(Algorithm::EdDSA);
119        let mut validation_access_key = Validation::new(Algorithm::EdDSA);
120        validation_access_key.set_issuer(std::slice::from_ref(&cfg.issuer));
121        validation_access_key.set_audience(std::slice::from_ref(&cfg.audience));
122        let mut validation_refresh_key = validation_access_key.clone();
123        validation_access_key.validate_exp = cfg.access_key_validate_exp;
124        validation_refresh_key.validate_exp = cfg.refresh_key_validate_exp;
125        validation_refresh_key.required_spec_claims.clear();
126        Ok(Self {
127            header,
128            encoding_access_key,
129            encoding_refresh_key,
130            decoding_access_key,
131            decoding_refresh_key,
132            validation_access_key,
133            validation_refresh_key,
134            iss: cfg.issuer,
135            aud: cfg.audience,
136            access_token_duration: cfg.access_token_duration,
137            refresh_token_duration: cfg.refresh_token_duration,
138        })
139    }
140
141    /// Generates a pair of access and refresh tokens.
142    pub fn generate_token_pair(&self, sub: String, ext: Option<Value>) -> Result<TokenPair> {
143        let access_token = self.generate_token(&TokenKind::Access, &sub, ext.clone())?;
144        let refresh_token = self.generate_token(&TokenKind::Refesh, &sub, ext)?;
145        Ok(TokenPair {
146            access_token,
147            refresh_token,
148        })
149    }
150
151    /// Generates a pair of access and refresh tokens for subject only (`ext = None`).
152    pub fn generate_token_pair_for_subject(&self, sub: String) -> Result<TokenPair> {
153        self.generate_token_pair(sub, None)
154    }
155
156    /// Refreshes an access token using a refresh token.
157    pub fn refresh_access_token(&self, refresh_token: &str) -> Result<String> {
158        let claims = self.validate_refresh_token(refresh_token)?;
159        self.generate_token(&TokenKind::Access, &claims.sub, claims.ext)
160    }
161
162    /// Validates an access token.
163    pub fn validate_access_token(&self, token: &str) -> Result<Claims> {
164        self.validate_token(&TokenKind::Access, token)
165            .map(|data| data.claims)
166    }
167
168    /// Validates a refresh token.
169    pub fn validate_refresh_token(&self, token: &str) -> Result<Claims> {
170        self.validate_token(&TokenKind::Refesh, token)
171            .map(|data| data.claims)
172    }
173
174    fn generate_token(&self, kind: &TokenKind, sub: &str, ext: Option<Value>) -> Result<String> {
175        let duration = self.get_token_duration(kind);
176        let (iat, exp) = self.generate_timestamps(duration);
177        let key = self.select_encoding_key(kind);
178        let claims = self.create_claims(sub, iat, exp, ext);
179        encode(&self.header, &claims, key).map_err(|e| Error::AuthError(e.to_string().into()))
180    }
181
182    fn validate_token(&self, kind: &TokenKind, token: &str) -> Result<TokenData<Claims>> {
183        let (key, validation) = self.select_decoding_key_and_validation(kind);
184        decode::<Claims>(token, key, validation).map_err(|e| Error::AuthError(e.to_string().into()))
185    }
186
187    fn get_token_duration(&self, kind: &TokenKind) -> usize {
188        match kind {
189            TokenKind::Access => self.access_token_duration,
190            TokenKind::Refesh => self.refresh_token_duration,
191        }
192    }
193
194    fn generate_timestamps(&self, duration: usize) -> (usize, usize) {
195        generate_expired_time(duration)
196    }
197
198    fn select_encoding_key(&self, kind: &TokenKind) -> &EncodingKey {
199        match kind {
200            TokenKind::Access => &self.encoding_access_key,
201            TokenKind::Refesh => &self.encoding_refresh_key,
202        }
203    }
204
205    fn create_claims(&self, sub: &str, iat: usize, exp: usize, ext: Option<Value>) -> Claims {
206        Claims::new_with_ext(
207            self.iss.clone(),
208            self.aud.clone(),
209            sub.to_string(),
210            exp,
211            iat,
212            ext,
213        )
214    }
215
216    fn select_decoding_key_and_validation(&self, kind: &TokenKind) -> (&DecodingKey, &Validation) {
217        match kind {
218            TokenKind::Access => (&self.decoding_access_key, &self.validation_access_key),
219            TokenKind::Refesh => (&self.decoding_refresh_key, &self.validation_refresh_key),
220        }
221    }
222}
223
/// Lets `Jwt` be used wherever an `AccessTokenVerifier` is expected.
impl AccessTokenVerifier for Jwt {
    /// Delegates to the inherent `Jwt::validate_access_token`.
    fn validate_access_token(&self, token: &str) -> Result<Claims> {
        // The explicit `Jwt::` path targets the inherent method of the same name,
        // making it obvious this is not a recursive call into the trait method.
        Jwt::validate_access_token(self, token)
    }
}
229
230fn generate_expired_time(duration: usize) -> (usize, usize) {
231    let now = Utc::now();
232    let iat = now.timestamp() as usize;
233    let exp = (now
234        + Duration::try_seconds(i64::try_from(duration).expect("duration overflow"))
235            .expect("duration out of range"))
236    .timestamp() as usize;
237    (iat, exp)
238}
239
240fn resolve_key_material(cfg: &JwtCfg) -> Result<(String, String, String, String)> {
241    if let Some(dir) = cfg.key_dir.as_deref() {
242        let dir = Path::new(dir);
243        let access_private = read_key_file(dir, "access_private_key.pem")?;
244        let access_public = read_key_file(dir, "access_public_key.pem")?;
245        let refresh_private = read_key_file(dir, "refresh_private_key.pem")?;
246        let refresh_public = read_key_file(dir, "refresh_public_key.pem")?;
247        return Ok((
248            access_private,
249            access_public,
250            refresh_private,
251            refresh_public,
252        ));
253    }
254
255    Ok((
256        require_non_empty(
257            cfg.access_private_key_pem.as_deref(),
258            "access_private_key_pem",
259        )?
260        .to_string(),
261        require_non_empty(
262            cfg.access_public_key_pem.as_deref(),
263            "access_public_key_pem",
264        )?
265        .to_string(),
266        require_non_empty(
267            cfg.refresh_private_key_pem.as_deref(),
268            "refresh_private_key_pem",
269        )?
270        .to_string(),
271        require_non_empty(
272            cfg.refresh_public_key_pem.as_deref(),
273            "refresh_public_key_pem",
274        )?
275        .to_string(),
276    ))
277}
278
279fn read_key_file(dir: &Path, file_name: &str) -> Result<String> {
280    let path = dir.join(file_name);
281    fs::read_to_string(&path).map_err(|e| {
282        Error::ErrorMessage(format!("failed to read key file {}: {e}", path.display()).into())
283    })
284}
285
286fn require_non_empty<'a>(value: Option<&'a str>, field_name: &str) -> Result<&'a str> {
287    value
288        .filter(|s| !s.is_empty())
289        .ok_or_else(|| Error::ErrorMessage(format!("missing required field: {field_name}").into()))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    // The access and refresh constants deliberately share one Ed25519 key pair.
    const ACCESS_PRIVATE_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIGrD/e7uKYqSY4twDEsRfMMuLSrODf14dpTiTK6K1YI0
-----END PRIVATE KEY-----";
    const ACCESS_PUBLIC_KEY_PEM: &str = "-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA2+Jj2UvNCvQiUPNYRgSi0cJSPiJI6Rs6D0UTeEpQVj8=
-----END PUBLIC KEY-----";
    const REFRESH_PRIVATE_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIGrD/e7uKYqSY4twDEsRfMMuLSrODf14dpTiTK6K1YI0
-----END PRIVATE KEY-----";
    const REFRESH_PUBLIC_KEY_PEM: &str = "-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA2+Jj2UvNCvQiUPNYRgSi0cJSPiJI6Rs6D0UTeEpQVj8=
-----END PUBLIC KEY-----";

    /// Valid inline-key configuration shared by the tests.
    fn test_cfg() -> JwtCfg {
        JwtCfg {
            key_dir: None,
            access_private_key_pem: Some(ACCESS_PRIVATE_KEY_PEM.to_string()),
            access_public_key_pem: Some(ACCESS_PUBLIC_KEY_PEM.to_string()),
            refresh_private_key_pem: Some(REFRESH_PRIVATE_KEY_PEM.to_string()),
            refresh_public_key_pem: Some(REFRESH_PUBLIC_KEY_PEM.to_string()),
            issuer: "test_issuer".to_string(),
            audience: "test_audience".to_string(),
            access_token_duration: 3600,
            refresh_token_duration: 86400,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        }
    }

    fn setup_jwt() -> Jwt {
        Jwt::new(test_cfg())
    }

    #[test]
    fn test_generate_token_pair() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();

        assert!(!token_pair.access_token.is_empty());
        assert!(!token_pair.refresh_token.is_empty());
    }

    #[test]
    fn test_validate_access_token() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let claims = jwt
            .validate_access_token(&token_pair.access_token)
            .expect("freshly issued access token must validate");

        assert_eq!(claims.iss, "test_issuer");
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_validate_refresh_token() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let claims = jwt
            .validate_refresh_token(&token_pair.refresh_token)
            .expect("freshly issued refresh token must validate");

        assert_eq!(claims.iss, "test_issuer");
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_key_dir_config() {
        use std::{
            fs,
            time::{SystemTime, UNIX_EPOCH},
        };

        /// Removes the temporary key directory even when an assertion fails.
        struct TempDirGuard(std::path::PathBuf);

        impl Drop for TempDirGuard {
            fn drop(&mut self) {
                // Best-effort cleanup: a leftover temp dir must not fail the test.
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let guard = TempDirGuard(
            std::env::temp_dir().join(format!("toolcraft_jwt_keys_{}_{ts}", std::process::id())),
        );
        let dir = &guard.0;
        fs::create_dir_all(dir).unwrap();
        fs::write(dir.join("access_private_key.pem"), ACCESS_PRIVATE_KEY_PEM).unwrap();
        fs::write(dir.join("access_public_key.pem"), ACCESS_PUBLIC_KEY_PEM).unwrap();
        fs::write(dir.join("refresh_private_key.pem"), REFRESH_PRIVATE_KEY_PEM).unwrap();
        fs::write(dir.join("refresh_public_key.pem"), REFRESH_PUBLIC_KEY_PEM).unwrap();

        let jwt = Jwt::new(JwtCfg {
            key_dir: Some(dir.to_string_lossy().to_string()),
            access_private_key_pem: None,
            access_public_key_pem: None,
            refresh_private_key_pem: None,
            refresh_public_key_pem: None,
            ..test_cfg()
        });

        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let claims = jwt.validate_access_token(&token_pair.access_token).unwrap();
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_refresh_access_token_keeps_ext() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair(
                "test_sub".to_string(),
                Some(serde_json::json!({"role":"admin"})),
            )
            .unwrap();
        let access_token = jwt.refresh_access_token(&token_pair.refresh_token).unwrap();
        let claims = jwt.validate_access_token(&access_token).unwrap();
        assert_eq!(claims.ext.unwrap()["role"], "admin");
    }

    #[test]
    fn test_missing_inline_key_is_rejected() {
        let cfg = JwtCfg {
            refresh_public_key_pem: None,
            ..test_cfg()
        };
        assert!(Jwt::try_new(cfg).is_err());
    }

    #[test]
    fn test_empty_inline_key_is_rejected() {
        let cfg = JwtCfg {
            access_private_key_pem: Some(String::new()),
            ..test_cfg()
        };
        assert!(Jwt::try_new(cfg).is_err());
    }

    #[test]
    fn test_rejects_token_from_other_issuer() {
        let jwt = setup_jwt();
        let other = Jwt::new(JwtCfg {
            issuer: "other_issuer".to_string(),
            ..test_cfg()
        });
        let token_pair = other
            .generate_token_pair_for_subject("test_sub".to_string())
            .unwrap();

        assert!(jwt.validate_access_token(&token_pair.access_token).is_err());
        assert!(jwt.validate_refresh_token(&token_pair.refresh_token).is_err());
    }

    #[test]
    fn test_rejects_malformed_token() {
        let jwt = setup_jwt();
        assert!(jwt.validate_access_token("not.a.jwt").is_err());
        assert!(jwt.refresh_access_token("").is_err());
    }
}