// llm_registry_api/jwt.rs
//! JWT token management
//!
//! This module provides JWT token generation, validation, and refresh functionality
//! for API authentication.

use chrono::{Duration, Utc};
use jsonwebtoken::{
    decode, encode, errors::Error as JwtError, Algorithm, DecodingKey, EncodingKey, Header,
    Validation,
};
use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;
use uuid::Uuid;

16/// JWT configuration
17#[derive(Debug, Clone)]
18pub struct JwtConfig {
19    /// Secret key for signing tokens
20    pub secret: String,
21
22    /// Token expiration in seconds
23    pub expiration_seconds: i64,
24
25    /// Refresh token expiration in seconds
26    pub refresh_expiration_seconds: i64,
27
28    /// Token issuer
29    pub issuer: String,
30
31    /// Token audience
32    pub audience: String,
33
34    /// Algorithm for signing
35    pub algorithm: Algorithm,
36}
37
38// Make config accessible through getters
39impl JwtConfig {
40    /// Get the issuer
41    pub fn issuer(&self) -> &str {
42        &self.issuer
43    }
44
45    /// Get the audience
46    pub fn audience(&self) -> &str {
47        &self.audience
48    }
49
50    /// Get expiration seconds
51    pub fn expiration_seconds(&self) -> i64 {
52        self.expiration_seconds
53    }
54}
55
56impl Default for JwtConfig {
57    fn default() -> Self {
58        Self {
59            secret: "change-me-in-production".to_string(),
60            expiration_seconds: 3600,          // 1 hour
61            refresh_expiration_seconds: 86400 * 7, // 7 days
62            issuer: "llm-registry".to_string(),
63            audience: "llm-registry-api".to_string(),
64            algorithm: Algorithm::HS256,
65        }
66    }
67}
68
69impl JwtConfig {
70    /// Create new JWT configuration
71    pub fn new(secret: impl Into<String>) -> Self {
72        Self {
73            secret: secret.into(),
74            ..Default::default()
75        }
76    }
77
78    /// Set token expiration in seconds
79    pub fn with_expiration(mut self, seconds: i64) -> Self {
80        self.expiration_seconds = seconds;
81        self
82    }
83
84    /// Set refresh token expiration in seconds
85    pub fn with_refresh_expiration(mut self, seconds: i64) -> Self {
86        self.refresh_expiration_seconds = seconds;
87        self
88    }
89
90    /// Set issuer
91    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
92        self.issuer = issuer.into();
93        self
94    }
95
96    /// Set audience
97    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
98        self.audience = audience.into();
99        self
100    }
101
102    /// Set signing algorithm
103    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
104        self.algorithm = algorithm;
105        self
106    }
107
108    /// Validate configuration
109    pub fn validate(&self) -> Result<(), JwtConfigError> {
110        if self.secret.is_empty() {
111            return Err(JwtConfigError::EmptySecret);
112        }
113
114        if self.secret == "change-me-in-production" {
115            tracing::warn!("Using default JWT secret - change this in production!");
116        }
117
118        if self.expiration_seconds <= 0 {
119            return Err(JwtConfigError::InvalidExpiration);
120        }
121
122        if self.refresh_expiration_seconds <= 0 {
123            return Err(JwtConfigError::InvalidExpiration);
124        }
125
126        if self.issuer.is_empty() {
127            return Err(JwtConfigError::EmptyIssuer);
128        }
129
130        if self.audience.is_empty() {
131            return Err(JwtConfigError::EmptyAudience);
132        }
133
134        Ok(())
135    }
136}
137
138/// JWT configuration errors
139#[derive(Debug, Error)]
140pub enum JwtConfigError {
141    #[error("JWT secret cannot be empty")]
142    EmptySecret,
143
144    #[error("JWT expiration must be positive")]
145    InvalidExpiration,
146
147    #[error("JWT issuer cannot be empty")]
148    EmptyIssuer,
149
150    #[error("JWT audience cannot be empty")]
151    EmptyAudience,
152}
153
154/// JWT claims structure
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct Claims {
157    /// Subject (user ID)
158    pub sub: String,
159
160    /// Issuer
161    pub iss: String,
162
163    /// Audience
164    pub aud: String,
165
166    /// Expiration time (Unix timestamp)
167    pub exp: i64,
168
169    /// Issued at (Unix timestamp)
170    pub iat: i64,
171
172    /// Not before (Unix timestamp)
173    pub nbf: i64,
174
175    /// JWT ID (unique token identifier)
176    pub jti: String,
177
178    /// User email
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub email: Option<String>,
181
182    /// User roles
183    #[serde(default, skip_serializing_if = "Vec::is_empty")]
184    pub roles: Vec<String>,
185
186    /// Custom claims
187    #[serde(flatten)]
188    pub custom: serde_json::Value,
189}
190
191impl Claims {
192    /// Create new claims with default values
193    pub fn new(
194        user_id: impl Into<String>,
195        issuer: impl Into<String>,
196        audience: impl Into<String>,
197        expiration_seconds: i64,
198    ) -> Self {
199        let now = Utc::now();
200        let exp = now + Duration::seconds(expiration_seconds);
201
202        Self {
203            sub: user_id.into(),
204            iss: issuer.into(),
205            aud: audience.into(),
206            exp: exp.timestamp(),
207            iat: now.timestamp(),
208            nbf: now.timestamp(),
209            jti: Uuid::new_v4().to_string(),
210            email: None,
211            roles: Vec::new(),
212            custom: serde_json::json!({}),
213        }
214    }
215
216    /// Add email to claims
217    pub fn with_email(mut self, email: impl Into<String>) -> Self {
218        self.email = Some(email.into());
219        self
220    }
221
222    /// Add roles to claims
223    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
224        self.roles = roles;
225        self
226    }
227
228    /// Add a single role
229    pub fn with_role(mut self, role: impl Into<String>) -> Self {
230        self.roles.push(role.into());
231        self
232    }
233
234    /// Add custom claims
235    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
236        if let Some(obj) = self.custom.as_object_mut() {
237            obj.insert(key.into(), value);
238        }
239        self
240    }
241
242    /// Check if token is expired
243    pub fn is_expired(&self) -> bool {
244        let now = Utc::now().timestamp();
245        self.exp < now
246    }
247
248    /// Check if token is not yet valid
249    pub fn is_not_yet_valid(&self) -> bool {
250        let now = Utc::now().timestamp();
251        self.nbf > now
252    }
253
254    /// Check if claims are valid
255    pub fn validate(&self) -> Result<(), TokenError> {
256        if self.is_expired() {
257            return Err(TokenError::Expired);
258        }
259
260        if self.is_not_yet_valid() {
261            return Err(TokenError::NotYetValid);
262        }
263
264        if self.sub.is_empty() {
265            return Err(TokenError::InvalidClaims("Subject cannot be empty".to_string()));
266        }
267
268        Ok(())
269    }
270
271    /// Check if user has a specific role
272    pub fn has_role(&self, role: &str) -> bool {
273        self.roles.iter().any(|r| r == role)
274    }
275
276    /// Check if user has any of the specified roles
277    pub fn has_any_role(&self, roles: &[&str]) -> bool {
278        roles.iter().any(|role| self.has_role(role))
279    }
280
281    /// Check if user has all of the specified roles
282    pub fn has_all_roles(&self, roles: &[&str]) -> bool {
283        roles.iter().all(|role| self.has_role(role))
284    }
285}
286
287impl fmt::Display for Claims {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        write!(f, "Claims(sub={}, jti={})", self.sub, self.jti)
290    }
291}
292
293/// Token errors
294#[derive(Debug, Error)]
295pub enum TokenError {
296    #[error("Token has expired")]
297    Expired,
298
299    #[error("Token is not yet valid")]
300    NotYetValid,
301
302    #[error("Invalid token claims: {0}")]
303    InvalidClaims(String),
304
305    #[error("JWT error: {0}")]
306    JwtError(#[from] JwtError),
307
308    #[error("Invalid token format")]
309    InvalidFormat,
310}
311
312/// JWT token pair (access + refresh)
313#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct TokenPair {
315    /// Access token
316    pub access_token: String,
317
318    /// Refresh token
319    pub refresh_token: String,
320
321    /// Token type (always "Bearer")
322    pub token_type: String,
323
324    /// Expiration in seconds
325    pub expires_in: i64,
326}
327
328impl TokenPair {
329    /// Create a new token pair
330    pub fn new(access_token: String, refresh_token: String, expires_in: i64) -> Self {
331        Self {
332            access_token,
333            refresh_token,
334            token_type: "Bearer".to_string(),
335            expires_in,
336        }
337    }
338}
339
340/// JWT token manager
341pub struct JwtManager {
342    pub config: JwtConfig,
343    encoding_key: EncodingKey,
344    decoding_key: DecodingKey,
345    validation: Validation,
346}
347
348impl JwtManager {
349    /// Create a new JWT manager
350    pub fn new(config: JwtConfig) -> Result<Self, JwtConfigError> {
351        config.validate()?;
352
353        let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
354        let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
355
356        let mut validation = Validation::new(config.algorithm);
357        validation.set_issuer(&[&config.issuer]);
358        validation.set_audience(&[&config.audience]);
359        validation.validate_exp = true;
360        validation.validate_nbf = true;
361
362        Ok(Self {
363            config,
364            encoding_key,
365            decoding_key,
366            validation,
367        })
368    }
369
370    /// Generate a new access token
371    pub fn generate_token(&self, user_id: impl Into<String>) -> Result<String, TokenError> {
372        let claims = Claims::new(
373            user_id,
374            &self.config.issuer,
375            &self.config.audience,
376            self.config.expiration_seconds,
377        );
378
379        let header = Header::new(self.config.algorithm);
380        encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
381    }
382
383    /// Generate a new access token with custom claims
384    pub fn generate_token_with_claims(&self, claims: Claims) -> Result<String, TokenError> {
385        let header = Header::new(self.config.algorithm);
386        encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
387    }
388
389    /// Generate a new refresh token
390    pub fn generate_refresh_token(&self, user_id: impl Into<String>) -> Result<String, TokenError> {
391        let claims = Claims::new(
392            user_id,
393            &self.config.issuer,
394            &self.config.audience,
395            self.config.refresh_expiration_seconds,
396        )
397        .with_role("refresh");
398
399        let header = Header::new(self.config.algorithm);
400        encode(&header, &claims, &self.encoding_key).map_err(TokenError::from)
401    }
402
403    /// Generate a token pair (access + refresh)
404    pub fn generate_token_pair(&self, user_id: impl Into<String>) -> Result<TokenPair, TokenError> {
405        let user_id = user_id.into();
406        let access_token = self.generate_token(&user_id)?;
407        let refresh_token = self.generate_refresh_token(&user_id)?;
408
409        Ok(TokenPair::new(
410            access_token,
411            refresh_token,
412            self.config.expiration_seconds,
413        ))
414    }
415
416    /// Validate and decode a token
417    pub fn validate_token(&self, token: &str) -> Result<Claims, TokenError> {
418        let token_data = decode::<Claims>(token, &self.decoding_key, &self.validation)?;
419        let claims = token_data.claims;
420        claims.validate()?;
421        Ok(claims)
422    }
423
424    /// Refresh an access token using a refresh token
425    pub fn refresh_access_token(&self, refresh_token: &str) -> Result<TokenPair, TokenError> {
426        let claims = self.validate_token(refresh_token)?;
427
428        // Verify it's a refresh token
429        if !claims.has_role("refresh") {
430            return Err(TokenError::InvalidClaims(
431                "Not a refresh token".to_string(),
432            ));
433        }
434
435        // Generate new token pair
436        self.generate_token_pair(&claims.sub)
437    }
438
439    /// Decode token without validation (use with caution)
440    pub fn decode_unverified(&self, token: &str) -> Result<Claims, TokenError> {
441        let token_data = decode::<Claims>(
442            token,
443            &self.decoding_key,
444            &Validation::new(self.config.algorithm),
445        )?;
446        Ok(token_data.claims)
447    }
448
449    /// Extract token from Authorization header value
450    pub fn extract_token_from_header(header_value: &str) -> Result<&str, TokenError> {
451        let parts: Vec<&str> = header_value.split_whitespace().collect();
452
453        if parts.len() != 2 {
454            return Err(TokenError::InvalidFormat);
455        }
456
457        if parts[0].to_lowercase() != "bearer" {
458            return Err(TokenError::InvalidFormat);
459        }
460
461        Ok(parts[1])
462    }
463}
464
465impl fmt::Debug for JwtManager {
466    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467        f.debug_struct("JwtManager")
468            .field("issuer", &self.config.issuer)
469            .field("audience", &self.config.audience)
470            .field("algorithm", &self.config.algorithm)
471            .field("expiration_seconds", &self.config.expiration_seconds)
472            .finish()
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-good configuration shared by the tests below.
    fn create_test_config() -> JwtConfig {
        JwtConfig::new("test-secret-key-for-testing")
            .with_issuer("test-issuer")
            .with_audience("test-audience")
            .with_expiration(3600)
    }

    /// Manager built from [`create_test_config`].
    fn test_manager() -> JwtManager {
        JwtManager::new(create_test_config()).expect("test configuration must be valid")
    }

    #[test]
    fn test_jwt_config_validation() {
        let outcome = create_test_config().validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_jwt_config_empty_secret() {
        let mut config = create_test_config();
        config.secret.clear();
        assert!(matches!(config.validate(), Err(JwtConfigError::EmptySecret)));
    }

    #[test]
    fn test_claims_creation() {
        let claims = Claims::new("user123", "test-issuer", "test-audience", 3600);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_claims_with_roles() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_role("admin")
            .with_role("user");

        for granted in ["admin", "user"] {
            assert!(claims.has_role(granted));
        }
        assert!(!claims.has_role("moderator"));
        assert!(claims.has_any_role(&["admin", "moderator"]));
        assert!(claims.has_all_roles(&["admin", "user"]));
        assert!(!claims.has_all_roles(&["admin", "moderator"]));
    }

    #[test]
    fn test_jwt_manager_creation() {
        assert!(JwtManager::new(create_test_config()).is_ok());
    }

    #[test]
    fn test_generate_and_validate_token() {
        let manager = test_manager();

        let token = manager.generate_token("user123").unwrap();
        let claims = manager.validate_token(&token).unwrap();

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
    }

    #[test]
    fn test_generate_token_pair() {
        let manager = test_manager();
        let pair = manager.generate_token_pair("user123").unwrap();

        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");
        assert_eq!(pair.expires_in, 3600);

        // Both halves of the pair must validate and name the same subject.
        let access = manager.validate_token(&pair.access_token).unwrap();
        assert_eq!(access.sub, "user123");

        let refresh = manager.validate_token(&pair.refresh_token).unwrap();
        assert_eq!(refresh.sub, "user123");
        assert!(refresh.has_role("refresh"));
    }

    #[test]
    fn test_refresh_access_token() {
        let manager = test_manager();

        let original = manager.generate_token_pair("user123").unwrap();
        let refreshed = manager.refresh_access_token(&original.refresh_token).unwrap();

        assert!(!refreshed.access_token.is_empty());
        assert_ne!(original.access_token, refreshed.access_token);
    }

    #[test]
    fn test_extract_token_from_header() {
        let token = JwtManager::extract_token_from_header("Bearer abc123xyz").unwrap();
        assert_eq!(token, "abc123xyz");
    }

    #[test]
    fn test_extract_token_invalid_format() {
        assert!(JwtManager::extract_token_from_header("InvalidFormat").is_err());
    }

    #[test]
    fn test_validate_invalid_token() {
        let manager = test_manager();
        assert!(manager.validate_token("invalid.token.here").is_err());
    }

    #[test]
    fn test_claims_with_email_and_custom() {
        let claims = Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_custom("org_id", serde_json::json!("org-456"));

        assert_eq!(claims.email.as_deref(), Some("user@example.com"));
        assert_eq!(claims.custom["org_id"], "org-456");
    }
}