Skip to main content

rustauth_deadpool_postgres/
adapter.rs

1use std::sync::Arc;
2
3use deadpool_postgres::Pool;
4use rustauth_core::db::SchemaMigrationPlan;
5use rustauth_core::db::{
6    auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
7    DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, SchemaCreation, TransactionCallback,
8    Update, UpdateMany,
9};
10use rustauth_core::error::RustAuthError;
11use rustauth_tokio_postgres::driver::{postgres_error, PostgresSqlState};
12use tokio::sync::Mutex;
13
14use crate::builder::DeadpoolPostgresBuilder;
15use crate::config::{deadpool_error, pg_client};
16use crate::transaction::DeadpoolPostgresTxAdapter;
17use crate::tx_guard::PooledClientRollbackGuard;
18
19/// Production-oriented Postgres adapter backed by a `deadpool-postgres` pool.
20#[derive(Clone)]
21pub struct DeadpoolPostgresAdapter {
22    pub(crate) pool: Pool,
23    pub(crate) schema: Arc<DbSchema>,
24}
25
26impl std::fmt::Debug for DeadpoolPostgresAdapter {
27    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        formatter
29            .debug_struct("DeadpoolPostgresAdapter")
30            .field("schema", &self.schema)
31            .finish_non_exhaustive()
32    }
33}
34
35impl DeadpoolPostgresAdapter {
36    pub fn new(pool: Pool) -> Self {
37        Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
38    }
39
40    pub fn with_schema(pool: Pool, schema: DbSchema) -> Self {
41        Self {
42            pool,
43            schema: Arc::new(schema),
44        }
45    }
46
47    pub fn pool(&self) -> &Pool {
48        &self.pool
49    }
50
51    pub fn builder() -> DeadpoolPostgresBuilder {
52        DeadpoolPostgresBuilder::new()
53    }
54
55    pub async fn plan_migrations(
56        &self,
57        schema: &DbSchema,
58    ) -> Result<SchemaMigrationPlan, RustAuthError> {
59        let client = self.pool.get().await.map_err(deadpool_error)?;
60        rustauth_tokio_postgres::driver::plan_migrations(pg_client(&client), schema).await
61    }
62
63    pub async fn validate_connection(&self) -> Result<(), RustAuthError> {
64        let client = self.pool.get().await.map_err(deadpool_error)?;
65        client
66            .simple_query("SELECT 1")
67            .await
68            .map_err(postgres_error)?;
69        Ok(())
70    }
71
72    pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, RustAuthError> {
73        Ok(self.plan_migrations(schema).await?.compile())
74    }
75
76    async fn run_with_state<T>(
77        &self,
78        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
79    ) -> Result<T, RustAuthError>
80    where
81        T: Send + 'static,
82    {
83        let client = self.pool.get().await.map_err(deadpool_error)?;
84        f(PostgresSqlState::new(
85            self.schema.as_ref(),
86            pg_client(&client),
87        ))
88        .await
89    }
90}
91
92impl DbAdapter for DeadpoolPostgresAdapter {
93    fn id(&self) -> &str {
94        "deadpool-postgres"
95    }
96
97    fn capabilities(&self) -> AdapterCapabilities {
98        AdapterCapabilities::new(self.id())
99            .named("deadpool-postgres")
100            .with_uuid_ids()
101            .with_json()
102            .with_arrays()
103            .with_native_joins()
104            .with_transactions()
105    }
106
107    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
108        Box::pin(async move {
109            self.run_with_state(|state| Box::pin(state.create(query)))
110                .await
111        })
112    }
113
114    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
115        Box::pin(async move {
116            self.run_with_state(|state| Box::pin(state.find_one(query)))
117                .await
118        })
119    }
120
121    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
122        Box::pin(async move {
123            self.run_with_state(|state| Box::pin(state.find_many(query)))
124                .await
125        })
126    }
127
128    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
129        Box::pin(async move {
130            self.run_with_state(|state| Box::pin(state.count(query)))
131                .await
132        })
133    }
134
135    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
136        Box::pin(async move {
137            self.run_with_state(|state| Box::pin(state.update(query)))
138                .await
139        })
140    }
141
142    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
143        Box::pin(async move {
144            self.run_with_state(|state| Box::pin(state.update_many(query)))
145                .await
146        })
147    }
148
149    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
150        Box::pin(async move {
151            self.run_with_state(|state| Box::pin(state.delete(query)))
152                .await
153        })
154    }
155
156    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
157        Box::pin(async move {
158            self.run_with_state(|state| Box::pin(state.delete_many(query)))
159                .await
160        })
161    }
162
163    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
164        Box::pin(async move {
165            let client = self.pool.get().await.map_err(deadpool_error)?;
166            client
167                .batch_execute("BEGIN")
168                .await
169                .map_err(postgres_error)?;
170            let client = Arc::new(Mutex::new(client));
171            let mut guard = PooledClientRollbackGuard::new(Arc::clone(&client));
172            let adapter = DeadpoolPostgresTxAdapter {
173                client: Arc::clone(&client),
174                schema: Arc::clone(&self.schema),
175            };
176            let result = callback(Box::new(adapter)).await;
177
178            let locked = client.lock().await;
179            match result {
180                Ok(()) => {
181                    if let Err(error) = locked.batch_execute("COMMIT").await {
182                        let _rollback_result = locked.batch_execute("ROLLBACK").await;
183                        guard.disarm();
184                        return Err(postgres_error(error));
185                    }
186                    guard.disarm();
187                    Ok(())
188                }
189                Err(error) => {
190                    let _rollback_result = locked.batch_execute("ROLLBACK").await;
191                    guard.disarm();
192                    Err(error)
193                }
194            }
195        })
196    }
197
198    fn create_schema<'a>(
199        &'a self,
200        schema: &'a DbSchema,
201        _file: Option<&'a str>,
202    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
203        Box::pin(async move {
204            let client = self.pool.get().await.map_err(deadpool_error)?;
205            rustauth_tokio_postgres::driver::create_schema(pg_client(&client), schema).await?;
206            Ok(None)
207        })
208    }
209
210    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
211        Box::pin(async move {
212            let client = self.pool.get().await.map_err(deadpool_error)?;
213            rustauth_tokio_postgres::driver::execute_migration_plan(pg_client(&client), schema)
214                .await
215        })
216    }
217}