openauth_deadpool_postgres/
lib.rs1pub mod migration;
8
9use std::fmt;
10use std::sync::Arc;
11
12use deadpool_postgres::{Config, Pool, PoolConfig, Runtime};
13use openauth_core::db::{
14 auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
15 DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, JoinAdapter, SchemaCreation,
16 SqlRateLimitNames, TransactionCallback, Update, UpdateMany,
17};
18use openauth_core::error::OpenAuthError;
19use openauth_core::options::{
20 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
21};
22use openauth_tokio_postgres::driver::{
23 consume_postgres_rate_limit_in_tx, postgres_error, postgres_rate_limit_plan, PostgresSqlState,
24};
25use tokio::sync::Mutex;
26use tokio_postgres::{Client, NoTls};
27
28const DEFAULT_POOL_MAX_SIZE: usize = 16;
29
30#[derive(Clone)]
32pub struct DeadpoolPostgresAdapter {
33 pool: Pool,
34 schema: Arc<DbSchema>,
35}
36
37#[derive(Clone)]
39pub struct DeadpoolPostgresRateLimitStore {
40 pool: Pool,
41 names: SqlRateLimitNames,
42}
43
44impl fmt::Debug for DeadpoolPostgresAdapter {
45 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46 formatter
47 .debug_struct("DeadpoolPostgresAdapter")
48 .field("schema", &self.schema)
49 .finish_non_exhaustive()
50 }
51}
52
53impl fmt::Debug for DeadpoolPostgresRateLimitStore {
54 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
55 formatter
56 .debug_struct("DeadpoolPostgresRateLimitStore")
57 .field("names", &self.names)
58 .finish_non_exhaustive()
59 }
60}
61
62impl DeadpoolPostgresAdapter {
63 pub fn new(pool: Pool) -> Self {
64 Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
65 }
66
67 pub fn with_schema(pool: Pool, schema: DbSchema) -> Self {
68 Self {
69 pool,
70 schema: Arc::new(schema),
71 }
72 }
73
74 pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
75 Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
76 }
77
78 pub async fn connect_with_schema(
79 database_url: &str,
80 schema: DbSchema,
81 ) -> Result<Self, OpenAuthError> {
82 let mut config = Config::new();
83 config.url = Some(database_url.to_owned());
84 Self::from_config_with_schema(config, schema, DEFAULT_POOL_MAX_SIZE)
85 }
86
87 pub fn from_config(config: Config, max_size: usize) -> Result<Self, OpenAuthError> {
88 Self::from_config_with_schema(config, auth_schema(AuthSchemaOptions::default()), max_size)
89 }
90
91 pub fn from_config_with_schema(
92 mut config: Config,
93 schema: DbSchema,
94 max_size: usize,
95 ) -> Result<Self, OpenAuthError> {
96 config.pool = Some(PoolConfig::new(max_size));
97 let pool = config
98 .create_pool(Some(Runtime::Tokio1), NoTls)
99 .map_err(deadpool_error)?;
100 Ok(Self::with_schema(pool, schema))
101 }
102
103 pub async fn plan_migrations(
104 &self,
105 schema: &DbSchema,
106 ) -> Result<SchemaMigrationPlan, OpenAuthError> {
107 let client = self.pool.get().await.map_err(deadpool_error)?;
108 openauth_tokio_postgres::driver::plan_migrations(pg_client(&client), schema).await
109 }
110
111 pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
112 Ok(self.plan_migrations(schema).await?.compile())
113 }
114
115 async fn run_with_state<T>(
116 &self,
117 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
118 ) -> Result<T, OpenAuthError>
119 where
120 T: Send + 'static,
121 {
122 let client = self.pool.get().await.map_err(deadpool_error)?;
123 f(PostgresSqlState::new(
124 self.schema.as_ref(),
125 pg_client(&client),
126 ))
127 .await
128 }
129}
130
131impl DeadpoolPostgresRateLimitStore {
132 pub fn new(pool: Pool) -> Self {
133 Self::with_table(pool, "rate_limits")
134 }
135
136 pub fn with_table(pool: Pool, table: impl Into<String>) -> Self {
137 Self {
138 pool,
139 names: SqlRateLimitNames::new(table),
140 }
141 }
142}
143
144impl From<&DeadpoolPostgresAdapter> for DeadpoolPostgresRateLimitStore {
145 fn from(adapter: &DeadpoolPostgresAdapter) -> Self {
146 Self {
147 pool: adapter.pool.clone(),
148 names: SqlRateLimitNames::from_schema(&adapter.schema),
149 }
150 }
151}
152
153impl RateLimitStore for DeadpoolPostgresRateLimitStore {
154 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
155 Box::pin(async move { consume_deadpool_rate_limit(self, input).await })
156 }
157}
158
159impl DbAdapter for DeadpoolPostgresAdapter {
160 fn id(&self) -> &str {
161 "deadpool-postgres"
162 }
163
164 fn capabilities(&self) -> AdapterCapabilities {
165 AdapterCapabilities::new(self.id())
166 .named("deadpool-postgres")
167 .with_json()
168 .with_arrays()
169 .with_joins()
170 .with_transactions()
171 }
172
173 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
174 Box::pin(async move {
175 self.run_with_state(|state| Box::pin(state.create(query)))
176 .await
177 })
178 }
179
180 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
181 Box::pin(async move {
182 self.run_with_state(|state| Box::pin(state.find_one(query)))
183 .await
184 })
185 }
186
187 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
188 Box::pin(async move {
189 if query.joins.len() <= 1 {
190 self.run_with_state(|state| Box::pin(state.find_many(query)))
191 .await
192 } else {
193 let adapter =
194 JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
195 adapter.find_many(query).await
196 }
197 })
198 }
199
200 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
201 Box::pin(async move {
202 self.run_with_state(|state| Box::pin(state.count(query)))
203 .await
204 })
205 }
206
207 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
208 Box::pin(async move {
209 self.run_with_state(|state| Box::pin(state.update(query)))
210 .await
211 })
212 }
213
214 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
215 Box::pin(async move {
216 self.run_with_state(|state| Box::pin(state.update_many(query)))
217 .await
218 })
219 }
220
221 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
222 Box::pin(async move {
223 self.run_with_state(|state| Box::pin(state.delete(query)))
224 .await
225 })
226 }
227
228 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
229 Box::pin(async move {
230 self.run_with_state(|state| Box::pin(state.delete_many(query)))
231 .await
232 })
233 }
234
235 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
236 Box::pin(async move {
237 let client = self.pool.get().await.map_err(deadpool_error)?;
238 client
239 .batch_execute("BEGIN")
240 .await
241 .map_err(postgres_error)?;
242 let client = Arc::new(Mutex::new(client));
243 let adapter = DeadpoolPostgresTxAdapter {
244 client: Arc::clone(&client),
245 schema: Arc::clone(&self.schema),
246 };
247 let result = callback(Box::new(adapter)).await;
248
249 let client = client.lock().await;
250 match result {
251 Ok(()) => client.batch_execute("COMMIT").await.map_err(postgres_error),
252 Err(error) => {
253 let _rollback_result = client.batch_execute("ROLLBACK").await;
254 Err(error)
255 }
256 }
257 })
258 }
259
260 fn create_schema<'a>(
261 &'a self,
262 schema: &'a DbSchema,
263 _file: Option<&'a str>,
264 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
265 Box::pin(async move {
266 let client = self.pool.get().await.map_err(deadpool_error)?;
267 openauth_tokio_postgres::driver::create_schema(pg_client(&client), schema).await?;
268 Ok(None)
269 })
270 }
271
272 fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
273 Box::pin(async move {
274 let client = self.pool.get().await.map_err(deadpool_error)?;
275 openauth_tokio_postgres::driver::execute_migration_plan(pg_client(&client), schema)
276 .await
277 })
278 }
279}
280
281struct DeadpoolPostgresTxAdapter {
282 client: Arc<Mutex<deadpool_postgres::Client>>,
283 schema: Arc<DbSchema>,
284}
285
286impl DeadpoolPostgresTxAdapter {
287 async fn run_with_state<T>(
288 &self,
289 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
290 ) -> Result<T, OpenAuthError>
291 where
292 T: Send + 'static,
293 {
294 let client = self.client.lock().await;
295 f(PostgresSqlState::new(
296 self.schema.as_ref(),
297 pg_client(&client),
298 ))
299 .await
300 }
301}
302
303impl DbAdapter for DeadpoolPostgresTxAdapter {
304 fn id(&self) -> &str {
305 "deadpool-postgres-tx"
306 }
307
308 fn capabilities(&self) -> AdapterCapabilities {
309 AdapterCapabilities::new(self.id())
310 .named("deadpool-postgres transaction")
311 .with_json()
312 .with_arrays()
313 .with_transactions()
314 }
315
316 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
317 Box::pin(async move {
318 self.run_with_state(|state| Box::pin(state.create(query)))
319 .await
320 })
321 }
322
323 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
324 Box::pin(async move {
325 self.run_with_state(|state| Box::pin(state.find_one(query)))
326 .await
327 })
328 }
329
330 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
331 Box::pin(async move {
332 self.run_with_state(|state| Box::pin(state.find_many(query)))
333 .await
334 })
335 }
336
337 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
338 Box::pin(async move {
339 self.run_with_state(|state| Box::pin(state.count(query)))
340 .await
341 })
342 }
343
344 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
345 Box::pin(async move {
346 self.run_with_state(|state| Box::pin(state.update(query)))
347 .await
348 })
349 }
350
351 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
352 Box::pin(async move {
353 self.run_with_state(|state| Box::pin(state.update_many(query)))
354 .await
355 })
356 }
357
358 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
359 Box::pin(async move {
360 self.run_with_state(|state| Box::pin(state.delete(query)))
361 .await
362 })
363 }
364
365 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
366 Box::pin(async move {
367 self.run_with_state(|state| Box::pin(state.delete_many(query)))
368 .await
369 })
370 }
371
372 fn transaction<'a>(&'a self, _callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
373 Box::pin(async {
374 Err(OpenAuthError::Adapter(
375 "nested deadpool-postgres transactions are not supported".to_owned(),
376 ))
377 })
378 }
379}
380
381async fn consume_deadpool_rate_limit(
382 store: &DeadpoolPostgresRateLimitStore,
383 input: RateLimitConsumeInput,
384) -> Result<RateLimitDecision, OpenAuthError> {
385 let plan = postgres_rate_limit_plan(
386 &store.names.table,
387 &store.names.key,
388 &store.names.count,
389 &store.names.last_request,
390 )?;
391 let client = store.pool.get().await.map_err(deadpool_error)?;
392 client
393 .batch_execute("BEGIN")
394 .await
395 .map_err(postgres_error)?;
396 let result = consume_postgres_rate_limit_in_tx(pg_client(&client), &plan, input).await;
397 match result {
398 Ok(decision) => {
399 client
400 .batch_execute("COMMIT")
401 .await
402 .map_err(postgres_error)?;
403 Ok(decision)
404 }
405 Err(error) => {
406 let _rollback_result = client.batch_execute("ROLLBACK").await;
407 Err(error)
408 }
409 }
410}
411
412fn pg_client(client: &deadpool_postgres::Client) -> &Client {
413 client
414}
415
416fn deadpool_error(error: impl fmt::Display) -> OpenAuthError {
417 OpenAuthError::Adapter(format!("deadpool-postgres error: {error}"))
418}
419
420pub use self::migration::{
421 ColumnToAdd, IndexToCreate, MigrationStatement, MigrationStatementKind, SchemaMigrationPlan,
422 SchemaMigrationWarning, TableToCreate,
423};