1use std::sync::Arc;
7
8use async_trait::async_trait;
9use dashmap::DashMap;
10
11use crate::error::Result;
12
13#[async_trait]
42pub trait StateStore: Send + Sync {
43 async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;
50
51 async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
56}
57
58#[derive(Debug)]
68pub struct InMemoryStateStore {
69 states: Arc<DashMap<String, (String, u64)>>,
71 max_states: usize,
73}
74
75impl InMemoryStateStore {
76 const MAX_STATES: usize = 10_000;
79
80 pub fn new() -> Self {
82 Self {
83 states: Arc::new(DashMap::new()),
84 max_states: Self::MAX_STATES,
85 }
86 }
87
88 pub fn with_max_states(max_states: usize) -> Self {
93 Self {
94 states: Arc::new(DashMap::new()),
95 max_states: max_states.max(1), }
97 }
98
99 fn cleanup_expired(&self) -> bool {
107 let now = std::time::SystemTime::now()
108 .duration_since(std::time::UNIX_EPOCH)
109 .unwrap_or_default()
110 .as_secs();
111
112 self.states.retain(|_key, (_provider, expiry)| *expiry > now);
114
115 self.states.len() >= self.max_states
117 }
118}
119
120impl Default for InMemoryStateStore {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126#[async_trait]
130impl StateStore for InMemoryStateStore {
131 async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
132 if self.cleanup_expired() {
134 return Err(crate::error::AuthError::ConfigError {
136 message: "State store at capacity, cannot store new state".to_string(),
137 });
138 }
139
140 self.states.insert(state, (provider, expiry_secs));
141 Ok(())
142 }
143
144 async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
145 let (_key, value) =
146 self.states.remove(state).ok_or_else(|| crate::error::AuthError::InvalidState)?;
147 Ok(value)
148 }
149}
150
/// [`StateStore`] backed by Redis.
///
/// Each state is kept under an `oauth:state:<state>` key whose Redis TTL is
/// derived from the requested expiry.
#[cfg(feature = "redis-rate-limiting")]
#[derive(Clone)]
pub struct RedisStateStore {
    /// Connection handle; each operation works on its own clone of it.
    client: redis::aio::ConnectionManager,
}
160
#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Opens a Redis client for `redis_url` and establishes a
    /// [`redis::aio::ConnectionManager`] on it.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` (carrying the Redis error text) if the URL
    /// is rejected or the connection manager cannot be created.
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Capture-free closure, hence `Copy`: reusable for both fallible steps.
        let to_config_error = |err: redis::RedisError| crate::error::AuthError::ConfigError {
            message: err.to_string(),
        };

        let manager = redis::Client::open(redis_url)
            .map_err(to_config_error)?
            .get_connection_manager()
            .await
            .map_err(to_config_error)?;

        Ok(Self { client: manager })
    }

    /// Builds the namespaced Redis key for `state`.
    fn state_key(state: &str) -> String {
        format!("oauth:state:{state}")
    }
}
203
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl StateStore for RedisStateStore {
    /// Writes `provider` under the state's key with `SETEX`.
    ///
    /// The TTL is the time left until `expiry_secs` (Unix seconds), clamped to at
    /// least one second. NOTE(review): the clamp presumably exists because `SETEX`
    /// refuses a zero TTL. It means an already-past deadline is kept for one second.
    ///
    /// # Errors
    ///
    /// Redis failures are reported as `AuthError::ConfigError`.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;

        let key = Self::state_key(&state);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let ttl = expiry_secs.saturating_sub(now).max(1);

        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;

        Ok(())
    }

    /// Atomically reads and deletes the state, returning `(provider, expiry_secs)`.
    ///
    /// `expiry_secs` is reconstructed as `now + remaining TTL`, which matches the
    /// deadline passed to `store` to within about one second.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::InvalidState` if the key is missing (never stored,
    /// already consumed, or expired by Redis). Returns `AuthError::ConfigError` on
    /// Redis failures.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();

        // GET, TTL and DEL execute inside one MULTI/EXEC transaction. With separate
        // GET and DEL round-trips, two concurrent callbacks could both read the state
        // before either deleted it, which would defeat replay protection.
        let (provider, ttl_secs): (Option<String>, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .ignore()
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;

        let provider = provider.ok_or(crate::error::AuthError::InvalidState)?;

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // TTL is negative only when the key has no expiry (-1) or is gone (-2);
        // treat that as "expires now" rather than inventing a deadline.
        let remaining = u64::try_from(ttl_secs).unwrap_or(0);

        Ok((provider, now.saturating_add(remaining)))
    }
}
261
// `unwrap` is acceptable here: a panic is exactly how a test reports failure.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// A deadline ten minutes from now, in Unix seconds.
    fn expiry_in_ten_minutes() -> u64 {
        now_secs() + 600
    }

    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();

        store
            .store("state123".to_string(), "google".to_string(), expiry_in_ten_minutes())
            .await
            .unwrap();

        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");

        // The state was consumed above, so a second lookup must fail.
        let result = store.retrieve("state123").await;
        assert!(
            matches!(result, Err(crate::error::AuthError::InvalidState)),
            "expected InvalidState for consumed state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_state_not_found() {
        let result = InMemoryStateStore::new().retrieve("nonexistent").await;
        assert!(
            matches!(result, Err(crate::error::AuthError::InvalidState)),
            "expected InvalidState for nonexistent state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();
        let deadline = expiry_in_ten_minutes();

        store.store("state_abc".to_string(), "auth0".to_string(), deadline).await.unwrap();

        let result1 = store.retrieve("state_abc").await;
        assert!(result1.is_ok(), "first retrieval should succeed: {result1:?}");

        let result2 = store.retrieve("state_abc").await;
        assert!(
            matches!(result2, Err(crate::error::AuthError::InvalidState)),
            "replay attempt should return InvalidState, got: {result2:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let deadline = expiry_in_ten_minutes();

        store.store("state1".to_string(), "google".to_string(), deadline).await.unwrap();
        store.store("state2".to_string(), "auth0".to_string(), deadline).await.unwrap();
        store.store("state3".to_string(), "okta".to_string(), deadline).await.unwrap();

        let (first, _) = store.retrieve("state1").await.unwrap();
        assert_eq!(first, "google");

        let (second, _) = store.retrieve("state2").await.unwrap();
        assert_eq!(second, "auth0");

        let (third, _) = store.retrieve("state3").await.unwrap();
        assert_eq!(third, "okta");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());
        let deadline = expiry_in_ten_minutes();

        store
            .store("state_trait".to_string(), "test_provider".to_string(), deadline)
            .await
            .unwrap();

        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_bounded() {
        let store = InMemoryStateStore::with_max_states(5);
        let deadline = expiry_in_ten_minutes();

        // Fill the store exactly to its limit.
        for i in 0..5 {
            store
                .store(format!("state_{}", i), "google".to_string(), deadline)
                .await
                .unwrap();
        }

        let result = store.store("state_5".to_string(), "google".to_string(), deadline).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when store at capacity, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_store_cleanup_expired() {
        let store = InMemoryStateStore::with_max_states(3);
        let now = now_secs();

        // Fill the store with states that are already past their deadline.
        for i in 0..3 {
            store
                .store(format!("expired_{}", i), "google".to_string(), now - 100)
                .await
                .unwrap();
        }

        let deadline = now + 600;
        let result = store.store("valid_state".to_string(), "auth0".to_string(), deadline).await;
        assert!(result.is_ok(), "Should succeed after cleaning up expired states");

        store
            .store("valid_state_2".to_string(), "google".to_string(), deadline)
            .await
            .unwrap();
        store
            .store("valid_state_3".to_string(), "okta".to_string(), deadline)
            .await
            .unwrap();

        // Now full of live states: nothing can be purged, so the insert must fail.
        let result = store.store("valid_state_4".to_string(), "auth0".to_string(), deadline).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when at capacity with valid states, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_in_memory_state_store_custom_max_size() {
        let store_small = InMemoryStateStore::with_max_states(1);
        let store_large = InMemoryStateStore::with_max_states(100);
        let deadline = expiry_in_ten_minutes();

        store_small.store("s1".to_string(), "p1".to_string(), deadline).await.unwrap();
        let result = store_small.store("s2".to_string(), "p2".to_string(), deadline).await;
        assert!(
            matches!(result, Err(crate::error::AuthError::ConfigError { .. })),
            "expected ConfigError when small store at capacity, got: {result:?}"
        );

        for i in 0..50 {
            store_large
                .store(format!("state_{}", i), "provider".to_string(), deadline)
                .await
                .unwrap();
        }
        assert_eq!(store_large.states.len(), 50);
    }

    #[tokio::test]
    async fn test_in_memory_state_store_zero_max_enforced() {
        // A requested capacity of zero is clamped up to one.
        let store = InMemoryStateStore::with_max_states(0);

        let result = store
            .store("state1".to_string(), "google".to_string(), expiry_in_ten_minutes())
            .await;
        assert!(result.is_ok(), "Should allow at least 1 state minimum");
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        let redis_url = "redis://localhost:6379";

        // These tests are best-effort: without a local Redis they are skipped.
        if let Ok(store) = RedisStateStore::new(redis_url).await {
            store
                .store("redis_state_1".to_string(), "google".to_string(), expiry_in_ten_minutes())
                .await
                .unwrap();

            let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
            assert_eq!(provider, "google");

            let result = store.retrieve("redis_state_1").await;
            assert!(
                matches!(result, Err(crate::error::AuthError::InvalidState)),
                "expected InvalidState for consumed redis state, got: {result:?}"
            );
        } else {
            eprintln!("Skipping Redis tests - Redis server not available");
        }
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        let redis_url = "redis://localhost:6379";

        if let Ok(store) = RedisStateStore::new(redis_url).await {
            store
                .store(
                    "redis_replay_test".to_string(),
                    "auth0".to_string(),
                    expiry_in_ten_minutes(),
                )
                .await
                .unwrap();

            let result1 = store.retrieve("redis_replay_test").await;
            assert!(result1.is_ok(), "first redis retrieval should succeed: {result1:?}");

            let result2 = store.retrieve("redis_replay_test").await;
            assert!(
                matches!(result2, Err(crate::error::AuthError::InvalidState)),
                "redis replay attempt should return InvalidState, got: {result2:?}"
            );
        }
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        let redis_url = "redis://localhost:6379";

        if let Ok(store) = RedisStateStore::new(redis_url).await {
            let deadline = expiry_in_ten_minutes();

            store
                .store("redis_state_a".to_string(), "google".to_string(), deadline)
                .await
                .unwrap();
            store
                .store("redis_state_b".to_string(), "okta".to_string(), deadline)
                .await
                .unwrap();

            let (first, _) = store.retrieve("redis_state_a").await.unwrap();
            assert_eq!(first, "google");

            let (second, _) = store.retrieve("redis_state_b").await.unwrap();
            assert_eq!(second, "okta");
        }
    }
}