1use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
7use rand::RngCore;
8use serde::{Deserialize, Serialize};
9
10use super::AuthError;
11use super::users::UserRole;
12
/// JWT payload issued and verified by [`JwtManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject: the user id the token was issued for.
    pub sub: String,
    /// Username of the subject, as passed in when the token was issued.
    pub username: String,
    /// Role of the subject, as passed in when the token was issued.
    pub role: UserRole,
    /// Issued-at time, as Unix seconds (`DateTime::timestamp`).
    pub iat: i64,
    /// Expiry time, as Unix seconds (`DateTime::timestamp`).
    pub exp: i64,
    /// Whether this is an access or a refresh token. A payload that omits the
    /// field deserializes as `TokenType::default()`, i.e. `TokenType::Access`.
    #[serde(default)]
    pub token_type: TokenType,
    /// Identifier shared by a chain of refresh tokens. `JwtManager` sets it only
    /// on refresh tokens; it is left out of the serialized payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family_id: Option<String>,
}
33
/// Distinguishes access tokens from refresh tokens.
///
/// Serialized in lowercase: `"access"` / `"refresh"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TokenType {
    /// Token accepted by `JwtManager::validate_access_token`. Also the default,
    /// which applies when a payload has no `token_type` field.
    #[default]
    Access,
    /// Token accepted by `JwtManager::refresh_tokens` to mint a new pair.
    Refresh,
}
44
/// An access/refresh token pair, as returned by `JwtManager::create_token_pair`
/// and `JwtManager::refresh_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
    /// Encoded access token.
    pub access_token: String,
    /// Encoded refresh token.
    pub refresh_token: String,
    /// Expiry time of `access_token`.
    pub expires_at: DateTime<Utc>,
    /// Expiry time of `refresh_token`.
    pub refresh_expires_at: DateTime<Utc>,
    /// Authorization scheme for `access_token`; `JwtManager` always sets `"Bearer"`.
    pub token_type: String,
}
59
/// Issues and validates JWTs signed with a single shared secret.
///
/// Holds the encoding and decoding keys derived from that secret, plus the
/// lifetimes applied to newly issued access and refresh tokens.
pub struct JwtManager {
    // Both keys are built from the same secret in `JwtManager::new`.
    encoding_key: EncodingKey,
    decoding_key: DecodingKey,
    // Lifetime of newly issued access tokens.
    access_expiry: Duration,
    // Lifetime of newly issued refresh tokens.
    refresh_expiry: Duration,
}
67
68impl JwtManager {
69 #[must_use]
73 pub fn new(secret: &[u8], access_expiry: Duration, refresh_expiry: Duration) -> Self {
74 Self {
75 encoding_key: EncodingKey::from_secret(secret),
76 decoding_key: DecodingKey::from_secret(secret),
77 access_expiry,
78 refresh_expiry,
79 }
80 }
81
82 pub fn from_hex_secret(
88 hex_secret: &str,
89 access_expiry: Duration,
90 refresh_expiry: Duration,
91 ) -> Result<Self, AuthError> {
92 let secret = hex::decode(hex_secret)
93 .map_err(|e| AuthError::Config(format!("Invalid hex secret: {e}")))?;
94 Ok(Self::new(&secret, access_expiry, refresh_expiry))
95 }
96
97 #[must_use]
99 pub fn generate_secret() -> [u8; 32] {
100 let mut bytes = [0u8; 32];
101 rand::thread_rng().fill_bytes(&mut bytes);
102 bytes
103 }
104
105 #[must_use]
107 pub fn generate_hex_secret() -> String {
108 hex::encode(Self::generate_secret())
109 }
110
111 pub fn create_access_token(
117 &self,
118 user_id: &str,
119 username: &str,
120 role: UserRole,
121 ) -> Result<(String, DateTime<Utc>), AuthError> {
122 let now = Utc::now();
123 let exp = now + chrono::Duration::from_std(self.access_expiry).unwrap_or_default();
124
125 let claims = Claims {
126 sub: user_id.to_string(),
127 username: username.to_string(),
128 role,
129 iat: now.timestamp(),
130 exp: exp.timestamp(),
131 token_type: TokenType::Access,
132 family_id: None,
133 };
134
135 let token = encode(&Header::default(), &claims, &self.encoding_key)
136 .map_err(|e| AuthError::TokenError(format!("Encoding failed: {e}")))?;
137
138 Ok((token, exp))
139 }
140
141 pub fn create_refresh_token(
147 &self,
148 user_id: &str,
149 username: &str,
150 role: UserRole,
151 family_id: Option<String>,
152 ) -> Result<(String, DateTime<Utc>), AuthError> {
153 let now = Utc::now();
154 let exp = now + chrono::Duration::from_std(self.refresh_expiry).unwrap_or_default();
155
156 let family_id = family_id.unwrap_or_else(|| {
158 let mut bytes = [0u8; 16];
159 rand::thread_rng().fill_bytes(&mut bytes);
160 hex::encode(bytes)
161 });
162
163 let claims = Claims {
164 sub: user_id.to_string(),
165 username: username.to_string(),
166 role,
167 iat: now.timestamp(),
168 exp: exp.timestamp(),
169 token_type: TokenType::Refresh,
170 family_id: Some(family_id),
171 };
172
173 let token = encode(&Header::default(), &claims, &self.encoding_key)
174 .map_err(|e| AuthError::TokenError(format!("Encoding failed: {e}")))?;
175
176 Ok((token, exp))
177 }
178
179 pub fn create_token_pair(
185 &self,
186 user_id: &str,
187 username: &str,
188 role: UserRole,
189 ) -> Result<TokenPair, AuthError> {
190 let (access_token, expires_at) = self.create_access_token(user_id, username, role)?;
191 let (refresh_token, refresh_expires_at) =
192 self.create_refresh_token(user_id, username, role, None)?;
193
194 Ok(TokenPair {
195 access_token,
196 refresh_token,
197 expires_at,
198 refresh_expires_at,
199 token_type: "Bearer".to_string(),
200 })
201 }
202
203 pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
209 let validation = Validation::default();
210
211 let token_data: TokenData<Claims> = decode(token, &self.decoding_key, &validation)
212 .map_err(|e| AuthError::TokenError(format!("Validation failed: {e}")))?;
213
214 Ok(token_data.claims)
215 }
216
217 pub fn validate_access_token(&self, token: &str) -> Result<Claims, AuthError> {
223 let claims = self.validate_token(token)?;
224
225 if claims.token_type != TokenType::Access {
226 return Err(AuthError::TokenError("Not an access token".to_string()));
227 }
228
229 Ok(claims)
230 }
231
232 pub fn refresh_tokens(&self, refresh_token: &str) -> Result<TokenPair, AuthError> {
238 let claims = self.validate_token(refresh_token)?;
239
240 if claims.token_type != TokenType::Refresh {
241 return Err(AuthError::TokenError("Not a refresh token".to_string()));
242 }
243
244 let (access_token, expires_at) =
246 self.create_access_token(&claims.sub, &claims.username, claims.role)?;
247 let (new_refresh_token, refresh_expires_at) = self.create_refresh_token(
248 &claims.sub,
249 &claims.username,
250 claims.role,
251 claims.family_id,
252 )?;
253
254 Ok(TokenPair {
255 access_token,
256 refresh_token: new_refresh_token,
257 expires_at,
258 refresh_expires_at,
259 token_type: "Bearer".to_string(),
260 })
261 }
262
263 #[must_use]
267 pub fn extract_from_header(header: &str) -> Option<&str> {
268 header
269 .strip_prefix("Bearer ")
270 .or_else(|| header.strip_prefix("bearer "))
271 }
272}
273
274impl std::fmt::Debug for JwtManager {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("JwtManager")
277 .field("access_expiry", &self.access_expiry)
278 .field("refresh_expiry", &self.refresh_expiry)
279 .finish_non_exhaustive()
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with a fresh random secret, 1-hour access tokens and
    /// 7-day refresh tokens.
    fn create_manager() -> JwtManager {
        let secret = JwtManager::generate_secret();
        JwtManager::new(
            &secret,
            Duration::from_secs(3600),
            Duration::from_secs(7 * 86400),
        )
    }

    #[test]
    fn test_generate_secret() {
        let secret1 = JwtManager::generate_secret();
        let secret2 = JwtManager::generate_secret();
        assert_ne!(secret1, secret2);
        assert_eq!(secret1.len(), 32);
    }

    #[test]
    fn test_create_access_token() {
        let manager = create_manager();
        let (token, expires) = manager
            .create_access_token("user_123", "testuser", UserRole::Admin)
            .unwrap();

        assert!(!token.is_empty());
        assert!(expires > Utc::now());
    }

    #[test]
    fn test_validate_token() {
        let manager = create_manager();
        let (token, _) = manager
            .create_access_token("user_123", "testuser", UserRole::Operator)
            .unwrap();

        let claims = manager.validate_token(&token).unwrap();
        assert_eq!(claims.sub, "user_123");
        assert_eq!(claims.username, "testuser");
        assert_eq!(claims.role, UserRole::Operator);
        assert_eq!(claims.token_type, TokenType::Access);
        assert!(claims.family_id.is_none());
    }

    #[test]
    fn test_token_pair() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");

        let access_claims = manager.validate_access_token(&pair.access_token).unwrap();
        assert_eq!(access_claims.token_type, TokenType::Access);

        let refresh_claims = manager.validate_token(&pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.token_type, TokenType::Refresh);
        assert!(refresh_claims.family_id.is_some());
    }

    #[test]
    fn test_refresh_tokens() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        let new_pair = manager.refresh_tokens(&pair.refresh_token).unwrap();

        let access_claims = manager
            .validate_access_token(&new_pair.access_token)
            .unwrap();
        assert_eq!(access_claims.sub, "user_123");
        assert_eq!(access_claims.username, "admin");
        assert_eq!(access_claims.role, UserRole::Admin);

        let refresh_claims = manager.validate_token(&new_pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.token_type, TokenType::Refresh);

        // A rotated refresh token can itself be refreshed again.
        let third_pair = manager.refresh_tokens(&new_pair.refresh_token).unwrap();
        assert!(
            manager
                .validate_access_token(&third_pair.access_token)
                .is_ok()
        );
    }

    #[test]
    fn test_refresh_preserves_family_id() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();
        let original = manager.validate_token(&pair.refresh_token).unwrap();

        let rotated = manager.refresh_tokens(&pair.refresh_token).unwrap();
        let rotated_claims = manager.validate_token(&rotated.refresh_token).unwrap();

        assert!(original.family_id.is_some());
        assert_eq!(original.family_id, rotated_claims.family_id);
    }

    #[test]
    fn test_access_token_cannot_refresh() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        assert!(manager.refresh_tokens(&pair.access_token).is_err());
    }

    #[test]
    fn test_refresh_token_is_not_access_token() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        assert!(manager.validate_access_token(&pair.refresh_token).is_err());
    }

    #[test]
    fn test_token_from_other_secret_rejected() {
        let issuer = create_manager();
        let verifier = create_manager();
        let (token, _) = issuer
            .create_access_token("user_123", "testuser", UserRole::Viewer)
            .unwrap();

        assert!(verifier.validate_token(&token).is_err());
    }

    #[test]
    fn test_invalid_token() {
        let manager = create_manager();
        let result = manager.validate_token("invalid.token.here");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtManager::extract_from_header("Bearer abc123"),
            Some("abc123")
        );
        assert_eq!(
            JwtManager::extract_from_header("bearer abc123"),
            Some("abc123")
        );
        assert_eq!(JwtManager::extract_from_header("abc123"), None);
        assert_eq!(JwtManager::extract_from_header("Bearer"), None);
        assert_eq!(JwtManager::extract_from_header("Basic abc123"), None);
    }

    #[test]
    fn test_hex_secret() {
        let hex_secret = JwtManager::generate_hex_secret();
        // 32 random bytes encode to 64 hex characters.
        assert_eq!(hex_secret.len(), 64);
        let manager = JwtManager::from_hex_secret(
            &hex_secret,
            Duration::from_secs(3600),
            Duration::from_secs(86400),
        )
        .unwrap();

        let (token, _) = manager
            .create_access_token("user_123", "test", UserRole::Viewer)
            .unwrap();
        assert!(manager.validate_token(&token).is_ok());
    }

    #[test]
    fn test_invalid_hex_secret_rejected() {
        let result = JwtManager::from_hex_secret(
            "not hex",
            Duration::from_secs(60),
            Duration::from_secs(60),
        );
        assert!(result.is_err());
    }
}