Skip to main content

openauth_tokio_postgres/
lib.rs

1//! Minimal `tokio-postgres` database adapter for OpenAuth.
2//!
3//! This crate is useful when the application already owns a single
4//! `tokio_postgres::Client` or wants the smallest async Postgres adapter.
5//! Production applications that need pooling should prefer
6//! `openauth-deadpool-postgres`.
7
8pub mod driver;
9mod errors;
10pub mod migration;
11mod query;
12mod row;
13mod schema;
14
15use std::fmt;
16use std::sync::Arc;
17
18use openauth_core::db::{
19    auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
20    DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, JoinAdapter, SchemaCreation,
21    SqlRateLimitNames, TransactionCallback, Update, UpdateMany,
22};
23use openauth_core::error::OpenAuthError;
24use openauth_core::options::{
25    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
26};
27use tokio::sync::Mutex;
28use tokio_postgres::{Client, NoTls};
29
30use self::driver::{consume_postgres_rate_limit_in_tx, postgres_rate_limit_plan, PostgresSqlState};
31use self::errors::postgres_error;
32use self::schema::{
33    create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
34};
35
36#[derive(Clone)]
37pub struct TokioPostgresAdapter {
38    client: Arc<Mutex<Client>>,
39    tx_gate: Arc<Mutex<()>>,
40    schema: Arc<DbSchema>,
41}
42
43#[derive(Clone)]
44pub struct TokioPostgresRateLimitStore {
45    client: Arc<Mutex<Client>>,
46    tx_gate: Arc<Mutex<()>>,
47    names: SqlRateLimitNames,
48}
49
50impl fmt::Debug for TokioPostgresAdapter {
51    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
52        formatter
53            .debug_struct("TokioPostgresAdapter")
54            .field("schema", &self.schema)
55            .finish_non_exhaustive()
56    }
57}
58
59impl fmt::Debug for TokioPostgresRateLimitStore {
60    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
61        formatter
62            .debug_struct("TokioPostgresRateLimitStore")
63            .field("names", &self.names)
64            .finish_non_exhaustive()
65    }
66}
67
68impl TokioPostgresAdapter {
69    pub fn new(client: Client) -> Self {
70        Self::with_schema(client, auth_schema(AuthSchemaOptions::default()))
71    }
72
73    pub fn with_schema(client: Client, schema: DbSchema) -> Self {
74        Self {
75            client: Arc::new(Mutex::new(client)),
76            tx_gate: Arc::new(Mutex::new(())),
77            schema: Arc::new(schema),
78        }
79    }
80
81    pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
82        Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
83    }
84
85    pub async fn connect_with_schema(
86        database_url: &str,
87        schema: DbSchema,
88    ) -> Result<Self, OpenAuthError> {
89        let (client, connection) = tokio_postgres::connect(database_url, NoTls)
90            .await
91            .map_err(postgres_error)?;
92        tokio::spawn(async move {
93            let _connection_result = connection.await;
94        });
95        Ok(Self::with_schema(client, schema))
96    }
97
98    pub async fn plan_migrations(
99        &self,
100        schema: &DbSchema,
101    ) -> Result<SchemaMigrationPlan, OpenAuthError> {
102        let _gate = self.tx_gate.lock().await;
103        let client = self.client.lock().await;
104        plan_schema_migrations(&client, schema).await
105    }
106
107    pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
108        Ok(self.plan_migrations(schema).await?.compile())
109    }
110
111    async fn run_with_state<T>(
112        &self,
113        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
114    ) -> Result<T, OpenAuthError>
115    where
116        T: Send + 'static,
117    {
118        let _gate = self.tx_gate.lock().await;
119        let client = self.client.lock().await;
120        f(PostgresSqlState::new(self.schema.as_ref(), &client)).await
121    }
122}
123
124impl TokioPostgresRateLimitStore {
125    pub fn new(client: Client) -> Self {
126        Self::with_table(client, "rate_limits")
127    }
128
129    pub fn with_table(client: Client, table: impl Into<String>) -> Self {
130        Self {
131            client: Arc::new(Mutex::new(client)),
132            tx_gate: Arc::new(Mutex::new(())),
133            names: SqlRateLimitNames::new(table),
134        }
135    }
136}
137
138impl From<&TokioPostgresAdapter> for TokioPostgresRateLimitStore {
139    fn from(adapter: &TokioPostgresAdapter) -> Self {
140        Self {
141            client: Arc::clone(&adapter.client),
142            tx_gate: Arc::clone(&adapter.tx_gate),
143            names: SqlRateLimitNames::from_schema(&adapter.schema),
144        }
145    }
146}
147
148impl RateLimitStore for TokioPostgresRateLimitStore {
149    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
150        Box::pin(async move { consume_postgres_rate_limit(self, input).await })
151    }
152}
153
154impl DbAdapter for TokioPostgresAdapter {
155    fn id(&self) -> &str {
156        "tokio-postgres"
157    }
158
159    fn capabilities(&self) -> AdapterCapabilities {
160        AdapterCapabilities::new(self.id())
161            .named("tokio-postgres")
162            .with_json()
163            .with_arrays()
164            .with_joins()
165            .with_transactions()
166    }
167
168    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
169        Box::pin(async move {
170            self.run_with_state(|state| Box::pin(state.create(query)))
171                .await
172        })
173    }
174
175    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
176        Box::pin(async move {
177            self.run_with_state(|state| Box::pin(state.find_one(query)))
178                .await
179        })
180    }
181
182    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
183        Box::pin(async move {
184            if query.joins.len() <= 1 {
185                self.run_with_state(|state| Box::pin(state.find_many(query)))
186                    .await
187            } else {
188                let adapter =
189                    JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
190                adapter.find_many(query).await
191            }
192        })
193    }
194
195    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
196        Box::pin(async move {
197            self.run_with_state(|state| Box::pin(state.count(query)))
198                .await
199        })
200    }
201
202    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
203        Box::pin(async move {
204            self.run_with_state(|state| Box::pin(state.update(query)))
205                .await
206        })
207    }
208
209    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
210        Box::pin(async move {
211            self.run_with_state(|state| Box::pin(state.update_many(query)))
212                .await
213        })
214    }
215
216    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
217        Box::pin(async move {
218            self.run_with_state(|state| Box::pin(state.delete(query)))
219                .await
220        })
221    }
222
223    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
224        Box::pin(async move {
225            self.run_with_state(|state| Box::pin(state.delete_many(query)))
226                .await
227        })
228    }
229
230    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
231        Box::pin(async move {
232            let _gate = self.tx_gate.lock().await;
233            let client = self.client.lock().await;
234            client
235                .batch_execute("BEGIN")
236                .await
237                .map_err(postgres_error)?;
238            drop(client);
239
240            let adapter = TokioPostgresTxAdapter {
241                client: Arc::clone(&self.client),
242                schema: Arc::clone(&self.schema),
243            };
244            let result = callback(Box::new(adapter)).await;
245
246            let client = self.client.lock().await;
247            match result {
248                Ok(()) => client.batch_execute("COMMIT").await.map_err(postgres_error),
249                Err(error) => {
250                    let _rollback_result = client.batch_execute("ROLLBACK").await;
251                    Err(error)
252                }
253            }
254        })
255    }
256
257    fn create_schema<'a>(
258        &'a self,
259        schema: &'a DbSchema,
260        _file: Option<&'a str>,
261    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
262        Box::pin(async move {
263            let _gate = self.tx_gate.lock().await;
264            let client = self.client.lock().await;
265            create_schema(&client, schema).await?;
266            Ok(None)
267        })
268    }
269
270    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
271        Box::pin(async move {
272            let _gate = self.tx_gate.lock().await;
273            let client = self.client.lock().await;
274            execute_migration_plan(&client, schema).await
275        })
276    }
277}
278
279struct TokioPostgresTxAdapter {
280    client: Arc<Mutex<Client>>,
281    schema: Arc<DbSchema>,
282}
283
284impl TokioPostgresTxAdapter {
285    async fn run_with_state<T>(
286        &self,
287        f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
288    ) -> Result<T, OpenAuthError>
289    where
290        T: Send + 'static,
291    {
292        let client = self.client.lock().await;
293        f(PostgresSqlState::new(self.schema.as_ref(), &client)).await
294    }
295}
296
297impl DbAdapter for TokioPostgresTxAdapter {
298    fn id(&self) -> &str {
299        "tokio-postgres-tx"
300    }
301
302    fn capabilities(&self) -> AdapterCapabilities {
303        AdapterCapabilities::new(self.id())
304            .named("tokio-postgres transaction")
305            .with_json()
306            .with_arrays()
307            .with_transactions()
308    }
309
310    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
311        Box::pin(async move {
312            self.run_with_state(|state| Box::pin(state.create(query)))
313                .await
314        })
315    }
316
317    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
318        Box::pin(async move {
319            self.run_with_state(|state| Box::pin(state.find_one(query)))
320                .await
321        })
322    }
323
324    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
325        Box::pin(async move {
326            self.run_with_state(|state| Box::pin(state.find_many(query)))
327                .await
328        })
329    }
330
331    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
332        Box::pin(async move {
333            self.run_with_state(|state| Box::pin(state.count(query)))
334                .await
335        })
336    }
337
338    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
339        Box::pin(async move {
340            self.run_with_state(|state| Box::pin(state.update(query)))
341                .await
342        })
343    }
344
345    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
346        Box::pin(async move {
347            self.run_with_state(|state| Box::pin(state.update_many(query)))
348                .await
349        })
350    }
351
352    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
353        Box::pin(async move {
354            self.run_with_state(|state| Box::pin(state.delete(query)))
355                .await
356        })
357    }
358
359    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
360        Box::pin(async move {
361            self.run_with_state(|state| Box::pin(state.delete_many(query)))
362                .await
363        })
364    }
365
366    fn transaction<'a>(&'a self, _callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
367        Box::pin(async {
368            Err(OpenAuthError::Adapter(
369                "nested tokio-postgres transactions are not supported".to_owned(),
370            ))
371        })
372    }
373}
374
375async fn consume_postgres_rate_limit(
376    store: &TokioPostgresRateLimitStore,
377    input: RateLimitConsumeInput,
378) -> Result<RateLimitDecision, OpenAuthError> {
379    let plan = postgres_rate_limit_plan(
380        &store.names.table,
381        &store.names.key,
382        &store.names.count,
383        &store.names.last_request,
384    )?;
385    let _gate = store.tx_gate.lock().await;
386    let client = store.client.lock().await;
387    client
388        .batch_execute("BEGIN")
389        .await
390        .map_err(postgres_error)?;
391    let result = consume_postgres_rate_limit_in_tx(&client, &plan, input).await;
392    match result {
393        Ok(decision) => {
394            client
395                .batch_execute("COMMIT")
396                .await
397                .map_err(postgres_error)?;
398            Ok(decision)
399        }
400        Err(error) => {
401            let _rollback_result = client.batch_execute("ROLLBACK").await;
402            Err(error)
403        }
404    }
405}
406
407pub use self::migration::{
408    ColumnToAdd, IndexToCreate, MigrationStatement, MigrationStatementKind, SchemaMigrationPlan,
409    SchemaMigrationWarning, TableToCreate,
410};