Skip to main content

fraiseql_server/auth/
session.rs

1// Session management - trait definition and implementations
2#[cfg(test)]
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10#[cfg(test)]
11use crate::auth::error::AuthError;
12use crate::auth::error::Result;
13
14/// Session data stored in the backend
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionData {
17    /// User ID (unique per user)
18    pub user_id:            String,
19    /// Session issued timestamp (Unix seconds)
20    pub issued_at:          u64,
21    /// Session expiration timestamp (Unix seconds)
22    pub expires_at:         u64,
23    /// Hash of the refresh token (stored securely)
24    pub refresh_token_hash: String,
25}
26
27impl SessionData {
28    /// Check if session is expired
29    pub fn is_expired(&self) -> bool {
30        let now = std::time::SystemTime::now()
31            .duration_since(std::time::UNIX_EPOCH)
32            .unwrap_or_default()
33            .as_secs();
34        self.expires_at <= now
35    }
36}
37
38/// Token pair returned after successful authentication
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct TokenPair {
41    /// JWT access token (short-lived, typically 15 min - 1 hour)
42    pub access_token:  String,
43    /// Refresh token (long-lived, typically 7-30 days)
44    pub refresh_token: String,
45    /// Time in seconds until access token expires
46    pub expires_in:    u64,
47}
48
49/// SessionStore trait - implement this for your storage backend
50///
51/// # Examples
52///
53/// Implement for PostgreSQL:
54/// ```ignore
55/// pub struct PostgresSessionStore {
56///     pool: PgPool,
57/// }
58///
59/// #[async_trait]
60/// impl SessionStore for PostgresSessionStore {
61///     async fn create_session(...) -> Result<TokenPair> { ... }
62///     // ... other methods
63/// }
64/// ```
65///
66/// Implement for Redis:
67/// ```ignore
68/// pub struct RedisSessionStore {
69///     client: redis::Client,
70/// }
71///
72/// #[async_trait]
73/// impl SessionStore for RedisSessionStore {
74///     async fn create_session(...) -> Result<TokenPair> { ... }
75///     // ... other methods
76/// }
77/// ```
78#[async_trait]
79pub trait SessionStore: Send + Sync {
80    /// Create a new session and return token pair
81    ///
82    /// # Arguments
83    /// * `user_id` - The user identifier
84    /// * `expires_at` - When the session should expire (Unix seconds)
85    ///
86    /// # Returns
87    /// TokenPair with access_token and refresh_token
88    ///
89    /// # Errors
90    /// Returns error if session creation fails
91    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair>;
92
93    /// Get session data by refresh token hash
94    ///
95    /// # Arguments
96    /// * `refresh_token_hash` - Hash of the refresh token
97    ///
98    /// # Returns
99    /// SessionData if session exists and is not revoked
100    ///
101    /// # Errors
102    /// Returns SessionError if session not found or revoked
103    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData>;
104
105    /// Revoke a single session
106    ///
107    /// # Arguments
108    /// * `refresh_token_hash` - Hash of the refresh token to revoke
109    ///
110    /// # Errors
111    /// Returns error if revocation fails
112    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()>;
113
114    /// Revoke all sessions for a user
115    ///
116    /// # Arguments
117    /// * `user_id` - The user identifier
118    ///
119    /// # Errors
120    /// Returns error if revocation fails
121    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()>;
122}
123
124/// Hash a refresh token for secure storage
125pub fn hash_token(token: &str) -> String {
126    let mut hasher = Sha256::new();
127    hasher.update(token.as_bytes());
128    format!("{:x}", hasher.finalize())
129}
130
131/// Generate a cryptographically secure refresh token
132pub fn generate_refresh_token() -> String {
133    use base64::Engine;
134    let mut rng = rand::thread_rng();
135    let random_bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
136    base64::engine::general_purpose::STANDARD.encode(&random_bytes)
137}
138
/// In-memory [`SessionStore`] for testing.
///
/// Sessions live in a concurrent map keyed by refresh-token hash. The map is
/// behind an `Arc`, so cloning a store produces another handle onto the *same*
/// sessions — useful for sharing one store between several tasks in a test.
#[cfg(test)]
#[derive(Clone, Debug)]
pub struct InMemorySessionStore {
    sessions: Arc<dashmap::DashMap<String, SessionData>>,
}
144
#[cfg(test)]
impl InMemorySessionStore {
    /// Builds an empty store backed by a fresh concurrent map.
    pub fn new() -> Self {
        let sessions = dashmap::DashMap::new();
        InMemorySessionStore {
            sessions: Arc::new(sessions),
        }
    }

    /// Drops every stored session (handy for resetting state between tests).
    pub fn clear(&self) {
        self.sessions.clear();
    }

    /// Number of sessions currently held (handy for test assertions).
    pub fn len(&self) -> usize {
        self.sessions.len()
    }

    /// `true` when the store holds no sessions at all.
    pub fn is_empty(&self) -> bool {
        self.sessions.is_empty()
    }
}
169
#[cfg(test)]
impl Default for InMemorySessionStore {
    /// Same as [`InMemorySessionStore::new`]: an empty store.
    fn default() -> Self {
        InMemorySessionStore::new()
    }
}
176
#[cfg(test)]
#[async_trait]
impl SessionStore for InMemorySessionStore {
    /// Stores a new session keyed by the hash of a freshly generated refresh token.
    ///
    /// `expires_in` in the returned pair is the number of seconds from now
    /// until `expires_at`, or `0` if `expires_at` is already in the past.
    /// The access token is a placeholder, not a real JWT.
    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
        let refresh_token = generate_refresh_token();
        let refresh_token_hash = hash_token(&refresh_token);

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let session = SessionData {
            user_id: user_id.to_string(),
            issued_at: now,
            expires_at,
            refresh_token_hash: refresh_token_hash.clone(),
        };

        self.sessions.insert(refresh_token_hash, session);

        let expires_in = expires_at.saturating_sub(now);

        // Placeholder access token built from independent randomness. It must
        // not embed the refresh token: the access token is presented on every
        // request, and the long-lived refresh token must not be recoverable
        // from it.
        let access_token = format!("access_token_{}", generate_refresh_token());

        Ok(TokenPair {
            access_token,
            refresh_token,
            expires_in,
        })
    }

    /// Returns a copy of the session stored under `refresh_token_hash`.
    ///
    /// Expiry is not checked here; use [`SessionData::is_expired`] on the result.
    ///
    /// # Errors
    /// `AuthError::TokenNotFound` if no session is stored under that hash.
    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
        self.sessions
            .get(refresh_token_hash)
            .map(|entry| entry.value().clone())
            .ok_or(AuthError::TokenNotFound)
    }

    /// Removes the session stored under `refresh_token_hash`.
    ///
    /// # Errors
    /// `AuthError::SessionError` ("Session not found") if no such session exists.
    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
        self.sessions
            .remove(refresh_token_hash)
            .ok_or_else(|| AuthError::SessionError {
                message: "Session not found".to_string(),
            })?;
        Ok(())
    }

    /// Removes every session owned by `user_id`; succeeds even if there are none.
    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
        // One in-place pass over the shards, instead of collecting matching
        // keys into a temporary Vec and looking each one up a second time.
        self.sessions.retain(|_, session| session.user_id != user_id);
        Ok(())
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in seconds (0 if the clock reads before the epoch).
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_hash_token() {
        let token = "my_secret_token";
        let hash1 = hash_token(token);
        let hash2 = hash_token(token);

        // Same token should produce same hash
        assert_eq!(hash1, hash2);

        // Different token should produce different hash
        let different_hash = hash_token("different_token");
        assert_ne!(hash1, different_hash);
    }

    #[test]
    fn test_hash_token_known_vector() {
        // FIPS 180-2 known-answer vector: SHA-256("abc") as lowercase hex.
        assert_eq!(
            hash_token("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_generate_refresh_token() {
        let token1 = generate_refresh_token();
        let token2 = generate_refresh_token();

        // Tokens should be random and different
        assert_ne!(token1, token2);
        // 32 bytes in padded standard base64 is always 44 characters
        assert_eq!(token1.len(), 44);
        assert_eq!(token2.len(), 44);
    }

    #[test]
    fn test_session_data_not_expired() {
        let now = now_secs();

        let session = SessionData {
            user_id:            "user123".to_string(),
            issued_at:          now,
            expires_at:         now + 3600,
            refresh_token_hash: "hash".to_string(),
        };

        assert!(!session.is_expired());
    }

    #[test]
    fn test_session_data_expired() {
        let now = now_secs();

        // saturating_sub keeps the test from panicking on overflow if the
        // clock helper ever returns a value smaller than the offsets.
        let session = SessionData {
            user_id:            "user123".to_string(),
            issued_at:          now.saturating_sub(3600),
            expires_at:         now.saturating_sub(100),
            refresh_token_hash: "hash".to_string(),
        };

        assert!(session.is_expired());
    }

    #[tokio::test]
    async fn test_in_memory_store_create_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store
            .create_session("user123", now + 3600)
            .await
            .expect("create_session should succeed");

        assert!(!tokens.access_token.is_empty());
        assert!(!tokens.refresh_token.is_empty());
        assert!(tokens.expires_in > 0);
        // The store reads the clock at or after `now`, so at most 3600s remain.
        assert!(tokens.expires_in <= 3600);
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store.create_session("user123", now + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        let session = store
            .get_session(&refresh_token_hash)
            .await
            .expect("session should exist");
        assert_eq!(session.user_id, "user123");
        assert_eq!(session.expires_at, now + 3600);
        assert_eq!(session.refresh_token_hash, refresh_token_hash);
        assert!(session.issued_at >= now);
    }

    #[tokio::test]
    async fn test_in_memory_store_get_unknown_session() {
        let store = InMemorySessionStore::new();

        let result = store.get_session(&hash_token("never-issued")).await;
        assert!(matches!(result, Err(AuthError::TokenNotFound)));
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_session() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        let tokens = store.create_session("user123", now + 3600).await.unwrap();
        let refresh_token_hash = hash_token(&tokens.refresh_token);

        assert!(store.revoke_session(&refresh_token_hash).await.is_ok());

        let session = store.get_session(&refresh_token_hash).await;
        assert!(session.is_err());
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_unknown_session() {
        let store = InMemorySessionStore::new();

        let result = store.revoke_session(&hash_token("never-issued")).await;
        assert!(matches!(result, Err(AuthError::SessionError { .. })));
    }

    #[tokio::test]
    async fn test_in_memory_store_revoke_all_sessions() {
        let store = InMemorySessionStore::new();
        let now = now_secs();

        // Create multiple sessions for same user
        let tokens1 = store.create_session("user123", now + 3600).await.unwrap();
        let tokens2 = store.create_session("user123", now + 3600).await.unwrap();

        // Create session for different user
        let tokens3 = store.create_session("user456", now + 3600).await.unwrap();

        assert_eq!(store.len(), 3);

        // Revoke all for user123
        assert!(store.revoke_all_sessions("user123").await.is_ok());
        assert_eq!(store.len(), 1);

        // user456 session should still exist
        let hash3 = hash_token(&tokens3.refresh_token);
        assert!(store.get_session(&hash3).await.is_ok());

        // user123 sessions should be gone
        let hash1 = hash_token(&tokens1.refresh_token);
        let hash2 = hash_token(&tokens2.refresh_token);
        assert!(store.get_session(&hash1).await.is_err());
        assert!(store.get_session(&hash2).await.is_err());

        // Revoking for a user with no sessions is not an error
        assert!(store.revoke_all_sessions("nobody").await.is_ok());
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_store_clear() {
        let store = InMemorySessionStore::default();
        let now = now_secs();

        store.create_session("user123", now + 3600).await.unwrap();
        store.create_session("user456", now + 3600).await.unwrap();
        assert_eq!(store.len(), 2);

        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }
}