1use std::{fs, path::Path};
2
3use chrono::{Duration, Utc};
4use jsonwebtoken::{
5 Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{AccessTokenVerifier, Result, error::Error};
11
12#[derive(Debug, Deserialize)]
14pub struct JwtCfg {
15 #[serde(default)]
16 pub key_dir: Option<String>,
17 #[serde(default)]
18 pub access_private_key_pem: Option<String>,
19 #[serde(default)]
20 pub access_public_key_pem: Option<String>,
21 #[serde(default)]
22 pub refresh_private_key_pem: Option<String>,
23 #[serde(default)]
24 pub refresh_public_key_pem: Option<String>,
25 pub issuer: String,
26 pub audience: String,
27 pub access_token_duration: usize,
28 pub refresh_token_duration: usize,
29 pub access_key_validate_exp: bool,
30 pub refresh_key_validate_exp: bool,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct Claims {
36 pub iss: String,
37 pub aud: String,
38 pub sub: String,
39 pub exp: usize,
40 pub iat: usize,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub ext: Option<Value>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct TokenPair {
47 pub access_token: String,
48 pub refresh_token: String,
49}
50
51impl Claims {
52 pub fn new(iss: String, aud: String, sub: String, exp: usize, iat: usize) -> Self {
54 Self::new_with_ext(iss, aud, sub, exp, iat, None)
55 }
56
57 pub fn new_with_ext(
59 iss: String,
60 aud: String,
61 sub: String,
62 exp: usize,
63 iat: usize,
64 ext: Option<Value>,
65 ) -> Self {
66 Self {
67 iss,
68 aud,
69 sub,
70 exp,
71 iat,
72 ext,
73 }
74 }
75}
76
77enum TokenKind {
79 Access,
80 Refesh,
81}
82
83#[derive(Clone)]
85pub struct Jwt {
86 header: Header,
87 encoding_access_key: EncodingKey,
88 encoding_refresh_key: EncodingKey,
89 decoding_access_key: DecodingKey,
90 decoding_refresh_key: DecodingKey,
91 validation_access_key: Validation,
92 validation_refresh_key: Validation,
93 iss: String,
94 aud: String,
95 access_token_duration: usize,
96 refresh_token_duration: usize,
97}
98
99impl Jwt {
100 pub fn new(cfg: JwtCfg) -> Self {
102 Self::try_new(cfg).expect("invalid jwt config")
103 }
104
105 pub fn try_new(cfg: JwtCfg) -> Result<Self> {
107 let (
108 access_private_key_pem,
109 access_public_key_pem,
110 refresh_private_key_pem,
111 refresh_public_key_pem,
112 ) = resolve_key_material(&cfg)?;
113 let encoding_access_key = EncodingKey::from_ed_pem(access_private_key_pem.as_bytes())?;
114 let encoding_refresh_key = EncodingKey::from_ed_pem(refresh_private_key_pem.as_bytes())?;
115 let decoding_access_key = DecodingKey::from_ed_pem(access_public_key_pem.as_bytes())?;
116 let decoding_refresh_key = DecodingKey::from_ed_pem(refresh_public_key_pem.as_bytes())?;
117
118 let header = Header::new(Algorithm::EdDSA);
119 let mut validation_access_key = Validation::new(Algorithm::EdDSA);
120 validation_access_key.set_issuer(std::slice::from_ref(&cfg.issuer));
121 validation_access_key.set_audience(std::slice::from_ref(&cfg.audience));
122 let mut validation_refresh_key = validation_access_key.clone();
123 validation_access_key.validate_exp = cfg.access_key_validate_exp;
124 validation_refresh_key.validate_exp = cfg.refresh_key_validate_exp;
125 validation_refresh_key.required_spec_claims.clear();
126 Ok(Self {
127 header,
128 encoding_access_key,
129 encoding_refresh_key,
130 decoding_access_key,
131 decoding_refresh_key,
132 validation_access_key,
133 validation_refresh_key,
134 iss: cfg.issuer,
135 aud: cfg.audience,
136 access_token_duration: cfg.access_token_duration,
137 refresh_token_duration: cfg.refresh_token_duration,
138 })
139 }
140
141 pub fn generate_token_pair(&self, sub: String, ext: Option<Value>) -> Result<TokenPair> {
143 let access_token = self.generate_token(&TokenKind::Access, &sub, ext.clone())?;
144 let refresh_token = self.generate_token(&TokenKind::Refesh, &sub, ext)?;
145 Ok(TokenPair {
146 access_token,
147 refresh_token,
148 })
149 }
150
151 pub fn generate_token_pair_for_subject(&self, sub: String) -> Result<TokenPair> {
153 self.generate_token_pair(sub, None)
154 }
155
156 pub fn refresh_access_token(&self, refresh_token: &str) -> Result<String> {
158 let claims = self.validate_refresh_token(refresh_token)?;
159 self.generate_token(&TokenKind::Access, &claims.sub, claims.ext)
160 }
161
162 pub fn validate_access_token(&self, token: &str) -> Result<Claims> {
164 self.validate_token(&TokenKind::Access, token)
165 .map(|data| data.claims)
166 }
167
168 pub fn validate_refresh_token(&self, token: &str) -> Result<Claims> {
170 self.validate_token(&TokenKind::Refesh, token)
171 .map(|data| data.claims)
172 }
173
174 fn generate_token(&self, kind: &TokenKind, sub: &str, ext: Option<Value>) -> Result<String> {
175 let duration = self.get_token_duration(kind);
176 let (iat, exp) = self.generate_timestamps(duration);
177 let key = self.select_encoding_key(kind);
178 let claims = self.create_claims(sub, iat, exp, ext);
179 encode(&self.header, &claims, key).map_err(|e| Error::AuthError(e.to_string().into()))
180 }
181
182 fn validate_token(&self, kind: &TokenKind, token: &str) -> Result<TokenData<Claims>> {
183 let (key, validation) = self.select_decoding_key_and_validation(kind);
184 decode::<Claims>(token, key, validation).map_err(|e| Error::AuthError(e.to_string().into()))
185 }
186
187 fn get_token_duration(&self, kind: &TokenKind) -> usize {
188 match kind {
189 TokenKind::Access => self.access_token_duration,
190 TokenKind::Refesh => self.refresh_token_duration,
191 }
192 }
193
194 fn generate_timestamps(&self, duration: usize) -> (usize, usize) {
195 generate_expired_time(duration)
196 }
197
198 fn select_encoding_key(&self, kind: &TokenKind) -> &EncodingKey {
199 match kind {
200 TokenKind::Access => &self.encoding_access_key,
201 TokenKind::Refesh => &self.encoding_refresh_key,
202 }
203 }
204
205 fn create_claims(&self, sub: &str, iat: usize, exp: usize, ext: Option<Value>) -> Claims {
206 Claims::new_with_ext(
207 self.iss.clone(),
208 self.aud.clone(),
209 sub.to_string(),
210 exp,
211 iat,
212 ext,
213 )
214 }
215
216 fn select_decoding_key_and_validation(&self, kind: &TokenKind) -> (&DecodingKey, &Validation) {
217 match kind {
218 TokenKind::Access => (&self.decoding_access_key, &self.validation_access_key),
219 TokenKind::Refesh => (&self.decoding_refresh_key, &self.validation_refresh_key),
220 }
221 }
222}
223
224impl AccessTokenVerifier for Jwt {
225 fn validate_access_token(&self, token: &str) -> Result<Claims> {
226 Jwt::validate_access_token(self, token)
227 }
228}
229
230fn generate_expired_time(duration: usize) -> (usize, usize) {
231 let now = Utc::now();
232 let iat = now.timestamp() as usize;
233 let exp = (now
234 + Duration::try_seconds(i64::try_from(duration).expect("duration overflow"))
235 .expect("duration out of range"))
236 .timestamp() as usize;
237 (iat, exp)
238}
239
240fn resolve_key_material(cfg: &JwtCfg) -> Result<(String, String, String, String)> {
241 if let Some(dir) = cfg.key_dir.as_deref() {
242 let dir = Path::new(dir);
243 let access_private = read_key_file(dir, "access_private_key.pem")?;
244 let access_public = read_key_file(dir, "access_public_key.pem")?;
245 let refresh_private = read_key_file(dir, "refresh_private_key.pem")?;
246 let refresh_public = read_key_file(dir, "refresh_public_key.pem")?;
247 return Ok((
248 access_private,
249 access_public,
250 refresh_private,
251 refresh_public,
252 ));
253 }
254
255 Ok((
256 require_non_empty(
257 cfg.access_private_key_pem.as_deref(),
258 "access_private_key_pem",
259 )?
260 .to_string(),
261 require_non_empty(
262 cfg.access_public_key_pem.as_deref(),
263 "access_public_key_pem",
264 )?
265 .to_string(),
266 require_non_empty(
267 cfg.refresh_private_key_pem.as_deref(),
268 "refresh_private_key_pem",
269 )?
270 .to_string(),
271 require_non_empty(
272 cfg.refresh_public_key_pem.as_deref(),
273 "refresh_public_key_pem",
274 )?
275 .to_string(),
276 ))
277}
278
279fn read_key_file(dir: &Path, file_name: &str) -> Result<String> {
280 let path = dir.join(file_name);
281 fs::read_to_string(&path).map_err(|e| {
282 Error::ErrorMessage(format!("failed to read key file {}: {e}", path.display()).into())
283 })
284}
285
286fn require_non_empty<'a>(value: Option<&'a str>, field_name: &str) -> Result<&'a str> {
287 value
288 .filter(|s| !s.is_empty())
289 .ok_or_else(|| Error::ErrorMessage(format!("missing required field: {field_name}").into()))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    const ACCESS_PRIVATE_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIGrD/e7uKYqSY4twDEsRfMMuLSrODf14dpTiTK6K1YI0
-----END PRIVATE KEY-----";
    const ACCESS_PUBLIC_KEY_PEM: &str = "-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA2+Jj2UvNCvQiUPNYRgSi0cJSPiJI6Rs6D0UTeEpQVj8=
-----END PUBLIC KEY-----";
    const REFRESH_PRIVATE_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIGrD/e7uKYqSY4twDEsRfMMuLSrODf14dpTiTK6K1YI0
-----END PRIVATE KEY-----";
    const REFRESH_PUBLIC_KEY_PEM: &str = "-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA2+Jj2UvNCvQiUPNYRgSi0cJSPiJI6Rs6D0UTeEpQVj8=
-----END PUBLIC KEY-----";

    /// Temporary directory that is deleted on drop, including when the test panics.
    struct TempKeyDir(std::path::PathBuf);

    impl Drop for TempKeyDir {
        fn drop(&mut self) {
            // Best-effort cleanup: failing to delete must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn setup_jwt() -> Jwt {
        Jwt::new(JwtCfg {
            key_dir: None,
            access_private_key_pem: Some(ACCESS_PRIVATE_KEY_PEM.to_string()),
            access_public_key_pem: Some(ACCESS_PUBLIC_KEY_PEM.to_string()),
            refresh_private_key_pem: Some(REFRESH_PRIVATE_KEY_PEM.to_string()),
            refresh_public_key_pem: Some(REFRESH_PUBLIC_KEY_PEM.to_string()),
            issuer: "test_issuer".to_string(),
            audience: "test_audience".to_string(),
            access_token_duration: 3600,
            refresh_token_duration: 86400,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        })
    }

    #[test]
    fn test_generate_token_pair() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();

        assert!(!token_pair.access_token.is_empty());
        assert!(!token_pair.refresh_token.is_empty());
    }

    #[test]
    fn test_validate_access_token() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let validation_result = jwt.validate_access_token(&token_pair.access_token);

        assert!(validation_result.is_ok());
        let claims = validation_result.unwrap();
        assert_eq!(claims.iss, "test_issuer");
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_validate_refresh_token() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let validation_result = jwt.validate_refresh_token(&token_pair.refresh_token);

        assert!(validation_result.is_ok());
        let claims = validation_result.unwrap();
        assert_eq!(claims.iss, "test_issuer");
        assert_eq!(claims.aud, "test_audience");
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_key_dir_config() {
        use std::{
            fs,
            time::{SystemTime, UNIX_EPOCH},
        };

        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        // The pid keeps concurrently running test binaries from sharing a directory.
        let dir = TempKeyDir(std::env::temp_dir().join(format!(
            "toolcraft_jwt_keys_{}_{ts}",
            std::process::id()
        )));
        fs::create_dir_all(&dir.0).unwrap();
        fs::write(dir.0.join("access_private_key.pem"), ACCESS_PRIVATE_KEY_PEM).unwrap();
        fs::write(dir.0.join("access_public_key.pem"), ACCESS_PUBLIC_KEY_PEM).unwrap();
        fs::write(dir.0.join("refresh_private_key.pem"), REFRESH_PRIVATE_KEY_PEM).unwrap();
        fs::write(dir.0.join("refresh_public_key.pem"), REFRESH_PUBLIC_KEY_PEM).unwrap();

        let jwt = Jwt::new(JwtCfg {
            key_dir: Some(dir.0.to_string_lossy().to_string()),
            access_private_key_pem: None,
            access_public_key_pem: None,
            refresh_private_key_pem: None,
            refresh_public_key_pem: None,
            issuer: "test_issuer".to_string(),
            audience: "test_audience".to_string(),
            access_token_duration: 3600,
            refresh_token_duration: 86400,
            access_key_validate_exp: true,
            refresh_key_validate_exp: true,
        });

        let token_pair = jwt
            .generate_token_pair("test_sub".to_string(), None)
            .unwrap();
        let claims = jwt.validate_access_token(&token_pair.access_token).unwrap();
        assert_eq!(claims.sub, "test_sub");
    }

    #[test]
    fn test_refresh_access_token_keeps_ext() {
        let jwt = setup_jwt();
        let token_pair = jwt
            .generate_token_pair(
                "test_sub".to_string(),
                Some(serde_json::json!({"role":"admin"})),
            )
            .unwrap();
        let access_token = jwt.refresh_access_token(&token_pair.refresh_token).unwrap();
        let claims = jwt.validate_access_token(&access_token).unwrap();
        assert_eq!(claims.ext.unwrap()["role"], "admin");
    }

    #[test]
    fn test_refresh_access_token_rejects_garbage() {
        let jwt = setup_jwt();
        assert!(jwt.refresh_access_token("not-a-jwt").is_err());
    }
}