Skip to main content

fraiseql_server/auth/
state_store.rs

// CSRF state store: trait definition and implementations.
// Prevents reuse (replay) of OAuth `state` parameters, including across server instances in distributed deployments.
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::auth::error::Result;
10
11/// StateStore trait - implement this for different storage backends
12///
13/// Stores OAuth state parameters with expiration for CSRF protection.
14/// In distributed deployments, use a persistent backend (Redis) instead of in-memory.
15///
16/// # Examples
17///
18/// Use in-memory store for single-instance deployments:
19/// ```ignore
20/// let state_store = Arc::new(InMemoryStateStore::new());
21/// ```
22///
23/// Use Redis for distributed deployments:
24/// ```ignore
25/// let state_store = Arc::new(RedisStateStore::new(redis_client).await?);
26/// ```
27#[async_trait]
28pub trait StateStore: Send + Sync {
29    /// Store a state value with provider and expiration
30    ///
31    /// # Arguments
32    /// * `state` - The state parameter value
33    /// * `provider` - OAuth provider name
34    /// * `expiry_secs` - Unix timestamp when this state expires
35    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;
36
37    /// Retrieve and remove a state value
38    ///
39    /// Returns (provider, expiry_secs) if state exists and is valid
40    /// Returns error if state doesn't exist or is invalid
41    async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
42}
43
44/// In-memory state store using DashMap
45///
46/// **Warning**: Only suitable for single-instance deployments!
47/// For distributed systems, use RedisStateStore instead.
48#[derive(Debug)]
49pub struct InMemoryStateStore {
50    // Map of state -> (provider, expiry_secs)
51    states: Arc<DashMap<String, (String, u64)>>,
52}
53
54impl InMemoryStateStore {
55    /// Create a new in-memory state store
56    pub fn new() -> Self {
57        Self {
58            states: Arc::new(DashMap::new()),
59        }
60    }
61}
62
63impl Default for InMemoryStateStore {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait]
70impl StateStore for InMemoryStateStore {
71    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
72        self.states.insert(state, (provider, expiry_secs));
73        Ok(())
74    }
75
76    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
77        let (_key, value) = self
78            .states
79            .remove(state)
80            .ok_or_else(|| crate::auth::error::AuthError::InvalidState)?;
81        Ok(value)
82    }
83}
84
/// [`StateStore`] backed by Redis, for distributed deployments.
///
/// States are kept in Redis under `oauth:state:<state>` keys, so any server instance
/// connected to the same Redis can validate a state issued by another instance.
/// States expire automatically once their Redis TTL elapses.
#[cfg(feature = "redis-rate-limiting")]
#[derive(Clone)]
pub struct RedisStateStore {
    client: redis::aio::ConnectionManager,
}

#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Connect to Redis and build a state store on top of a managed connection.
    ///
    /// # Arguments
    /// * `redis_url` - Connection string (e.g., "redis://localhost:6379")
    ///
    /// # Errors
    /// Returns `AuthError::ConfigError` if the URL cannot be parsed or the connection
    /// manager cannot be created.
    ///
    /// # Example
    /// ```ignore
    /// let store = RedisStateStore::new("redis://localhost:6379").await?;
    /// ```
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Both setup steps report failures the same way.
        fn config_error(err: redis::RedisError) -> crate::auth::error::AuthError {
            crate::auth::error::AuthError::ConfigError {
                message: err.to_string(),
            }
        }

        let redis_client = redis::Client::open(redis_url).map_err(config_error)?;
        let manager = redis_client
            .get_connection_manager()
            .await
            .map_err(config_error)?;

        Ok(Self { client: manager })
    }

    /// Redis key under which `state` is stored.
    fn state_key(state: &str) -> String {
        format!("oauth:state:{state}")
    }
}
129
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl StateStore for RedisStateStore {
    /// Store `state` -> `provider` under a Redis key that expires at `expiry_secs`.
    ///
    /// `expiry_secs` is an absolute Unix timestamp and is converted into a relative
    /// TTL for `SETEX`. The TTL is clamped to at least one second because `SETEX`
    /// rejects a zero expiry. As a result, a timestamp that is already in the past
    /// still yields a state that lives for one second.
    ///
    /// # Errors
    /// Any Redis failure is reported as `AuthError::ConfigError`.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;

        let key = Self::state_key(&state);
        let ttl = expiry_secs
            .saturating_sub(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
            )
            .max(1); // Minimum 1 second TTL

        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;

        Ok(())
    }

    /// Atomically look up and delete `state`, returning `(provider, expiry_secs)`.
    ///
    /// `expiry_secs` is reconstructed from the key's remaining TTL at the moment it
    /// was consumed. It therefore matches the stored expiry to within about one second.
    ///
    /// # Errors
    /// Returns `AuthError::InvalidState` if the key does not exist (never stored,
    /// already consumed, or expired by Redis). Redis failures are reported as
    /// `AuthError::ConfigError`.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();

        // GET, TTL and DEL are sent as a single MULTI/EXEC transaction. If GET and DEL
        // were issued separately, two concurrent callbacks carrying the same state could
        // both read the provider before either deleted it, defeating replay protection.
        let (provider, ttl_secs, _deleted): (Option<String>, i64, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;

        let provider = provider.ok_or(crate::auth::error::AuthError::InvalidState)?;

        // TTL is -1 for a key without an expiry and -2 for a missing key. Neither is
        // expected for a key that was just read, but any negative value maps to
        // "expires now".
        let remaining_secs = u64::try_from(ttl_secs).unwrap_or(0);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        Ok((provider, now.saturating_add(remaining_secs)))
    }
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Unix timestamp `secs` seconds in the future, used as the expiry of test states.
    fn expiry_in(secs: u64) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + secs
    }

    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();

        store
            .store("state123".to_string(), "google".to_string(), expiry_in(600))
            .await
            .unwrap();

        // The first lookup yields the provider the state was issued for.
        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");

        // That lookup consumed the state, so a second one must fail.
        assert!(store.retrieve("state123").await.is_err());
    }

    #[tokio::test]
    async fn test_state_not_found() {
        let store = InMemoryStateStore::new();
        assert!(store.retrieve("nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();
        let expiry = expiry_in(600);

        store.store("state_abc".to_string(), "auth0".to_string(), expiry).await.unwrap();

        let first = store.retrieve("state_abc").await;
        assert!(first.is_ok());

        // Presenting the same state again is a replay and must be rejected.
        let replay = store.retrieve("state_abc").await;
        assert!(replay.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let expiry = expiry_in(600);
        let cases = [("state1", "google"), ("state2", "auth0"), ("state3", "okta")];

        for (state, provider) in cases {
            store.store(state.to_string(), provider.to_string(), expiry).await.unwrap();
        }

        // Every state resolves to its own provider, independently of the others.
        for (state, provider) in cases {
            let (found, _) = store.retrieve(state).await.unwrap();
            assert_eq!(found, provider);
        }
    }

    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());

        store
            .store("state_trait".to_string(), "test_provider".to_string(), expiry_in(600))
            .await
            .unwrap();

        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        // Needs a Redis server on localhost; the test is skipped when none is reachable.
        let Ok(store) = RedisStateStore::new("redis://localhost:6379").await else {
            eprintln!("Skipping Redis tests - Redis server not available");
            return;
        };

        store
            .store("redis_state_1".to_string(), "google".to_string(), expiry_in(600))
            .await
            .unwrap();

        let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
        assert_eq!(provider, "google");

        // The retrieval above consumed the state.
        assert!(store.retrieve("redis_state_1").await.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        let Ok(store) = RedisStateStore::new("redis://localhost:6379").await else {
            return;
        };

        store
            .store("redis_replay_test".to_string(), "auth0".to_string(), expiry_in(600))
            .await
            .unwrap();

        let first = store.retrieve("redis_replay_test").await;
        assert!(first.is_ok());

        // A replayed state must be rejected.
        let replay = store.retrieve("redis_replay_test").await;
        assert!(replay.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        let Ok(store) = RedisStateStore::new("redis://localhost:6379").await else {
            return;
        };
        let expiry = expiry_in(600);
        let cases = [("redis_state_a", "google"), ("redis_state_b", "okta")];

        for (state, provider) in cases {
            store.store(state.to_string(), provider.to_string(), expiry).await.unwrap();
        }

        for (state, provider) in cases {
            let (found, _) = store.retrieve(state).await.unwrap();
            assert_eq!(found, provider);
        }
    }
}