1use chrono::{Duration, Utc};
7use jsonwebtoken::{
8 decode, encode, errors::Error as JwtError, Algorithm, DecodingKey, EncodingKey, Header,
9 Validation,
10};
11use serde::{Deserialize, Serialize};
12use std::fmt;
13use thiserror::Error;
14use uuid::Uuid;
15
16#[derive(Debug, Clone)]
18pub struct JwtConfig {
19 pub secret: String,
21
22 pub expiration_seconds: i64,
24
25 pub refresh_expiration_seconds: i64,
27
28 pub issuer: String,
30
31 pub audience: String,
33
34 pub algorithm: Algorithm,
36}
37
38impl JwtConfig {
40 pub fn issuer(&self) -> &str {
42 &self.issuer
43 }
44
45 pub fn audience(&self) -> &str {
47 &self.audience
48 }
49
50 pub fn expiration_seconds(&self) -> i64 {
52 self.expiration_seconds
53 }
54}
55
56impl Default for JwtConfig {
57 fn default() -> Self {
58 Self {
59 secret: "change-me-in-production".to_string(),
60 expiration_seconds: 3600, refresh_expiration_seconds: 86400 * 7, issuer: "llm-registry".to_string(),
63 audience: "llm-registry-api".to_string(),
64 algorithm: Algorithm::HS256,
65 }
66 }
67}
68
69impl JwtConfig {
70 pub fn new(secret: impl Into<String>) -> Self {
72 Self {
73 secret: secret.into(),
74 ..Default::default()
75 }
76 }
77
78 pub fn with_expiration(mut self, seconds: i64) -> Self {
80 self.expiration_seconds = seconds;
81 self
82 }
83
84 pub fn with_refresh_expiration(mut self, seconds: i64) -> Self {
86 self.refresh_expiration_seconds = seconds;
87 self
88 }
89
90 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
92 self.issuer = issuer.into();
93 self
94 }
95
96 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
98 self.audience = audience.into();
99 self
100 }
101
102 pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
104 self.algorithm = algorithm;
105 self
106 }
107
108 pub fn validate(&self) -> Result<(), JwtConfigError> {
110 if self.secret.is_empty() {
111 return Err(JwtConfigError::EmptySecret);
112 }
113
114 if self.secret == "change-me-in-production" {
115 tracing::warn!("Using default JWT secret - change this in production!");
116 }
117
118 if self.expiration_seconds <= 0 {
119 return Err(JwtConfigError::InvalidExpiration);
120 }
121
122 if self.refresh_expiration_seconds <= 0 {
123 return Err(JwtConfigError::InvalidExpiration);
124 }
125
126 if self.issuer.is_empty() {
127 return Err(JwtConfigError::EmptyIssuer);
128 }
129
130 if self.audience.is_empty() {
131 return Err(JwtConfigError::EmptyAudience);
132 }
133
134 Ok(())
135 }
136}
137
138#[derive(Debug, Error)]
140pub enum JwtConfigError {
141 #[error("JWT secret cannot be empty")]
142 EmptySecret,
143
144 #[error("JWT expiration must be positive")]
145 InvalidExpiration,
146
147 #[error("JWT issuer cannot be empty")]
148 EmptyIssuer,
149
150 #[error("JWT audience cannot be empty")]
151 EmptyAudience,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct Claims {
157 pub sub: String,
159
160 pub iss: String,
162
163 pub aud: String,
165
166 pub exp: i64,
168
169 pub iat: i64,
171
172 pub nbf: i64,
174
175 pub jti: String,
177
178 #[serde(skip_serializing_if = "Option::is_none")]
180 pub email: Option<String>,
181
182 #[serde(default, skip_serializing_if = "Vec::is_empty")]
184 pub roles: Vec<String>,
185
186 #[serde(flatten)]
188 pub custom: serde_json::Value,
189}
190
191impl Claims {
192 pub fn new(
194 user_id: impl Into<String>,
195 issuer: impl Into<String>,
196 audience: impl Into<String>,
197 expiration_seconds: i64,
198 ) -> Self {
199 let now = Utc::now();
200 let exp = now + Duration::seconds(expiration_seconds);
201
202 Self {
203 sub: user_id.into(),
204 iss: issuer.into(),
205 aud: audience.into(),
206 exp: exp.timestamp(),
207 iat: now.timestamp(),
208 nbf: now.timestamp(),
209 jti: Uuid::new_v4().to_string(),
210 email: None,
211 roles: Vec::new(),
212 custom: serde_json::json!({}),
213 }
214 }
215
216 pub fn with_email(mut self, email: impl Into<String>) -> Self {
218 self.email = Some(email.into());
219 self
220 }
221
222 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
224 self.roles = roles;
225 self
226 }
227
228 pub fn with_role(mut self, role: impl Into<String>) -> Self {
230 self.roles.push(role.into());
231 self
232 }
233
234 pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
236 if let Some(obj) = self.custom.as_object_mut() {
237 obj.insert(key.into(), value);
238 }
239 self
240 }
241
242 pub fn is_expired(&self) -> bool {
244 let now = Utc::now().timestamp();
245 self.exp < now
246 }
247
248 pub fn is_not_yet_valid(&self) -> bool {
250 let now = Utc::now().timestamp();
251 self.nbf > now
252 }
253
254 pub fn validate(&self) -> Result<(), TokenError> {
256 if self.is_expired() {
257 return Err(TokenError::Expired);
258 }
259
260 if self.is_not_yet_valid() {
261 return Err(TokenError::NotYetValid);
262 }
263
264 if self.sub.is_empty() {
265 return Err(TokenError::InvalidClaims("Subject cannot be empty".to_string()));
266 }
267
268 Ok(())
269 }
270
271 pub fn has_role(&self, role: &str) -> bool {
273 self.roles.iter().any(|r| r == role)
274 }
275
276 pub fn has_any_role(&self, roles: &[&str]) -> bool {
278 roles.iter().any(|role| self.has_role(role))
279 }
280
281 pub fn has_all_roles(&self, roles: &[&str]) -> bool {
283 roles.iter().all(|role| self.has_role(role))
284 }
285}
286
287impl fmt::Display for Claims {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 write!(f, "Claims(sub={}, jti={})", self.sub, self.jti)
290 }
291}
292
293#[derive(Debug, Error)]
295pub enum TokenError {
296 #[error("Token has expired")]
297 Expired,
298
299 #[error("Token is not yet valid")]
300 NotYetValid,
301
302 #[error("Invalid token claims: {0}")]
303 InvalidClaims(String),
304
305 #[error("JWT error: {0}")]
306 JwtError(#[from] JwtError),
307
308 #[error("Invalid token format")]
309 InvalidFormat,
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct TokenPair {
315 pub access_token: String,
317
318 pub refresh_token: String,
320
321 pub token_type: String,
323
324 pub expires_in: i64,
326}
327
328impl TokenPair {
329 pub fn new(access_token: String, refresh_token: String, expires_in: i64) -> Self {
331 Self {
332 access_token,
333 refresh_token,
334 token_type: "Bearer".to_string(),
335 expires_in,
336 }
337 }
338}
339
340pub struct JwtManager {
342 pub config: JwtConfig,
343 encoding_key: EncodingKey,
344 decoding_key: DecodingKey,
345 validation: Validation,
346}
347
348impl JwtManager {
349 pub fn new(config: JwtConfig) -> Result<Self, JwtConfigError> {
351 config.validate()?;
352
353 let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
354 let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
355
356 let mut validation = Validation::new(config.algorithm);
357 validation.set_issuer(&[&config.issuer]);
358 validation.set_audience(&[&config.audience]);
359 validation.validate_exp = true;
360 validation.validate_nbf = true;
361
362 Ok(Self {
363 config,
364 encoding_key,
365 decoding_key,
366 validation,
367 })
368 }
369
370 pub fn generate_token(&self, user_id: impl Into<String>) -> Result<String, TokenError> {
372 let claims = Claims::new(
373 user_id,
374 &self.config.issuer,
375 &self.config.audience,
376 self.config.expiration_seconds,
377 );
378
379 let header = Header::new(self.config.algorithm);
380 encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
381 }
382
383 pub fn generate_token_with_claims(&self, claims: Claims) -> Result<String, TokenError> {
385 let header = Header::new(self.config.algorithm);
386 encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
387 }
388
389 pub fn generate_refresh_token(&self, user_id: impl Into<String>) -> Result<String, TokenError> {
391 let claims = Claims::new(
392 user_id,
393 &self.config.issuer,
394 &self.config.audience,
395 self.config.refresh_expiration_seconds,
396 )
397 .with_role("refresh");
398
399 let header = Header::new(self.config.algorithm);
400 encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
401 }
402
403 pub fn generate_token_pair(&self, user_id: impl Into<String>) -> Result<TokenPair, TokenError> {
405 let user_id = user_id.into();
406 let access_token = self.generate_token(&user_id)?;
407 let refresh_token = self.generate_refresh_token(&user_id)?;
408
409 Ok(TokenPair::new(
410 access_token,
411 refresh_token,
412 self.config.expiration_seconds,
413 ))
414 }
415
416 pub fn validate_token(&self, token: &str) -> Result<Claims, TokenError> {
418 let token_data = decode::<Claims>(token, &self.decoding_key, &self.validation)?;
419 let claims = token_data.claims;
420 claims.validate()?;
421 Ok(claims)
422 }
423
424 pub fn refresh_access_token(&self, refresh_token: &str) -> Result<TokenPair, TokenError> {
426 let claims = self.validate_token(refresh_token)?;
427
428 if !claims.has_role("refresh") {
430 return Err(TokenError::InvalidClaims(
431 "Not a refresh token".to_string(),
432 ));
433 }
434
435 self.generate_token_pair(&claims.sub)
437 }
438
439 pub fn decode_unverified(&self, token: &str) -> Result<Claims, TokenError> {
441 let token_data = decode::<Claims>(
442 token,
443 &self.decoding_key,
444 &Validation::new(self.config.algorithm),
445 )?;
446 Ok(token_data.claims)
447 }
448
449 pub fn extract_token_from_header(header_value: &str) -> Result<&str, TokenError> {
451 let parts: Vec<&str> = header_value.split_whitespace().collect();
452
453 if parts.len() != 2 {
454 return Err(TokenError::InvalidFormat);
455 }
456
457 if parts[0].to_lowercase() != "bearer" {
458 return Err(TokenError::InvalidFormat);
459 }
460
461 Ok(parts[1])
462 }
463}
464
465impl fmt::Debug for JwtManager {
466 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467 f.debug_struct("JwtManager")
468 .field("issuer", &self.config.issuer)
469 .field("audience", &self.config.audience)
470 .field("algorithm", &self.config.algorithm)
471 .field("expiration_seconds", &self.config.expiration_seconds)
472 .finish()
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_config() -> JwtConfig {
        JwtConfig::new("test-secret-key-for-testing")
            .with_issuer("test-issuer")
            .with_audience("test-audience")
            .with_expiration(3600)
    }

    #[test]
    fn test_jwt_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_jwt_config_empty_secret() {
        let config = JwtConfig {
            secret: String::new(),
            ..create_test_config()
        };
        assert!(matches!(config.validate(), Err(JwtConfigError::EmptySecret)));
    }

    #[test]
    fn test_jwt_config_non_positive_expiration() {
        assert!(matches!(
            create_test_config().with_expiration(0).validate(),
            Err(JwtConfigError::InvalidExpiration)
        ));
        assert!(matches!(
            create_test_config().with_refresh_expiration(-1).validate(),
            Err(JwtConfigError::InvalidExpiration)
        ));
    }

    #[test]
    fn test_claims_creation() {
        let claims = Claims::new("user123", "test-issuer", "test-audience", 3600);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_claims_with_roles() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_role("admin")
            .with_role("user");

        assert!(claims.has_role("admin"));
        assert!(claims.has_role("user"));
        assert!(!claims.has_role("moderator"));
        assert!(claims.has_any_role(&["admin", "moderator"]));
        assert!(claims.has_all_roles(&["admin", "user"]));
        assert!(!claims.has_all_roles(&["admin", "moderator"]));
    }

    #[test]
    fn test_jwt_manager_creation() {
        let config = create_test_config();
        let manager = JwtManager::new(config);
        assert!(manager.is_ok());
    }

    #[test]
    fn test_generate_and_validate_token() {
        let config = create_test_config();
        let manager = JwtManager::new(config).unwrap();

        let token = manager.generate_token("user123").unwrap();
        let claims = manager.validate_token(&token).unwrap();

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
    }

    #[test]
    fn test_validate_token_wrong_secret() {
        let issuer = JwtManager::new(create_test_config()).unwrap();
        let other = JwtManager::new(JwtConfig {
            secret: "a-completely-different-secret-key".to_string(),
            ..create_test_config()
        })
        .unwrap();

        let token = issuer.generate_token("user123").unwrap();
        assert!(other.validate_token(&token).is_err());
    }

    #[test]
    fn test_generate_token_pair() {
        let config = create_test_config();
        let manager = JwtManager::new(config).unwrap();

        let pair = manager.generate_token_pair("user123").unwrap();

        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");
        assert_eq!(pair.expires_in, 3600);

        let access_claims = manager.validate_token(&pair.access_token).unwrap();
        assert_eq!(access_claims.sub, "user123");

        let refresh_claims = manager.validate_token(&pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.sub, "user123");
        assert!(refresh_claims.has_role("refresh"));
    }

    #[test]
    fn test_refresh_access_token() {
        let config = create_test_config();
        let manager = JwtManager::new(config).unwrap();

        let pair = manager.generate_token_pair("user123").unwrap();
        let new_pair = manager.refresh_access_token(&pair.refresh_token).unwrap();

        assert!(!new_pair.access_token.is_empty());
        assert_ne!(pair.access_token, new_pair.access_token);
    }

    #[test]
    fn test_refresh_rejects_access_token() {
        let manager = JwtManager::new(create_test_config()).unwrap();
        let pair = manager.generate_token_pair("user123").unwrap();

        let result = manager.refresh_access_token(&pair.access_token);
        assert!(matches!(result, Err(TokenError::InvalidClaims(_))));
    }

    #[test]
    fn test_extract_token_from_header() {
        let header = "Bearer abc123xyz";
        let token = JwtManager::extract_token_from_header(header).unwrap();
        assert_eq!(token, "abc123xyz");
    }

    #[test]
    fn test_extract_token_scheme_is_case_insensitive() {
        assert_eq!(JwtManager::extract_token_from_header("bearer abc").unwrap(), "abc");
        assert_eq!(JwtManager::extract_token_from_header("BEARER abc").unwrap(), "abc");
    }

    #[test]
    fn test_extract_token_invalid_format() {
        let header = "InvalidFormat";
        assert!(JwtManager::extract_token_from_header(header).is_err());
    }

    #[test]
    fn test_extract_token_rejects_wrong_scheme_or_extra_parts() {
        assert!(JwtManager::extract_token_from_header("Basic abc").is_err());
        assert!(JwtManager::extract_token_from_header("Bearer abc def").is_err());
        assert!(JwtManager::extract_token_from_header("").is_err());
    }

    #[test]
    fn test_validate_invalid_token() {
        let config = create_test_config();
        let manager = JwtManager::new(config).unwrap();

        let result = manager.validate_token("invalid.token.here");
        assert!(result.is_err());
    }

    #[test]
    fn test_claims_with_email_and_custom() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_custom("org_id", serde_json::json!("org-456"));

        assert_eq!(claims.email, Some("user@example.com".to_string()));
        assert_eq!(claims.custom["org_id"], "org-456");
    }
}