atproto_oauth/storage.rs
1//! OAuth request storage abstraction.
2//!
3//! Storage trait for OAuth request CRUD operations supporting multiple
4//! backends with state tracking and expiration handling.
5
6use anyhow::Result;
7
8use crate::workflow::OAuthRequest;
9
10/// Trait for implementing OAuth request CRUD operations across different storage backends.
11///
12/// This trait provides an abstraction layer for storing and retrieving OAuth authorization request
13/// state, allowing different implementations for various storage systems such as databases, file systems,
14/// in-memory stores, or cloud storage services.
15///
16/// All methods return `anyhow::Result` to allow implementations to use their own error types
17/// while providing a consistent interface for callers. Implementations should handle their
18/// specific error conditions and convert them into errors with descriptive messages.
19///
20/// ## Thread Safety
21///
22/// This trait requires implementations to be thread-safe (`Send + Sync`), meaning:
23/// - `Send`: The storage implementation can be moved between threads
24/// - `Sync`: The storage implementation can be safely accessed from multiple threads simultaneously
25///
26/// This is essential for async applications where the storage might be accessed from different
27/// async tasks running on different threads. Implementations should use appropriate
28/// synchronization primitives (like `Arc<Mutex<>>`, `RwLock`, or database connection pools)
29/// to ensure thread safety.
30///
31/// ## OAuth Request Lifecycle
32///
33/// OAuth requests have a natural lifecycle with expiration times. Implementations should:
34/// - Store requests with their creation and expiration timestamps
35/// - Support efficient lookup by OAuth state parameter
36/// - Provide cleanup mechanisms for expired requests
37/// - Handle concurrent access safely
38///
39/// ## Usage
40///
41/// Implementors of this trait can provide storage for OAuth requests in any backend:
42///
43/// ```rust,ignore
44/// use atproto_oauth::storage::OAuthRequestStorage;
45/// use atproto_oauth::workflow::OAuthRequest;
46/// use anyhow::Result;
47/// use std::sync::Arc;
48/// use tokio::sync::RwLock;
49/// use std::collections::HashMap;
50/// use chrono::Utc;
51///
52/// // Thread-safe in-memory storage using Arc<RwLock<>>
53/// #[derive(Clone)]
54/// struct InMemoryOAuthStorage {
55/// requests: Arc<RwLock<HashMap<String, OAuthRequest>>>, // state -> request mapping
56/// }
57///
58/// #[async_trait::async_trait]
59/// impl OAuthRequestStorage for InMemoryOAuthStorage {
60/// async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
61/// let requests = self.requests.read().await;
62/// Ok(requests.get(state).filter(|req| req.expires_at > Utc::now()).cloned()) // expired entries are reported as `None`
63/// }
64///
65/// async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
66/// let mut requests = self.requests.write().await;
67/// requests.insert(request.oauth_state.clone(), request); // replaces any existing entry for the same state
68/// Ok(())
69/// }
70///
71/// async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
72/// let mut requests = self.requests.write().await;
73/// requests.remove(state); // no-op when the state is not present
74/// Ok(())
75/// }
76///
77/// async fn clear_expired_oauth_requests(&self) -> Result<u64> {
78/// let mut requests = self.requests.write().await;
79/// let now = Utc::now();
80/// let initial_count = requests.len();
81///
82/// requests.retain(|_, req| req.expires_at > now);
83/// let final_count = requests.len();
84///
85/// Ok((initial_count - final_count) as u64)
86/// }
87/// }
88///
89/// // Database storage with thread-safe connection pool
90/// struct DatabaseOAuthStorage {
91/// pool: sqlx::Pool<sqlx::Postgres>, // Thread-safe connection pool
92/// }
93///
94/// #[async_trait::async_trait]
95/// impl OAuthRequestStorage for DatabaseOAuthStorage {
96/// async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
97/// let row: Option<_> = sqlx::query_as!(
98/// OAuthRequestRow, // illustrative row type with an `into_oauth_request` conversion; not defined here
99/// "SELECT oauth_state, issuer, authorization_server, nonce, pkce_verifier, signing_public_key,
100/// dpop_private_key, created_at, expires_at
101/// FROM oauth_requests WHERE oauth_state = $1 AND expires_at > NOW()",
102/// state
103/// )
104/// .fetch_optional(&self.pool)
105/// .await?;
106///
107/// Ok(row.map(|r| r.into_oauth_request()))
108/// }
109///
110/// async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
111/// sqlx::query!(
112/// "INSERT INTO oauth_requests
113/// (oauth_state, issuer, authorization_server, nonce, pkce_verifier, signing_public_key,
114/// dpop_private_key, created_at, expires_at)
115/// VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
116/// request.oauth_state,
117/// request.issuer,
118/// request.authorization_server,
119/// request.nonce,
120/// request.pkce_verifier,
121/// request.signing_public_key,
122/// request.dpop_private_key,
123/// request.created_at,
124/// request.expires_at
125/// )
126/// .execute(&self.pool)
127/// .await?;
128/// Ok(())
129/// }
130///
131/// async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
132/// sqlx::query!("DELETE FROM oauth_requests WHERE oauth_state = $1", state)
133/// .execute(&self.pool)
134/// .await?;
135/// Ok(())
136/// }
137///
138/// async fn clear_expired_oauth_requests(&self) -> Result<u64> {
139/// let result = sqlx::query!("DELETE FROM oauth_requests WHERE expires_at <= NOW()")
140/// .execute(&self.pool)
141/// .await?;
142/// Ok(result.rows_affected())
143/// }
144/// }
145/// ```
146#[async_trait::async_trait]
147pub trait OAuthRequestStorage: Send + Sync {
148 /// Retrieves an OAuth request by its state parameter.
149 ///
150 /// This method looks up an OAuth authorization request using the state parameter,
151 /// which is a unique identifier for each OAuth flow used to prevent CSRF attacks.
152 /// The state parameter is generated during the initial authorization request and
153 /// used to correlate the callback with the original request.
154 ///
155 /// Implementations should:
156 /// - Return only non-expired requests (check `expires_at` against current time)
157 /// - Handle the case where the state doesn't exist gracefully (return `None`)
158 /// - Ensure thread-safe access to the underlying storage
159 ///
160 /// # Arguments
161 /// * `state` - The OAuth state parameter to look up. This is a randomly generated
162 /// string that uniquely identifies the OAuth authorization request.
163 ///
164 /// # Returns
165 /// * `Ok(Some(request))` - If a valid, non-expired OAuth request is found
166 /// * `Ok(None)` - If no request exists for the given state or if the request has expired
167 /// * `Err(error)` - If an error occurs during retrieval (storage failure, etc.)
168 ///
169 /// # Examples
170 ///
171 /// ```rust,ignore
172 /// let storage = MyOAuthStorage::new();
173 /// let request = storage.get_oauth_request_by_state("unique-state-value").await?;
174 /// match request {
175 /// Some(req) => {
176 /// println!("Found OAuth request for issuer: {}", req.issuer);
177 /// // Continue with OAuth flow
178 /// },
179 /// None => {
180 /// println!("No valid OAuth request found for this state");
181 /// // Handle invalid/expired state
182 /// }
183 /// }
184 /// ```
185 async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>>;
186
187 /// Deletes an OAuth request by its state parameter.
188 ///
189 /// This method removes an OAuth authorization request from storage using its state parameter.
190 /// This is typically called after the OAuth flow completes successfully or when cleaning up
191 /// failed/abandoned flows.
192 ///
193 /// Implementations should:
194 /// - Handle the case where the state doesn't exist gracefully (return `Ok(())`)
195 /// - Ensure the deletion is atomic
196 /// - Clean up any related data or indexes
197 /// - Be thread-safe for concurrent access
198 ///
199 /// # Arguments
200 /// * `state` - The OAuth state parameter identifying the request to delete.
201 ///
202 /// # Returns
203 /// * `Ok(())` - If the OAuth request was successfully deleted or didn't exist
204 /// * `Err(error)` - If an error occurs during deletion (storage failure, etc.)
205 ///
206 /// # Examples
207 ///
208 /// ```rust,ignore
209 /// let storage = MyOAuthStorage::new();
210 /// // After successful OAuth completion
211 /// storage.delete_oauth_request_by_state("completed-state-value").await?;
212 /// println!("OAuth request cleaned up successfully");
213 /// ```
214 async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()>;
215
216 /// Inserts a new OAuth request into storage.
217 ///
218 /// This method stores a new OAuth authorization request, typically called at the beginning
219 /// of an OAuth flow when the authorization request is initiated. The request contains all
220 /// the necessary state information to complete the OAuth flow.
221 ///
222 /// Implementations should:
223 /// - Store all fields of the `OAuthRequest` struct
224 /// - Handle duplicate state parameters appropriately (either reject or replace)
225 /// - Ensure the insertion is atomic
226 /// - Maintain indexes for efficient lookups by state
227 /// - Be thread-safe for concurrent insertions
228 ///
229 /// # Arguments
230 /// * `request` - The complete OAuth request to store, including state, timing,
231 /// cryptographic keys, and user information.
232 ///
233 /// # Returns
234 /// * `Ok(())` - If the OAuth request was successfully stored
235 /// * `Err(error)` - If an error occurs during insertion (storage failure,
236 /// constraint violation, etc.)
237 ///
238 /// # Examples
239 ///
240 /// ```rust,ignore
241 /// use chrono::{Utc, Duration};
242 ///
243 /// let storage = MyOAuthStorage::new();
244 /// let request = OAuthRequest {
245 /// oauth_state: "unique-random-state".to_string(),
246 /// issuer: "https://pds.example.com".to_string(),
247 /// did: "did:plc:example123".to_string(), // NOTE(review): the sqlx example on this trait reads `request.authorization_server` instead; confirm the field name against `OAuthRequest`
248 /// nonce: "random-nonce".to_string(),
249 /// pkce_verifier: "code-verifier".to_string(),
250 /// signing_public_key: "public-key-data".to_string(),
251 /// dpop_private_key: "private-key-data".to_string(),
252 /// created_at: Utc::now(),
253 /// expires_at: Utc::now() + Duration::minutes(10),
254 /// };
255 ///
256 /// storage.insert_oauth_request(request).await?;
257 /// println!("OAuth request stored successfully");
258 /// ```
259 async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()>;
260
261 /// Clears all expired OAuth requests from storage.
262 ///
263 /// This method performs cleanup by removing OAuth requests that have passed their
264 /// expiration time. This is important for:
265 /// - Preventing storage bloat from abandoned OAuth flows
266 /// - Maintaining security by ensuring expired flows cannot be resumed
267 /// - Optimizing storage performance by removing stale data
268 ///
269 /// Implementations should:
270 /// - Compare `expires_at` against the current time (`Utc::now()`)
271 /// - Remove all requests where `expires_at <= current_time`
272 /// - Return the count of removed requests for monitoring/logging
273 /// - Be efficient for large datasets (use bulk operations when possible)
274 /// - Be thread-safe and handle concurrent access appropriately
275 ///
276 /// This method is typically called:
277 /// - Periodically by a background cleanup task
278 /// - Before inserting new requests to maintain storage hygiene
279 /// - During application startup to clean stale data
280 ///
281 /// # Returns
282 /// * `Ok(count)` - The number of expired requests that were successfully removed
283 /// * `Err(error)` - If an error occurs during cleanup (storage failure, etc.)
284 ///
285 /// # Examples
286 ///
287 /// ```rust,ignore
288 /// let storage = MyOAuthStorage::new();
289 ///
290 /// // Periodic cleanup
291 /// let removed_count = storage.clear_expired_oauth_requests().await?;
292 /// println!("Cleaned up {} expired OAuth requests", removed_count);
293 ///
294 /// // In a background task
295 /// tokio::spawn(async move {
296 /// let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); // 5 minutes
297 /// loop {
298 /// interval.tick().await;
299 /// if let Err(e) = storage.clear_expired_oauth_requests().await {
300 /// eprintln!("Error during OAuth cleanup: {}", e);
301 /// }
302 /// }
303 /// });
304 /// ```
305 async fn clear_expired_oauth_requests(&self) -> Result<u64>;
306}