Skip to main content

assay_auth/store/
sqlite.rs

//! SQLite implementations of [`UserStore`] and [`SessionStore`].
//!
//! Tables live in the attached `auth` database (engine boot ATTACHes
//! `data/auth.db` AS `auth` before [`crate::schema::migrate_sqlite`]
//! creates them). Queries are schema-qualified (`auth.users`, …) so the
//! table references match the PG store — the `auth.` prefix resolves
//! against the ATTACH alias on SQLite and against the schema on PG. Only the
//! bind-placeholder syntax differs (`?` here; sqlx's Postgres driver uses `$N`).
8
9use std::sync::Arc;
10
11use anyhow::{Context, Result};
12use sqlx::{Row, SqlitePool};
13
14use super::types::{PasskeyCred, Session, User};
15use super::{SessionStore, UserStore};
16
17/// User store backed by a shared `SqlitePool`. Mirrors
18/// [`super::postgres::PostgresUserStore`] in shape so callers swap one
19/// for the other based on the engine's selected backend.
20#[derive(Clone)]
21pub struct SqliteUserStore {
22    pool: SqlitePool,
23}
24
25impl SqliteUserStore {
26    pub fn new(pool: SqlitePool) -> Self {
27        Self { pool }
28    }
29
30    /// Wrap into an `Arc<dyn UserStore>` for [`crate::ctx::AuthCtx`].
31    pub fn into_dyn(self) -> Arc<dyn UserStore> {
32        Arc::new(self)
33    }
34}
35
36#[async_trait::async_trait]
37impl UserStore for SqliteUserStore {
38    async fn create_user(&self, user: &User) -> Result<()> {
39        sqlx::query(
40            "INSERT INTO auth.users
41                 (id, email, email_verified, display_name, password_hash, created_at)
42             VALUES (?, ?, ?, ?, NULL, ?)",
43        )
44        .bind(&user.id)
45        .bind(&user.email)
46        .bind(if user.email_verified { 1i64 } else { 0i64 })
47        .bind(&user.display_name)
48        .bind(user.created_at)
49        .execute(&self.pool)
50        .await
51        .context("auth.users insert")?;
52        Ok(())
53    }
54
55    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
56        let row = sqlx::query(
57            "SELECT id, email, email_verified, display_name, created_at
58             FROM auth.users WHERE id = ?",
59        )
60        .bind(id)
61        .fetch_optional(&self.pool)
62        .await
63        .context("auth.users select by id")?;
64        Ok(row.map(map_user_row_sqlite))
65    }
66
67    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
68        let row = sqlx::query(
69            "SELECT id, email, email_verified, display_name, created_at
70             FROM auth.users WHERE email = ?",
71        )
72        .bind(email)
73        .fetch_optional(&self.pool)
74        .await
75        .context("auth.users select by email")?;
76        Ok(row.map(map_user_row_sqlite))
77    }
78
79    async fn update_user(&self, user: &User) -> Result<()> {
80        sqlx::query(
81            "UPDATE auth.users
82             SET email = ?,
83                 email_verified = ?,
84                 display_name = ?
85             WHERE id = ?",
86        )
87        .bind(&user.email)
88        .bind(if user.email_verified { 1i64 } else { 0i64 })
89        .bind(&user.display_name)
90        .bind(&user.id)
91        .execute(&self.pool)
92        .await
93        .context("auth.users update")?;
94        Ok(())
95    }
96
97    async fn set_password_hash(&self, user_id: &str, hash: &str) -> Result<()> {
98        sqlx::query("UPDATE auth.users SET password_hash = ? WHERE id = ?")
99            .bind(hash)
100            .bind(user_id)
101            .execute(&self.pool)
102            .await
103            .context("auth.users set password_hash")?;
104        Ok(())
105    }
106
107    async fn get_password_hash(&self, user_id: &str) -> Result<Option<String>> {
108        let row: Option<(Option<String>,)> =
109            sqlx::query_as("SELECT password_hash FROM auth.users WHERE id = ?")
110                .bind(user_id)
111                .fetch_optional(&self.pool)
112                .await
113                .context("auth.users select password_hash")?;
114        Ok(row.and_then(|r| r.0))
115    }
116
117    async fn list_passkeys(&self, user_id: &str) -> Result<Vec<PasskeyCred>> {
118        let rows = sqlx::query(
119            "SELECT credential_id, public_key, sign_count, transports, created_at
120             FROM auth.passkeys WHERE user_id = ?
121             ORDER BY created_at",
122        )
123        .bind(user_id)
124        .fetch_all(&self.pool)
125        .await
126        .context("auth.passkeys list")?;
127        Ok(rows.into_iter().map(map_passkey_row_sqlite).collect())
128    }
129
130    async fn add_passkey(&self, user_id: &str, cred: &PasskeyCred) -> Result<()> {
131        sqlx::query(
132            "INSERT INTO auth.passkeys
133                 (credential_id, user_id, public_key, sign_count, transports, created_at)
134             VALUES (?, ?, ?, ?, ?, ?)",
135        )
136        .bind(&cred.credential_id)
137        .bind(user_id)
138        .bind(&cred.public_key)
139        .bind(cred.sign_count as i64)
140        .bind(cred.transports.join(","))
141        .bind(cred.created_at)
142        .execute(&self.pool)
143        .await
144        .context("auth.passkeys insert")?;
145        Ok(())
146    }
147
148    async fn remove_passkey(&self, credential_id: &[u8]) -> Result<bool> {
149        let res = sqlx::query("DELETE FROM auth.passkeys WHERE credential_id = ?")
150            .bind(credential_id)
151            .execute(&self.pool)
152            .await
153            .context("auth.passkeys delete")?;
154        Ok(res.rows_affected() > 0)
155    }
156
157    async fn link_upstream(&self, user_id: &str, provider: &str, subject: &str) -> Result<()> {
158        sqlx::query(
159            "INSERT INTO auth.user_upstream (provider, subject, user_id)
160             VALUES (?, ?, ?)
161             ON CONFLICT (provider, subject) DO UPDATE SET user_id = excluded.user_id",
162        )
163        .bind(provider)
164        .bind(subject)
165        .bind(user_id)
166        .execute(&self.pool)
167        .await
168        .context("auth.user_upstream upsert")?;
169        Ok(())
170    }
171
172    async fn get_user_by_upstream(
173        &self,
174        provider: &str,
175        subject: &str,
176    ) -> Result<Option<User>> {
177        let row = sqlx::query(
178            "SELECT u.id, u.email, u.email_verified, u.display_name, u.created_at
179             FROM auth.users u
180             JOIN auth.user_upstream l ON l.user_id = u.id
181             WHERE l.provider = ? AND l.subject = ?",
182        )
183        .bind(provider)
184        .bind(subject)
185        .fetch_optional(&self.pool)
186        .await
187        .context("auth.user_upstream lookup")?;
188        Ok(row.map(map_user_row_sqlite))
189    }
190
191    async fn list_users(
192        &self,
193        limit: i64,
194        offset: i64,
195        search: Option<&str>,
196    ) -> Result<Vec<User>> {
197        let lim = limit.clamp(1, 500);
198        let off = offset.max(0);
199        let rows = if let Some(needle) = search {
200            let pat = format!("%{}%", needle.to_lowercase());
201            sqlx::query(
202                "SELECT id, email, email_verified, display_name, created_at
203                 FROM auth.users
204                 WHERE LOWER(COALESCE(email, '')) LIKE ?
205                    OR LOWER(COALESCE(display_name, '')) LIKE ?
206                 ORDER BY created_at DESC
207                 LIMIT ? OFFSET ?",
208            )
209            .bind(&pat)
210            .bind(&pat)
211            .bind(lim)
212            .bind(off)
213            .fetch_all(&self.pool)
214            .await
215            .context("auth.users list (search)")?
216        } else {
217            sqlx::query(
218                "SELECT id, email, email_verified, display_name, created_at
219                 FROM auth.users
220                 ORDER BY created_at DESC
221                 LIMIT ? OFFSET ?",
222            )
223            .bind(lim)
224            .bind(off)
225            .fetch_all(&self.pool)
226            .await
227            .context("auth.users list")?
228        };
229        Ok(rows.into_iter().map(map_user_row_sqlite).collect())
230    }
231
232    async fn count_users(&self, search: Option<&str>) -> Result<i64> {
233        let row: (i64,) = if let Some(needle) = search {
234            let pat = format!("%{}%", needle.to_lowercase());
235            sqlx::query_as(
236                "SELECT COUNT(*) FROM auth.users
237                 WHERE LOWER(COALESCE(email, '')) LIKE ?
238                    OR LOWER(COALESCE(display_name, '')) LIKE ?",
239            )
240            .bind(&pat)
241            .bind(&pat)
242            .fetch_one(&self.pool)
243            .await
244            .context("auth.users count (search)")?
245        } else {
246            sqlx::query_as("SELECT COUNT(*) FROM auth.users")
247                .fetch_one(&self.pool)
248                .await
249                .context("auth.users count")?
250        };
251        Ok(row.0)
252    }
253
254    async fn delete_user(&self, id: &str) -> Result<bool> {
255        let res = sqlx::query("DELETE FROM auth.users WHERE id = ?")
256            .bind(id)
257            .execute(&self.pool)
258            .await
259            .context("auth.users delete")?;
260        Ok(res.rows_affected() > 0)
261    }
262
263    async fn list_upstream_for_user(&self, user_id: &str) -> Result<Vec<(String, String)>> {
264        let rows = sqlx::query(
265            "SELECT provider, subject FROM auth.user_upstream
266             WHERE user_id = ? ORDER BY provider, subject",
267        )
268        .bind(user_id)
269        .fetch_all(&self.pool)
270        .await
271        .context("auth.user_upstream list")?;
272        Ok(rows
273            .into_iter()
274            .map(|r| (r.get::<String, _>("provider"), r.get::<String, _>("subject")))
275            .collect())
276    }
277}
278
279/// Session store backed by `auth.sessions`. Independent struct from
280/// [`SqliteUserStore`] because they're independently mockable in
281/// tests and the engine may swap one without the other.
282#[derive(Clone)]
283pub struct SqliteSessionStore {
284    pool: SqlitePool,
285}
286
287impl SqliteSessionStore {
288    pub fn new(pool: SqlitePool) -> Self {
289        Self { pool }
290    }
291
292    pub fn into_dyn(self) -> Arc<dyn SessionStore> {
293        Arc::new(self)
294    }
295}
296
297#[async_trait::async_trait]
298impl SessionStore for SqliteSessionStore {
299    async fn create(&self, session: &Session) -> Result<()> {
300        sqlx::query(
301            "INSERT INTO auth.sessions
302                 (id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash)
303             VALUES (?, ?, ?, ?, ?, ?, ?)",
304        )
305        .bind(&session.id)
306        .bind(&session.user_id)
307        .bind(&session.csrf_token)
308        .bind(session.created_at)
309        .bind(session.expires_at)
310        .bind(&session.ip_hash)
311        .bind(&session.user_agent_hash)
312        .execute(&self.pool)
313        .await
314        .context("auth.sessions insert")?;
315        Ok(())
316    }
317
318    async fn get(&self, id: &str) -> Result<Option<Session>> {
319        let row = sqlx::query(
320            "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
321             FROM auth.sessions WHERE id = ?",
322        )
323        .bind(id)
324        .fetch_optional(&self.pool)
325        .await
326        .context("auth.sessions select")?;
327        Ok(row.map(map_session_row_sqlite))
328    }
329
330    async fn delete(&self, id: &str) -> Result<bool> {
331        let res = sqlx::query("DELETE FROM auth.sessions WHERE id = ?")
332            .bind(id)
333            .execute(&self.pool)
334            .await
335            .context("auth.sessions delete")?;
336        Ok(res.rows_affected() > 0)
337    }
338
339    async fn list_for_user(&self, user_id: &str) -> Result<Vec<Session>> {
340        let rows = sqlx::query(
341            "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
342             FROM auth.sessions WHERE user_id = ? ORDER BY created_at DESC",
343        )
344        .bind(user_id)
345        .fetch_all(&self.pool)
346        .await
347        .context("auth.sessions list_for_user")?;
348        Ok(rows.into_iter().map(map_session_row_sqlite).collect())
349    }
350
351    async fn delete_for_user(&self, user_id: &str) -> Result<u64> {
352        let res = sqlx::query("DELETE FROM auth.sessions WHERE user_id = ?")
353            .bind(user_id)
354            .execute(&self.pool)
355            .await
356            .context("auth.sessions delete_for_user")?;
357        Ok(res.rows_affected())
358    }
359
360    async fn purge_expired(&self, now: f64) -> Result<u64> {
361        let res = sqlx::query("DELETE FROM auth.sessions WHERE expires_at <= ?")
362            .bind(now)
363            .execute(&self.pool)
364            .await
365            .context("auth.sessions purge_expired")?;
366        Ok(res.rows_affected())
367    }
368
369    async fn list_all(
370        &self,
371        limit: i64,
372        offset: i64,
373        user_filter: Option<&str>,
374    ) -> Result<Vec<Session>> {
375        let lim = limit.clamp(1, 500);
376        let off = offset.max(0);
377        let rows = if let Some(uid) = user_filter {
378            sqlx::query(
379                "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
380                 FROM auth.sessions WHERE user_id = ?
381                 ORDER BY created_at DESC
382                 LIMIT ? OFFSET ?",
383            )
384            .bind(uid)
385            .bind(lim)
386            .bind(off)
387            .fetch_all(&self.pool)
388            .await
389            .context("auth.sessions list_all (user filter)")?
390        } else {
391            sqlx::query(
392                "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
393                 FROM auth.sessions
394                 ORDER BY created_at DESC
395                 LIMIT ? OFFSET ?",
396            )
397            .bind(lim)
398            .bind(off)
399            .fetch_all(&self.pool)
400            .await
401            .context("auth.sessions list_all")?
402        };
403        Ok(rows.into_iter().map(map_session_row_sqlite).collect())
404    }
405
406    async fn count_all(&self, user_filter: Option<&str>) -> Result<i64> {
407        let row: (i64,) = if let Some(uid) = user_filter {
408            sqlx::query_as("SELECT COUNT(*) FROM auth.sessions WHERE user_id = ?")
409                .bind(uid)
410                .fetch_one(&self.pool)
411                .await
412                .context("auth.sessions count_all (user filter)")?
413        } else {
414            sqlx::query_as("SELECT COUNT(*) FROM auth.sessions")
415                .fetch_one(&self.pool)
416                .await
417                .context("auth.sessions count_all")?
418        };
419        Ok(row.0)
420    }
421}
422
423fn map_user_row_sqlite(row: sqlx::sqlite::SqliteRow) -> User {
424    let email_verified: i64 = row.get("email_verified");
425    User {
426        id: row.get("id"),
427        email: row.get("email"),
428        email_verified: email_verified != 0,
429        display_name: row.get("display_name"),
430        created_at: row.get("created_at"),
431    }
432}
433
434fn map_session_row_sqlite(row: sqlx::sqlite::SqliteRow) -> Session {
435    Session {
436        id: row.get("id"),
437        user_id: row.get("user_id"),
438        csrf_token: row.get("csrf_token"),
439        created_at: row.get("created_at"),
440        expires_at: row.get("expires_at"),
441        ip_hash: row.get("ip_hash"),
442        user_agent_hash: row.get("user_agent_hash"),
443    }
444}
445
446fn map_passkey_row_sqlite(row: sqlx::sqlite::SqliteRow) -> PasskeyCred {
447    let transports: String = row.get("transports");
448    let sign_count: i64 = row.get("sign_count");
449    PasskeyCred {
450        credential_id: row.get("credential_id"),
451        public_key: row.get("public_key"),
452        sign_count: sign_count.max(0) as u32,
453        transports: if transports.is_empty() {
454            Vec::new()
455        } else {
456            transports.split(',').map(|s| s.to_string()).collect()
457        },
458        created_at: row.get("created_at"),
459    }
460}