Skip to main content

assay_auth/store/
postgres.rs

//! Postgres implementations of [`UserStore`] and [`SessionStore`].
//!
//! Tables live in the `auth` schema (created/migrated by
//! [`crate::schema::migrate_postgres`]). Queries are schema-qualified
//! (`auth.users`, …) so the connection's `search_path` is irrelevant;
//! this matches the rest of the codebase's "explicit > implicit" stance.
7
8use std::sync::Arc;
9
10use anyhow::{Context, Result};
11use sqlx::{PgPool, Row};
12
13use super::types::{PasskeyCred, Session, User};
14use super::{SessionStore, UserStore};
15
16/// User store backed by a shared `PgPool`. Cheap to clone (the
17/// underlying pool is `Arc`d already).
18#[derive(Clone)]
19pub struct PostgresUserStore {
20    pool: PgPool,
21}
22
23impl PostgresUserStore {
24    pub fn new(pool: PgPool) -> Self {
25        Self { pool }
26    }
27
28    /// Wrap into an `Arc<dyn UserStore>` for [`crate::ctx::AuthCtx`].
29    pub fn into_dyn(self) -> Arc<dyn UserStore> {
30        Arc::new(self)
31    }
32}
33
34#[async_trait::async_trait]
35impl UserStore for PostgresUserStore {
36    async fn create_user(&self, user: &User) -> Result<()> {
37        sqlx::query(
38            "INSERT INTO auth.users
39                 (id, email, email_verified, display_name, password_hash, created_at)
40             VALUES ($1, $2, $3, $4, NULL, $5)",
41        )
42        .bind(&user.id)
43        .bind(&user.email)
44        .bind(user.email_verified)
45        .bind(&user.display_name)
46        .bind(user.created_at)
47        .execute(&self.pool)
48        .await
49        .context("auth.users insert")?;
50        Ok(())
51    }
52
53    async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
54        let row = sqlx::query(
55            "SELECT id, email, email_verified, display_name, created_at
56             FROM auth.users WHERE id = $1",
57        )
58        .bind(id)
59        .fetch_optional(&self.pool)
60        .await
61        .context("auth.users select by id")?;
62        Ok(row.map(map_user_row_pg))
63    }
64
65    async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
66        let row = sqlx::query(
67            "SELECT id, email, email_verified, display_name, created_at
68             FROM auth.users WHERE email = $1",
69        )
70        .bind(email)
71        .fetch_optional(&self.pool)
72        .await
73        .context("auth.users select by email")?;
74        Ok(row.map(map_user_row_pg))
75    }
76
77    async fn update_user(&self, user: &User) -> Result<()> {
78        sqlx::query(
79            "UPDATE auth.users
80             SET email = $2,
81                 email_verified = $3,
82                 display_name = $4
83             WHERE id = $1",
84        )
85        .bind(&user.id)
86        .bind(&user.email)
87        .bind(user.email_verified)
88        .bind(&user.display_name)
89        .execute(&self.pool)
90        .await
91        .context("auth.users update")?;
92        Ok(())
93    }
94
95    async fn set_password_hash(&self, user_id: &str, hash: &str) -> Result<()> {
96        sqlx::query("UPDATE auth.users SET password_hash = $2 WHERE id = $1")
97            .bind(user_id)
98            .bind(hash)
99            .execute(&self.pool)
100            .await
101            .context("auth.users set password_hash")?;
102        Ok(())
103    }
104
105    async fn get_password_hash(&self, user_id: &str) -> Result<Option<String>> {
106        let row: Option<(Option<String>,)> =
107            sqlx::query_as("SELECT password_hash FROM auth.users WHERE id = $1")
108                .bind(user_id)
109                .fetch_optional(&self.pool)
110                .await
111                .context("auth.users select password_hash")?;
112        Ok(row.and_then(|r| r.0))
113    }
114
115    async fn list_passkeys(&self, user_id: &str) -> Result<Vec<PasskeyCred>> {
116        let rows = sqlx::query(
117            "SELECT credential_id, public_key, sign_count, transports, created_at
118             FROM auth.passkeys WHERE user_id = $1
119             ORDER BY created_at",
120        )
121        .bind(user_id)
122        .fetch_all(&self.pool)
123        .await
124        .context("auth.passkeys list")?;
125        Ok(rows.into_iter().map(map_passkey_row_pg).collect())
126    }
127
128    async fn add_passkey(&self, user_id: &str, cred: &PasskeyCred) -> Result<()> {
129        sqlx::query(
130            "INSERT INTO auth.passkeys
131                 (credential_id, user_id, public_key, sign_count, transports, created_at)
132             VALUES ($1, $2, $3, $4, $5, $6)",
133        )
134        .bind(&cred.credential_id)
135        .bind(user_id)
136        .bind(&cred.public_key)
137        .bind(cred.sign_count as i32)
138        .bind(cred.transports.join(","))
139        .bind(cred.created_at)
140        .execute(&self.pool)
141        .await
142        .context("auth.passkeys insert")?;
143        Ok(())
144    }
145
146    async fn remove_passkey(&self, credential_id: &[u8]) -> Result<bool> {
147        let res = sqlx::query("DELETE FROM auth.passkeys WHERE credential_id = $1")
148            .bind(credential_id)
149            .execute(&self.pool)
150            .await
151            .context("auth.passkeys delete")?;
152        Ok(res.rows_affected() > 0)
153    }
154
155    async fn link_upstream(&self, user_id: &str, provider: &str, subject: &str) -> Result<()> {
156        sqlx::query(
157            "INSERT INTO auth.user_upstream (provider, subject, user_id)
158             VALUES ($1, $2, $3)
159             ON CONFLICT (provider, subject) DO UPDATE SET user_id = EXCLUDED.user_id",
160        )
161        .bind(provider)
162        .bind(subject)
163        .bind(user_id)
164        .execute(&self.pool)
165        .await
166        .context("auth.user_upstream upsert")?;
167        Ok(())
168    }
169
170    async fn get_user_by_upstream(
171        &self,
172        provider: &str,
173        subject: &str,
174    ) -> Result<Option<User>> {
175        let row = sqlx::query(
176            "SELECT u.id, u.email, u.email_verified, u.display_name, u.created_at
177             FROM auth.users u
178             JOIN auth.user_upstream l ON l.user_id = u.id
179             WHERE l.provider = $1 AND l.subject = $2",
180        )
181        .bind(provider)
182        .bind(subject)
183        .fetch_optional(&self.pool)
184        .await
185        .context("auth.user_upstream lookup")?;
186        Ok(row.map(map_user_row_pg))
187    }
188
189    async fn list_users(
190        &self,
191        limit: i64,
192        offset: i64,
193        search: Option<&str>,
194    ) -> Result<Vec<User>> {
195        // Cap limit at 500 — the admin dashboard pages 50 at a time so
196        // anything bigger is either a misconfigured client or a probe.
197        let lim = limit.clamp(1, 500);
198        let off = offset.max(0);
199        let rows = if let Some(needle) = search {
200            let pat = format!("%{}%", needle.to_lowercase());
201            sqlx::query(
202                "SELECT id, email, email_verified, display_name, created_at
203                 FROM auth.users
204                 WHERE LOWER(COALESCE(email, '')) LIKE $1
205                    OR LOWER(COALESCE(display_name, '')) LIKE $1
206                 ORDER BY created_at DESC
207                 LIMIT $2 OFFSET $3",
208            )
209            .bind(pat)
210            .bind(lim)
211            .bind(off)
212            .fetch_all(&self.pool)
213            .await
214            .context("auth.users list (search)")?
215        } else {
216            sqlx::query(
217                "SELECT id, email, email_verified, display_name, created_at
218                 FROM auth.users
219                 ORDER BY created_at DESC
220                 LIMIT $1 OFFSET $2",
221            )
222            .bind(lim)
223            .bind(off)
224            .fetch_all(&self.pool)
225            .await
226            .context("auth.users list")?
227        };
228        Ok(rows.into_iter().map(map_user_row_pg).collect())
229    }
230
231    async fn count_users(&self, search: Option<&str>) -> Result<i64> {
232        let row: (i64,) = if let Some(needle) = search {
233            let pat = format!("%{}%", needle.to_lowercase());
234            sqlx::query_as(
235                "SELECT COUNT(*) FROM auth.users
236                 WHERE LOWER(COALESCE(email, '')) LIKE $1
237                    OR LOWER(COALESCE(display_name, '')) LIKE $1",
238            )
239            .bind(pat)
240            .fetch_one(&self.pool)
241            .await
242            .context("auth.users count (search)")?
243        } else {
244            sqlx::query_as("SELECT COUNT(*) FROM auth.users")
245                .fetch_one(&self.pool)
246                .await
247                .context("auth.users count")?
248        };
249        Ok(row.0)
250    }
251
252    async fn delete_user(&self, id: &str) -> Result<bool> {
253        let res = sqlx::query("DELETE FROM auth.users WHERE id = $1")
254            .bind(id)
255            .execute(&self.pool)
256            .await
257            .context("auth.users delete")?;
258        Ok(res.rows_affected() > 0)
259    }
260
261    async fn list_upstream_for_user(&self, user_id: &str) -> Result<Vec<(String, String)>> {
262        let rows = sqlx::query(
263            "SELECT provider, subject FROM auth.user_upstream
264             WHERE user_id = $1 ORDER BY provider, subject",
265        )
266        .bind(user_id)
267        .fetch_all(&self.pool)
268        .await
269        .context("auth.user_upstream list")?;
270        Ok(rows
271            .into_iter()
272            .map(|r| (r.get::<String, _>("provider"), r.get::<String, _>("subject")))
273            .collect())
274    }
275}
276
277/// Session store backed by `auth.sessions`. Independent struct from
278/// [`PostgresUserStore`] because they're independently mockable in
279/// tests and the engine may swap one without the other.
280#[derive(Clone)]
281pub struct PostgresSessionStore {
282    pool: PgPool,
283}
284
285impl PostgresSessionStore {
286    pub fn new(pool: PgPool) -> Self {
287        Self { pool }
288    }
289
290    pub fn into_dyn(self) -> Arc<dyn SessionStore> {
291        Arc::new(self)
292    }
293}
294
295#[async_trait::async_trait]
296impl SessionStore for PostgresSessionStore {
297    async fn create(&self, session: &Session) -> Result<()> {
298        sqlx::query(
299            "INSERT INTO auth.sessions
300                 (id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash)
301             VALUES ($1, $2, $3, $4, $5, $6, $7)",
302        )
303        .bind(&session.id)
304        .bind(&session.user_id)
305        .bind(&session.csrf_token)
306        .bind(session.created_at)
307        .bind(session.expires_at)
308        .bind(&session.ip_hash)
309        .bind(&session.user_agent_hash)
310        .execute(&self.pool)
311        .await
312        .context("auth.sessions insert")?;
313        Ok(())
314    }
315
316    async fn get(&self, id: &str) -> Result<Option<Session>> {
317        let row = sqlx::query(
318            "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
319             FROM auth.sessions WHERE id = $1",
320        )
321        .bind(id)
322        .fetch_optional(&self.pool)
323        .await
324        .context("auth.sessions select")?;
325        Ok(row.map(map_session_row_pg))
326    }
327
328    async fn delete(&self, id: &str) -> Result<bool> {
329        let res = sqlx::query("DELETE FROM auth.sessions WHERE id = $1")
330            .bind(id)
331            .execute(&self.pool)
332            .await
333            .context("auth.sessions delete")?;
334        Ok(res.rows_affected() > 0)
335    }
336
337    async fn list_for_user(&self, user_id: &str) -> Result<Vec<Session>> {
338        let rows = sqlx::query(
339            "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
340             FROM auth.sessions WHERE user_id = $1 ORDER BY created_at DESC",
341        )
342        .bind(user_id)
343        .fetch_all(&self.pool)
344        .await
345        .context("auth.sessions list_for_user")?;
346        Ok(rows.into_iter().map(map_session_row_pg).collect())
347    }
348
349    async fn delete_for_user(&self, user_id: &str) -> Result<u64> {
350        let res = sqlx::query("DELETE FROM auth.sessions WHERE user_id = $1")
351            .bind(user_id)
352            .execute(&self.pool)
353            .await
354            .context("auth.sessions delete_for_user")?;
355        Ok(res.rows_affected())
356    }
357
358    async fn purge_expired(&self, now: f64) -> Result<u64> {
359        let res = sqlx::query("DELETE FROM auth.sessions WHERE expires_at <= $1")
360            .bind(now)
361            .execute(&self.pool)
362            .await
363            .context("auth.sessions purge_expired")?;
364        Ok(res.rows_affected())
365    }
366
367    async fn list_all(
368        &self,
369        limit: i64,
370        offset: i64,
371        user_filter: Option<&str>,
372    ) -> Result<Vec<Session>> {
373        let lim = limit.clamp(1, 500);
374        let off = offset.max(0);
375        let rows = if let Some(uid) = user_filter {
376            sqlx::query(
377                "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
378                 FROM auth.sessions WHERE user_id = $1
379                 ORDER BY created_at DESC
380                 LIMIT $2 OFFSET $3",
381            )
382            .bind(uid)
383            .bind(lim)
384            .bind(off)
385            .fetch_all(&self.pool)
386            .await
387            .context("auth.sessions list_all (user filter)")?
388        } else {
389            sqlx::query(
390                "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
391                 FROM auth.sessions
392                 ORDER BY created_at DESC
393                 LIMIT $1 OFFSET $2",
394            )
395            .bind(lim)
396            .bind(off)
397            .fetch_all(&self.pool)
398            .await
399            .context("auth.sessions list_all")?
400        };
401        Ok(rows.into_iter().map(map_session_row_pg).collect())
402    }
403
404    async fn count_all(&self, user_filter: Option<&str>) -> Result<i64> {
405        let row: (i64,) = if let Some(uid) = user_filter {
406            sqlx::query_as("SELECT COUNT(*) FROM auth.sessions WHERE user_id = $1")
407                .bind(uid)
408                .fetch_one(&self.pool)
409                .await
410                .context("auth.sessions count_all (user filter)")?
411        } else {
412            sqlx::query_as("SELECT COUNT(*) FROM auth.sessions")
413                .fetch_one(&self.pool)
414                .await
415                .context("auth.sessions count_all")?
416        };
417        Ok(row.0)
418    }
419}
420
421fn map_user_row_pg(row: sqlx::postgres::PgRow) -> User {
422    User {
423        id: row.get("id"),
424        email: row.get("email"),
425        email_verified: row.get("email_verified"),
426        display_name: row.get("display_name"),
427        created_at: row.get("created_at"),
428    }
429}
430
431fn map_session_row_pg(row: sqlx::postgres::PgRow) -> Session {
432    Session {
433        id: row.get("id"),
434        user_id: row.get("user_id"),
435        csrf_token: row.get("csrf_token"),
436        created_at: row.get("created_at"),
437        expires_at: row.get("expires_at"),
438        ip_hash: row.get("ip_hash"),
439        user_agent_hash: row.get("user_agent_hash"),
440    }
441}
442
443fn map_passkey_row_pg(row: sqlx::postgres::PgRow) -> PasskeyCred {
444    let transports: String = row.get("transports");
445    let sign_count: i32 = row.get("sign_count");
446    PasskeyCred {
447        credential_id: row.get("credential_id"),
448        public_key: row.get("public_key"),
449        sign_count: sign_count.max(0) as u32,
450        transports: if transports.is_empty() {
451            Vec::new()
452        } else {
453            transports.split(',').map(|s| s.to_string()).collect()
454        },
455        created_at: row.get("created_at"),
456    }
457}