atproto_oauth/
storage.rs

//! OAuth request storage abstraction.
//!
//! Defines the [`OAuthRequestStorage`] trait for OAuth request CRUD operations, so that
//! multiple storage backends can be used with state-parameter lookup and expiration handling.
5
6use anyhow::Result;
7
8use crate::workflow::OAuthRequest;
9
10/// Trait for implementing OAuth request CRUD operations across different storage backends.
11///
12/// This trait provides an abstraction layer for storing and retrieving OAuth authorization request
13/// state, allowing different implementations for various storage systems such as databases, file systems,
14/// in-memory stores, or cloud storage services.
15///
16/// All methods return `anyhow::Result` to allow implementations to use their own error types
17/// while providing a consistent interface for callers. Implementations should handle their
18/// specific error conditions and convert them to appropriate error messages.
19///
20/// ## Thread Safety
21///
22/// This trait requires implementations to be thread-safe (`Send + Sync`), meaning:
23/// - `Send`: The storage implementation can be moved between threads
24/// - `Sync`: The storage implementation can be safely accessed from multiple threads simultaneously
25///
26/// This is essential for async applications where the storage might be accessed from different
27/// async tasks running on different threads. Implementations should use appropriate
28/// synchronization primitives (like `Arc<Mutex<>>`, `RwLock`, or database connection pools)
29/// to ensure thread safety.
30///
31/// ## OAuth Request Lifecycle
32///
33/// OAuth requests have a natural lifecycle with expiration times. Implementations should:
34/// - Store requests with their creation and expiration timestamps
35/// - Support efficient lookup by OAuth state parameter
36/// - Provide cleanup mechanisms for expired requests
37/// - Handle concurrent access safely
38///
39/// ## Usage
40///
41/// Implementors of this trait can provide storage for OAuth requests in any backend:
42///
43/// ```rust,ignore
44/// use atproto_oauth::storage::OAuthRequestStorage;
45/// use atproto_oauth::workflow::OAuthRequest;
46/// use anyhow::Result;
47/// use std::sync::Arc;
48/// use tokio::sync::RwLock;
49/// use std::collections::HashMap;
50/// use chrono::{DateTime, Utc};
51///
52/// // Thread-safe in-memory storage using Arc<RwLock<>>
53/// #[derive(Clone)]
54/// struct InMemoryOAuthStorage {
55///     requests: Arc<RwLock<HashMap<String, OAuthRequest>>>, // state -> request mapping
56/// }
57///
58/// #[async_trait::async_trait]
59/// impl OAuthRequestStorage for InMemoryOAuthStorage {
60///     async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
61///         let requests = self.requests.read().await;
62///         Ok(requests.get(state).cloned())
63///     }
64///     
65///     async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
66///         let mut requests = self.requests.write().await;
67///         requests.insert(request.oauth_state.clone(), request);
68///         Ok(())
69///     }
70///     
71///     async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
72///         let mut requests = self.requests.write().await;
73///         requests.remove(state);
74///         Ok(())
75///     }
76///     
77///     async fn clear_expired_oauth_requests(&self) -> Result<u64> {
78///         let mut requests = self.requests.write().await;
79///         let now = Utc::now();
80///         let initial_count = requests.len();
81///         
82///         requests.retain(|_, req| req.expires_at > now);
83///         let final_count = requests.len();
84///         
85///         Ok((initial_count - final_count) as u64)
86///     }
87/// }
88///
89/// // Database storage with thread-safe connection pool
90/// struct DatabaseOAuthStorage {
91///     pool: sqlx::Pool<sqlx::Postgres>, // Thread-safe connection pool
92/// }
93///
94/// #[async_trait::async_trait]
95/// impl OAuthRequestStorage for DatabaseOAuthStorage {
96///     async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
97///         let row: Option<_> = sqlx::query_as!(
98///             OAuthRequestRow,
99///             "SELECT oauth_state, issuer, did, nonce, pkce_verifier, signing_public_key,
100///              dpop_private_key, created_at, expires_at
101///              FROM oauth_requests WHERE oauth_state = $1 AND expires_at > NOW()"
102///         )
103///         .bind(state)
104///         .fetch_optional(&self.pool)
105///         .await?;
106///         
107///         Ok(row.map(|r| r.into_oauth_request()))
108///     }
109///     
110///     async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
111///         sqlx::query!(
112///             "INSERT INTO oauth_requests
113///              (oauth_state, issuer, authorization_server, nonce, pkce_verifier, signing_public_key,
114///               dpop_private_key, created_at, expires_at)
115///              VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
116///             request.oauth_state,
117///             request.issuer,
118///             request.authorization_server,
119///             request.nonce,
120///             request.pkce_verifier,
121///             request.signing_public_key,
122///             request.dpop_private_key,
123///             request.created_at,
124///             request.expires_at
125///         )
126///         .execute(&self.pool)
127///         .await?;
128///         Ok(())
129///     }
130///     
131///     async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
132///         sqlx::query!("DELETE FROM oauth_requests WHERE oauth_state = $1", state)
133///             .execute(&self.pool)
134///             .await?;
135///         Ok(())
136///     }
137///     
138///     async fn clear_expired_oauth_requests(&self) -> Result<u64> {
139///         let result = sqlx::query!("DELETE FROM oauth_requests WHERE expires_at <= NOW()")
140///             .execute(&self.pool)
141///             .await?;
142///         Ok(result.rows_affected())
143///     }
144/// }
145/// ```
146#[async_trait::async_trait]
147pub trait OAuthRequestStorage: Send + Sync {
148    /// Retrieves an OAuth request by its state parameter.
149    ///
150    /// This method looks up an OAuth authorization request using the state parameter,
151    /// which is a unique identifier for each OAuth flow used to prevent CSRF attacks.
152    /// The state parameter is generated during the initial authorization request and
153    /// used to correlate the callback with the original request.
154    ///
155    /// Implementations should:
156    /// - Return only non-expired requests (check `expires_at` against current time)
157    /// - Handle the case where the state doesn't exist gracefully (return `None`)
158    /// - Ensure thread-safe access to the underlying storage
159    ///
160    /// # Arguments
161    /// * `state` - The OAuth state parameter to look up. This is a randomly generated
162    ///            string that uniquely identifies the OAuth authorization request.
163    ///
164    /// # Returns
165    /// * `Ok(Some(request))` - If a valid, non-expired OAuth request is found
166    /// * `Ok(None)` - If no request exists for the given state or if the request has expired
167    /// * `Err(error)` - If an error occurs during retrieval (storage failure, etc.)
168    ///
169    /// # Examples
170    ///
171    /// ```rust,ignore
172    /// let storage = MyOAuthStorage::new();
173    /// let request = storage.get_oauth_request_by_state("unique-state-value").await?;
174    /// match request {
175    ///     Some(req) => {
176    ///         println!("Found OAuth request for issuer: {}", req.issuer);
177    ///         // Continue with OAuth flow
178    ///     },
179    ///     None => {
180    ///         println!("No valid OAuth request found for this state");
181    ///         // Handle invalid/expired state
182    ///     }
183    /// }
184    /// ```
185    async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>>;
186
187    /// Deletes an OAuth request by its state parameter.
188    ///
189    /// This method removes an OAuth authorization request from storage using its state parameter.
190    /// This is typically called after the OAuth flow completes successfully or when cleaning up
191    /// failed/abandoned flows.
192    ///
193    /// Implementations should:
194    /// - Handle the case where the state doesn't exist gracefully (return `Ok(())`)
195    /// - Ensure the deletion is atomic
196    /// - Clean up any related data or indexes
197    /// - Be thread-safe for concurrent access
198    ///
199    /// # Arguments
200    /// * `state` - The OAuth state parameter identifying the request to delete.
201    ///
202    /// # Returns
203    /// * `Ok(())` - If the OAuth request was successfully deleted or didn't exist
204    /// * `Err(error)` - If an error occurs during deletion (storage failure, etc.)
205    ///
206    /// # Examples
207    ///
208    /// ```rust,ignore
209    /// let storage = MyOAuthStorage::new();
210    /// // After successful OAuth completion
211    /// storage.delete_oauth_request_by_state("completed-state-value").await?;
212    /// println!("OAuth request cleaned up successfully");
213    /// ```
214    async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()>;
215
216    /// Inserts a new OAuth request into storage.
217    ///
218    /// This method stores a new OAuth authorization request, typically called at the beginning
219    /// of an OAuth flow when the authorization request is initiated. The request contains all
220    /// the necessary state information to complete the OAuth flow.
221    ///
222    /// Implementations should:
223    /// - Store all fields of the `OAuthRequest` struct
224    /// - Handle duplicate state parameters appropriately (either reject or replace)
225    /// - Ensure the insertion is atomic
226    /// - Maintain indexes for efficient lookups by state
227    /// - Be thread-safe for concurrent insertions
228    ///
229    /// # Arguments
230    /// * `request` - The complete OAuth request to store, including state, timing,
231    ///              cryptographic keys, and user information.
232    ///
233    /// # Returns
234    /// * `Ok(())` - If the OAuth request was successfully stored
235    /// * `Err(error)` - If an error occurs during insertion (storage failure,
236    ///                 constraint violation, etc.)
237    ///
238    /// # Examples
239    ///
240    /// ```rust,ignore
241    /// use chrono::{Utc, Duration};
242    ///
243    /// let storage = MyOAuthStorage::new();
244    /// let request = OAuthRequest {
245    ///     oauth_state: "unique-random-state".to_string(),
246    ///     issuer: "https://pds.example.com".to_string(),
247    ///     did: "did:plc:example123".to_string(),
248    ///     nonce: "random-nonce".to_string(),
249    ///     pkce_verifier: "code-verifier".to_string(),
250    ///     signing_public_key: "public-key-data".to_string(),
251    ///     dpop_private_key: "private-key-data".to_string(),
252    ///     created_at: Utc::now(),
253    ///     expires_at: Utc::now() + Duration::minutes(10),
254    /// };
255    ///
256    /// storage.insert_oauth_request(request).await?;
257    /// println!("OAuth request stored successfully");
258    /// ```
259    async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()>;
260
261    /// Clears all expired OAuth requests from storage.
262    ///
263    /// This method performs cleanup by removing OAuth requests that have passed their
264    /// expiration time. This is important for:
265    /// - Preventing storage bloat from abandoned OAuth flows
266    /// - Maintaining security by ensuring expired flows cannot be resumed
267    /// - Optimizing storage performance by removing stale data
268    ///
269    /// Implementations should:
270    /// - Compare `expires_at` against the current time (`Utc::now()`)
271    /// - Remove all requests where `expires_at <= current_time`
272    /// - Return the count of removed requests for monitoring/logging
273    /// - Be efficient for large datasets (use bulk operations when possible)
274    /// - Be thread-safe and handle concurrent access appropriately
275    ///
276    /// This method is typically called:
277    /// - Periodically by a background cleanup task
278    /// - Before inserting new requests to maintain storage hygiene
279    /// - During application startup to clean stale data
280    ///
281    /// # Returns
282    /// * `Ok(count)` - The number of expired requests that were successfully removed
283    /// * `Err(error)` - If an error occurs during cleanup (storage failure, etc.)
284    ///
285    /// # Examples
286    ///
287    /// ```rust,ignore
288    /// let storage = MyOAuthStorage::new();
289    ///
290    /// // Periodic cleanup
291    /// let removed_count = storage.clear_expired_oauth_requests().await?;
292    /// println!("Cleaned up {} expired OAuth requests", removed_count);
293    ///
294    /// // In a background task
295    /// tokio::spawn(async move {
296    ///     let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5 minutes
297    ///     loop {
298    ///         interval.tick().await;
299    ///         if let Err(e) = storage.clear_expired_oauth_requests().await {
300    ///             eprintln!("Error during OAuth cleanup: {}", e);
301    ///         }
302    ///     }
303    /// });
304    /// ```
305    async fn clear_expired_oauth_requests(&self) -> Result<u64>;
306}