Skip to main content

openauth_deadpool_postgres/
lib.rs

1//! Pooled Postgres database adapter for OpenAuth.
2//!
3//! This crate is the recommended Postgres adapter for production deployments.
4//! It keeps pooling in `deadpool-postgres` and reuses OpenAuth's shared SQL
5//! planning plus `openauth-tokio-postgres` driver helpers.
6
7pub mod migration;
8
9use std::fmt;
10use std::sync::Arc;
11
12use deadpool_postgres::{Config, Pool, PoolConfig, Runtime};
13use openauth_core::db::{
14    auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
15    DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, JoinAdapter, SchemaCreation,
16    SqlRateLimitNames, TransactionCallback, Update, UpdateMany,
17};
18use openauth_core::error::OpenAuthError;
19use openauth_core::options::{
20    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
21};
22use openauth_tokio_postgres::driver::{
23    consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan, PostgresSqlState,
24};
25use tokio::sync::Mutex;
26use tokio_postgres::{Client, NoTls};
27
/// Pool size used by `DeadpoolPostgresAdapter::connect` and
/// `DeadpoolPostgresAdapter::connect_with_schema` when the caller does not choose one.
const DEFAULT_POOL_MAX_SIZE: usize = 16;
29
30/// Production-oriented Postgres adapter backed by a `deadpool-postgres` pool.
31#[derive(Clone)]
32pub struct DeadpoolPostgresAdapter {
33    pool: Pool,
34    schema: Arc<DbSchema>,
35}
36
37/// Database-backed rate-limit store backed by a `deadpool-postgres` pool.
38#[derive(Clone)]
39pub struct DeadpoolPostgresRateLimitStore {
40    pool: Pool,
41    names: SqlRateLimitNames,
42}
43
44impl fmt::Debug for DeadpoolPostgresAdapter {
45    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46        formatter
47            .debug_struct("DeadpoolPostgresAdapter")
48            .field("schema", &self.schema)
49            .finish_non_exhaustive()
50    }
51}
52
53impl fmt::Debug for DeadpoolPostgresRateLimitStore {
54    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
55        formatter
56            .debug_struct("DeadpoolPostgresRateLimitStore")
57            .field("names", &self.names)
58            .finish_non_exhaustive()
59    }
60}
61
62impl DeadpoolPostgresAdapter {
63    pub fn new(pool: Pool) -> Self {
64        Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
65    }
66
67    pub fn with_schema(pool: Pool, schema: DbSchema) -> Self {
68        Self {
69            pool,
70            schema: Arc::new(schema),
71        }
72    }
73
74    pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
75        Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
76    }
77
78    pub async fn connect_with_schema(
79        database_url: &str,
80        schema: DbSchema,
81    ) -> Result<Self, OpenAuthError> {
82        let mut config = Config::new();
83        config.url = Some(database_url.to_owned());
84        Self::from_config_with_schema(config, schema, DEFAULT_POOL_MAX_SIZE)
85    }
86
87    pub fn from_config(config: Config, max_size: usize) -> Result<Self, OpenAuthError> {
88        Self::from_config_with_schema(config, auth_schema(AuthSchemaOptions::default()), max_size)
89    }
90
91    pub fn from_config_with_schema(
92        mut config: Config,
93        schema: DbSchema,
94        max_size: usize,
95    ) -> Result<Self, OpenAuthError> {
96        config.pool = Some(PoolConfig::new(max_size));
97        let pool = config
98            .create_pool(Some(Runtime::Tokio1), NoTls)
99            .map_err(deadpool_error)?;
100        Ok(Self::with_schema(pool, schema))
101    }
102
103    pub async fn plan_migrations(
104        &self,
105        schema: &DbSchema,
106    ) -> Result<SchemaMigrationPlan, OpenAuthError> {
107        let client = self.pool.get().await.map_err(deadpool_error)?;
108        openauth_tokio_postgres::driver::plan_migrations(pg_client(&client), schema).await
109    }
110
111    pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
112        Ok(self.plan_migrations(schema).await?.compile())
113    }
114
115    async fn run_with_state<T>(
116        &self,
117        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
118    ) -> Result<T, OpenAuthError>
119    where
120        T: Send + 'static,
121    {
122        let client = self.pool.get().await.map_err(deadpool_error)?;
123        f(PostgresSqlState::new(
124            self.schema.as_ref(),
125            pg_client(&client),
126        ))
127        .await
128    }
129}
130
131impl DeadpoolPostgresRateLimitStore {
132    pub fn new(pool: Pool) -> Self {
133        Self::with_table(pool, "rate_limits")
134    }
135
136    pub fn with_table(pool: Pool, table: impl Into<String>) -> Self {
137        Self {
138            pool,
139            names: SqlRateLimitNames::new(table),
140        }
141    }
142}
143
144impl From<&DeadpoolPostgresAdapter> for DeadpoolPostgresRateLimitStore {
145    fn from(adapter: &DeadpoolPostgresAdapter) -> Self {
146        Self {
147            pool: adapter.pool.clone(),
148            names: SqlRateLimitNames::from_schema(&adapter.schema),
149        }
150    }
151}
152
153impl RateLimitStore for DeadpoolPostgresRateLimitStore {
154    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
155        Box::pin(async move { consume_deadpool_rate_limit(self, input).await })
156    }
157}
158
159impl DbAdapter for DeadpoolPostgresAdapter {
160    fn id(&self) -> &str {
161        "deadpool-postgres"
162    }
163
164    fn capabilities(&self) -> AdapterCapabilities {
165        AdapterCapabilities::new(self.id())
166            .named("deadpool-postgres")
167            .with_json()
168            .with_arrays()
169            .with_joins()
170            .with_transactions()
171    }
172
173    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
174        Box::pin(async move {
175            self.run_with_state(|state| Box::pin(state.create(query)))
176                .await
177        })
178    }
179
180    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
181        Box::pin(async move {
182            self.run_with_state(|state| Box::pin(state.find_one(query)))
183                .await
184        })
185    }
186
187    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
188        Box::pin(async move {
189            if query.joins.len() <= 1 {
190                self.run_with_state(|state| Box::pin(state.find_many(query)))
191                    .await
192            } else {
193                let adapter =
194                    JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
195                adapter.find_many(query).await
196            }
197        })
198    }
199
200    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
201        Box::pin(async move {
202            self.run_with_state(|state| Box::pin(state.count(query)))
203                .await
204        })
205    }
206
207    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
208        Box::pin(async move {
209            self.run_with_state(|state| Box::pin(state.update(query)))
210                .await
211        })
212    }
213
214    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
215        Box::pin(async move {
216            self.run_with_state(|state| Box::pin(state.update_many(query)))
217                .await
218        })
219    }
220
221    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
222        Box::pin(async move {
223            self.run_with_state(|state| Box::pin(state.delete(query)))
224                .await
225        })
226    }
227
228    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
229        Box::pin(async move {
230            self.run_with_state(|state| Box::pin(state.delete_many(query)))
231                .await
232        })
233    }
234
235    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
236        Box::pin(async move {
237            let client = self.pool.get().await.map_err(deadpool_error)?;
238            client
239                .batch_execute("BEGIN")
240                .await
241                .map_err(postgres_error)?;
242            let client = Arc::new(Mutex::new(client));
243            let adapter = DeadpoolPostgresTxAdapter {
244                client: Arc::clone(&client),
245                schema: Arc::clone(&self.schema),
246            };
247            let result = callback(Box::new(adapter)).await;
248
249            let client = client.lock().await;
250            match result {
251                Ok(()) => client.batch_execute("COMMIT").await.map_err(postgres_error),
252                Err(error) => {
253                    let _rollback_result = client.batch_execute("ROLLBACK").await;
254                    Err(error)
255                }
256            }
257        })
258    }
259
260    fn create_schema<'a>(
261        &'a self,
262        schema: &'a DbSchema,
263        _file: Option<&'a str>,
264    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
265        Box::pin(async move {
266            let client = self.pool.get().await.map_err(deadpool_error)?;
267            openauth_tokio_postgres::driver::create_schema(pg_client(&client), schema).await?;
268            Ok(None)
269        })
270    }
271
272    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
273        Box::pin(async move {
274            let client = self.pool.get().await.map_err(deadpool_error)?;
275            openauth_tokio_postgres::driver::execute_migration_plan(pg_client(&client), schema)
276                .await
277        })
278    }
279}
280
281struct DeadpoolPostgresTxAdapter {
282    client: Arc<Mutex<deadpool_postgres::Client>>,
283    schema: Arc<DbSchema>,
284}
285
286impl DeadpoolPostgresTxAdapter {
287    async fn run_with_state<T>(
288        &self,
289        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
290    ) -> Result<T, OpenAuthError>
291    where
292        T: Send + 'static,
293    {
294        let client = self.client.lock().await;
295        f(PostgresSqlState::new(
296            self.schema.as_ref(),
297            pg_client(&client),
298        ))
299        .await
300    }
301}
302
303impl DbAdapter for DeadpoolPostgresTxAdapter {
304    fn id(&self) -> &str {
305        "deadpool-postgres-tx"
306    }
307
308    fn capabilities(&self) -> AdapterCapabilities {
309        AdapterCapabilities::new(self.id())
310            .named("deadpool-postgres transaction")
311            .with_json()
312            .with_arrays()
313            .with_transactions()
314    }
315
316    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
317        Box::pin(async move {
318            self.run_with_state(|state| Box::pin(state.create(query)))
319                .await
320        })
321    }
322
323    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
324        Box::pin(async move {
325            self.run_with_state(|state| Box::pin(state.find_one(query)))
326                .await
327        })
328    }
329
330    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
331        Box::pin(async move {
332            self.run_with_state(|state| Box::pin(state.find_many(query)))
333                .await
334        })
335    }
336
337    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
338        Box::pin(async move {
339            self.run_with_state(|state| Box::pin(state.count(query)))
340                .await
341        })
342    }
343
344    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
345        Box::pin(async move {
346            self.run_with_state(|state| Box::pin(state.update(query)))
347                .await
348        })
349    }
350
351    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
352        Box::pin(async move {
353            self.run_with_state(|state| Box::pin(state.update_many(query)))
354                .await
355        })
356    }
357
358    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
359        Box::pin(async move {
360            self.run_with_state(|state| Box::pin(state.delete(query)))
361                .await
362        })
363    }
364
365    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
366        Box::pin(async move {
367            self.run_with_state(|state| Box::pin(state.delete_many(query)))
368                .await
369        })
370    }
371
372    fn transaction<'a>(&'a self, _callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
373        Box::pin(async {
374            Err(OpenAuthError::Adapter(
375                "nested deadpool-postgres transactions are not supported".to_owned(),
376            ))
377        })
378    }
379}
380
381async fn consume_deadpool_rate_limit(
382    store: &DeadpoolPostgresRateLimitStore,
383    input: RateLimitConsumeInput,
384) -> Result<RateLimitDecision, OpenAuthError> {
385    let plan = postgres_rate_limit_plan(
386        &store.names.table,
387        &store.names.key,
388        &store.names.count,
389        &store.names.last_request,
390    )?;
391    let client = store.pool.get().await.map_err(deadpool_error)?;
392    client
393        .batch_execute("BEGIN")
394        .await
395        .map_err(postgres_error)?;
396    let result = consume_postgres_rate_limit_in_tx(pg_client(&client), &plan, input).await;
397    match result {
398        Ok(decision) => {
399            client
400                .batch_execute("COMMIT")
401                .await
402                .map_err(postgres_error)?;
403            Ok(decision)
404        }
405        Err(error) => {
406            let _rollback_result = client.batch_execute("ROLLBACK").await;
407            Err(error)
408        }
409    }
410}
411
412fn pg_client(client: &deadpool_postgres::Client) -> &Client {
413    client
414}
415
416fn deadpool_error(error: impl fmt::Display) -> OpenAuthError {
417    OpenAuthError::Adapter(format!("deadpool-postgres error: {error}"))
418}
419
420pub use self::migration::{
421    ColumnToAdd, IndexToCreate, MigrationStatement, MigrationStatementKind, SchemaMigrationPlan,
422    SchemaMigrationWarning, TableToCreate,
423};