Skip to main content

fraiseql_auth/
state_store.rs

//! CSRF state store — trait definition and backends.
//!
//! Stores OAuth `state` parameters for the duration of an authorization flow and
//! removes them on first retrieval, preventing state-replay attacks.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use dashmap::DashMap;
10
11use crate::error::Result;
12
13/// StateStore trait - implement this for different storage backends
14///
15/// Stores OAuth state parameters with expiration for CSRF protection.
16/// In distributed deployments, use a persistent backend (Redis) instead of in-memory.
17///
18/// # Examples
19///
20/// Use in-memory store for single-instance deployments:
21/// ```rust
22/// use std::sync::Arc;
23/// use fraiseql_auth::state_store::InMemoryStateStore;
24/// let state_store = Arc::new(InMemoryStateStore::new());
25/// ```
26///
27/// Use Redis for distributed deployments:
28/// ```no_run
29/// // Requires: live Redis server.
30/// use std::sync::Arc;
31/// # async fn example() -> fraiseql_auth::error::Result<()> {
32/// # #[cfg(feature = "redis-rate-limiting")] {
33/// use fraiseql_auth::state_store::RedisStateStore;
34/// let state_store = Arc::new(RedisStateStore::new("redis://localhost:6379").await?);
35/// # }
36/// # Ok(())
37/// # }
38/// ```
39// Reason: used as dyn Trait (Arc<dyn StateStore>); async_trait ensures Send bounds and
40// dyn-compatibility async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
41#[async_trait]
42pub trait StateStore: Send + Sync {
43    /// Store a state value with provider and expiration
44    ///
45    /// # Arguments
46    /// * `state` - The state parameter value
47    /// * `provider` - OAuth provider name
48    /// * `expiry_secs` - Unix timestamp when this state expires
49    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;
50
51    /// Retrieve and remove a state value
52    ///
53    /// Returns (provider, expiry_secs) if state exists and is valid
54    /// Returns error if state doesn't exist or is invalid
55    async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
56}
57
58/// In-memory state store using DashMap
59///
60/// **Warning**: Only suitable for single-instance deployments!
61/// For distributed systems, use RedisStateStore instead.
62///
63/// # SECURITY
64/// - Bounded to MAX_STATES entries to prevent unbounded memory growth
65/// - Expired states are automatically cleaned up on store operations
66/// - Implements LRU-like eviction when max capacity is reached
67#[derive(Debug)]
68pub struct InMemoryStateStore {
69    // Map of state -> (provider, expiry_secs)
70    states:     Arc<DashMap<String, (String, u64)>>,
71    // Maximum number of states to store (prevents memory exhaustion)
72    max_states: usize,
73}
74
75impl InMemoryStateStore {
76    /// Default maximum number of states to store (10,000 states)
77    /// At ~100 bytes per state, this limits memory to ~1 MB
78    const MAX_STATES: usize = 10_000;
79
80    /// Create a new in-memory state store with default limits
81    pub fn new() -> Self {
82        Self {
83            states:     Arc::new(DashMap::new()),
84            max_states: Self::MAX_STATES,
85        }
86    }
87
88    /// Create a new in-memory state store with custom max size
89    ///
90    /// # Arguments
91    /// * `max_states` - Maximum number of states to store
92    pub fn with_max_states(max_states: usize) -> Self {
93        Self {
94            states:     Arc::new(DashMap::new()),
95            max_states: max_states.max(1), // Ensure at least 1 state
96        }
97    }
98
99    /// Clean up expired states and check capacity
100    ///
101    /// # SECURITY
102    /// Called before inserting new states to:
103    /// 1. Remove expired states (automatic cleanup)
104    /// 2. Check if store is at capacity
105    /// 3. Return eviction needed flag if cleanup doesn't free space
106    fn cleanup_expired(&self) -> bool {
107        let now = std::time::SystemTime::now()
108            .duration_since(std::time::UNIX_EPOCH)
109            .unwrap_or_default()
110            .as_secs();
111
112        // Remove all expired states
113        self.states.retain(|_key, (_provider, expiry)| *expiry > now);
114
115        // Return true if we're still over capacity after cleanup
116        self.states.len() >= self.max_states
117    }
118}
119
120impl Default for InMemoryStateStore {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126// Reason: StateStore is defined with #[async_trait]; all implementations must match
127// its transformed method signatures to satisfy the trait contract
128// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
129#[async_trait]
130impl StateStore for InMemoryStateStore {
131    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
132        // SECURITY: Clean up expired states before inserting new one
133        if self.cleanup_expired() {
134            // Still over capacity after cleanup - reject to prevent memory exhaustion
135            return Err(crate::error::AuthError::ConfigError {
136                message: "State store at capacity, cannot store new state".to_string(),
137            });
138        }
139
140        self.states.insert(state, (provider, expiry_secs));
141        Ok(())
142    }
143
144    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
145        let (_key, value) =
146            self.states.remove(state).ok_or_else(|| crate::error::AuthError::InvalidState)?;
147        Ok(value)
148    }
149}
150
/// Shared, Redis-backed [`StateStore`] for multi-instance deployments.
///
/// All server instances pointed at the same Redis see the same pending states, so a
/// flow started on one instance can be completed on another. Each state is written
/// with a Redis TTL derived from its expiry, so Redis discards it once it expires.
#[cfg(feature = "redis-rate-limiting")]
#[derive(Clone)]
pub struct RedisStateStore {
    // Redis connection handle; the `StateStore` impl clones it for each operation.
    client: redis::aio::ConnectionManager,
}
160
#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Connect to Redis and build a state store on top of a managed connection.
    ///
    /// # Arguments
    /// * `redis_url` - Connection string (e.g., "redis://localhost:6379")
    ///
    /// # Example
    /// ```no_run
    /// // Requires: live Redis server.
    /// # async fn example() -> fraiseql_auth::error::Result<()> {
    /// use fraiseql_auth::state_store::RedisStateStore;
    /// let store = RedisStateStore::new("redis://localhost:6379").await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::ConfigError`](crate::error::AuthError::ConfigError) if the
    /// Redis URL cannot be parsed or the connection manager cannot be established.
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Both failure points surface the Redis error text as a configuration error.
        let to_config_error = |e: redis::RedisError| crate::error::AuthError::ConfigError {
            message: e.to_string(),
        };

        let redis_client = redis::Client::open(redis_url).map_err(to_config_error)?;
        let manager = redis_client.get_connection_manager().await.map_err(to_config_error)?;

        Ok(Self { client: manager })
    }

    /// Namespaced Redis key under which `state` is stored.
    fn state_key(state: &str) -> String {
        format!("oauth:state:{state}")
    }
}
203
#[cfg(feature = "redis-rate-limiting")]
// Reason: StateStore is defined with #[async_trait]; all implementations must match
// its transformed method signatures to satisfy the trait contract
// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
#[async_trait]
impl StateStore for RedisStateStore {
    /// Store `provider` under the namespaced key for `state`, with a Redis TTL.
    ///
    /// The TTL is the number of seconds left until `expiry_secs`. It is clamped to at
    /// least one second because Redis requires a positive `EX` value.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` if the Redis command fails.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;

        let key = Self::state_key(&state);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let ttl = expiry_secs.saturating_sub(now).max(1); // Minimum 1 second TTL

        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;

        Ok(())
    }

    /// Atomically read and delete the state, returning `(provider, expiry_secs)`.
    ///
    /// `expiry_secs` is reconstructed from the key's remaining TTL, so it only has
    /// whole-second precision relative to the value passed to `store`.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::InvalidState` if the key does not exist (never stored,
    /// already consumed, or expired by Redis). Returns `AuthError::ConfigError` if the
    /// Redis transaction fails.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();

        // SECURITY: GET, TTL and DEL run inside one MULTI/EXEC transaction. A separate
        // GET followed by DEL lets two concurrent callbacks carrying the same state both
        // read it before either deletes it, defeating replay protection.
        let (provider, ttl, _deleted): (Option<String>, i64, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;

        let provider = provider.ok_or(crate::error::AuthError::InvalidState)?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // TTL is negative for a key with no expiry (-1) or a missing key (-2).
        // Those cases fall back to `now`.
        let remaining = u64::try_from(ttl).unwrap_or(0);

        Ok((provider, now.saturating_add(remaining)))
    }
}
261
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// An absolute expiry ten minutes from now, valid for the whole test run.
    fn future_expiry() -> u64 {
        unix_now() + 600
    }

    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();

        // Store a state
        store.store("state123".to_string(), "google".to_string(), future_expiry()).await.unwrap();

        // Retrieve it
        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");

        // Should be gone now (consumed)
        let result = store.retrieve("state123").await;
        assert!(
            matches!(result, Err(crate::error::AuthError::InvalidState)),
            "expected InvalidState for consumed state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_state_not_found() {
        let store = InMemoryStateStore::new();
        let result = store.retrieve("nonexistent").await;
        assert!(
            matches!(result, Err(crate::error::AuthError::InvalidState)),
            "expected InvalidState for nonexistent state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();

        store.store("state_abc".to_string(), "auth0".to_string(), future_expiry()).await.unwrap();

        // First retrieval succeeds
        let result1 = store.retrieve("state_abc").await;
        assert!(result1.is_ok(), "first retrieval should succeed: {result1:?}");

        // Replay attempt fails
        let result2 = store.retrieve("state_abc").await;
        assert!(
            matches!(result2, Err(crate::error::AuthError::InvalidState)),
            "replay attempt should return InvalidState, got: {result2:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let expiry = future_expiry();

        // Store multiple states
        store.store("state1".to_string(), "google".to_string(), expiry).await.unwrap();
        store.store("state2".to_string(), "auth0".to_string(), expiry).await.unwrap();
        store.store("state3".to_string(), "okta".to_string(), expiry).await.unwrap();

        // Retrieve each independently
        let (p1, _) = store.retrieve("state1").await.unwrap();
        assert_eq!(p1, "google");

        let (p2, _) = store.retrieve("state2").await.unwrap();
        assert_eq!(p2, "auth0");

        let (p3, _) = store.retrieve("state3").await.unwrap();
        assert_eq!(p3, "okta");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());

        store
            .store("state_trait".to_string(), "test_provider".to_string(), future_expiry())
            .await
            .unwrap();

        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_bounded() {
        // SECURITY: Test that store respects max size limit
        let store = InMemoryStateStore::with_max_states(5);
        let expiry = future_expiry();

        // Store 5 states (at capacity)
        for i in 0..5 {
            store.store(format!("state_{i}"), "google".to_string(), expiry).await.unwrap();
        }

        // 6th state should be rejected when store is at capacity
        let result = store.store("state_5".to_string(), "google".to_string(), expiry).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when store at capacity, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_store_cleanup_expired() {
        // SECURITY: Test that expired states are cleaned up automatically
        let store = InMemoryStateStore::with_max_states(3);
        let now = unix_now();

        // Store 3 expired states
        for i in 0..3 {
            store.store(format!("expired_{i}"), "google".to_string(), now - 100).await.unwrap();
        }

        // Store 1 valid state - should succeed because expired states are cleaned up
        let expiry = now + 600;
        let result = store.store("valid_state".to_string(), "auth0".to_string(), expiry).await;
        assert!(result.is_ok(), "Should succeed after cleaning up expired states");

        // Store 2 more valid states
        store.store("valid_state_2".to_string(), "google".to_string(), expiry).await.unwrap();
        store.store("valid_state_3".to_string(), "okta".to_string(), expiry).await.unwrap();

        // Now at capacity with valid states
        let result = store.store("valid_state_4".to_string(), "auth0".to_string(), expiry).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when at capacity with valid states, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_store_custom_max_size() {
        // Test with different max sizes
        let store_small = InMemoryStateStore::with_max_states(1);
        let store_large = InMemoryStateStore::with_max_states(100);
        let expiry = future_expiry();

        // Small store should reject after 1 state
        store_small.store("s1".to_string(), "p1".to_string(), expiry).await.unwrap();
        let result = store_small.store("s2".to_string(), "p2".to_string(), expiry).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when small store at capacity, got: {result:?}"
        );

        // Large store should allow more states
        for i in 0..50 {
            store_large.store(format!("state_{i}"), "provider".to_string(), expiry).await.unwrap();
        }
        assert_eq!(store_large.states.len(), 50);
    }

    #[tokio::test]
    async fn test_in_memory_state_store_zero_max_enforced() {
        // Edge case: verify that min(1) is enforced
        let store = InMemoryStateStore::with_max_states(0);

        // Even with max_states=0, should allow at least 1
        let result = store.store("state1".to_string(), "google".to_string(), future_expiry()).await;
        assert!(result.is_ok(), "Should allow at least 1 state minimum");
    }

    #[tokio::test]
    async fn test_in_memory_consumed_state_frees_capacity() {
        // Retrieving a state removes it, so a full store accepts a new state afterwards.
        let store = InMemoryStateStore::with_max_states(1);
        let expiry = future_expiry();

        store.store("first".to_string(), "google".to_string(), expiry).await.unwrap();
        store.retrieve("first").await.unwrap();

        let result = store.store("second".to_string(), "google".to_string(), expiry).await;
        assert!(result.is_ok(), "consumed state should free capacity, got: {result:?}");
    }

    #[test]
    fn test_in_memory_default_capacity() {
        let store = InMemoryStateStore::default();
        assert_eq!(store.max_states, InMemoryStateStore::MAX_STATES);
        assert!(store.states.is_empty());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        // This test requires Redis to be running; it is skipped if Redis is unavailable.
        let store = match RedisStateStore::new("redis://localhost:6379").await {
            Ok(store) => store,
            Err(_) => {
                eprintln!("Skipping Redis tests - Redis server not available");
                return;
            },
        };

        // Store a state
        store
            .store("redis_state_1".to_string(), "google".to_string(), future_expiry())
            .await
            .unwrap();

        // Retrieve it
        let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
        assert_eq!(provider, "google");

        // Should not be retrievable again (consumed)
        let result = store.retrieve("redis_state_1").await;
        assert!(
            matches!(result, Err(crate::error::AuthError::InvalidState)),
            "expected InvalidState for consumed redis state, got: {result:?}"
        );
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        let store = match RedisStateStore::new("redis://localhost:6379").await {
            Ok(store) => store,
            Err(_) => {
                eprintln!("Skipping Redis tests - Redis server not available");
                return;
            },
        };

        store
            .store("redis_replay_test".to_string(), "auth0".to_string(), future_expiry())
            .await
            .unwrap();

        // First retrieval succeeds
        let result1 = store.retrieve("redis_replay_test").await;
        assert!(result1.is_ok(), "first redis retrieval should succeed: {result1:?}");

        // Replay attempt fails
        let result2 = store.retrieve("redis_replay_test").await;
        assert!(
            matches!(result2, Err(crate::error::AuthError::InvalidState)),
            "redis replay attempt should return InvalidState, got: {result2:?}"
        );
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        let store = match RedisStateStore::new("redis://localhost:6379").await {
            Ok(store) => store,
            Err(_) => {
                eprintln!("Skipping Redis tests - Redis server not available");
                return;
            },
        };
        let expiry = future_expiry();

        // Store multiple states
        store.store("redis_state_a".to_string(), "google".to_string(), expiry).await.unwrap();
        store.store("redis_state_b".to_string(), "okta".to_string(), expiry).await.unwrap();

        // Retrieve each independently
        let (p1, _) = store.retrieve("redis_state_a").await.unwrap();
        assert_eq!(p1, "google");

        let (p2, _) = store.retrieve("redis_state_b").await.unwrap();
        assert_eq!(p2, "okta");
    }
}