fraiseql_server/auth/
state_store.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use dashmap::DashMap;
8
9use crate::auth::error::Result;
10
11#[async_trait]
28pub trait StateStore: Send + Sync {
29 async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;
36
37 async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
42}
43
44#[derive(Debug)]
49pub struct InMemoryStateStore {
50 states: Arc<DashMap<String, (String, u64)>>,
52}
53
54impl InMemoryStateStore {
55 pub fn new() -> Self {
57 Self {
58 states: Arc::new(DashMap::new()),
59 }
60 }
61}
62
63impl Default for InMemoryStateStore {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait]
70impl StateStore for InMemoryStateStore {
71 async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
72 self.states.insert(state, (provider, expiry_secs));
73 Ok(())
74 }
75
76 async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
77 let (_key, value) = self
78 .states
79 .remove(state)
80 .ok_or_else(|| crate::auth::error::AuthError::InvalidState)?;
81 Ok(value)
82 }
83}
84
#[cfg(feature = "redis-rate-limiting")]
/// Redis-backed [`StateStore`]: each pending state is stored as its own key with
/// a TTL, so it is visible to every server instance sharing the Redis server.
///
/// Cloning clones the underlying `ConnectionManager` handle.
#[derive(Clone)]
pub struct RedisStateStore {
    // Named `client`, but this holds an async `ConnectionManager`, not a `redis::Client`.
    client: redis::aio::ConnectionManager,
}
94
#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Connects to the Redis server at `redis_url`.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` carrying the Redis error text in two
    /// cases: the URL cannot be parsed, or the connection manager cannot be
    /// established.
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Non-capturing closure, hence `Copy`: reusable for both fallible steps.
        let into_config_error = |err: redis::RedisError| {
            crate::auth::error::AuthError::ConfigError {
                message: err.to_string(),
            }
        };

        let redis_client = redis::Client::open(redis_url).map_err(into_config_error)?;
        let manager = redis_client
            .get_connection_manager()
            .await
            .map_err(into_config_error)?;

        Ok(Self { client: manager })
    }

    /// Namespaced Redis key under which the OAuth `state` is stored.
    fn state_key(state: &str) -> String {
        ["oauth:state:", state].concat()
    }
}
129
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl StateStore for RedisStateStore {
    /// Stores `state -> provider` with a Redis TTL derived from the absolute Unix
    /// timestamp `expiry_secs`.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` carrying the Redis error text if the write fails.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;

        let key = Self::state_key(&state);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Redis expires keys relative to "now", so convert the deadline into a
        // remaining TTL. The 1-second floor avoids sending a zero expiry, which Redis rejects.
        let ttl = expiry_secs.saturating_sub(now).max(1);

        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;

        Ok(())
    }

    /// Atomically reads and deletes `state`, returning its provider and expiry.
    ///
    /// The returned expiry is reconstructed as `now + remaining TTL`. It is
    /// therefore within about a second of the `expiry_secs` originally passed to
    /// `store`.
    ///
    /// # Errors
    ///
    /// - `AuthError::InvalidState` if the key does not exist (never stored,
    ///   already consumed, or expired by Redis).
    /// - `AuthError::ConfigError` if the Redis transaction fails.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();

        // GET, TTL and DEL run inside a single MULTI/EXEC transaction. With separate
        // GET and DEL round-trips, two concurrent callbacks presenting the same
        // state could both read it before either deleted it, defeating replay
        // protection.
        let (provider, ttl): (Option<String>, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .ignore()
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;

        let provider = provider.ok_or(crate::auth::error::AuthError::InvalidState)?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // TTL reports negative sentinels (-1: no expiry, -2: missing key); treat
        // them as "no time remaining" rather than wrapping into a huge u64.
        let remaining = u64::try_from(ttl).unwrap_or(0);

        Ok((provider, now.saturating_add(remaining)))
    }
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Unix timestamp (seconds) ten minutes in the future, a comfortably valid expiry.
    fn ten_minutes_from_now() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs() + 600
    }

    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();

        store
            .store("state123".to_owned(), "google".to_owned(), ten_minutes_from_now())
            .await
            .unwrap();

        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");

        // States are single-use: the first retrieval consumed it.
        let second = store.retrieve("state123").await;
        assert!(second.is_err());
    }

    #[tokio::test]
    async fn test_state_not_found() {
        let store = InMemoryStateStore::new();
        let missing = store.retrieve("nonexistent").await;
        assert!(missing.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();
        let deadline = ten_minutes_from_now();

        store
            .store("state_abc".to_owned(), "auth0".to_owned(), deadline)
            .await
            .unwrap();

        let first = store.retrieve("state_abc").await;
        assert!(first.is_ok());

        let replay = store.retrieve("state_abc").await;
        assert!(replay.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let deadline = ten_minutes_from_now();

        let entries = [("state1", "google"), ("state2", "auth0"), ("state3", "okta")];

        for (state, provider) in entries {
            store
                .store(state.to_owned(), provider.to_owned(), deadline)
                .await
                .unwrap();
        }

        for (state, expected_provider) in entries {
            let (provider, _) = store.retrieve(state).await.unwrap();
            assert_eq!(provider, expected_provider);
        }
    }

    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());
        let deadline = ten_minutes_from_now();

        store
            .store("state_trait".to_owned(), "test_provider".to_owned(), deadline)
            .await
            .unwrap();

        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        let redis_url = "redis://localhost:6379";

        let store = match RedisStateStore::new(redis_url).await {
            Ok(store) => store,
            Err(_) => {
                eprintln!("Skipping Redis tests - Redis server not available");
                return;
            },
        };

        let deadline = ten_minutes_from_now();

        store
            .store("redis_state_1".to_owned(), "google".to_owned(), deadline)
            .await
            .unwrap();

        let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
        assert_eq!(provider, "google");

        let second = store.retrieve("redis_state_1").await;
        assert!(second.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        let redis_url = "redis://localhost:6379";

        // Silently skip when no Redis server is reachable.
        let Ok(store) = RedisStateStore::new(redis_url).await else {
            return;
        };

        let deadline = ten_minutes_from_now();

        store
            .store("redis_replay_test".to_owned(), "auth0".to_owned(), deadline)
            .await
            .unwrap();

        let first = store.retrieve("redis_replay_test").await;
        assert!(first.is_ok());

        let replay = store.retrieve("redis_replay_test").await;
        assert!(replay.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        let redis_url = "redis://localhost:6379";

        // Silently skip when no Redis server is reachable.
        let Ok(store) = RedisStateStore::new(redis_url).await else {
            return;
        };

        let deadline = ten_minutes_from_now();

        store
            .store("redis_state_a".to_owned(), "google".to_owned(), deadline)
            .await
            .unwrap();
        store
            .store("redis_state_b".to_owned(), "okta".to_owned(), deadline)
            .await
            .unwrap();

        let (first_provider, _) = store.retrieve("redis_state_a").await.unwrap();
        assert_eq!(first_provider, "google");

        let (second_provider, _) = store.retrieve("redis_state_b").await.unwrap();
        assert_eq!(second_provider, "okta");
    }
}