atproto_oauth/storage_lru.rs
1//! LRU-based implementation of the `OAuthRequestStorage` trait.
2//!
3//! This module provides an in-memory implementation of OAuth request storage using an LRU (Least Recently Used)
4//! cache. This is useful for scenarios where you want to cache OAuth requests with automatic
5//! eviction of least recently used entries when the cache reaches its capacity limit.
6//!
7//! The implementation is thread-safe and can be used in concurrent environments where multiple
8//! async tasks need to access the OAuth request storage simultaneously.
9
10use std::num::NonZeroUsize;
11use std::sync::{Arc, Mutex};
12
13use anyhow::Result;
14use chrono::Utc;
15use lru::LruCache;
16
17use crate::errors::OAuthStorageError;
18use crate::storage::OAuthRequestStorage;
19use crate::workflow::OAuthRequest;
20
21/// An LRU-based implementation of `OAuthRequestStorage` that maintains a fixed-size cache of OAuth requests.
22///
23/// This storage implementation uses an LRU (Least Recently Used) cache to store OAuth requests
24/// in memory with automatic eviction of the least recently accessed entries when the cache reaches
25/// its capacity. This is ideal for scenarios where you want to cache frequently accessed OAuth requests
26/// while keeping memory usage bounded.
27///
28/// ## Thread Safety
29///
30/// This implementation is thread-safe through the use of `Arc<Mutex<LruCache<String, OAuthRequest>>>`.
31/// All operations are protected by a mutex, ensuring safe concurrent access from multiple threads
32/// or async tasks.
33///
34/// ## Cache Behavior
35///
36/// - **Get operations**: Move accessed entries to the front of the LRU order
37/// - **Insert operations**: Add new entries at the front, evicting the least recently used if at capacity
38/// - **Delete operations**: Remove entries from the cache entirely
39/// - **Capacity management**: Automatically evicts least recently used entries when capacity is exceeded
40/// - **Expiration handling**: Returns only non-expired OAuth requests based on `expires_at` timestamp
41///
42/// ## Use Cases
43///
44/// This implementation is particularly suitable for:
45/// - Caching OAuth authorization requests during the OAuth flow
46/// - Scenarios with bounded memory requirements for OAuth state management
47/// - Applications where some OAuth request lookup misses are acceptable
48/// - High-performance applications requiring fast in-memory OAuth state access
49/// - Stateless OAuth servers that need temporary request storage
50///
51/// ## Limitations
52///
53/// - **Persistence**: Data is lost when the application restarts
54/// - **Capacity**: Limited to the configured cache size
55/// - **Cache misses**: Older entries may be evicted and need OAuth flow restart
56/// - **Memory usage**: All cached OAuth data is kept in memory
57/// - **Request size**: OAuth requests with large key data consume more memory per entry
58///
59/// ## Examples
60///
61/// ```rust
62/// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
63/// use atproto_oauth::storage::OAuthRequestStorage;
64/// use atproto_oauth::workflow::OAuthRequest;
65/// use std::num::NonZeroUsize;
66/// use chrono::{Utc, Duration};
67///
68/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
69/// // Create an LRU cache with capacity for 1000 OAuth requests
70/// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(1000).unwrap());
71///
72/// // Create a sample OAuth request
73/// let request = OAuthRequest {
74/// oauth_state: "unique-state-123".to_string(),
75/// issuer: "https://pds.example.com".to_string(),
76/// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
77/// nonce: "secure-nonce".to_string(),
78/// pkce_verifier: "code-verifier".to_string(),
79/// signing_public_key: "public-key-data".to_string(),
80/// dpop_private_key: "private-key-data".to_string(),
81/// created_at: Utc::now(),
82/// expires_at: Utc::now() + Duration::minutes(10),
83/// };
84///
85/// // Store the OAuth request
86/// storage.insert_oauth_request(request.clone()).await?;
87///
88/// // Retrieve the OAuth request
89/// let retrieved = storage.get_oauth_request_by_state("unique-state-123").await?;
90/// assert_eq!(retrieved.as_ref().map(|r| &r.oauth_state), Some(&request.oauth_state));
91///
92/// // Delete the OAuth request
93/// storage.delete_oauth_request_by_state("unique-state-123").await?;
94/// let retrieved = storage.get_oauth_request_by_state("unique-state-123").await?;
95/// assert_eq!(retrieved, None);
96/// # Ok::<(), anyhow::Error>(())
97/// # }).unwrap();
98/// ```
99///
100/// ## Capacity Planning
101///
102/// When choosing the cache capacity, consider:
103/// - **Expected concurrent OAuth flows**: Size cache to hold active OAuth requests
104/// - **Memory constraints**: Each entry uses approximately (request size + state length + overhead) bytes
105/// - **OAuth request complexity**: Requests with large key data use more memory
106/// - **Access patterns**: Higher capacity reduces cache misses for concurrent OAuth flows
107/// - **Performance requirements**: Larger caches may have slightly higher lookup times
108///
109/// ```rust
110/// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
111/// use std::num::NonZeroUsize;
112///
113/// // Small cache for testing or low-traffic scenarios
114/// let small_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
115///
116/// // Medium cache for typical applications
117/// let medium_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(10_000).unwrap());
118///
119/// // Large cache for high-traffic OAuth services
120/// let large_cache = LruOAuthRequestStorage::new(NonZeroUsize::new(100_000).unwrap());
121/// ```
122#[derive(Clone)]
123pub struct LruOAuthRequestStorage {
124 /// The LRU cache storing state -> OAuthRequest mappings, protected by a mutex for thread safety.
125 ///
126 /// We use the OAuth state parameter as the key since the primary operation is looking up
127 /// requests by state during OAuth callback processing. The cache is wrapped in Arc<Mutex<>>
128 /// to ensure thread-safe access across multiple async tasks and threads.
129 cache: Arc<Mutex<LruCache<String, OAuthRequest>>>,
130}
131
132impl LruOAuthRequestStorage {
133 /// Creates a new `LruOAuthRequestStorage` with the specified capacity.
134 ///
135 /// The capacity determines the maximum number of OAuth requests that can be stored
136 /// in the cache. When the cache reaches this capacity, the least recently used
137 /// entries will be automatically evicted to make room for new entries.
138 ///
139 /// # Arguments
140 /// * `capacity` - The maximum number of OAuth requests to store. Must be greater than 0.
141 ///
142 /// # Examples
143 ///
144 /// ```rust
145 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
146 /// use std::num::NonZeroUsize;
147 ///
148 /// // Create a cache that can hold up to 5000 OAuth requests
149 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(5000).unwrap());
150 /// ```
151 ///
152 /// # Performance Considerations
153 ///
154 /// - Larger capacities provide better cache hit rates but use more memory
155 /// - The underlying LRU implementation has O(1) access time for all operations
156 /// - Memory usage is approximately: capacity * (average_request_size + state_size + overhead)
157 /// - OAuth request size varies based on key data and metadata
158 pub fn new(capacity: NonZeroUsize) -> Self {
159 Self {
160 cache: Arc::new(Mutex::new(LruCache::new(capacity))),
161 }
162 }
163
164 /// Returns the current number of entries in the cache.
165 ///
166 /// This method provides visibility into cache usage for monitoring and debugging purposes.
167 /// The count represents the current number of OAuth requests stored in the cache.
168 ///
169 /// # Returns
170 /// The number of entries currently stored in the cache.
171 ///
172 /// # Examples
173 ///
174 /// ```rust
175 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
176 /// use atproto_oauth::storage::OAuthRequestStorage;
177 /// use atproto_oauth::workflow::OAuthRequest;
178 /// use std::num::NonZeroUsize;
179 /// use chrono::{Utc, Duration};
180 ///
181 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
182 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
183 /// assert_eq!(storage.len(), 0);
184 ///
185 /// let request = OAuthRequest {
186 /// oauth_state: "state1".to_string(),
187 /// issuer: "https://pds.example.com".to_string(),
188 /// did: "did:plc:example1".to_string(),
189 /// nonce: "nonce1".to_string(),
190 /// pkce_verifier: "verifier1".to_string(),
191 /// signing_public_key: "pubkey1".to_string(),
192 /// dpop_private_key: "privkey1".to_string(),
193 /// created_at: Utc::now(),
194 /// expires_at: Utc::now() + Duration::minutes(10),
195 /// };
196 /// storage.insert_oauth_request(request).await?;
197 /// assert_eq!(storage.len(), 1);
198 /// # Ok::<(), anyhow::Error>(())
199 /// # }).unwrap();
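    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex has been poisoned, i.e. another thread panicked while
    /// holding the lock.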
201 pub fn len(&self) -> usize {
202 self.cache.lock().unwrap().len()
203 }
204
205 /// Returns whether the cache is empty.
206 ///
207 /// # Returns
208 /// `true` if the cache contains no entries, `false` otherwise.
209 ///
210 /// # Examples
211 ///
212 /// ```rust
213 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
214 /// use std::num::NonZeroUsize;
215 ///
216 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
217 /// assert!(storage.is_empty());
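    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex has been poisoned, i.e. another thread panicked while
    /// holding the lock.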
219 pub fn is_empty(&self) -> bool {
220 self.cache.lock().unwrap().is_empty()
221 }
222
223 /// Returns the maximum capacity of the cache.
224 ///
225 /// This returns the capacity that was set when the cache was created and represents
226 /// the maximum number of OAuth requests that can be stored before eviction occurs.
227 ///
228 /// # Returns
229 /// The maximum capacity of the cache.
230 ///
231 /// # Examples
232 ///
233 /// ```rust
234 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
235 /// use std::num::NonZeroUsize;
236 ///
237 /// let capacity = NonZeroUsize::new(500).unwrap();
238 /// let storage = LruOAuthRequestStorage::new(capacity);
239 /// assert_eq!(storage.capacity().get(), 500);
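    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex has been poisoned, i.e. another thread panicked while
    /// holding the lock.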
241 pub fn capacity(&self) -> NonZeroUsize {
242 self.cache.lock().unwrap().cap()
243 }
244
245 /// Clears all entries from the cache.
246 ///
247 /// This method removes all OAuth requests from the cache, effectively resetting
248 /// it to an empty state. This can be useful for testing or when you need to
249 /// invalidate all cached OAuth data.
250 ///
251 /// # Examples
252 ///
253 /// ```rust
254 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
255 /// use atproto_oauth::storage::OAuthRequestStorage;
256 /// use atproto_oauth::workflow::OAuthRequest;
257 /// use std::num::NonZeroUsize;
258 /// use chrono::{Utc, Duration};
259 ///
260 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
261 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
262 /// let request = OAuthRequest {
263 /// oauth_state: "test-state".to_string(),
264 /// issuer: "https://pds.example.com".to_string(),
265 /// did: "did:plc:example".to_string(),
266 /// nonce: "test-nonce".to_string(),
267 /// pkce_verifier: "test-verifier".to_string(),
268 /// signing_public_key: "test-pubkey".to_string(),
269 /// dpop_private_key: "test-privkey".to_string(),
270 /// created_at: Utc::now(),
271 /// expires_at: Utc::now() + Duration::minutes(10),
272 /// };
273 /// storage.insert_oauth_request(request).await?;
274 /// assert_eq!(storage.len(), 1);
275 ///
276 /// storage.clear();
277 /// assert_eq!(storage.len(), 0);
278 /// assert!(storage.is_empty());
279 /// # Ok::<(), anyhow::Error>(())
280 /// # }).unwrap();
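    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex has been poisoned, i.e. another thread panicked while
    /// holding the lock.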
282 pub fn clear(&self) {
283 self.cache.lock().unwrap().clear();
284 }
285}
286
287#[async_trait::async_trait]
288impl OAuthRequestStorage for LruOAuthRequestStorage {
289 /// Retrieves an OAuth request by its state parameter from the LRU cache.
290 ///
291 /// This method looks up an OAuth authorization request using the state parameter
292 /// that is currently cached. If the state is found in the cache, the entry is moved
293 /// to the front of the LRU order (marking it as recently used) and the request is returned.
294 ///
    /// Expired requests (where `expires_at <= current_time`) are removed from the cache when
    /// they are looked up and are reported as `None`, ensuring only valid OAuth requests are
    /// accessible.
297 ///
298 /// # Arguments
299 /// * `state` - The OAuth state parameter to look up in the cache
300 ///
301 /// # Returns
302 /// * `Ok(Some(request))` - If the state is found in the cache and not expired
303 /// * `Ok(None)` - If the state is not found in the cache or the request has expired
304 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
305 ///
306 /// # Cache Behavior
307 ///
308 /// When a request is successfully retrieved, it's marked as recently used in the LRU order,
309 /// making it less likely to be evicted in future operations. Expired requests are treated
310 /// as cache misses.
311 ///
312 /// # Examples
313 ///
314 /// ```rust
315 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
316 /// use atproto_oauth::storage::OAuthRequestStorage;
317 /// use atproto_oauth::workflow::OAuthRequest;
318 /// use std::num::NonZeroUsize;
319 /// use chrono::{Utc, Duration};
320 ///
321 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
322 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
323 ///
324 /// // Cache miss - state not in cache
325 /// let request = storage.get_oauth_request_by_state("unknown-state").await?;
326 /// assert_eq!(request, None);
327 ///
328 /// // Add request to cache
329 /// let oauth_req = OAuthRequest {
330 /// oauth_state: "valid-state-123".to_string(),
331 /// issuer: "https://pds.example.com".to_string(),
332 /// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
333 /// nonce: "secure-nonce".to_string(),
334 /// pkce_verifier: "code-verifier".to_string(),
335 /// signing_public_key: "public-key-data".to_string(),
336 /// dpop_private_key: "private-key-data".to_string(),
337 /// created_at: Utc::now(),
338 /// expires_at: Utc::now() + Duration::minutes(10),
339 /// };
340 /// storage.insert_oauth_request(oauth_req.clone()).await?;
341 ///
342 /// // Cache hit - state found in cache
343 /// let request = storage.get_oauth_request_by_state("valid-state-123").await?;
344 /// assert_eq!(request.as_ref().map(|r| &r.oauth_state), Some(&oauth_req.oauth_state));
345 /// # Ok::<(), anyhow::Error>(())
346 /// # }).unwrap();
347 /// ```
348 async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
349 let mut cache = self
350 .cache
351 .lock()
352 .map_err(|e| OAuthStorageError::CacheLockFailedGet {
353 details: e.to_string(),
354 })?;
355
356 if let Some(request) = cache.get(state) {
357 // Check if the request has expired
358 let now = Utc::now();
359 if request.expires_at > now {
360 Ok(Some(request.clone()))
361 } else {
362 // Request has expired, remove it from cache and return None
363 cache.pop(state);
364 Ok(None)
365 }
366 } else {
367 Ok(None)
368 }
369 }
370
371 /// Deletes an OAuth request from the LRU cache by its state parameter.
372 ///
373 /// This method removes an OAuth authorization request from the cache using its state parameter.
374 /// If the state exists in the cache, it is removed entirely, freeing up space for new entries.
375 ///
376 /// # Arguments
377 /// * `state` - The OAuth state parameter identifying the request to delete
378 ///
379 /// # Returns
380 /// * `Ok(())` - If the OAuth request was successfully deleted or didn't exist
381 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
382 ///
383 /// # Cache Behavior
384 ///
385 /// - If the state exists in the cache, it is removed completely
386 /// - If the state doesn't exist, the operation succeeds without error
387 /// - Removing entries frees up capacity for new entries
388 ///
389 /// # Examples
390 ///
391 /// ```rust
392 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
393 /// use atproto_oauth::storage::OAuthRequestStorage;
394 /// use atproto_oauth::workflow::OAuthRequest;
395 /// use std::num::NonZeroUsize;
396 /// use chrono::{Utc, Duration};
397 ///
398 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
399 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
400 ///
401 /// // Add an OAuth request
402 /// let request = OAuthRequest {
403 /// oauth_state: "deletable-state".to_string(),
404 /// issuer: "https://pds.example.com".to_string(),
405 /// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
406 /// nonce: "test-nonce".to_string(),
407 /// pkce_verifier: "test-verifier".to_string(),
408 /// signing_public_key: "test-pubkey".to_string(),
409 /// dpop_private_key: "test-privkey".to_string(),
410 /// created_at: Utc::now(),
411 /// expires_at: Utc::now() + Duration::minutes(10),
412 /// };
413 /// storage.insert_oauth_request(request).await?;
414 /// let retrieved = storage.get_oauth_request_by_state("deletable-state").await?;
415 /// assert!(retrieved.is_some());
416 ///
417 /// // Delete the request
418 /// storage.delete_oauth_request_by_state("deletable-state").await?;
419 /// let retrieved = storage.get_oauth_request_by_state("deletable-state").await?;
420 /// assert_eq!(retrieved, None);
421 ///
422 /// // Deleting non-existent entry is safe
423 /// storage.delete_oauth_request_by_state("non-existent-state").await?;
424 /// # Ok::<(), anyhow::Error>(())
425 /// # }).unwrap();
426 /// ```
427 async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
428 let mut cache =
429 self.cache
430 .lock()
431 .map_err(|e| OAuthStorageError::CacheLockFailedDelete {
432 details: e.to_string(),
433 })?;
434
435 cache.pop(state);
436 Ok(())
437 }
438
439 /// Inserts a new OAuth request into the LRU cache.
440 ///
441 /// This method stores an OAuth authorization request in the cache. If a request with the
442 /// same state already exists in the cache, it is replaced and the entry is moved to the front
443 /// of the LRU order. If the state is new and the cache is at capacity, the least recently
444 /// used entry is evicted to make room.
445 ///
446 /// # Arguments
447 /// * `request` - The complete OAuth request to store. The request's `oauth_state` field
448 /// will be used as the storage key.
449 ///
450 /// # Returns
451 /// * `Ok(())` - If the OAuth request was successfully stored
452 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
453 ///
454 /// # Cache Behavior
455 ///
456 /// - If the cache is at capacity and this is a new state, the least recently used entry is evicted
457 /// - The new or updated entry is placed at the front of the LRU order
458 /// - Existing entries with the same state are replaced in place
459 ///
460 /// # Examples
461 ///
462 /// ```rust
463 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
464 /// use atproto_oauth::storage::OAuthRequestStorage;
465 /// use atproto_oauth::workflow::OAuthRequest;
466 /// use std::num::NonZeroUsize;
467 /// use chrono::{Utc, Duration};
468 ///
469 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
470 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(2).unwrap()); // Small cache for demo
471 ///
472 /// // Add first request
473 /// let req1 = OAuthRequest {
474 /// oauth_state: "state1".to_string(),
475 /// issuer: "https://pds.example.com".to_string(),
476 /// did: "did:plc:user1".to_string(),
477 /// nonce: "nonce1".to_string(),
478 /// pkce_verifier: "verifier1".to_string(),
479 /// signing_public_key: "pubkey1".to_string(),
480 /// dpop_private_key: "privkey1".to_string(),
481 /// created_at: Utc::now(),
482 /// expires_at: Utc::now() + Duration::minutes(10),
483 /// };
484 /// storage.insert_oauth_request(req1).await?;
485 /// assert_eq!(storage.len(), 1);
486 ///
487 /// // Add second request
488 /// let req2 = OAuthRequest {
489 /// oauth_state: "state2".to_string(),
490 /// issuer: "https://pds.example.com".to_string(),
491 /// did: "did:plc:user2".to_string(),
492 /// nonce: "nonce2".to_string(),
493 /// pkce_verifier: "verifier2".to_string(),
494 /// signing_public_key: "pubkey2".to_string(),
495 /// dpop_private_key: "privkey2".to_string(),
496 /// created_at: Utc::now(),
497 /// expires_at: Utc::now() + Duration::minutes(10),
498 /// };
499 /// storage.insert_oauth_request(req2).await?;
500 /// assert_eq!(storage.len(), 2);
501 ///
502 /// // Add third request - this will evict the least recently used entry (state1)
503 /// let req3 = OAuthRequest {
504 /// oauth_state: "state3".to_string(),
505 /// issuer: "https://pds.example.com".to_string(),
506 /// did: "did:plc:user3".to_string(),
507 /// nonce: "nonce3".to_string(),
508 /// pkce_verifier: "verifier3".to_string(),
509 /// signing_public_key: "pubkey3".to_string(),
510 /// dpop_private_key: "privkey3".to_string(),
511 /// created_at: Utc::now(),
512 /// expires_at: Utc::now() + Duration::minutes(10),
513 /// };
514 /// storage.insert_oauth_request(req3).await?;
515 /// assert_eq!(storage.len(), 2); // Still at capacity
516 ///
517 /// // state1 should be evicted
518 /// let request = storage.get_oauth_request_by_state("state1").await?;
519 /// assert_eq!(request, None);
520 ///
521 /// // state2 and state3 should still be present
522 /// let req2_retrieved = storage.get_oauth_request_by_state("state2").await?;
523 /// let req3_retrieved = storage.get_oauth_request_by_state("state3").await?;
524 /// assert!(req2_retrieved.is_some());
525 /// assert!(req3_retrieved.is_some());
526 /// # Ok::<(), anyhow::Error>(())
527 /// # }).unwrap();
528 /// ```
529 async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
530 let mut cache =
531 self.cache
532 .lock()
533 .map_err(|e| OAuthStorageError::CacheLockFailedInsert {
534 details: e.to_string(),
535 })?;
536
537 cache.put(request.oauth_state.clone(), request);
538 Ok(())
539 }
540
541 /// Clears all expired OAuth requests from the LRU cache.
542 ///
543 /// This method performs cleanup by removing OAuth requests that have passed their
544 /// expiration time (`expires_at <= current_time`). This is important for maintaining
545 /// security and preventing storage bloat from abandoned OAuth flows.
546 ///
547 /// # Returns
548 /// * `Ok(count)` - The number of expired requests that were successfully removed
549 /// * `Err(error)` - If an error occurs (primarily mutex poisoning, which is very rare)
550 ///
551 /// # Cache Behavior
552 ///
553 /// - Compares each request's `expires_at` against the current time (`Utc::now()`)
554 /// - Removes all requests where `expires_at <= current_time`
555 /// - Maintains LRU order for remaining valid requests
556 /// - Frees up capacity for new OAuth requests
557 ///
558 /// # Examples
559 ///
560 /// ```rust
561 /// use atproto_oauth::storage_lru::LruOAuthRequestStorage;
562 /// use atproto_oauth::storage::OAuthRequestStorage;
563 /// use atproto_oauth::workflow::OAuthRequest;
564 /// use std::num::NonZeroUsize;
565 /// use chrono::{Utc, Duration};
566 ///
567 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
568 /// let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
569 ///
570 /// // Add a request that will expire soon
571 /// let expired_request = OAuthRequest {
572 /// oauth_state: "soon-expired".to_string(),
573 /// issuer: "https://pds.example.com".to_string(),
574 /// did: "did:plc:user1".to_string(),
575 /// nonce: "nonce1".to_string(),
576 /// pkce_verifier: "verifier1".to_string(),
577 /// signing_public_key: "pubkey1".to_string(),
578 /// dpop_private_key: "privkey1".to_string(),
579 /// created_at: Utc::now() - Duration::minutes(20),
580 /// expires_at: Utc::now() - Duration::minutes(10), // Already expired
581 /// };
582 /// storage.insert_oauth_request(expired_request).await?;
583 ///
584 /// // Add a valid request
585 /// let valid_request = OAuthRequest {
586 /// oauth_state: "still-valid".to_string(),
587 /// issuer: "https://pds.example.com".to_string(),
588 /// did: "did:plc:user2".to_string(),
589 /// nonce: "nonce2".to_string(),
590 /// pkce_verifier: "verifier2".to_string(),
591 /// signing_public_key: "pubkey2".to_string(),
592 /// dpop_private_key: "privkey2".to_string(),
593 /// created_at: Utc::now(),
594 /// expires_at: Utc::now() + Duration::minutes(10), // Still valid
595 /// };
596 /// storage.insert_oauth_request(valid_request).await?;
597 ///
598 /// assert_eq!(storage.len(), 2);
599 ///
600 /// // Clean up expired requests
601 /// let removed_count = storage.clear_expired_oauth_requests().await?;
602 /// assert_eq!(removed_count, 1); // One expired request removed
603 /// assert_eq!(storage.len(), 1); // One valid request remains
604 ///
605 /// // Verify the valid request is still accessible
606 /// let remaining = storage.get_oauth_request_by_state("still-valid").await?;
607 /// assert!(remaining.is_some());
608 /// # Ok::<(), anyhow::Error>(())
609 /// # }).unwrap();
610 /// ```
611 async fn clear_expired_oauth_requests(&self) -> Result<u64> {
612 let mut cache =
613 self.cache
614 .lock()
615 .map_err(|e| OAuthStorageError::CacheLockFailedCleanup {
616 details: e.to_string(),
617 })?;
618
619 let now = Utc::now();
620
621 // Collect keys of expired requests
622 let expired_keys: Vec<String> = cache
623 .iter()
624 .filter_map(|(key, request)| {
625 if request.expires_at <= now {
626 Some(key.clone())
627 } else {
628 None
629 }
630 })
631 .collect();
632
633 // Remove expired requests
634 for key in &expired_keys {
635 cache.pop(key);
636 }
637
638 Ok(expired_keys.len() as u64)
639 }
640}
641
642#[cfg(test)]
643mod tests {
644 use super::*;
645 use chrono::{Duration, Utc};
646 use std::num::NonZeroUsize;
647
648 fn create_test_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest {
649 OAuthRequest {
650 oauth_state: state.to_string(),
651 issuer: issuer.to_string(),
652 did: did.to_string(),
653 nonce: format!("nonce-{}", state),
654 pkce_verifier: format!("verifier-{}", state),
655 signing_public_key: format!("pubkey-{}", state),
656 dpop_private_key: format!("privkey-{}", state),
657 created_at: Utc::now(),
658 expires_at: Utc::now() + Duration::minutes(10),
659 }
660 }
661
662 fn create_expired_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest {
663 OAuthRequest {
664 oauth_state: state.to_string(),
665 issuer: issuer.to_string(),
666 did: did.to_string(),
667 nonce: format!("nonce-{}", state),
668 pkce_verifier: format!("verifier-{}", state),
669 signing_public_key: format!("pubkey-{}", state),
670 dpop_private_key: format!("privkey-{}", state),
671 created_at: Utc::now() - Duration::minutes(20),
672 expires_at: Utc::now() - Duration::minutes(10), // Already expired
673 }
674 }
675
676 #[tokio::test]
677 async fn test_new_storage() {
678 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap());
679 assert_eq!(storage.len(), 0);
680 assert!(storage.is_empty());
681 assert_eq!(storage.capacity().get(), 100);
682 }
683
684 #[tokio::test]
685 async fn test_basic_operations() -> Result<()> {
686 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
687
688 // Test get on empty cache
689 let result = storage.get_oauth_request_by_state("unknown-state").await?;
690 assert_eq!(result, None);
691
692 // Test insert and get
693 let request =
694 create_test_oauth_request("test-state", "https://pds.example.com", "did:plc:test");
695 storage.insert_oauth_request(request.clone()).await?;
696 let result = storage.get_oauth_request_by_state("test-state").await?;
697 assert!(result.is_some());
698 assert_eq!(result.as_ref().unwrap().oauth_state, request.oauth_state);
699 assert_eq!(storage.len(), 1);
700
701 // Test update existing
702 let updated_request = create_test_oauth_request(
703 "test-state",
704 "https://updated.example.com",
705 "did:plc:updated",
706 );
707 storage
708 .insert_oauth_request(updated_request.clone())
709 .await?;
710 let result = storage.get_oauth_request_by_state("test-state").await?;
711 assert!(result.is_some());
712 assert_eq!(result.as_ref().unwrap().issuer, updated_request.issuer);
713 assert_eq!(storage.len(), 1); // Should still be 1
714
715 // Test delete
716 storage.delete_oauth_request_by_state("test-state").await?;
717 let result = storage.get_oauth_request_by_state("test-state").await?;
718 assert_eq!(result, None);
719 assert_eq!(storage.len(), 0);
720
721 Ok(())
722 }
723
724 #[tokio::test]
725 async fn test_expiration_handling() -> Result<()> {
726 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
727
728 // Insert an expired request
729 let expired_request = create_expired_oauth_request(
730 "expired-state",
731 "https://pds.example.com",
732 "did:plc:expired",
733 );
734 storage.insert_oauth_request(expired_request).await?;
735 assert_eq!(storage.len(), 1);
736
737 // Try to get the expired request - should return None and remove from cache
738 let result = storage.get_oauth_request_by_state("expired-state").await?;
739 assert_eq!(result, None);
740 assert_eq!(storage.len(), 0); // Should be removed from cache
741
742 Ok(())
743 }
744
745 #[tokio::test]
746 async fn test_lru_eviction() -> Result<()> {
747 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(2).unwrap());
748
749 // Fill cache to capacity
750 let req1 = create_test_oauth_request("state1", "https://pds.example.com", "did:plc:user1");
751 let req2 = create_test_oauth_request("state2", "https://pds.example.com", "did:plc:user2");
752 storage.insert_oauth_request(req1.clone()).await?;
753 storage.insert_oauth_request(req2).await?;
754 assert_eq!(storage.len(), 2);
755
756 // Access req1 to make it recently used
757 let _ = storage.get_oauth_request_by_state("state1").await?;
758
759 // Add req3, which should evict req2 (least recently used)
760 let req3 = create_test_oauth_request("state3", "https://pds.example.com", "did:plc:user3");
761 storage.insert_oauth_request(req3.clone()).await?;
762 assert_eq!(storage.len(), 2);
763
764 // req1 and req3 should be present, req2 should be evicted
765 let result1 = storage.get_oauth_request_by_state("state1").await?;
766 assert!(result1.is_some());
767 assert_eq!(result1.unwrap().did, req1.did);
768
769 let result3 = storage.get_oauth_request_by_state("state3").await?;
770 assert!(result3.is_some());
771 assert_eq!(result3.unwrap().did, req3.did);
772
773 assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);
774
775 Ok(())
776 }
777
778 #[tokio::test]
779 async fn test_clear() -> Result<()> {
780 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
781
782 // Add some entries
783 let req1 = create_test_oauth_request("state1", "https://pds.example.com", "did:plc:user1");
784 let req2 = create_test_oauth_request("state2", "https://pds.example.com", "did:plc:user2");
785 storage.insert_oauth_request(req1).await?;
786 storage.insert_oauth_request(req2).await?;
787 assert_eq!(storage.len(), 2);
788
789 // Clear cache
790 storage.clear();
791 assert_eq!(storage.len(), 0);
792 assert!(storage.is_empty());
793
794 // Verify entries are gone
795 assert_eq!(storage.get_oauth_request_by_state("state1").await?, None);
796 assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);
797
798 Ok(())
799 }
800
801 #[tokio::test]
802 async fn test_clear_expired_requests() -> Result<()> {
803 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
804
805 // Add expired and valid requests
806 let expired1 =
807 create_expired_oauth_request("expired1", "https://pds.example.com", "did:plc:expired1");
808 let expired2 =
809 create_expired_oauth_request("expired2", "https://pds.example.com", "did:plc:expired2");
810 let valid1 =
811 create_test_oauth_request("valid1", "https://pds.example.com", "did:plc:valid1");
812 let valid2 =
813 create_test_oauth_request("valid2", "https://pds.example.com", "did:plc:valid2");
814
815 storage.insert_oauth_request(expired1).await?;
816 storage.insert_oauth_request(valid1).await?;
817 storage.insert_oauth_request(expired2).await?;
818 storage.insert_oauth_request(valid2).await?;
819 assert_eq!(storage.len(), 4);
820
821 // Clear expired requests
822 let removed_count = storage.clear_expired_oauth_requests().await?;
823 assert_eq!(removed_count, 2); // Two expired requests removed
824 assert_eq!(storage.len(), 2); // Two valid requests remain
825
826 // Verify only valid requests remain
827 assert!(
828 storage
829 .get_oauth_request_by_state("valid1")
830 .await?
831 .is_some()
832 );
833 assert!(
834 storage
835 .get_oauth_request_by_state("valid2")
836 .await?
837 .is_some()
838 );
839 assert_eq!(storage.get_oauth_request_by_state("expired1").await?, None);
840 assert_eq!(storage.get_oauth_request_by_state("expired2").await?, None);
841
842 Ok(())
843 }
844
845 #[tokio::test]
846 async fn test_delete_nonexistent() -> Result<()> {
847 let storage = LruOAuthRequestStorage::new(NonZeroUsize::new(10).unwrap());
848
849 // Deleting non-existent entry should not error
850 storage
851 .delete_oauth_request_by_state("non-existent-state")
852 .await?;
853 assert_eq!(storage.len(), 0);
854
855 Ok(())
856 }
857
858 #[tokio::test]
859 async fn test_thread_safety() -> Result<()> {
860 let storage = Arc::new(LruOAuthRequestStorage::new(NonZeroUsize::new(100).unwrap()));
861 let mut handles = Vec::new();
862
863 // Spawn multiple tasks that concurrently access the storage
864 for i in 0..10 {
865 let storage_clone = Arc::clone(&storage);
866 let handle = tokio::spawn(async move {
867 let state = format!("state{}", i);
868 let issuer = format!("https://pds{}.example.com", i);
869 let did = format!("did:plc:user{}", i);
870 let request = create_test_oauth_request(&state, &issuer, &did);
871
872 // Insert a request
873 storage_clone.insert_oauth_request(request.clone()).await?;
874
875 // Get the request back
876 let result = storage_clone.get_oauth_request_by_state(&state).await?;
877 assert!(result.is_some());
878 assert_eq!(result.unwrap().oauth_state, request.oauth_state);
879
880 // Delete the request
881 storage_clone.delete_oauth_request_by_state(&state).await?;
882 let result = storage_clone.get_oauth_request_by_state(&state).await?;
883 assert_eq!(result, None);
884
885 Ok::<(), anyhow::Error>(())
886 });
887 handles.push(handle);
888 }
889
890 // Wait for all tasks to complete
891 for handle in handles {
892 handle.await??;
893 }
894
895 // Storage should be empty after all deletions
896 assert_eq!(storage.len(), 0);
897 Ok(())
898 }
899}