Skip to main content

openauth_sqlx/sqlite/
mod.rs

1mod errors;
2mod query;
3mod row;
4mod schema;
5mod state;
6mod support;
7
8use std::sync::Arc;
9
10use openauth_core::db::{
11    auth_schema, rate_limit_consume_statements, AdapterCapabilities, AdapterFuture,
12    AuthSchemaOptions, Count, Create, DbAdapter, DbRecord, DbSchema, Delete, DeleteMany, FindMany,
13    FindOne, JoinAdapter, SchemaCreation, SqlDialect, TransactionCallback, Update, UpdateMany,
14};
15use openauth_core::error::OpenAuthError;
16use openauth_core::options::{
17    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitStore,
18};
19use sqlx::sqlite::SqlitePoolOptions;
20use sqlx::{Executor, Row, Sqlite, SqlitePool, Transaction};
21use tokio::sync::Mutex;
22
23use self::errors::sql_error;
24use self::schema::{
25    create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
26};
27use self::state::{SqliteExecutor, SqliteState};
28use crate::migration::SchemaMigrationPlan;
29use crate::{consume_record, RateLimitSqlNames};
30
31#[derive(Debug, Clone)]
32pub struct SqliteAdapter {
33    pool: SqlitePool,
34    schema: Arc<DbSchema>,
35}
36
37#[derive(Debug, Clone)]
38pub struct SqliteRateLimitStore {
39    pool: SqlitePool,
40    names: RateLimitSqlNames,
41}
42
43impl SqliteRateLimitStore {
44    pub fn new(pool: SqlitePool) -> Self {
45        Self::with_table(pool, "rate_limits")
46    }
47
48    pub fn with_table(pool: SqlitePool, table: impl Into<String>) -> Self {
49        Self {
50            pool,
51            names: RateLimitSqlNames::new(table),
52        }
53    }
54}
55
56impl From<&SqliteAdapter> for SqliteRateLimitStore {
57    fn from(adapter: &SqliteAdapter) -> Self {
58        Self {
59            pool: adapter.pool.clone(),
60            names: RateLimitSqlNames::from_schema(&adapter.schema),
61        }
62    }
63}
64
65impl RateLimitStore for SqliteRateLimitStore {
66    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
67        Box::pin(async move { consume_sqlite_rate_limit(&self.pool, &self.names, input).await })
68    }
69}
70
71async fn consume_sqlite_rate_limit(
72    pool: &SqlitePool,
73    names: &RateLimitSqlNames,
74    input: RateLimitConsumeInput,
75) -> Result<RateLimitDecision, OpenAuthError> {
76    let plan = rate_limit_consume_statements(
77        SqlDialect::Sqlite,
78        &names.table,
79        &names.key,
80        &names.count,
81        &names.last_request,
82    )?;
83    let mut tx = pool.begin().await.map_err(sql_error)?;
84    sqlx::query(&plan.insert_ignore.sql)
85        .bind(&input.key)
86        .bind(input.now_ms)
87        .execute(&mut *tx)
88        .await
89        .map_err(sql_error)?;
90    let row = sqlx::query(&plan.select.sql)
91        .bind(&input.key)
92        .fetch_optional(&mut *tx)
93        .await
94        .map_err(sql_error)?
95        .ok_or_else(|| OpenAuthError::Adapter("missing rate limit row".to_owned()))?;
96    let (decision, record, update) = consume_record(input, Some(sqlite_record(row)));
97    if decision.permitted && update {
98        sqlx::query(&plan.update.sql)
99            .bind(record.count as i64)
100            .bind(record.last_request)
101            .bind(&record.key)
102            .execute(&mut *tx)
103            .await
104            .map_err(sql_error)?;
105    }
106    tx.commit().await.map_err(sql_error)?;
107    Ok(decision)
108}
109
110fn sqlite_record(row: sqlx::sqlite::SqliteRow) -> RateLimitRecord {
111    RateLimitRecord {
112        key: String::new(),
113        count: row.get::<i64, _>("count") as u64,
114        last_request: row.get("last_request"),
115    }
116}
117
118impl SqliteAdapter {
119    pub fn new(pool: SqlitePool) -> Self {
120        Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
121    }
122
123    pub fn with_schema(pool: SqlitePool, schema: DbSchema) -> Self {
124        Self {
125            pool,
126            schema: Arc::new(schema),
127        }
128    }
129
130    pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
131        Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
132    }
133
134    pub async fn connect_with_schema(
135        database_url: &str,
136        schema: DbSchema,
137    ) -> Result<Self, OpenAuthError> {
138        let pool = SqlitePoolOptions::new()
139            .connect(database_url)
140            .await
141            .map_err(sql_error)?;
142        Ok(Self::with_schema(pool, schema))
143    }
144
145    pub async fn plan_migrations(
146        &self,
147        schema: &DbSchema,
148    ) -> Result<SchemaMigrationPlan, OpenAuthError> {
149        plan_schema_migrations(SqliteExecutor::Pool(&self.pool), schema).await
150    }
151
152    pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
153        Ok(self.plan_migrations(schema).await?.compile())
154    }
155
156    fn state(&self) -> SqliteState<'_, '_> {
157        SqliteState {
158            schema: &self.schema,
159            executor: SqliteExecutor::Pool(&self.pool),
160        }
161    }
162}
163
164impl DbAdapter for SqliteAdapter {
165    fn id(&self) -> &str {
166        "sqlx-sqlite"
167    }
168
169    fn capabilities(&self) -> AdapterCapabilities {
170        AdapterCapabilities::new(self.id())
171            .named("SQLx SQLite")
172            .with_json()
173            .with_arrays()
174            .with_joins()
175            .with_transactions()
176    }
177
178    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
179        Box::pin(async move { self.state().create(query).await })
180    }
181
182    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
183        Box::pin(async move { self.state().find_one(query).await })
184    }
185
186    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
187        Box::pin(async move {
188            if query.joins.len() <= 1 {
189                self.state().find_many(query).await
190            } else {
191                let adapter =
192                    JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
193                adapter.find_many(query).await
194            }
195        })
196    }
197
198    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
199        Box::pin(async move { self.state().count(query).await })
200    }
201
202    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
203        Box::pin(async move { self.state().update(query).await })
204    }
205
206    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
207        Box::pin(async move { self.state().update_many(query).await })
208    }
209
210    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
211        Box::pin(async move { self.state().delete(query).await })
212    }
213
214    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
215        Box::pin(async move { self.state().delete_many(query).await })
216    }
217
218    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
219        Box::pin(async move {
220            let tx = self.pool.begin().await.map_err(sql_error)?;
221            let adapter = Arc::new(SqliteTxAdapter {
222                schema: Arc::clone(&self.schema),
223                tx: Mutex::new(Some(tx)),
224            });
225            let result = callback(Box::new(Arc::clone(&adapter))).await;
226            let mut guard = adapter.tx.lock().await;
227            let Some(tx) = guard.take() else {
228                return Err(OpenAuthError::Adapter(
229                    "sqlite transaction was already completed".to_owned(),
230                ));
231            };
232            drop(guard);
233            match result {
234                Ok(()) => tx.commit().await.map_err(sql_error),
235                Err(error) => {
236                    let _rollback_result = tx.rollback().await;
237                    Err(error)
238                }
239            }
240        })
241    }
242
243    fn create_schema<'a>(
244        &'a self,
245        schema: &'a DbSchema,
246        _file: Option<&'a str>,
247    ) -> AdapterFuture<'a, Option<SchemaCreation>> {
248        Box::pin(async move {
249            self.pool
250                .execute("PRAGMA foreign_keys = ON")
251                .await
252                .map_err(sql_error)?;
253            create_schema(SqliteExecutor::Pool(&self.pool), schema).await?;
254            Ok(None)
255        })
256    }
257
258    fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
259        Box::pin(async move {
260            self.pool
261                .execute("PRAGMA foreign_keys = ON")
262                .await
263                .map_err(sql_error)?;
264            let plan = plan_schema_migrations(SqliteExecutor::Pool(&self.pool), schema).await?;
265            let mut executor = SqliteExecutor::Pool(&self.pool);
266            execute_migration_plan(&mut executor, &plan).await?;
267            Ok(())
268        })
269    }
270}
271
272struct SqliteTxAdapter<'tx> {
273    schema: Arc<DbSchema>,
274    tx: Mutex<Option<Transaction<'tx, Sqlite>>>,
275}
276
277impl DbAdapter for SqliteTxAdapter<'_> {
278    fn id(&self) -> &str {
279        "sqlx-sqlite"
280    }
281
282    fn capabilities(&self) -> AdapterCapabilities {
283        AdapterCapabilities::new(self.id())
284            .named("SQLx SQLite")
285            .with_json()
286            .with_arrays()
287            .with_transactions()
288    }
289
290    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
291        Box::pin(async move { self.state().await?.create(query).await })
292    }
293
294    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
295        Box::pin(async move { self.state().await?.find_one(query).await })
296    }
297
298    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
299        Box::pin(async move { self.state().await?.find_many(query).await })
300    }
301
302    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
303        Box::pin(async move { self.state().await?.count(query).await })
304    }
305
306    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
307        Box::pin(async move { self.state().await?.update(query).await })
308    }
309
310    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
311        Box::pin(async move { self.state().await?.update_many(query).await })
312    }
313
314    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
315        Box::pin(async move { self.state().await?.delete(query).await })
316    }
317
318    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
319        Box::pin(async move { self.state().await?.delete_many(query).await })
320    }
321
322    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
323        callback(Box::new(self))
324    }
325}
326
327impl<'tx> SqliteTxAdapter<'tx> {
328    async fn state<'a>(&'a self) -> Result<SqliteState<'a, 'tx>, OpenAuthError> {
329        let guard = self.tx.lock().await;
330        if guard.is_none() {
331            return Err(OpenAuthError::Adapter(
332                "sqlite transaction is no longer active".to_owned(),
333            ));
334        }
335        Ok(SqliteState {
336            schema: &self.schema,
337            executor: SqliteExecutor::Transaction(guard),
338        })
339    }
340}