atproto_oauth/
storage.rs

1//! OAuth request storage abstraction for AT Protocol OAuth operations.
2//!
3//! This module provides the `OAuthRequestStorage` trait for implementing OAuth request CRUD operations
4//! across different storage backends (database, file system, in-memory, etc.).
5//!
6//! The trait enables flexible storage implementations while maintaining a consistent interface
7//! for OAuth request management operations throughout the AT Protocol OAuth ecosystem, including
8//! state tracking, expiration handling, and cleanup operations.
9
10use anyhow::Result;
11
12use crate::workflow::OAuthRequest;
13
14/// Trait for implementing OAuth request CRUD operations across different storage backends.
15///
16/// This trait provides an abstraction layer for storing and retrieving OAuth authorization request
17/// state, allowing different implementations for various storage systems such as databases, file systems,
18/// in-memory stores, or cloud storage services.
19///
20/// All methods return `anyhow::Result` to allow implementations to use their own error types
21/// while providing a consistent interface for callers. Implementations should handle their
22/// specific error conditions and convert them to appropriate error messages.
23///
24/// ## Thread Safety
25///
26/// This trait requires implementations to be thread-safe (`Send + Sync`), meaning:
27/// - `Send`: The storage implementation can be moved between threads
28/// - `Sync`: The storage implementation can be safely accessed from multiple threads simultaneously
29///
30/// This is essential for async applications where the storage might be accessed from different
31/// async tasks running on different threads. Implementations should use appropriate
32/// synchronization primitives (like `Arc<Mutex<>>`, `RwLock`, or database connection pools)
33/// to ensure thread safety.
34///
35/// ## OAuth Request Lifecycle
36///
37/// OAuth requests have a natural lifecycle with expiration times. Implementations should:
38/// - Store requests with their creation and expiration timestamps
39/// - Support efficient lookup by OAuth state parameter
40/// - Provide cleanup mechanisms for expired requests
41/// - Handle concurrent access safely
42///
43/// ## Usage
44///
45/// Implementors of this trait can provide storage for OAuth requests in any backend:
46///
47/// ```rust,ignore
48/// use atproto_oauth::storage::OAuthRequestStorage;
49/// use atproto_oauth::workflow::OAuthRequest;
50/// use anyhow::Result;
51/// use std::sync::Arc;
52/// use tokio::sync::RwLock;
53/// use std::collections::HashMap;
54/// use chrono::Utc;
55///
56/// // Thread-safe in-memory storage using Arc<RwLock<>>
57/// #[derive(Clone)]
58/// struct InMemoryOAuthStorage {
59///     requests: Arc<RwLock<HashMap<String, OAuthRequest>>>, // state -> request mapping
60/// }
61///
62/// #[async_trait::async_trait]
63/// impl OAuthRequestStorage for InMemoryOAuthStorage {
64///     async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
65///         let requests = self.requests.read().await;
66///         Ok(requests.get(state).filter(|req| req.expires_at > Utc::now()).cloned())
67///     }
68///     
69///     async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
70///         let mut requests = self.requests.write().await;
71///         requests.insert(request.oauth_state.clone(), request);
72///         Ok(())
73///     }
74///     
75///     async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
76///         let mut requests = self.requests.write().await;
77///         requests.remove(state);
78///         Ok(())
79///     }
80///     
81///     async fn clear_expired_oauth_requests(&self) -> Result<u64> {
82///         let mut requests = self.requests.write().await;
83///         let now = Utc::now();
84///         let initial_count = requests.len();
85///         
86///         requests.retain(|_, req| req.expires_at > now);
87///         let final_count = requests.len();
88///         
89///         Ok((initial_count - final_count) as u64)
90///     }
91/// }
92///
93/// // Database storage with thread-safe connection pool
94/// struct DatabaseOAuthStorage {
95///     pool: sqlx::Pool<sqlx::Postgres>, // Thread-safe connection pool
96/// }
97///
98/// #[async_trait::async_trait]
99/// impl OAuthRequestStorage for DatabaseOAuthStorage {
100///     async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
101///         let row: Option<_> = sqlx::query_as!(
102///             OAuthRequestRow,
103///             "SELECT oauth_state, issuer, did, nonce, pkce_verifier, signing_public_key,
104///              dpop_private_key, created_at, expires_at
105///              FROM oauth_requests WHERE oauth_state = $1 AND expires_at > NOW()",
106///             state
107///         )
108///         .fetch_optional(&self.pool)
109///         .await?;
110///         
111///         Ok(row.map(|r| r.into_oauth_request()))
112///     }
113///     
114///     async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
115///         sqlx::query!(
116///             "INSERT INTO oauth_requests
117///              (oauth_state, issuer, did, nonce, pkce_verifier, signing_public_key,
118///               dpop_private_key, created_at, expires_at)
119///              VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
120///             request.oauth_state,
121///             request.issuer,
122///             request.did,
123///             request.nonce,
124///             request.pkce_verifier,
125///             request.signing_public_key,
126///             request.dpop_private_key,
127///             request.created_at,
128///             request.expires_at
129///         )
130///         .execute(&self.pool)
131///         .await?;
132///         Ok(())
133///     }
134///     
135///     async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
136///         sqlx::query!("DELETE FROM oauth_requests WHERE oauth_state = $1", state)
137///             .execute(&self.pool)
138///             .await?;
139///         Ok(())
140///     }
141///     
142///     async fn clear_expired_oauth_requests(&self) -> Result<u64> {
143///         let result = sqlx::query!("DELETE FROM oauth_requests WHERE expires_at <= NOW()")
144///             .execute(&self.pool)
145///             .await?;
146///         Ok(result.rows_affected())
147///     }
148/// }
149/// ```
150#[async_trait::async_trait]
151pub trait OAuthRequestStorage: Send + Sync {
152    /// Retrieves an OAuth request by its state parameter.
153    ///
154    /// This method looks up an OAuth authorization request using the state parameter,
155    /// which is a unique identifier for each OAuth flow used to prevent CSRF attacks.
156    /// The state parameter is generated during the initial authorization request and
157    /// used to correlate the callback with the original request.
158    ///
159    /// Implementations should:
160    /// - Return only non-expired requests (`expires_at` strictly later than the current time)
161    /// - Handle the case where the state doesn't exist gracefully (return `None`)
162    /// - Ensure thread-safe access to the underlying storage
163    ///
164    /// # Arguments
165    /// * `state` - The OAuth state parameter to look up. This is a randomly generated
166    ///   string that uniquely identifies the OAuth authorization request.
167    ///
168    /// # Returns
169    /// * `Ok(Some(request))` - If a valid, non-expired OAuth request is found
170    /// * `Ok(None)` - If no request exists for the given state or if the request has expired
171    /// * `Err(error)` - If an error occurs during retrieval (storage failure, etc.)
172    ///
173    /// # Examples
174    ///
175    /// ```rust,ignore
176    /// let storage = MyOAuthStorage::new();
177    /// let request = storage.get_oauth_request_by_state("unique-state-value").await?;
178    /// match request {
179    ///     Some(req) => {
180    ///         println!("Found OAuth request for DID: {}", req.did);
181    ///         // Continue with OAuth flow
182    ///     },
183    ///     None => {
184    ///         println!("No valid OAuth request found for this state");
185    ///         // Handle invalid/expired state
186    ///     }
187    /// }
188    /// ```
189    async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>>;
190
191    /// Deletes an OAuth request by its state parameter.
192    ///
193    /// This method removes an OAuth authorization request from storage using its state parameter.
194    /// This is typically called after the OAuth flow completes successfully or when cleaning up
195    /// failed/abandoned flows.
196    ///
197    /// Implementations should:
198    /// - Handle the case where the state doesn't exist gracefully (return `Ok(())`)
199    /// - Ensure the deletion is atomic
200    /// - Clean up any related data or indexes
201    /// - Be thread-safe for concurrent access
202    ///
203    /// # Arguments
204    /// * `state` - The OAuth state parameter identifying the request to delete.
205    ///
206    /// # Returns
207    /// * `Ok(())` - If the OAuth request was successfully deleted or didn't exist
208    /// * `Err(error)` - If an error occurs during deletion (storage failure, etc.)
209    ///
210    /// # Examples
211    ///
212    /// ```rust,ignore
213    /// let storage = MyOAuthStorage::new();
214    /// // After successful OAuth completion
215    /// storage.delete_oauth_request_by_state("completed-state-value").await?;
216    /// println!("OAuth request cleaned up successfully");
217    /// ```
218    async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()>;
219
220    /// Inserts a new OAuth request into storage.
221    ///
222    /// This method stores a new OAuth authorization request, typically called at the beginning
223    /// of an OAuth flow when the authorization request is initiated. The request contains all
224    /// the necessary state information to complete the OAuth flow.
225    ///
226    /// Implementations should:
227    /// - Store all fields of the [`OAuthRequest`] struct
228    /// - Handle duplicate state parameters appropriately (either reject or replace)
229    /// - Ensure the insertion is atomic
230    /// - Maintain indexes for efficient lookups by state
231    /// - Be thread-safe for concurrent insertions
232    ///
233    /// # Arguments
234    /// * `request` - The complete OAuth request to store, including state, timing,
235    ///   cryptographic keys, and user information.
236    ///
237    /// # Returns
238    /// * `Ok(())` - If the OAuth request was successfully stored
239    /// * `Err(error)` - If an error occurs during insertion (storage failure,
240    ///   constraint violation, etc.)
241    ///
242    /// # Examples
243    ///
244    /// ```rust,ignore
245    /// use chrono::{Utc, Duration};
246    ///
247    /// let storage = MyOAuthStorage::new();
248    /// let request = OAuthRequest {
249    ///     oauth_state: "unique-random-state".to_string(),
250    ///     issuer: "https://pds.example.com".to_string(),
251    ///     did: "did:plc:example123".to_string(),
252    ///     nonce: "random-nonce".to_string(),
253    ///     pkce_verifier: "code-verifier".to_string(),
254    ///     signing_public_key: "public-key-data".to_string(),
255    ///     dpop_private_key: "private-key-data".to_string(),
256    ///     created_at: Utc::now(),
257    ///     expires_at: Utc::now() + Duration::minutes(10),
258    /// };
259    ///
260    /// storage.insert_oauth_request(request).await?;
261    /// println!("OAuth request stored successfully");
262    /// ```
263    async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()>;
264
265    /// Clears all expired OAuth requests from storage.
266    ///
267    /// This method performs cleanup by removing OAuth requests that have passed their
268    /// expiration time. This is important for:
269    /// - Preventing storage bloat from abandoned OAuth flows
270    /// - Maintaining security by ensuring expired flows cannot be resumed
271    /// - Optimizing storage performance by removing stale data
272    ///
273    /// Implementations should:
274    /// - Compare `expires_at` against the current time (`Utc::now()`)
275    /// - Remove all requests where `expires_at <= current_time`
276    /// - Return the count of removed requests for monitoring/logging
277    /// - Be efficient for large datasets (use bulk operations when possible)
278    /// - Be thread-safe and handle concurrent access appropriately
279    ///
280    /// This method is typically called:
281    /// - Periodically by a background cleanup task
282    /// - Before inserting new requests to maintain storage hygiene
283    /// - During application startup to clean stale data
284    ///
285    /// # Returns
286    /// * `Ok(count)` - The number of expired requests that were successfully removed
287    /// * `Err(error)` - If an error occurs during cleanup (storage failure, etc.)
288    ///
289    /// # Examples
290    ///
291    /// ```rust,ignore
292    /// let storage = MyOAuthStorage::new();
293    ///
294    /// // Periodic cleanup
295    /// let removed_count = storage.clear_expired_oauth_requests().await?;
296    /// println!("Cleaned up {} expired OAuth requests", removed_count);
297    ///
298    /// // In a background task
299    /// tokio::spawn(async move {
300    ///     let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); // 5 minutes
301    ///     loop {
302    ///         interval.tick().await;
303    ///         if let Err(e) = storage.clear_expired_oauth_requests().await {
304    ///             eprintln!("Error during OAuth cleanup: {}", e);
305    ///         }
306    ///     }
307    /// });
308    /// ```
309    async fn clear_expired_oauth_requests(&self) -> Result<u64>;
310}