Skip to main content

openclaw_gateway/auth/
jwt.rs

1//! JWT token management.
2
3use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
7use rand::RngCore;
8use serde::{Deserialize, Serialize};
9
10use super::AuthError;
11use super::users::UserRole;
12
13/// JWT claims.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Claims {
16    /// Subject (user ID).
17    pub sub: String,
18    /// Username.
19    pub username: String,
20    /// User role.
21    pub role: UserRole,
22    /// Issued at (Unix timestamp).
23    pub iat: i64,
24    /// Expiration (Unix timestamp).
25    pub exp: i64,
26    /// Token type (access or refresh).
27    #[serde(default)]
28    pub token_type: TokenType,
29    /// Token family ID (for refresh token rotation).
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub family_id: Option<String>,
32}
33
34/// Token type.
35#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "lowercase")]
37pub enum TokenType {
38    /// Access token for API calls.
39    #[default]
40    Access,
41    /// Refresh token for getting new access tokens.
42    Refresh,
43}
44
45/// A pair of access and refresh tokens.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct TokenPair {
48    /// Access token.
49    pub access_token: String,
50    /// Refresh token.
51    pub refresh_token: String,
52    /// Access token expiration.
53    pub expires_at: DateTime<Utc>,
54    /// Refresh token expiration.
55    pub refresh_expires_at: DateTime<Utc>,
56    /// Token type (always "Bearer").
57    pub token_type: String,
58}
59
60/// JWT manager for creating and validating tokens.
61pub struct JwtManager {
62    encoding_key: EncodingKey,
63    decoding_key: DecodingKey,
64    access_expiry: Duration,
65    refresh_expiry: Duration,
66}
67
68impl JwtManager {
69    /// Create a new JWT manager with a secret key.
70    ///
71    /// The secret should be at least 32 bytes for security.
72    #[must_use]
73    pub fn new(secret: &[u8], access_expiry: Duration, refresh_expiry: Duration) -> Self {
74        Self {
75            encoding_key: EncodingKey::from_secret(secret),
76            decoding_key: DecodingKey::from_secret(secret),
77            access_expiry,
78            refresh_expiry,
79        }
80    }
81
82    /// Create a JWT manager from a hex-encoded secret.
83    ///
84    /// # Errors
85    ///
86    /// Returns error if hex decoding fails.
87    pub fn from_hex_secret(
88        hex_secret: &str,
89        access_expiry: Duration,
90        refresh_expiry: Duration,
91    ) -> Result<Self, AuthError> {
92        let secret = hex::decode(hex_secret)
93            .map_err(|e| AuthError::Config(format!("Invalid hex secret: {e}")))?;
94        Ok(Self::new(&secret, access_expiry, refresh_expiry))
95    }
96
97    /// Generate a random 256-bit secret key.
98    #[must_use]
99    pub fn generate_secret() -> [u8; 32] {
100        let mut bytes = [0u8; 32];
101        rand::thread_rng().fill_bytes(&mut bytes);
102        bytes
103    }
104
105    /// Generate a random secret as hex string.
106    #[must_use]
107    pub fn generate_hex_secret() -> String {
108        hex::encode(Self::generate_secret())
109    }
110
111    /// Create an access token for a user.
112    ///
113    /// # Errors
114    ///
115    /// Returns error if token encoding fails.
116    pub fn create_access_token(
117        &self,
118        user_id: &str,
119        username: &str,
120        role: UserRole,
121    ) -> Result<(String, DateTime<Utc>), AuthError> {
122        let now = Utc::now();
123        let exp = now + chrono::Duration::from_std(self.access_expiry).unwrap_or_default();
124
125        let claims = Claims {
126            sub: user_id.to_string(),
127            username: username.to_string(),
128            role,
129            iat: now.timestamp(),
130            exp: exp.timestamp(),
131            token_type: TokenType::Access,
132            family_id: None,
133        };
134
135        let token = encode(&Header::default(), &claims, &self.encoding_key)
136            .map_err(|e| AuthError::TokenError(format!("Encoding failed: {e}")))?;
137
138        Ok((token, exp))
139    }
140
141    /// Create a refresh token for a user.
142    ///
143    /// # Errors
144    ///
145    /// Returns error if token encoding fails.
146    pub fn create_refresh_token(
147        &self,
148        user_id: &str,
149        username: &str,
150        role: UserRole,
151        family_id: Option<String>,
152    ) -> Result<(String, DateTime<Utc>), AuthError> {
153        let now = Utc::now();
154        let exp = now + chrono::Duration::from_std(self.refresh_expiry).unwrap_or_default();
155
156        // Generate new family ID if not provided (new login)
157        let family_id = family_id.unwrap_or_else(|| {
158            let mut bytes = [0u8; 16];
159            rand::thread_rng().fill_bytes(&mut bytes);
160            hex::encode(bytes)
161        });
162
163        let claims = Claims {
164            sub: user_id.to_string(),
165            username: username.to_string(),
166            role,
167            iat: now.timestamp(),
168            exp: exp.timestamp(),
169            token_type: TokenType::Refresh,
170            family_id: Some(family_id),
171        };
172
173        let token = encode(&Header::default(), &claims, &self.encoding_key)
174            .map_err(|e| AuthError::TokenError(format!("Encoding failed: {e}")))?;
175
176        Ok((token, exp))
177    }
178
179    /// Create a token pair (access + refresh) for a user.
180    ///
181    /// # Errors
182    ///
183    /// Returns error if token creation fails.
184    pub fn create_token_pair(
185        &self,
186        user_id: &str,
187        username: &str,
188        role: UserRole,
189    ) -> Result<TokenPair, AuthError> {
190        let (access_token, expires_at) = self.create_access_token(user_id, username, role)?;
191        let (refresh_token, refresh_expires_at) =
192            self.create_refresh_token(user_id, username, role, None)?;
193
194        Ok(TokenPair {
195            access_token,
196            refresh_token,
197            expires_at,
198            refresh_expires_at,
199            token_type: "Bearer".to_string(),
200        })
201    }
202
203    /// Validate and decode a token.
204    ///
205    /// # Errors
206    ///
207    /// Returns error if token is invalid or expired.
208    pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
209        let validation = Validation::default();
210
211        let token_data: TokenData<Claims> = decode(token, &self.decoding_key, &validation)
212            .map_err(|e| AuthError::TokenError(format!("Validation failed: {e}")))?;
213
214        Ok(token_data.claims)
215    }
216
217    /// Validate an access token (must be access type).
218    ///
219    /// # Errors
220    ///
221    /// Returns error if token is invalid, expired, or not an access token.
222    pub fn validate_access_token(&self, token: &str) -> Result<Claims, AuthError> {
223        let claims = self.validate_token(token)?;
224
225        if claims.token_type != TokenType::Access {
226            return Err(AuthError::TokenError("Not an access token".to_string()));
227        }
228
229        Ok(claims)
230    }
231
232    /// Validate a refresh token and optionally create new tokens.
233    ///
234    /// # Errors
235    ///
236    /// Returns error if token is invalid, expired, or not a refresh token.
237    pub fn refresh_tokens(&self, refresh_token: &str) -> Result<TokenPair, AuthError> {
238        let claims = self.validate_token(refresh_token)?;
239
240        if claims.token_type != TokenType::Refresh {
241            return Err(AuthError::TokenError("Not a refresh token".to_string()));
242        }
243
244        // Create new tokens with the same family ID (for rotation tracking)
245        let (access_token, expires_at) =
246            self.create_access_token(&claims.sub, &claims.username, claims.role)?;
247        let (new_refresh_token, refresh_expires_at) = self.create_refresh_token(
248            &claims.sub,
249            &claims.username,
250            claims.role,
251            claims.family_id,
252        )?;
253
254        Ok(TokenPair {
255            access_token,
256            refresh_token: new_refresh_token,
257            expires_at,
258            refresh_expires_at,
259            token_type: "Bearer".to_string(),
260        })
261    }
262
263    /// Extract token from Authorization header.
264    ///
265    /// Expects format: "Bearer <token>"
266    #[must_use]
267    pub fn extract_from_header(header: &str) -> Option<&str> {
268        header
269            .strip_prefix("Bearer ")
270            .or_else(|| header.strip_prefix("bearer "))
271    }
272}
273
274impl std::fmt::Debug for JwtManager {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("JwtManager")
277            .field("access_expiry", &self.access_expiry)
278            .field("refresh_expiry", &self.refresh_expiry)
279            .finish_non_exhaustive()
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with a fresh random secret, 1h access and 7d refresh lifetimes.
    fn create_manager() -> JwtManager {
        let secret = JwtManager::generate_secret();
        JwtManager::new(
            &secret,
            Duration::from_secs(3600),      // 1 hour
            Duration::from_secs(7 * 86400), // 7 days
        )
    }

    #[test]
    fn test_generate_secret() {
        let secret1 = JwtManager::generate_secret();
        let secret2 = JwtManager::generate_secret();
        assert_ne!(secret1, secret2);
        assert_eq!(secret1.len(), 32);
    }

    #[test]
    fn test_create_access_token() {
        let manager = create_manager();
        let (token, expires) = manager
            .create_access_token("user_123", "testuser", UserRole::Admin)
            .unwrap();

        assert!(!token.is_empty());
        assert!(expires > Utc::now());
    }

    #[test]
    fn test_validate_token() {
        let manager = create_manager();
        let (token, _) = manager
            .create_access_token("user_123", "testuser", UserRole::Operator)
            .unwrap();

        let claims = manager.validate_token(&token).unwrap();
        assert_eq!(claims.sub, "user_123");
        assert_eq!(claims.username, "testuser");
        assert_eq!(claims.role, UserRole::Operator);
        assert_eq!(claims.token_type, TokenType::Access);
    }

    #[test]
    fn test_token_pair() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");

        // Validate access token
        let access_claims = manager.validate_access_token(&pair.access_token).unwrap();
        assert_eq!(access_claims.token_type, TokenType::Access);

        // Validate refresh token
        let refresh_claims = manager.validate_token(&pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.token_type, TokenType::Refresh);
    }

    #[test]
    fn test_refresh_tokens() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        // Refresh tokens should produce valid new tokens
        let new_pair = manager.refresh_tokens(&pair.refresh_token).unwrap();

        // Verify new tokens are valid
        let access_claims = manager
            .validate_access_token(&new_pair.access_token)
            .unwrap();
        assert_eq!(access_claims.sub, "user_123");
        assert_eq!(access_claims.username, "admin");
        assert_eq!(access_claims.role, UserRole::Admin);

        let refresh_claims = manager.validate_token(&new_pair.refresh_token).unwrap();
        assert_eq!(refresh_claims.token_type, TokenType::Refresh);

        // The new refresh token should also be valid for refreshing
        let third_pair = manager.refresh_tokens(&new_pair.refresh_token).unwrap();
        assert!(
            manager
                .validate_access_token(&third_pair.access_token)
                .is_ok()
        );
    }

    #[test]
    fn test_refresh_preserves_family_id() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();
        let original = manager.validate_token(&pair.refresh_token).unwrap();

        let rotated_pair = manager.refresh_tokens(&pair.refresh_token).unwrap();
        let rotated = manager.validate_token(&rotated_pair.refresh_token).unwrap();

        assert!(original.family_id.is_some());
        assert_eq!(original.family_id, rotated.family_id);
    }

    #[test]
    fn test_new_logins_get_distinct_families() {
        let manager = create_manager();
        let first = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();
        let second = manager
            .create_token_pair("user_123", "admin", UserRole::Admin)
            .unwrap();

        let first_family = manager.validate_token(&first.refresh_token).unwrap().family_id;
        let second_family = manager.validate_token(&second.refresh_token).unwrap().family_id;
        assert_ne!(first_family, second_family);
    }

    #[test]
    fn test_refresh_rejects_access_token() {
        let manager = create_manager();
        let (access, _) = manager
            .create_access_token("user_123", "testuser", UserRole::Viewer)
            .unwrap();

        assert!(manager.refresh_tokens(&access).is_err());
    }

    #[test]
    fn test_validate_access_rejects_refresh_token() {
        let manager = create_manager();
        let pair = manager
            .create_token_pair("user_123", "testuser", UserRole::Viewer)
            .unwrap();

        assert!(manager.validate_access_token(&pair.refresh_token).is_err());
    }

    #[test]
    fn test_token_from_other_key_rejected() {
        let issuer = create_manager();
        let verifier = create_manager();
        let (token, _) = issuer
            .create_access_token("user_123", "testuser", UserRole::Viewer)
            .unwrap();

        assert!(verifier.validate_token(&token).is_err());
    }

    #[test]
    fn test_invalid_token() {
        let manager = create_manager();
        let result = manager.validate_token("invalid.token.here");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtManager::extract_from_header("Bearer abc123"),
            Some("abc123")
        );
        assert_eq!(
            JwtManager::extract_from_header("bearer abc123"),
            Some("abc123")
        );
        assert_eq!(JwtManager::extract_from_header("abc123"), None);
        assert_eq!(JwtManager::extract_from_header("Basic abc123"), None);
    }

    #[test]
    fn test_hex_secret() {
        let hex_secret = JwtManager::generate_hex_secret();
        assert_eq!(hex_secret.len(), 64); // 32 bytes = 64 hex chars

        let manager = JwtManager::from_hex_secret(
            &hex_secret,
            Duration::from_secs(3600),
            Duration::from_secs(86400),
        )
        .unwrap();

        let (token, _) = manager
            .create_access_token("user_123", "test", UserRole::Viewer)
            .unwrap();
        assert!(manager.validate_token(&token).is_ok());
    }
}