Skip to main content

rustauth_tokio_postgres/
adapter.rs

1use std::fmt;
2use std::sync::Arc;
3
4use rustauth_core::db::{
5    auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
6    DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, SchemaCreation, TransactionCallback,
7    Update, UpdateMany,
8};
9use rustauth_core::error::RustAuthError;
10use tokio_postgres::Client;
11
12use crate::connection::TokioPostgresConnection;
13use crate::driver::PostgresSqlState;
14use crate::errors::postgres_error;
15use crate::rate_limit::TokioPostgresRateLimitStore;
16use crate::schema::{
17    create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
18};
19use crate::transaction::TokioPostgresTxAdapter;
20use crate::tx_guard::SharedClientRollbackGuard;
21use rustauth_core::db::SchemaMigrationPlan;
22
23#[derive(Clone)]
24pub struct TokioPostgresAdapter {
25    pub(crate) connection: TokioPostgresConnection,
26    pub(crate) schema: Arc<DbSchema>,
27}
28
29impl fmt::Debug for TokioPostgresAdapter {
30    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
31        formatter
32            .debug_struct("TokioPostgresAdapter")
33            .field("schema", &self.schema)
34            .finish_non_exhaustive()
35    }
36}
37
38impl TokioPostgresAdapter {
39    pub fn new(client: Client) -> Self {
40        Self::with_schema(client, auth_schema(AuthSchemaOptions::default()))
41    }
42
43    pub fn with_schema(client: Client, schema: DbSchema) -> Self {
44        Self::with_connection(TokioPostgresConnection::from_client(client), schema)
45    }
46
47    pub fn with_connection(connection: TokioPostgresConnection, schema: DbSchema) -> Self {
48        Self {
49            connection,
50            schema: Arc::new(schema),
51        }
52    }
53
54    /// Returns the shared client and transaction gate used by this adapter.
55    pub fn connection(&self) -> &TokioPostgresConnection {
56        &self.connection
57    }
58
59    /// Builds a SQL-backed rate-limit store that shares this adapter's client
60    /// and transaction gate.
61    pub fn rate_limit_store(&self) -> TokioPostgresRateLimitStore {
62        TokioPostgresRateLimitStore::from(self)
63    }
64
65    /// Connects to Postgres and spawns the `tokio-postgres` connection driver.
66    pub async fn connect(database_url: &str) -> Result<Self, RustAuthError> {
67        Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
68    }
69
70    /// Connects to Postgres with a custom RustAuth schema.
71    ///
72    /// The returned adapter owns the client handle and keeps the driver future
73    /// running in a background task as required by `tokio-postgres`.
74    pub async fn connect_with_schema(
75        database_url: &str,
76        schema: DbSchema,
77    ) -> Result<Self, RustAuthError> {
78        Ok(Self::with_connection(
79            TokioPostgresConnection::connect(database_url).await?,
80            schema,
81        ))
82    }
83
84    pub async fn plan_migrations(
85        &self,
86        schema: &DbSchema,
87    ) -> Result<SchemaMigrationPlan, RustAuthError> {
88        let _gate = self.connection.tx_gate.write().await;
89        plan_schema_migrations(self.connection.client.as_ref(), schema).await
90    }
91
92    pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, RustAuthError> {
93        Ok(self.plan_migrations(schema).await?.compile())
94    }
95
96    async fn run_with_state<T>(
97        &self,
98        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
99    ) -> Result<T, RustAuthError>
100    where
101        T: Send + 'static,
102    {
103        let _gate = self.connection.tx_gate.read().await;
104        f(PostgresSqlState::new(
105            self.schema.as_ref(),
106            self.connection.client.as_ref(),
107        ))
108        .await
109    }
110}
111
112impl DbAdapter for TokioPostgresAdapter {
113    fn id(&self) -> &str {
114        "tokio-postgres"
115    }
116
117    fn capabilities(&self) -> AdapterCapabilities {
118        AdapterCapabilities::new(self.id())
119            .named("tokio-postgres")
120            .with_uuid_ids()
121            .with_json()
122            .with_arrays()
123            .with_native_joins()
124            .with_transactions()
125    }
126
127    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
128        Box::pin(async move {
129            self.run_with_state(|state| Box::pin(state.create(query)))
130                .await
131        })
132    }
133
134    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
135        Box::pin(async move {
136            self.run_with_state(|state| Box::pin(state.find_one(query)))
137                .await
138        })
139    }
140
141    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
142        Box::pin(async move {
143            self.run_with_state(|state| Box::pin(state.find_many(query)))
144                .await
145        })
146    }
147
148    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
149        Box::pin(async move {
150            self.run_with_state(|state| Box::pin(state.count(query)))
151                .await
152        })
153    }
154
155    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
156        Box::pin(async move {
157            self.run_with_state(|state| Box::pin(state.update(query)))
158                .await
159        })
160    }
161
162    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
163        Box::pin(async move {
164            self.run_with_state(|state| Box::pin(state.update_many(query)))
165                .await
166        })
167    }
168
169    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
170        Box::pin(async move {
171            self.run_with_state(|state| Box::pin(state.delete(query)))
172                .await
173        })
174    }
175
176    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
177        Box::pin(async move {
178            self.run_with_state(|state| Box::pin(state.delete_many(query)))
179                .await
180        })
181    }
182
183    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
184        Box::pin(async move {
185            let gate = Arc::clone(&self.connection.tx_gate).write_owned().await;
186            self.connection
187                .client
188                .batch_execute("BEGIN")
189                .await
190                .map_err(postgres_error)?;
191            let mut guard =
192                SharedClientRollbackGuard::new(Arc::clone(&self.connection.client), gate);
193
194            let adapter = TokioPostgresTxAdapter::new(
195                Arc::clone(&self.connection.client),
196                Arc::clone(&self.schema),
197            );
198            let result = callback(Box::new(adapter)).await;
199
200            match result {
201                Ok(()) => {
202                    if let Err(error) = self.connection.client.batch_execute("COMMIT").await {
203                        let _rollback_result =
204                            self.connection.client.batch_execute("ROLLBACK").await;
205                        guard.disarm();
206                        return Err(postgres_error(error));
207                    }
208                    guard.disarm();
209                    Ok(())
210                }
211                Err(error) => {
212                    let _rollback_result = self.connection.client.batch_execute("ROLLBACK").await;
213                    guard.disarm();
214                    Err(error)
215                }
216            }
217        })
218    }
219
220    fn create_schema<'a>(
221        &'a self,
222        schema: &'a DbSchema,
223        _file: Option<&'a str>,
224    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
225        Box::pin(async move {
226            let _gate = self.connection.tx_gate.write().await;
227            create_schema(self.connection.client.as_ref(), schema).await?;
228            Ok(None)
229        })
230    }
231
232    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
233        Box::pin(async move {
234            let _gate = self.connection.tx_gate.write().await;
235            execute_migration_plan(self.connection.client.as_ref(), schema).await
236        })
237    }
238}