1pub mod driver;
9mod errors;
10pub mod migration;
11mod query;
12mod row;
13mod schema;
14
15use std::fmt;
16use std::sync::Arc;
17
18use openauth_core::db::{
19 auth_schema, AdapterCapabilities, AdapterFuture, AuthSchemaOptions, Count, Create, DbAdapter,
20 DbRecord, DbSchema, Delete, DeleteMany, FindMany, FindOne, JoinAdapter, SchemaCreation,
21 SqlRateLimitNames, TransactionCallback, Update, UpdateMany,
22};
23use openauth_core::error::OpenAuthError;
24use openauth_core::options::{
25 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
26};
27use tokio::sync::Mutex;
28use tokio_postgres::{Client, NoTls};
29
30use self::driver::{consume_postgres_rate_limit_in_tx, postgres_rate_limit_plan, PostgresSqlState};
31use self::errors::postgres_error;
32use self::schema::{
33 create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
34};
35
36#[derive(Clone)]
37pub struct TokioPostgresAdapter {
38 client: Arc<Mutex<Client>>,
39 tx_gate: Arc<Mutex<()>>,
40 schema: Arc<DbSchema>,
41}
42
43#[derive(Clone)]
44pub struct TokioPostgresRateLimitStore {
45 client: Arc<Mutex<Client>>,
46 tx_gate: Arc<Mutex<()>>,
47 names: SqlRateLimitNames,
48}
49
50impl fmt::Debug for TokioPostgresAdapter {
51 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
52 formatter
53 .debug_struct("TokioPostgresAdapter")
54 .field("schema", &self.schema)
55 .finish_non_exhaustive()
56 }
57}
58
59impl fmt::Debug for TokioPostgresRateLimitStore {
60 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
61 formatter
62 .debug_struct("TokioPostgresRateLimitStore")
63 .field("names", &self.names)
64 .finish_non_exhaustive()
65 }
66}
67
68impl TokioPostgresAdapter {
69 pub fn new(client: Client) -> Self {
70 Self::with_schema(client, auth_schema(AuthSchemaOptions::default()))
71 }
72
73 pub fn with_schema(client: Client, schema: DbSchema) -> Self {
74 Self {
75 client: Arc::new(Mutex::new(client)),
76 tx_gate: Arc::new(Mutex::new(())),
77 schema: Arc::new(schema),
78 }
79 }
80
81 pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
82 Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
83 }
84
85 pub async fn connect_with_schema(
86 database_url: &str,
87 schema: DbSchema,
88 ) -> Result<Self, OpenAuthError> {
89 let (client, connection) = tokio_postgres::connect(database_url, NoTls)
90 .await
91 .map_err(postgres_error)?;
92 tokio::spawn(async move {
93 let _connection_result = connection.await;
94 });
95 Ok(Self::with_schema(client, schema))
96 }
97
98 pub async fn plan_migrations(
99 &self,
100 schema: &DbSchema,
101 ) -> Result<SchemaMigrationPlan, OpenAuthError> {
102 let _gate = self.tx_gate.lock().await;
103 let client = self.client.lock().await;
104 plan_schema_migrations(&client, schema).await
105 }
106
107 pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
108 Ok(self.plan_migrations(schema).await?.compile())
109 }
110
111 async fn run_with_state<T>(
112 &self,
113 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
114 ) -> Result<T, OpenAuthError>
115 where
116 T: Send + 'static,
117 {
118 let _gate = self.tx_gate.lock().await;
119 let client = self.client.lock().await;
120 f(PostgresSqlState::new(self.schema.as_ref(), &client)).await
121 }
122}
123
124impl TokioPostgresRateLimitStore {
125 pub fn new(client: Client) -> Self {
126 Self::with_table(client, "rate_limits")
127 }
128
129 pub fn with_table(client: Client, table: impl Into<String>) -> Self {
130 Self {
131 client: Arc::new(Mutex::new(client)),
132 tx_gate: Arc::new(Mutex::new(())),
133 names: SqlRateLimitNames::new(table),
134 }
135 }
136}
137
138impl From<&TokioPostgresAdapter> for TokioPostgresRateLimitStore {
139 fn from(adapter: &TokioPostgresAdapter) -> Self {
140 Self {
141 client: Arc::clone(&adapter.client),
142 tx_gate: Arc::clone(&adapter.tx_gate),
143 names: SqlRateLimitNames::from_schema(&adapter.schema),
144 }
145 }
146}
147
148impl RateLimitStore for TokioPostgresRateLimitStore {
149 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
150 Box::pin(async move { consume_postgres_rate_limit(self, input).await })
151 }
152}
153
154impl DbAdapter for TokioPostgresAdapter {
155 fn id(&self) -> &str {
156 "tokio-postgres"
157 }
158
159 fn capabilities(&self) -> AdapterCapabilities {
160 AdapterCapabilities::new(self.id())
161 .named("tokio-postgres")
162 .with_json()
163 .with_arrays()
164 .with_joins()
165 .with_transactions()
166 }
167
168 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
169 Box::pin(async move {
170 self.run_with_state(|state| Box::pin(state.create(query)))
171 .await
172 })
173 }
174
175 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
176 Box::pin(async move {
177 self.run_with_state(|state| Box::pin(state.find_one(query)))
178 .await
179 })
180 }
181
182 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
183 Box::pin(async move {
184 if query.joins.len() <= 1 {
185 self.run_with_state(|state| Box::pin(state.find_many(query)))
186 .await
187 } else {
188 let adapter =
189 JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
190 adapter.find_many(query).await
191 }
192 })
193 }
194
195 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
196 Box::pin(async move {
197 self.run_with_state(|state| Box::pin(state.count(query)))
198 .await
199 })
200 }
201
202 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
203 Box::pin(async move {
204 self.run_with_state(|state| Box::pin(state.update(query)))
205 .await
206 })
207 }
208
209 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
210 Box::pin(async move {
211 self.run_with_state(|state| Box::pin(state.update_many(query)))
212 .await
213 })
214 }
215
216 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
217 Box::pin(async move {
218 self.run_with_state(|state| Box::pin(state.delete(query)))
219 .await
220 })
221 }
222
223 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
224 Box::pin(async move {
225 self.run_with_state(|state| Box::pin(state.delete_many(query)))
226 .await
227 })
228 }
229
230 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
231 Box::pin(async move {
232 let _gate = self.tx_gate.lock().await;
233 let client = self.client.lock().await;
234 client
235 .batch_execute("BEGIN")
236 .await
237 .map_err(postgres_error)?;
238 drop(client);
239
240 let adapter = TokioPostgresTxAdapter {
241 client: Arc::clone(&self.client),
242 schema: Arc::clone(&self.schema),
243 };
244 let result = callback(Box::new(adapter)).await;
245
246 let client = self.client.lock().await;
247 match result {
248 Ok(()) => client.batch_execute("COMMIT").await.map_err(postgres_error),
249 Err(error) => {
250 let _rollback_result = client.batch_execute("ROLLBACK").await;
251 Err(error)
252 }
253 }
254 })
255 }
256
257 fn create_schema<'a>(
258 &'a self,
259 schema: &'a DbSchema,
260 _file: Option<&'a str>,
261 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
262 Box::pin(async move {
263 let _gate = self.tx_gate.lock().await;
264 let client = self.client.lock().await;
265 create_schema(&client, schema).await?;
266 Ok(None)
267 })
268 }
269
270 fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
271 Box::pin(async move {
272 let _gate = self.tx_gate.lock().await;
273 let client = self.client.lock().await;
274 execute_migration_plan(&client, schema).await
275 })
276 }
277}
278
279struct TokioPostgresTxAdapter {
280 client: Arc<Mutex<Client>>,
281 schema: Arc<DbSchema>,
282}
283
284impl TokioPostgresTxAdapter {
285 async fn run_with_state<T>(
286 &self,
287 f: impl for<'a> FnOnce(PostgresSqlState<'a>) -> AdapterFuture<'a, T> + Send,
288 ) -> Result<T, OpenAuthError>
289 where
290 T: Send + 'static,
291 {
292 let client = self.client.lock().await;
293 f(PostgresSqlState::new(self.schema.as_ref(), &client)).await
294 }
295}
296
297impl DbAdapter for TokioPostgresTxAdapter {
298 fn id(&self) -> &str {
299 "tokio-postgres-tx"
300 }
301
302 fn capabilities(&self) -> AdapterCapabilities {
303 AdapterCapabilities::new(self.id())
304 .named("tokio-postgres transaction")
305 .with_json()
306 .with_arrays()
307 .with_transactions()
308 }
309
310 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
311 Box::pin(async move {
312 self.run_with_state(|state| Box::pin(state.create(query)))
313 .await
314 })
315 }
316
317 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
318 Box::pin(async move {
319 self.run_with_state(|state| Box::pin(state.find_one(query)))
320 .await
321 })
322 }
323
324 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
325 Box::pin(async move {
326 self.run_with_state(|state| Box::pin(state.find_many(query)))
327 .await
328 })
329 }
330
331 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
332 Box::pin(async move {
333 self.run_with_state(|state| Box::pin(state.count(query)))
334 .await
335 })
336 }
337
338 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
339 Box::pin(async move {
340 self.run_with_state(|state| Box::pin(state.update(query)))
341 .await
342 })
343 }
344
345 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
346 Box::pin(async move {
347 self.run_with_state(|state| Box::pin(state.update_many(query)))
348 .await
349 })
350 }
351
352 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
353 Box::pin(async move {
354 self.run_with_state(|state| Box::pin(state.delete(query)))
355 .await
356 })
357 }
358
359 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
360 Box::pin(async move {
361 self.run_with_state(|state| Box::pin(state.delete_many(query)))
362 .await
363 })
364 }
365
366 fn transaction<'a>(&'a self, _callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
367 Box::pin(async {
368 Err(OpenAuthError::Adapter(
369 "nested tokio-postgres transactions are not supported".to_owned(),
370 ))
371 })
372 }
373}
374
375async fn consume_postgres_rate_limit(
376 store: &TokioPostgresRateLimitStore,
377 input: RateLimitConsumeInput,
378) -> Result<RateLimitDecision, OpenAuthError> {
379 let plan = postgres_rate_limit_plan(
380 &store.names.table,
381 &store.names.key,
382 &store.names.count,
383 &store.names.last_request,
384 )?;
385 let _gate = store.tx_gate.lock().await;
386 let client = store.client.lock().await;
387 client
388 .batch_execute("BEGIN")
389 .await
390 .map_err(postgres_error)?;
391 let result = consume_postgres_rate_limit_in_tx(&client, &plan, input).await;
392 match result {
393 Ok(decision) => {
394 client
395 .batch_execute("COMMIT")
396 .await
397 .map_err(postgres_error)?;
398 Ok(decision)
399 }
400 Err(error) => {
401 let _rollback_result = client.batch_execute("ROLLBACK").await;
402 Err(error)
403 }
404 }
405}
406
407pub use self::migration::{
408 ColumnToAdd, IndexToCreate, MigrationStatement, MigrationStatementKind, SchemaMigrationPlan,
409 SchemaMigrationWarning, TableToCreate,
410};