1mod errors;
2mod query;
3mod row;
4mod schema;
5mod state;
6mod support;
7
8use std::sync::Arc;
9
10use openauth_core::db::{
11 auth_schema, rate_limit_consume_statements, AdapterCapabilities, AdapterFuture,
12 AuthSchemaOptions, Count, Create, DbAdapter, DbRecord, DbSchema, Delete, DeleteMany, FindMany,
13 FindOne, JoinAdapter, SchemaCreation, SqlDialect, TransactionCallback, Update, UpdateMany,
14};
15use openauth_core::error::OpenAuthError;
16use openauth_core::options::{
17 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitStore,
18};
19use sqlx::sqlite::SqlitePoolOptions;
20use sqlx::{Executor, Row, Sqlite, SqlitePool, Transaction};
21use tokio::sync::Mutex;
22
23use self::errors::sql_error;
24use self::schema::{
25 create_schema, execute_migration_plan, plan_migrations as plan_schema_migrations,
26};
27use self::state::{SqliteExecutor, SqliteState};
28use crate::migration::SchemaMigrationPlan;
29use crate::{consume_record, RateLimitSqlNames};
30
/// SQLite implementation of [`DbAdapter`] built on an sqlx [`SqlitePool`].
///
/// Cloning is cheap: the sqlx pool is a shared handle and the schema sits behind an
/// [`Arc`]. This is what allows `find_many` to hand an owned clone of the adapter to a
/// `JoinAdapter`.
#[derive(Debug, Clone)]
pub struct SqliteAdapter {
    // Pool that every non-transactional operation runs against.
    pool: SqlitePool,
    // Schema handed to every `SqliteState`; shared (not copied) with transaction adapters.
    schema: Arc<DbSchema>,
}
36
/// [`RateLimitStore`] that persists request counters in a SQLite table.
///
/// Build one with [`SqliteRateLimitStore::new`] or [`SqliteRateLimitStore::with_table`].
/// Alternatively, convert from `&SqliteAdapter` to reuse the adapter's pool and its
/// schema-derived table/column names.
#[derive(Debug, Clone)]
pub struct SqliteRateLimitStore {
    // Pool on which each `consume` call opens its own transaction.
    pool: SqlitePool,
    // Table and column identifiers passed to the generated rate-limit statements.
    names: RateLimitSqlNames,
}
42
43impl SqliteRateLimitStore {
44 pub fn new(pool: SqlitePool) -> Self {
45 Self::with_table(pool, "rate_limits")
46 }
47
48 pub fn with_table(pool: SqlitePool, table: impl Into<String>) -> Self {
49 Self {
50 pool,
51 names: RateLimitSqlNames::new(table),
52 }
53 }
54}
55
56impl From<&SqliteAdapter> for SqliteRateLimitStore {
57 fn from(adapter: &SqliteAdapter) -> Self {
58 Self {
59 pool: adapter.pool.clone(),
60 names: RateLimitSqlNames::from_schema(&adapter.schema),
61 }
62 }
63}
64
65impl RateLimitStore for SqliteRateLimitStore {
66 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
67 Box::pin(async move { consume_sqlite_rate_limit(&self.pool, &self.names, input).await })
68 }
69}
70
71async fn consume_sqlite_rate_limit(
72 pool: &SqlitePool,
73 names: &RateLimitSqlNames,
74 input: RateLimitConsumeInput,
75) -> Result<RateLimitDecision, OpenAuthError> {
76 let plan = rate_limit_consume_statements(
77 SqlDialect::Sqlite,
78 &names.table,
79 &names.key,
80 &names.count,
81 &names.last_request,
82 )?;
83 let mut tx = pool.begin().await.map_err(sql_error)?;
84 sqlx::query(&plan.insert_ignore.sql)
85 .bind(&input.key)
86 .bind(input.now_ms)
87 .execute(&mut *tx)
88 .await
89 .map_err(sql_error)?;
90 let row = sqlx::query(&plan.select.sql)
91 .bind(&input.key)
92 .fetch_optional(&mut *tx)
93 .await
94 .map_err(sql_error)?
95 .ok_or_else(|| OpenAuthError::Adapter("missing rate limit row".to_owned()))?;
96 let (decision, record, update) = consume_record(input, Some(sqlite_record(row)));
97 if decision.permitted && update {
98 sqlx::query(&plan.update.sql)
99 .bind(record.count as i64)
100 .bind(record.last_request)
101 .bind(&record.key)
102 .execute(&mut *tx)
103 .await
104 .map_err(sql_error)?;
105 }
106 tx.commit().await.map_err(sql_error)?;
107 Ok(decision)
108}
109
110fn sqlite_record(row: sqlx::sqlite::SqliteRow) -> RateLimitRecord {
111 RateLimitRecord {
112 key: String::new(),
113 count: row.get::<i64, _>("count") as u64,
114 last_request: row.get("last_request"),
115 }
116}
117
118impl SqliteAdapter {
119 pub fn new(pool: SqlitePool) -> Self {
120 Self::with_schema(pool, auth_schema(AuthSchemaOptions::default()))
121 }
122
123 pub fn with_schema(pool: SqlitePool, schema: DbSchema) -> Self {
124 Self {
125 pool,
126 schema: Arc::new(schema),
127 }
128 }
129
130 pub async fn connect(database_url: &str) -> Result<Self, OpenAuthError> {
131 Self::connect_with_schema(database_url, auth_schema(AuthSchemaOptions::default())).await
132 }
133
134 pub async fn connect_with_schema(
135 database_url: &str,
136 schema: DbSchema,
137 ) -> Result<Self, OpenAuthError> {
138 let pool = SqlitePoolOptions::new()
139 .connect(database_url)
140 .await
141 .map_err(sql_error)?;
142 Ok(Self::with_schema(pool, schema))
143 }
144
145 pub async fn plan_migrations(
146 &self,
147 schema: &DbSchema,
148 ) -> Result<SchemaMigrationPlan, OpenAuthError> {
149 plan_schema_migrations(SqliteExecutor::Pool(&self.pool), schema).await
150 }
151
152 pub async fn compile_migrations(&self, schema: &DbSchema) -> Result<String, OpenAuthError> {
153 Ok(self.plan_migrations(schema).await?.compile())
154 }
155
156 fn state(&self) -> SqliteState<'_, '_> {
157 SqliteState {
158 schema: &self.schema,
159 executor: SqliteExecutor::Pool(&self.pool),
160 }
161 }
162}
163
164impl DbAdapter for SqliteAdapter {
165 fn id(&self) -> &str {
166 "sqlx-sqlite"
167 }
168
169 fn capabilities(&self) -> AdapterCapabilities {
170 AdapterCapabilities::new(self.id())
171 .named("SQLx SQLite")
172 .with_json()
173 .with_arrays()
174 .with_joins()
175 .with_transactions()
176 }
177
178 fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
179 Box::pin(async move { self.state().create(query).await })
180 }
181
182 fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
183 Box::pin(async move { self.state().find_one(query).await })
184 }
185
186 fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
187 Box::pin(async move {
188 if query.joins.len() <= 1 {
189 self.state().find_many(query).await
190 } else {
191 let adapter =
192 JoinAdapter::new(self.schema.as_ref().clone(), Arc::new(self.clone()), false);
193 adapter.find_many(query).await
194 }
195 })
196 }
197
198 fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
199 Box::pin(async move { self.state().count(query).await })
200 }
201
202 fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
203 Box::pin(async move { self.state().update(query).await })
204 }
205
206 fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
207 Box::pin(async move { self.state().update_many(query).await })
208 }
209
210 fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
211 Box::pin(async move { self.state().delete(query).await })
212 }
213
214 fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
215 Box::pin(async move { self.state().delete_many(query).await })
216 }
217
218 fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
219 Box::pin(async move {
220 let tx = self.pool.begin().await.map_err(sql_error)?;
221 let adapter = Arc::new(SqliteTxAdapter {
222 schema: Arc::clone(&self.schema),
223 tx: Mutex::new(Some(tx)),
224 });
225 let result = callback(Box::new(Arc::clone(&adapter))).await;
226 let mut guard = adapter.tx.lock().await;
227 let Some(tx) = guard.take() else {
228 return Err(OpenAuthError::Adapter(
229 "sqlite transaction was already completed".to_owned(),
230 ));
231 };
232 drop(guard);
233 match result {
234 Ok(()) => tx.commit().await.map_err(sql_error),
235 Err(error) => {
236 let _rollback_result = tx.rollback().await;
237 Err(error)
238 }
239 }
240 })
241 }
242
243 fn create_schema<'a>(
244 &'a self,
245 schema: &'a DbSchema,
246 _file: Option<&'a str>,
247 ) -> AdapterFuture<'a, Option<SchemaCreation>> {
248 Box::pin(async move {
249 self.pool
250 .execute("PRAGMA foreign_keys = ON")
251 .await
252 .map_err(sql_error)?;
253 create_schema(SqliteExecutor::Pool(&self.pool), schema).await?;
254 Ok(None)
255 })
256 }
257
258 fn run_migrations<'a>(&'a self, schema: &'a DbSchema) -> AdapterFuture<'a, ()> {
259 Box::pin(async move {
260 self.pool
261 .execute("PRAGMA foreign_keys = ON")
262 .await
263 .map_err(sql_error)?;
264 let plan = plan_schema_migrations(SqliteExecutor::Pool(&self.pool), schema).await?;
265 let mut executor = SqliteExecutor::Pool(&self.pool);
266 execute_migration_plan(&mut executor, &plan).await?;
267 Ok(())
268 })
269 }
270}
271
/// [`DbAdapter`] view over one open SQLite transaction.
///
/// `SqliteAdapter::transaction` shares it with the user callback behind an `Arc`.
struct SqliteTxAdapter<'tx> {
    // Same schema `Arc` as the `SqliteAdapter` that created this transaction.
    schema: Arc<DbSchema>,
    // The live transaction; `None` once it has been taken back for commit/rollback.
    // Every operation locks this for its whole duration (see `state`).
    tx: Mutex<Option<Transaction<'tx, Sqlite>>>,
}
276
// Every operation goes through `state()`, which locks the shared transaction and fails
// with `OpenAuthError::Adapter` once the transaction is no longer active.
impl DbAdapter for SqliteTxAdapter<'_> {
    /// Same identifier as the pool-backed `SqliteAdapter`.
    fn id(&self) -> &str {
        "sqlx-sqlite"
    }

    /// Unlike `SqliteAdapter::capabilities`, this omits `with_joins()`.
    ///
    /// Matching that, `find_many` below has no `JoinAdapter` fallback for multi-join
    /// queries.
    fn capabilities(&self) -> AdapterCapabilities {
        AdapterCapabilities::new(self.id())
            .named("SQLx SQLite")
            .with_json()
            .with_arrays()
            .with_transactions()
    }

    fn create<'a>(&'a self, query: Create) -> AdapterFuture<'a, DbRecord> {
        Box::pin(async move { self.state().await?.create(query).await })
    }

    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
        Box::pin(async move { self.state().await?.find_one(query).await })
    }

    // All joins are passed straight to the transaction-backed state.
    fn find_many<'a>(&'a self, query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
        Box::pin(async move { self.state().await?.find_many(query).await })
    }

    fn count<'a>(&'a self, query: Count) -> AdapterFuture<'a, u64> {
        Box::pin(async move { self.state().await?.count(query).await })
    }

    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
        Box::pin(async move { self.state().await?.update(query).await })
    }

    fn update_many<'a>(&'a self, query: UpdateMany) -> AdapterFuture<'a, u64> {
        Box::pin(async move { self.state().await?.update_many(query).await })
    }

    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
        Box::pin(async move { self.state().await?.delete(query).await })
    }

    fn delete_many<'a>(&'a self, query: DeleteMany) -> AdapterFuture<'a, u64> {
        Box::pin(async move { self.state().await?.delete_many(query).await })
    }

    /// Nested transactions are flattened.
    ///
    /// The callback runs against this same open transaction, so no separate savepoint
    /// exists to roll back independently.
    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
        callback(Box::new(self))
    }
}
326
327impl<'tx> SqliteTxAdapter<'tx> {
328 async fn state<'a>(&'a self) -> Result<SqliteState<'a, 'tx>, OpenAuthError> {
329 let guard = self.tx.lock().await;
330 if guard.is_none() {
331 return Err(OpenAuthError::Adapter(
332 "sqlite transaction is no longer active".to_owned(),
333 ));
334 }
335 Ok(SqliteState {
336 schema: &self.schema,
337 executor: SqliteExecutor::Transaction(guard),
338 })
339 }
340}