1use std::sync::Arc;
10
11use anyhow::{Context, Result};
12use sqlx::{Row, SqlitePool};
13
14use super::types::{PasskeyCred, Session, User};
15use super::{SessionStore, UserStore};
16
17#[derive(Clone)]
21pub struct SqliteUserStore {
22 pool: SqlitePool,
23}
24
25impl SqliteUserStore {
26 pub fn new(pool: SqlitePool) -> Self {
27 Self { pool }
28 }
29
30 pub fn into_dyn(self) -> Arc<dyn UserStore> {
32 Arc::new(self)
33 }
34}
35
36#[async_trait::async_trait]
37impl UserStore for SqliteUserStore {
38 async fn create_user(&self, user: &User) -> Result<()> {
39 sqlx::query(
40 "INSERT INTO auth.users
41 (id, email, email_verified, display_name, password_hash, created_at)
42 VALUES (?, ?, ?, ?, NULL, ?)",
43 )
44 .bind(&user.id)
45 .bind(&user.email)
46 .bind(if user.email_verified { 1i64 } else { 0i64 })
47 .bind(&user.display_name)
48 .bind(user.created_at)
49 .execute(&self.pool)
50 .await
51 .context("auth.users insert")?;
52 Ok(())
53 }
54
55 async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
56 let row = sqlx::query(
57 "SELECT id, email, email_verified, display_name, created_at
58 FROM auth.users WHERE id = ?",
59 )
60 .bind(id)
61 .fetch_optional(&self.pool)
62 .await
63 .context("auth.users select by id")?;
64 Ok(row.map(map_user_row_sqlite))
65 }
66
67 async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
68 let row = sqlx::query(
69 "SELECT id, email, email_verified, display_name, created_at
70 FROM auth.users WHERE email = ?",
71 )
72 .bind(email)
73 .fetch_optional(&self.pool)
74 .await
75 .context("auth.users select by email")?;
76 Ok(row.map(map_user_row_sqlite))
77 }
78
79 async fn update_user(&self, user: &User) -> Result<()> {
80 sqlx::query(
81 "UPDATE auth.users
82 SET email = ?,
83 email_verified = ?,
84 display_name = ?
85 WHERE id = ?",
86 )
87 .bind(&user.email)
88 .bind(if user.email_verified { 1i64 } else { 0i64 })
89 .bind(&user.display_name)
90 .bind(&user.id)
91 .execute(&self.pool)
92 .await
93 .context("auth.users update")?;
94 Ok(())
95 }
96
97 async fn set_password_hash(&self, user_id: &str, hash: &str) -> Result<()> {
98 sqlx::query("UPDATE auth.users SET password_hash = ? WHERE id = ?")
99 .bind(hash)
100 .bind(user_id)
101 .execute(&self.pool)
102 .await
103 .context("auth.users set password_hash")?;
104 Ok(())
105 }
106
107 async fn get_password_hash(&self, user_id: &str) -> Result<Option<String>> {
108 let row: Option<(Option<String>,)> =
109 sqlx::query_as("SELECT password_hash FROM auth.users WHERE id = ?")
110 .bind(user_id)
111 .fetch_optional(&self.pool)
112 .await
113 .context("auth.users select password_hash")?;
114 Ok(row.and_then(|r| r.0))
115 }
116
117 async fn list_passkeys(&self, user_id: &str) -> Result<Vec<PasskeyCred>> {
118 let rows = sqlx::query(
119 "SELECT credential_id, public_key, sign_count, transports, created_at
120 FROM auth.passkeys WHERE user_id = ?
121 ORDER BY created_at",
122 )
123 .bind(user_id)
124 .fetch_all(&self.pool)
125 .await
126 .context("auth.passkeys list")?;
127 Ok(rows.into_iter().map(map_passkey_row_sqlite).collect())
128 }
129
130 async fn add_passkey(&self, user_id: &str, cred: &PasskeyCred) -> Result<()> {
131 sqlx::query(
132 "INSERT INTO auth.passkeys
133 (credential_id, user_id, public_key, sign_count, transports, created_at)
134 VALUES (?, ?, ?, ?, ?, ?)",
135 )
136 .bind(&cred.credential_id)
137 .bind(user_id)
138 .bind(&cred.public_key)
139 .bind(cred.sign_count as i64)
140 .bind(cred.transports.join(","))
141 .bind(cred.created_at)
142 .execute(&self.pool)
143 .await
144 .context("auth.passkeys insert")?;
145 Ok(())
146 }
147
148 async fn remove_passkey(&self, credential_id: &[u8]) -> Result<bool> {
149 let res = sqlx::query("DELETE FROM auth.passkeys WHERE credential_id = ?")
150 .bind(credential_id)
151 .execute(&self.pool)
152 .await
153 .context("auth.passkeys delete")?;
154 Ok(res.rows_affected() > 0)
155 }
156
157 async fn link_upstream(&self, user_id: &str, provider: &str, subject: &str) -> Result<()> {
158 sqlx::query(
159 "INSERT INTO auth.user_upstream (provider, subject, user_id)
160 VALUES (?, ?, ?)
161 ON CONFLICT (provider, subject) DO UPDATE SET user_id = excluded.user_id",
162 )
163 .bind(provider)
164 .bind(subject)
165 .bind(user_id)
166 .execute(&self.pool)
167 .await
168 .context("auth.user_upstream upsert")?;
169 Ok(())
170 }
171
172 async fn get_user_by_upstream(
173 &self,
174 provider: &str,
175 subject: &str,
176 ) -> Result<Option<User>> {
177 let row = sqlx::query(
178 "SELECT u.id, u.email, u.email_verified, u.display_name, u.created_at
179 FROM auth.users u
180 JOIN auth.user_upstream l ON l.user_id = u.id
181 WHERE l.provider = ? AND l.subject = ?",
182 )
183 .bind(provider)
184 .bind(subject)
185 .fetch_optional(&self.pool)
186 .await
187 .context("auth.user_upstream lookup")?;
188 Ok(row.map(map_user_row_sqlite))
189 }
190
191 async fn list_users(
192 &self,
193 limit: i64,
194 offset: i64,
195 search: Option<&str>,
196 ) -> Result<Vec<User>> {
197 let lim = limit.clamp(1, 500);
198 let off = offset.max(0);
199 let rows = if let Some(needle) = search {
200 let pat = format!("%{}%", needle.to_lowercase());
201 sqlx::query(
202 "SELECT id, email, email_verified, display_name, created_at
203 FROM auth.users
204 WHERE LOWER(COALESCE(email, '')) LIKE ?
205 OR LOWER(COALESCE(display_name, '')) LIKE ?
206 ORDER BY created_at DESC
207 LIMIT ? OFFSET ?",
208 )
209 .bind(&pat)
210 .bind(&pat)
211 .bind(lim)
212 .bind(off)
213 .fetch_all(&self.pool)
214 .await
215 .context("auth.users list (search)")?
216 } else {
217 sqlx::query(
218 "SELECT id, email, email_verified, display_name, created_at
219 FROM auth.users
220 ORDER BY created_at DESC
221 LIMIT ? OFFSET ?",
222 )
223 .bind(lim)
224 .bind(off)
225 .fetch_all(&self.pool)
226 .await
227 .context("auth.users list")?
228 };
229 Ok(rows.into_iter().map(map_user_row_sqlite).collect())
230 }
231
232 async fn count_users(&self, search: Option<&str>) -> Result<i64> {
233 let row: (i64,) = if let Some(needle) = search {
234 let pat = format!("%{}%", needle.to_lowercase());
235 sqlx::query_as(
236 "SELECT COUNT(*) FROM auth.users
237 WHERE LOWER(COALESCE(email, '')) LIKE ?
238 OR LOWER(COALESCE(display_name, '')) LIKE ?",
239 )
240 .bind(&pat)
241 .bind(&pat)
242 .fetch_one(&self.pool)
243 .await
244 .context("auth.users count (search)")?
245 } else {
246 sqlx::query_as("SELECT COUNT(*) FROM auth.users")
247 .fetch_one(&self.pool)
248 .await
249 .context("auth.users count")?
250 };
251 Ok(row.0)
252 }
253
254 async fn delete_user(&self, id: &str) -> Result<bool> {
255 let res = sqlx::query("DELETE FROM auth.users WHERE id = ?")
256 .bind(id)
257 .execute(&self.pool)
258 .await
259 .context("auth.users delete")?;
260 Ok(res.rows_affected() > 0)
261 }
262
263 async fn list_upstream_for_user(&self, user_id: &str) -> Result<Vec<(String, String)>> {
264 let rows = sqlx::query(
265 "SELECT provider, subject FROM auth.user_upstream
266 WHERE user_id = ? ORDER BY provider, subject",
267 )
268 .bind(user_id)
269 .fetch_all(&self.pool)
270 .await
271 .context("auth.user_upstream list")?;
272 Ok(rows
273 .into_iter()
274 .map(|r| (r.get::<String, _>("provider"), r.get::<String, _>("subject")))
275 .collect())
276 }
277}
278
279#[derive(Clone)]
283pub struct SqliteSessionStore {
284 pool: SqlitePool,
285}
286
287impl SqliteSessionStore {
288 pub fn new(pool: SqlitePool) -> Self {
289 Self { pool }
290 }
291
292 pub fn into_dyn(self) -> Arc<dyn SessionStore> {
293 Arc::new(self)
294 }
295}
296
297#[async_trait::async_trait]
298impl SessionStore for SqliteSessionStore {
299 async fn create(&self, session: &Session) -> Result<()> {
300 sqlx::query(
301 "INSERT INTO auth.sessions
302 (id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash)
303 VALUES (?, ?, ?, ?, ?, ?, ?)",
304 )
305 .bind(&session.id)
306 .bind(&session.user_id)
307 .bind(&session.csrf_token)
308 .bind(session.created_at)
309 .bind(session.expires_at)
310 .bind(&session.ip_hash)
311 .bind(&session.user_agent_hash)
312 .execute(&self.pool)
313 .await
314 .context("auth.sessions insert")?;
315 Ok(())
316 }
317
318 async fn get(&self, id: &str) -> Result<Option<Session>> {
319 let row = sqlx::query(
320 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
321 FROM auth.sessions WHERE id = ?",
322 )
323 .bind(id)
324 .fetch_optional(&self.pool)
325 .await
326 .context("auth.sessions select")?;
327 Ok(row.map(map_session_row_sqlite))
328 }
329
330 async fn delete(&self, id: &str) -> Result<bool> {
331 let res = sqlx::query("DELETE FROM auth.sessions WHERE id = ?")
332 .bind(id)
333 .execute(&self.pool)
334 .await
335 .context("auth.sessions delete")?;
336 Ok(res.rows_affected() > 0)
337 }
338
339 async fn list_for_user(&self, user_id: &str) -> Result<Vec<Session>> {
340 let rows = sqlx::query(
341 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
342 FROM auth.sessions WHERE user_id = ? ORDER BY created_at DESC",
343 )
344 .bind(user_id)
345 .fetch_all(&self.pool)
346 .await
347 .context("auth.sessions list_for_user")?;
348 Ok(rows.into_iter().map(map_session_row_sqlite).collect())
349 }
350
351 async fn delete_for_user(&self, user_id: &str) -> Result<u64> {
352 let res = sqlx::query("DELETE FROM auth.sessions WHERE user_id = ?")
353 .bind(user_id)
354 .execute(&self.pool)
355 .await
356 .context("auth.sessions delete_for_user")?;
357 Ok(res.rows_affected())
358 }
359
360 async fn purge_expired(&self, now: f64) -> Result<u64> {
361 let res = sqlx::query("DELETE FROM auth.sessions WHERE expires_at <= ?")
362 .bind(now)
363 .execute(&self.pool)
364 .await
365 .context("auth.sessions purge_expired")?;
366 Ok(res.rows_affected())
367 }
368
369 async fn list_all(
370 &self,
371 limit: i64,
372 offset: i64,
373 user_filter: Option<&str>,
374 ) -> Result<Vec<Session>> {
375 let lim = limit.clamp(1, 500);
376 let off = offset.max(0);
377 let rows = if let Some(uid) = user_filter {
378 sqlx::query(
379 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
380 FROM auth.sessions WHERE user_id = ?
381 ORDER BY created_at DESC
382 LIMIT ? OFFSET ?",
383 )
384 .bind(uid)
385 .bind(lim)
386 .bind(off)
387 .fetch_all(&self.pool)
388 .await
389 .context("auth.sessions list_all (user filter)")?
390 } else {
391 sqlx::query(
392 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
393 FROM auth.sessions
394 ORDER BY created_at DESC
395 LIMIT ? OFFSET ?",
396 )
397 .bind(lim)
398 .bind(off)
399 .fetch_all(&self.pool)
400 .await
401 .context("auth.sessions list_all")?
402 };
403 Ok(rows.into_iter().map(map_session_row_sqlite).collect())
404 }
405
406 async fn count_all(&self, user_filter: Option<&str>) -> Result<i64> {
407 let row: (i64,) = if let Some(uid) = user_filter {
408 sqlx::query_as("SELECT COUNT(*) FROM auth.sessions WHERE user_id = ?")
409 .bind(uid)
410 .fetch_one(&self.pool)
411 .await
412 .context("auth.sessions count_all (user filter)")?
413 } else {
414 sqlx::query_as("SELECT COUNT(*) FROM auth.sessions")
415 .fetch_one(&self.pool)
416 .await
417 .context("auth.sessions count_all")?
418 };
419 Ok(row.0)
420 }
421}
422
423fn map_user_row_sqlite(row: sqlx::sqlite::SqliteRow) -> User {
424 let email_verified: i64 = row.get("email_verified");
425 User {
426 id: row.get("id"),
427 email: row.get("email"),
428 email_verified: email_verified != 0,
429 display_name: row.get("display_name"),
430 created_at: row.get("created_at"),
431 }
432}
433
434fn map_session_row_sqlite(row: sqlx::sqlite::SqliteRow) -> Session {
435 Session {
436 id: row.get("id"),
437 user_id: row.get("user_id"),
438 csrf_token: row.get("csrf_token"),
439 created_at: row.get("created_at"),
440 expires_at: row.get("expires_at"),
441 ip_hash: row.get("ip_hash"),
442 user_agent_hash: row.get("user_agent_hash"),
443 }
444}
445
446fn map_passkey_row_sqlite(row: sqlx::sqlite::SqliteRow) -> PasskeyCred {
447 let transports: String = row.get("transports");
448 let sign_count: i64 = row.get("sign_count");
449 PasskeyCred {
450 credential_id: row.get("credential_id"),
451 public_key: row.get("public_key"),
452 sign_count: sign_count.max(0) as u32,
453 transports: if transports.is_empty() {
454 Vec::new()
455 } else {
456 transports.split(',').map(|s| s.to_string()).collect()
457 },
458 created_at: row.get("created_at"),
459 }
460}