atproto_oauth/storage_lru.rs
1//! LRU cache implementation for OAuth request storage.
2//!
3//! Thread-safe in-memory storage with automatic eviction of least recently used
4//! OAuth requests when capacity is reached.
5
6use std::num::NonZeroUsize;
7use std::sync::{Arc, Mutex};
8
9use anyhow::Result;
10use chrono::Utc;
11use lru::LruCache;
12
13use crate::errors::OAuthStorageError;
14use crate::storage::OAuthRequestStorage;
15use crate::workflow::OAuthRequest;
16
17/// An LRU-based implementation of `OAuthRequestStorage` that maintains a fixed-size cache of OAuth requests.
18///
19/// This storage implementation uses an LRU (Least Recently Used) cache to store OAuth requests
20/// in memory with automatic eviction of the least recently accessed entries when the cache reaches
21/// its capacity. This is ideal for scenarios where you want to cache frequently accessed OAuth requests
22/// while keeping memory usage bounded.
23///
24/// ## Thread Safety
25///
26/// This implementation is thread-safe through the use of `Arc<Mutex<LruCache<String, OAuthRequest>>>`.
27/// All operations are protected by a mutex, ensuring safe concurrent access from multiple threads
28/// or async tasks.
29///
30/// ## Cache Behavior
31///
32/// - **Get operations**: Move accessed entries to the front of the LRU order
33/// - **Insert operations**: Add new entries at the front, evicting the least recently used if at capacity
34/// - **Delete operations**: Remove entries from the cache entirely
35/// - **Capacity management**: Automatically evicts least recently used entries when capacity is exceeded
36/// - **Expiration handling**: Returns only non-expired OAuth requests based on `expires_at` timestamp
37///
38/// ## Use Cases
39///
40/// This implementation is particularly suitable for:
41/// - Caching OAuth authorization requests during the OAuth flow
42/// - Scenarios with bounded memory requirements for OAuth state management
43/// - Applications where some OAuth request lookup misses are acceptable
44/// - High-performance applications requiring fast in-memory OAuth state access
45/// - Stateless OAuth servers that need temporary request storage
46///
47/// ## Limitations
48///
49/// - **Persistence**: Data is lost when the application restarts
50/// - **Capacity**: Limited to the configured cache size
51/// - **Cache misses**: Older entries may be evicted and need OAuth flow restart
52/// - **Memory usage**: All cached OAuth data is kept in memory
53/// - **Request size**: OAuth requests with large key data consume more memory per entry
54///
55/// ## Examples
56///
57/// ```rust
58/// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
59/// use atproto_oauth::storage::OAuthRequestStorage;
60/// use atproto_oauth::workflow::OAuthRequest;
61/// use std::num::NonZeroUsize;
62/// use chrono::{Utc, Duration};
63///
64/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
65/// // Create an LRU cache with capacity for 1000 OAuth requests
66/// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(1000).unwrap());
67///
68/// // Create a sample OAuth request
69/// let request = OAuthRequest {
70/// oauth_state: "unique-state-123".to_string(),
71/// issuer: "https://pds.example.com".to_string(),
72/// authorization_server: "https://pds.example.com".to_string(),
73/// nonce: "secure-nonce".to_string(),
74/// pkce_verifier: "code-verifier".to_string(),
75/// signing_public_key: "public-key-data".to_string(),
76/// dpop_private_key: "private-key-data".to_string(),
77/// created_at: Utc::now(),
78/// expires_at: Utc::now() + Duration::minutes(10),
79/// };
80///
81/// // Store the OAuth request
82/// storage.insert_oauth_request(request.clone()).await?;
83///
84/// // Retrieve the OAuth request
85/// let retrieved = storage.get_oauth_request_by_state("unique-state-123").await?;
86/// assert_eq!(retrieved.as_ref().map(|r| &r.oauth_state), Some(&request.oauth_state));
87///
88/// // Delete the OAuth request
89/// storage.delete_oauth_request_by_state("unique-state-123").await?;
90/// let retrieved = storage.get_oauth_request_by_state("unique-state-123").await?;
91/// assert_eq!(retrieved, None);
92/// # Ok::<(), anyhow::Error>(())
93/// # }).unwrap();
94/// ```
95///
96/// ## Capacity Planning
97///
98/// When choosing the cache capacity, consider:
99/// - **Expected concurrent OAuth flows**: Size cache to hold active OAuth requests
100/// - **Memory constraints**: Each entry uses approximately (request size + state length + overhead) bytes
101/// - **OAuth request complexity**: Requests with large key data use more memory
102/// - **Access patterns**: Higher capacity reduces cache misses for concurrent OAuth flows
103/// - **Performance requirements**: Larger caches may have slightly higher lookup times
104///
105/// ```rust
106/// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
107/// use std::num::NonZeroUsize;
108///
109/// // Small cache for testing or low-traffic scenarios
110/// let small_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
111///
112/// // Medium cache for typical applications
113/// let medium_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(10_000).unwrap());
114///
115/// // Large cache for high-traffic OAuth services
116/// let large_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(100_000).unwrap());
117/// ```
118#[derive(Clone)]
119pub struct LruOAuthRequestStorage {
120 /// The LRU cache storing state -> OAuthRequest mappings, protected by a mutex for thread safety.
121 ///
122 /// We use the OAuth state parameter as the key since the primary operation is looking up
123 /// requests by state during OAuth callback processing. The cache is wrapped in Arc<Mutex<>>
124 /// to ensure thread-safe access across multiple async tasks and threads.
125 cache: Arc<Mutex<LruCache<String, OAuthRequest>>>,
126}
127
128impl LruOAuthRequestStorage {
129 /// Creates a new `LruOAuthRequestStorage` with the specified capacity.
130 ///
131 /// The capacity determines the maximum number of OAuth requests that can be stored
132 /// in the cache. When the cache reaches this capacity, the least recently used
133 /// entries will be automatically evicted to make room for new entries.
134 ///
135 /// # Arguments
136 /// * `capacity` - The maximum number of OAuth requests to store. Must be greater than 0.
137 ///
138 /// # Examples
139 ///
140 /// ```rust
141 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
142 /// use std::num::NonZeroUsize;
143 ///
144 /// // Create a cache that can hold up to 5000 OAuth requests
145 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(5000).unwrap());
146 /// ```
147 ///
148 /// # Performance Considerations
149 ///
150 /// - Larger capacities provide better cache hit rates but use more memory
151 /// - The underlying LRU implementation has O(1) access time for all operations
152 /// - Memory usage is approximately: capacity * (average_request_size + state_size + overhead)
153 /// - OAuth request size varies based on key data and metadata
154 pub fn new(capacity: NonZeroUsize) -> Self {
155 Self {
156 cache: Arc::new(Mutex::new(LruCache::new(capacity))),
157 }
158 }
159
160 /// Returns the current number of entries in the cache.
161 ///
162 /// This method provides visibility into cache usage for monitoring and debugging purposes.
163 /// The count represents the current number of OAuth requests stored in the cache.
164 ///
165 /// # Returns
166 /// The number of entries currently stored in the cache.
167 ///
168 /// # Examples
169 ///
170 /// ```rust
171 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
172 /// use atproto_oauth::storage::OAuthRequestStorage;
173 /// use atproto_oauth::workflow::OAuthRequest;
174 /// use std::num::NonZeroUsize;
175 /// use chrono::{Utc, Duration};
176 ///
177 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
178 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
179 /// assert_eq!(storage.len(), 0);
180 ///
181 /// let request = OAuthRequest {
182 /// oauth_state: "state1".to_string(),
183 /// issuer: "https://pds.example.com".to_string(),
184 /// authorization_server: "https://pds.example.com".to_string(),
185 /// nonce: "nonce1".to_string(),
186 /// pkce_verifier: "verifier1".to_string(),
187 /// signing_public_key: "pubkey1".to_string(),
188 /// dpop_private_key: "privkey1".to_string(),
189 /// created_at: Utc::now(),
190 /// expires_at: Utc::now() + Duration::minutes(10),
191 /// };
192 /// storage.insert_oauth_request(request).await?;
193 /// assert_eq!(storage.len(), 1);
194 /// # Ok::<(), anyhow::Error>(())
195 /// # }).unwrap();
196 /// ```
197 pub fn len(&self) -> usize {
198 self.cache.lock().unwrap().len()
199 }
200
201 /// Returns whether the cache is empty.
202 ///
203 /// # Returns
204 /// `true` if the cache contains no entries, `false` otherwise.
205 ///
206 /// # Examples
207 ///
208 /// ```rust
209 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
210 /// use std::num::NonZeroUsize;
211 ///
212 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
213 /// assert!(storage.is_empty());
214 /// ```
215 pub fn is_empty(&self) -> bool {
216 self.cache.lock().unwrap().is_empty()
217 }
218
219 /// Returns the maximum capacity of the cache.
220 ///
221 /// This returns the capacity that was set when the cache was created and represents
222 /// the maximum number of OAuth requests that can be stored before eviction occurs.
223 ///
224 /// # Returns
225 /// The maximum capacity of the cache.
226 ///
227 /// # Examples
228 ///
229 /// ```rust
230 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
231 /// use std::num::NonZeroUsize;
232 ///
233 /// let capacity = NonZeroUsize::new(500).unwrap();
234 /// let storage = LruOAuthRequestStorage::new(capacity);
235 /// assert_eq!(storage.capacity().get(), 500);
236 /// ```
237 pub fn capacity(&self) -> NonZeroUsize {
238 self.cache.lock().unwrap().cap()
239 }
240
241 /// Clears all entries from the cache.
242 ///
243 /// This method removes all OAuth requests from the cache, effectively resetting
244 /// it to an empty state. This can be useful for testing or when you need to
245 /// invalidate all cached OAuth data.
246 ///
247 /// # Examples
248 ///
249 /// ```rust
250 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
251 /// use atproto_oauth::storage::OAuthRequestStorage;
252 /// use atproto_oauth::workflow::OAuthRequest;
253 /// use std::num::NonZeroUsize;
254 /// use chrono::{Utc, Duration};
255 ///
256 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
257 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
258 /// let request = OAuthRequest {
259 /// oauth_state: "test-state".to_string(),
260 /// issuer: "https://pds.example.com".to_string(),
261 /// authorization_server: "https://pds.example.com".to_string(),
262 /// nonce: "test-nonce".to_string(),
263 /// pkce_verifier: "test-verifier".to_string(),
264 /// signing_public_key: "test-pubkey".to_string(),
265 /// dpop_private_key: "test-privkey".to_string(),
266 /// created_at: Utc::now(),
267 /// expires_at: Utc::now() + Duration::minutes(10),
268 /// };
269 /// storage.insert_oauth_request(request).await?;
270 /// assert_eq!(storage.len(), 1);
271 ///
272 /// storage.clear();
273 /// assert_eq!(storage.len(), 0);
274 /// assert!(storage.is_empty());
275 /// # Ok::<(), anyhow::Error>(())
276 /// # }).unwrap();
277 /// ```
278 pub fn clear(&self) {
279 self.cache.lock().unwrap().clear();
280 }
281}
282
283#[async_trait::async_trait]
284impl OAuthRequestStorage for LruOAuthRequestStorage {
285 /// Retrieves an OAuth request by its state parameter from the LRU cache.
286 ///
287 /// This method looks up an OAuth authorization request using the state parameter
288 /// that is currently cached. If the state is found in the cache, the entry is moved
289 /// to the front of the LRU order (marking it as recently used) and the request is returned.
290 ///
291 /// Expired requests (where `expires_at < current_time`) are automatically filtered out
292 /// and not returned, ensuring only valid OAuth requests are accessible.
293 ///
294 /// # Arguments
295 /// * `state` - The OAuth state parameter to look up in the cache
296 ///
297 /// # Returns
298 /// * `Ok(Some(request))` - If the state is found in the cache and not expired
299 /// * `Ok(None)` - If the state is not found in the cache or the request has expired
300 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
301 ///
302 /// # Cache Behavior
303 ///
304 /// When a request is successfully retrieved, it's marked as recently used in the LRU order,
305 /// making it less likely to be evicted in future operations. Expired requests are treated
306 /// as cache misses.
307 ///
308 /// # Examples
309 ///
310 /// ```rust
311 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
312 /// use atproto_oauth::storage::OAuthRequestStorage;
313 /// use atproto_oauth::workflow::OAuthRequest;
314 /// use std::num::NonZeroUsize;
315 /// use chrono::{Utc, Duration};
316 ///
317 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
318 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
319 ///
320 /// // Cache miss - state not in cache
321 /// let request = storage.get_oauth_request_by_state("unknown-state").await?;
322 /// assert_eq!(request, None);
323 ///
324 /// // Add request to cache
325 /// let oauth_req = OAuthRequest {
326 /// oauth_state: "valid-state-123".to_string(),
327 /// issuer: "https://pds.example.com".to_string(),
328 /// authorization_server: "https://pds.example.com".to_string(),
329 /// nonce: "secure-nonce".to_string(),
330 /// pkce_verifier: "code-verifier".to_string(),
331 /// signing_public_key: "public-key-data".to_string(),
332 /// dpop_private_key: "private-key-data".to_string(),
333 /// created_at: Utc::now(),
334 /// expires_at: Utc::now() + Duration::minutes(10),
335 /// };
336 /// storage.insert_oauth_request(oauth_req.clone()).await?;
337 ///
338 /// // Cache hit - state found in cache
339 /// let request = storage.get_oauth_request_by_state("valid-state-123").await?;
340 /// assert_eq!(request.as_ref().map(|r| &r.oauth_state), Some(&oauth_req.oauth_state));
341 /// # Ok::<(), anyhow::Error>(())
342 /// # }).unwrap();
343 /// ```
344 async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
345 let mut cache = self
346 .cache
347 .lock()
348 .map_err(|e| OAuthStorageError::CacheLockFailedGet {
349 details: e.to_string(),
350 })?;
351
352 if let Some(request) = cache.get(state) {
353 // Check if the request has expired
354 let now = Utc::now();
355 if request.expires_at > now {
356 Ok(Some(request.clone()))
357 } else {
358 // Request has expired, remove it from cache and return None
359 cache.pop(state);
360 Ok(None)
361 }
362 } else {
363 Ok(None)
364 }
365 }
366
367 /// Deletes an OAuth request from the LRU cache by its state parameter.
368 ///
369 /// This method removes an OAuth authorization request from the cache using its state parameter.
370 /// If the state exists in the cache, it is removed entirely, freeing up space for new entries.
371 ///
372 /// # Arguments
373 /// * `state` - The OAuth state parameter identifying the request to delete
374 ///
375 /// # Returns
376 /// * `Ok(())` - If the OAuth request was successfully deleted or didn't exist
377 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
378 ///
379 /// # Cache Behavior
380 ///
381 /// - If the state exists in the cache, it is removed completely
382 /// - If the state doesn't exist, the operation succeeds without error
383 /// - Removing entries frees up capacity for new entries
384 ///
385 /// # Examples
386 ///
387 /// ```rust
388 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
389 /// use atproto_oauth::storage::OAuthRequestStorage;
390 /// use atproto_oauth::workflow::OAuthRequest;
391 /// use std::num::NonZeroUsize;
392 /// use chrono::{Utc, Duration};
393 ///
394 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
395 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
396 ///
397 /// // Add an OAuth request
398 /// let request = OAuthRequest {
399 /// oauth_state: "deletable-state".to_string(),
400 /// issuer: "https://pds.example.com".to_string(),
401 /// authorization_server: "https://pds.example.com".to_string(),
402 /// nonce: "test-nonce".to_string(),
403 /// pkce_verifier: "test-verifier".to_string(),
404 /// signing_public_key: "test-pubkey".to_string(),
405 /// dpop_private_key: "test-privkey".to_string(),
406 /// created_at: Utc::now(),
407 /// expires_at: Utc::now() + Duration::minutes(10),
408 /// };
409 /// storage.insert_oauth_request(request).await?;
410 /// let retrieved = storage.get_oauth_request_by_state("deletable-state").await?;
411 /// assert!(retrieved.is_some());
412 ///
413 /// // Delete the request
414 /// storage.delete_oauth_request_by_state("deletable-state").await?;
415 /// let retrieved = storage.get_oauth_request_by_state("deletable-state").await?;
416 /// assert_eq!(retrieved, None);
417 ///
418 /// // Deleting non-existent entry is safe
419 /// storage.delete_oauth_request_by_state("non-existent-state").await?;
420 /// # Ok::<(), anyhow::Error>(())
421 /// # }).unwrap();
422 /// ```
423 async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
424 let mut cache =
425 self.cache
426 .lock()
427 .map_err(|e| OAuthStorageError::CacheLockFailedDelete {
428 details: e.to_string(),
429 })?;
430
431 cache.pop(state);
432 Ok(())
433 }
434
435 /// Inserts a new OAuth request into the LRU cache.
436 ///
437 /// This method stores an OAuth authorization request in the cache. If a request with the
438 /// same state already exists in the cache, it is replaced and the entry is moved to the front
439 /// of the LRU order. If the state is new and the cache is at capacity, the least recently
440 /// used entry is evicted to make room.
441 ///
442 /// # Arguments
443 /// * `request` - The complete OAuth request to store. The request's `oauth_state` field
444 /// will be used as the storage key.
445 ///
446 /// # Returns
447 /// * `Ok(())` - If the OAuth request was successfully stored
448 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
449 ///
450 /// # Cache Behavior
451 ///
452 /// - If the cache is at capacity and this is a new state, the least recently used entry is evicted
453 /// - The new or updated entry is placed at the front of the LRU order
454 /// - Existing entries with the same state are replaced in place
455 ///
456 /// # Examples
457 ///
458 /// ```rust
459 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
460 /// use atproto_oauth::storage::OAuthRequestStorage;
461 /// use atproto_oauth::workflow::OAuthRequest;
462 /// use std::num::NonZeroUsize;
463 /// use chrono::{Utc, Duration};
464 ///
465 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
466 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(2).unwrap()); // Small cache for demo
467 ///
468 /// // Add first request
469 /// let req1 = OAuthRequest {
470 /// oauth_state: "state1".to_string(),
471 /// issuer: "https://pds.example.com".to_string(),
472 /// authorization_server: "https://pds.example.com".to_string(),
473 /// nonce: "nonce1".to_string(),
474 /// pkce_verifier: "verifier1".to_string(),
475 /// signing_public_key: "pubkey1".to_string(),
476 /// dpop_private_key: "privkey1".to_string(),
477 /// created_at: Utc::now(),
478 /// expires_at: Utc::now() + Duration::minutes(10),
479 /// };
480 /// storage.insert_oauth_request(req1).await?;
481 /// assert_eq!(storage.len(), 1);
482 ///
483 /// // Add second request
484 /// let req2 = OAuthRequest {
485 /// oauth_state: "state2".to_string(),
486 /// issuer: "https://pds.example.com".to_string(),
487 /// authorization_server: "https://pds.example.com".to_string(),
488 /// nonce: "nonce2".to_string(),
489 /// pkce_verifier: "verifier2".to_string(),
490 /// signing_public_key: "pubkey2".to_string(),
491 /// dpop_private_key: "privkey2".to_string(),
492 /// created_at: Utc::now(),
493 /// expires_at: Utc::now() + Duration::minutes(10),
494 /// };
495 /// storage.insert_oauth_request(req2).await?;
496 /// assert_eq!(storage.len(), 2);
497 ///
498 /// // Add third request - this will evict the least recently used entry (state1)
499 /// let req3 = OAuthRequest {
500 /// oauth_state: "state3".to_string(),
501 /// issuer: "https://pds.example.com".to_string(),
502 /// authorization_server: "https://pds.example.com".to_string(),
503 /// nonce: "nonce3".to_string(),
504 /// pkce_verifier: "verifier3".to_string(),
505 /// signing_public_key: "pubkey3".to_string(),
506 /// dpop_private_key: "privkey3".to_string(),
507 /// created_at: Utc::now(),
508 /// expires_at: Utc::now() + Duration::minutes(10),
509 /// };
510 /// storage.insert_oauth_request(req3).await?;
511 /// assert_eq!(storage.len(), 2); // Still at capacity
512 ///
513 /// // state1 should be evicted
514 /// let request = storage.get_oauth_request_by_state("state1").await?;
515 /// assert_eq!(request, None);
516 ///
517 /// // state2 and state3 should still be present
518 /// let req2_retrieved = storage.get_oauth_request_by_state("state2").await?;
519 /// let req3_retrieved = storage.get_oauth_request_by_state("state3").await?;
520 /// assert!(req2_retrieved.is_some());
521 /// assert!(req3_retrieved.is_some());
522 /// # Ok::<(), anyhow::Error>(())
523 /// # }).unwrap();
524 /// ```
525 async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
526 let mut cache =
527 self.cache
528 .lock()
529 .map_err(|e| OAuthStorageError::CacheLockFailedInsert {
530 details: e.to_string(),
531 })?;
532
533 cache.put(request.oauth_state.clone(), request);
534 Ok(())
535 }
536
537 /// Clears all expired OAuth requests from the LRU cache.
538 ///
539 /// This method performs cleanup by removing OAuth requests that have passed their
540 /// expiration time (`expires_at <= current_time`). This is important for maintaining
541 /// security and preventing storage bloat from abandoned OAuth flows.
542 ///
543 /// # Returns
544 /// * `Ok(count)` - The number of expired requests that were successfully removed
545 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
546 ///
547 /// # Cache Behavior
548 ///
549 /// - Compares each request's `expires_at` against the current time (`Utc::now()`)
550 /// - Removes all requests where `expires_at <= current_time`
551 /// - Maintains LRU order for remaining valid requests
552 /// - Frees up capacity for new OAuth requests
553 ///
554 /// # Examples
555 ///
556 /// ```rust
557 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
558 /// use atproto_oauth::storage::OAuthRequestStorage;
559 /// use atproto_oauth::workflow::OAuthRequest;
560 /// use std::num::NonZeroUsize;
561 /// use chrono::{Utc, Duration};
562 ///
563 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
564 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
565 ///
566 /// // Add a request that will expire soon
567 /// let expired_request = OAuthRequest {
568 /// oauth_state: "soon-expired".to_string(),
569 /// issuer: "https://pds.example.com".to_string(),
570 /// authorization_server: "https://pds.example.com".to_string(),
571 /// nonce: "nonce1".to_string(),
572 /// pkce_verifier: "verifier1".to_string(),
573 /// signing_public_key: "pubkey1".to_string(),
574 /// dpop_private_key: "privkey1".to_string(),
575 /// created_at: Utc::now() - Duration::minutes(20),
576 /// expires_at: Utc::now() - Duration::minutes(10), // Already expired
577 /// };
578 /// storage.insert_oauth_request(expired_request).await?;
579 ///
580 /// // Add a valid request
581 /// let valid_request = OAuthRequest {
582 /// oauth_state: "still-valid".to_string(),
583 /// issuer: "https://pds.example.com".to_string(),
584 /// authorization_server: "https://pds.example.com".to_string(),
585 /// nonce: "nonce2".to_string(),
586 /// pkce_verifier: "verifier2".to_string(),
587 /// signing_public_key: "pubkey2".to_string(),
588 /// dpop_private_key: "privkey2".to_string(),
589 /// created_at: Utc::now(),
590 /// expires_at: Utc::now() + Duration::minutes(10), // Still valid
591 /// };
592 /// storage.insert_oauth_request(valid_request).await?;
593 ///
594 /// assert_eq!(storage.len(), 2);
595 ///
596 /// // Clean up expired requests
597 /// let removed_count = storage.clear_expired_oauth_requests().await?;
598 /// assert_eq!(removed_count, 1); // One expired request removed
599 /// assert_eq!(storage.len(), 1); // One valid request remains
600 ///
601 /// // Verify the valid request is still accessible
602 /// let remaining = storage.get_oauth_request_by_state("still-valid").await?;
603 /// assert!(remaining.is_some());
604 /// # Ok::<(), anyhow::Error>(())
605 /// # }).unwrap();
606 /// ```
607 async fn clear_expired_oauth_requests(&self) -> Result<u64> {
608 let mut cache =
609 self.cache
610 .lock()
611 .map_err(|e| OAuthStorageError::CacheLockFailedCleanup {
612 details: e.to_string(),
613 })?;
614
615 let now = Utc::now();
616
617 // Collect keys of expired requests
618 let expired_keys: Vec<String> = cache
619 .iter()
620 .filter_map(|(key, request)| {
621 if request.expires_at <= now {
622 Some(key.clone())
623 } else {
624 None
625 }
626 })
627 .collect();
628
629 // Remove expired requests
630 for key in &expired_keys {
631 cache.pop(key);
632 }
633
634 Ok(expired_keys.len() as u64)
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use std::num::NonZeroUsize;

    /// Issuer shared by tests that don't care about the issuer value.
    const ISSUER: &str = "https://pds.example.com";

    /// Builds a request for `state` that is valid for the next ten minutes.
    fn create_test_oauth_request(state: &str, issuer: &str) -> OAuthRequest {
        let now = Utc::now();
        OAuthRequest {
            oauth_state: state.to_string(),
            issuer: issuer.to_string(),
            authorization_server: issuer.to_string(),
            nonce: format!("nonce-{}", state),
            pkce_verifier: format!("verifier-{}", state),
            signing_public_key: format!("pubkey-{}", state),
            dpop_private_key: format!("privkey-{}", state),
            created_at: now,
            expires_at: now + Duration::minutes(10),
        }
    }

    /// Builds a request for `state` whose `expires_at` lies ten minutes in the past.
    fn create_expired_oauth_request(state: &str, issuer: &str) -> OAuthRequest {
        let now = Utc::now();
        OAuthRequest {
            created_at: now - Duration::minutes(20),
            expires_at: now - Duration::minutes(10),
            ..create_test_oauth_request(state, issuer)
        }
    }

    #[test]
    fn test_new_storage() {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());
        assert_eq!(storage.capacity().get(), 100);
    }

    #[tokio::test]
    async fn test_basic_operations() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());

        // Test get on empty cache
        let result = storage.get_oauth_request_by_state("unknown-state").await?;
        assert_eq!(result, None);

        // Test insert and get
        let request = create_test_oauth_request("test-state", ISSUER);
        storage.insert_oauth_request(request.clone()).await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert!(result.is_some());
        assert_eq!(result.as_ref().unwrap().oauth_state, request.oauth_state);
        assert_eq!(storage.len(), 1);

        // Test update existing
        let updated_request =
            create_test_oauth_request("test-state", "https://updated.example.com");
        storage
            .insert_oauth_request(updated_request.clone())
            .await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert!(result.is_some());
        assert_eq!(result.as_ref().unwrap().issuer, updated_request.issuer);
        assert_eq!(storage.len(), 1); // Should still be 1

        // Test delete
        storage.delete_oauth_request_by_state("test-state").await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert_eq!(result, None);
        assert_eq!(storage.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_expiration_handling() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());

        // Insert an expired request
        let expired_request = create_expired_oauth_request("expired-state", ISSUER);
        storage.insert_oauth_request(expired_request).await?;
        assert_eq!(storage.len(), 1);

        // Try to get the expired request - should return None and remove from cache
        let result = storage.get_oauth_request_by_state("expired-state").await?;
        assert_eq!(result, None);
        assert_eq!(storage.len(), 0); // Should be removed from cache

        Ok(())
    }

    #[tokio::test]
    async fn test_lru_eviction() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(2).unwrap());

        // Fill cache to capacity
        let req1 = create_test_oauth_request("state1", ISSUER);
        let req2 = create_test_oauth_request("state2", ISSUER);
        storage.insert_oauth_request(req1.clone()).await?;
        storage.insert_oauth_request(req2).await?;
        assert_eq!(storage.len(), 2);

        // Access req1 to make it recently used
        let _ = storage.get_oauth_request_by_state("state1").await?;

        // Add req3, which should evict req2 (least recently used)
        let req3 = create_test_oauth_request("state3", ISSUER);
        storage.insert_oauth_request(req3.clone()).await?;
        assert_eq!(storage.len(), 2);

        // req1 and req3 should be present, req2 should be evicted
        let result1 = storage.get_oauth_request_by_state("state1").await?;
        assert!(result1.is_some());
        assert_eq!(result1.unwrap().oauth_state, req1.oauth_state);

        let result3 = storage.get_oauth_request_by_state("state3").await?;
        assert!(result3.is_some());
        assert_eq!(result3.unwrap().oauth_state, req3.oauth_state);

        assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_reinsert_promotes_entry() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(2).unwrap());

        storage
            .insert_oauth_request(create_test_oauth_request("state1", ISSUER))
            .await?;
        storage
            .insert_oauth_request(create_test_oauth_request("state2", ISSUER))
            .await?;

        // Re-inserting state1 replaces it and makes it the most recently used entry...
        storage
            .insert_oauth_request(create_test_oauth_request("state1", ISSUER))
            .await?;
        assert_eq!(storage.len(), 2);

        // ...so the next insert evicts state2 instead.
        storage
            .insert_oauth_request(create_test_oauth_request("state3", ISSUER))
            .await?;
        assert_eq!(storage.len(), 2);

        assert!(storage.get_oauth_request_by_state("state1").await?.is_some());
        assert!(storage.get_oauth_request_by_state("state3").await?.is_some());
        assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_clones_share_cache() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
        let handle = storage.clone();

        storage
            .insert_oauth_request(create_test_oauth_request("shared", ISSUER))
            .await?;
        assert_eq!(handle.len(), 1);
        assert!(handle.get_oauth_request_by_state("shared").await?.is_some());

        handle.delete_oauth_request_by_state("shared").await?;
        assert!(storage.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_clear() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());

        // Add some entries
        let req1 = create_test_oauth_request("state1", ISSUER);
        let req2 = create_test_oauth_request("state2", ISSUER);
        storage.insert_oauth_request(req1).await?;
        storage.insert_oauth_request(req2).await?;
        assert_eq!(storage.len(), 2);

        // Clear cache
        storage.clear();
        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());

        // Verify entries are gone
        assert_eq!(storage.get_oauth_request_by_state("state1").await?, None);
        assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_clear_expired_requests() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());

        // Add expired and valid requests, interleaved
        storage
            .insert_oauth_request(create_expired_oauth_request("expired1", ISSUER))
            .await?;
        storage
            .insert_oauth_request(create_test_oauth_request("valid1", ISSUER))
            .await?;
        storage
            .insert_oauth_request(create_expired_oauth_request("expired2", ISSUER))
            .await?;
        storage
            .insert_oauth_request(create_test_oauth_request("valid2", ISSUER))
            .await?;
        assert_eq!(storage.len(), 4);

        // Clear expired requests
        let removed_count = storage.clear_expired_oauth_requests().await?;
        assert_eq!(removed_count, 2); // Two expired requests removed
        assert_eq!(storage.len(), 2); // Two valid requests remain

        // Verify only valid requests remain
        assert!(storage.get_oauth_request_by_state("valid1").await?.is_some());
        assert!(storage.get_oauth_request_by_state("valid2").await?.is_some());
        assert_eq!(storage.get_oauth_request_by_state("expired1").await?, None);
        assert_eq!(storage.get_oauth_request_by_state("expired2").await?, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_clear_expired_on_empty_cache() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
        assert_eq!(storage.clear_expired_oauth_requests().await?, 0);
        assert!(storage.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_delete_nonexistent() -> Result<()> {
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());

        // Deleting non-existent entry should not error
        storage
            .delete_oauth_request_by_state("non-existent-state")
            .await?;
        assert_eq!(storage.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_thread_safety() -> Result<()> {
        // The storage is already a cheap, shared handle; clones all point at one cache.
        let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
        let mut handles = Vec::new();

        // Spawn multiple tasks that concurrently access the storage
        for i in 0..10 {
            let storage_clone = storage.clone();
            let handle = tokio::spawn(async move {
                let state = format!("state{}", i);
                let issuer = format!("https://pds{}.example.com", i);
                let request = create_test_oauth_request(&state, &issuer);

                // Insert a request
                storage_clone.insert_oauth_request(request.clone()).await?;

                // Get the request back
                let result = storage_clone.get_oauth_request_by_state(&state).await?;
                assert!(result.is_some());
                assert_eq!(result.unwrap().oauth_state, request.oauth_state);

                // Delete the request
                storage_clone.delete_oauth_request_by_state(&state).await?;
                let result = storage_clone.get_oauth_request_by_state(&state).await?;
                assert_eq!(result, None);

                Ok::<(), anyhow::Error>(())
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await??;
        }

        // Storage should be empty after all deletions
        assert_eq!(storage.len(), 0);
        Ok(())
    }
}