1use std::sync::Arc;
9
10use anyhow::{Context, Result};
11use sqlx::{PgPool, Row};
12
13use super::types::{PasskeyCred, Session, User};
14use super::{SessionStore, UserStore};
15
16#[derive(Clone)]
19pub struct PostgresUserStore {
20 pool: PgPool,
21}
22
23impl PostgresUserStore {
24 pub fn new(pool: PgPool) -> Self {
25 Self { pool }
26 }
27
28 pub fn into_dyn(self) -> Arc<dyn UserStore> {
30 Arc::new(self)
31 }
32}
33
34#[async_trait::async_trait]
35impl UserStore for PostgresUserStore {
36 async fn create_user(&self, user: &User) -> Result<()> {
37 sqlx::query(
38 "INSERT INTO auth.users
39 (id, email, email_verified, display_name, password_hash, created_at)
40 VALUES ($1, $2, $3, $4, NULL, $5)",
41 )
42 .bind(&user.id)
43 .bind(&user.email)
44 .bind(user.email_verified)
45 .bind(&user.display_name)
46 .bind(user.created_at)
47 .execute(&self.pool)
48 .await
49 .context("auth.users insert")?;
50 Ok(())
51 }
52
53 async fn get_user_by_id(&self, id: &str) -> Result<Option<User>> {
54 let row = sqlx::query(
55 "SELECT id, email, email_verified, display_name, created_at
56 FROM auth.users WHERE id = $1",
57 )
58 .bind(id)
59 .fetch_optional(&self.pool)
60 .await
61 .context("auth.users select by id")?;
62 Ok(row.map(map_user_row_pg))
63 }
64
65 async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
66 let row = sqlx::query(
67 "SELECT id, email, email_verified, display_name, created_at
68 FROM auth.users WHERE email = $1",
69 )
70 .bind(email)
71 .fetch_optional(&self.pool)
72 .await
73 .context("auth.users select by email")?;
74 Ok(row.map(map_user_row_pg))
75 }
76
77 async fn update_user(&self, user: &User) -> Result<()> {
78 sqlx::query(
79 "UPDATE auth.users
80 SET email = $2,
81 email_verified = $3,
82 display_name = $4
83 WHERE id = $1",
84 )
85 .bind(&user.id)
86 .bind(&user.email)
87 .bind(user.email_verified)
88 .bind(&user.display_name)
89 .execute(&self.pool)
90 .await
91 .context("auth.users update")?;
92 Ok(())
93 }
94
95 async fn set_password_hash(&self, user_id: &str, hash: &str) -> Result<()> {
96 sqlx::query("UPDATE auth.users SET password_hash = $2 WHERE id = $1")
97 .bind(user_id)
98 .bind(hash)
99 .execute(&self.pool)
100 .await
101 .context("auth.users set password_hash")?;
102 Ok(())
103 }
104
105 async fn get_password_hash(&self, user_id: &str) -> Result<Option<String>> {
106 let row: Option<(Option<String>,)> =
107 sqlx::query_as("SELECT password_hash FROM auth.users WHERE id = $1")
108 .bind(user_id)
109 .fetch_optional(&self.pool)
110 .await
111 .context("auth.users select password_hash")?;
112 Ok(row.and_then(|r| r.0))
113 }
114
115 async fn list_passkeys(&self, user_id: &str) -> Result<Vec<PasskeyCred>> {
116 let rows = sqlx::query(
117 "SELECT credential_id, public_key, sign_count, transports, created_at
118 FROM auth.passkeys WHERE user_id = $1
119 ORDER BY created_at",
120 )
121 .bind(user_id)
122 .fetch_all(&self.pool)
123 .await
124 .context("auth.passkeys list")?;
125 Ok(rows.into_iter().map(map_passkey_row_pg).collect())
126 }
127
128 async fn add_passkey(&self, user_id: &str, cred: &PasskeyCred) -> Result<()> {
129 sqlx::query(
130 "INSERT INTO auth.passkeys
131 (credential_id, user_id, public_key, sign_count, transports, created_at)
132 VALUES ($1, $2, $3, $4, $5, $6)",
133 )
134 .bind(&cred.credential_id)
135 .bind(user_id)
136 .bind(&cred.public_key)
137 .bind(cred.sign_count as i32)
138 .bind(cred.transports.join(","))
139 .bind(cred.created_at)
140 .execute(&self.pool)
141 .await
142 .context("auth.passkeys insert")?;
143 Ok(())
144 }
145
146 async fn remove_passkey(&self, credential_id: &[u8]) -> Result<bool> {
147 let res = sqlx::query("DELETE FROM auth.passkeys WHERE credential_id = $1")
148 .bind(credential_id)
149 .execute(&self.pool)
150 .await
151 .context("auth.passkeys delete")?;
152 Ok(res.rows_affected() > 0)
153 }
154
155 async fn link_upstream(&self, user_id: &str, provider: &str, subject: &str) -> Result<()> {
156 sqlx::query(
157 "INSERT INTO auth.user_upstream (provider, subject, user_id)
158 VALUES ($1, $2, $3)
159 ON CONFLICT (provider, subject) DO UPDATE SET user_id = EXCLUDED.user_id",
160 )
161 .bind(provider)
162 .bind(subject)
163 .bind(user_id)
164 .execute(&self.pool)
165 .await
166 .context("auth.user_upstream upsert")?;
167 Ok(())
168 }
169
170 async fn get_user_by_upstream(
171 &self,
172 provider: &str,
173 subject: &str,
174 ) -> Result<Option<User>> {
175 let row = sqlx::query(
176 "SELECT u.id, u.email, u.email_verified, u.display_name, u.created_at
177 FROM auth.users u
178 JOIN auth.user_upstream l ON l.user_id = u.id
179 WHERE l.provider = $1 AND l.subject = $2",
180 )
181 .bind(provider)
182 .bind(subject)
183 .fetch_optional(&self.pool)
184 .await
185 .context("auth.user_upstream lookup")?;
186 Ok(row.map(map_user_row_pg))
187 }
188
189 async fn list_users(
190 &self,
191 limit: i64,
192 offset: i64,
193 search: Option<&str>,
194 ) -> Result<Vec<User>> {
195 let lim = limit.clamp(1, 500);
198 let off = offset.max(0);
199 let rows = if let Some(needle) = search {
200 let pat = format!("%{}%", needle.to_lowercase());
201 sqlx::query(
202 "SELECT id, email, email_verified, display_name, created_at
203 FROM auth.users
204 WHERE LOWER(COALESCE(email, '')) LIKE $1
205 OR LOWER(COALESCE(display_name, '')) LIKE $1
206 ORDER BY created_at DESC
207 LIMIT $2 OFFSET $3",
208 )
209 .bind(pat)
210 .bind(lim)
211 .bind(off)
212 .fetch_all(&self.pool)
213 .await
214 .context("auth.users list (search)")?
215 } else {
216 sqlx::query(
217 "SELECT id, email, email_verified, display_name, created_at
218 FROM auth.users
219 ORDER BY created_at DESC
220 LIMIT $1 OFFSET $2",
221 )
222 .bind(lim)
223 .bind(off)
224 .fetch_all(&self.pool)
225 .await
226 .context("auth.users list")?
227 };
228 Ok(rows.into_iter().map(map_user_row_pg).collect())
229 }
230
231 async fn count_users(&self, search: Option<&str>) -> Result<i64> {
232 let row: (i64,) = if let Some(needle) = search {
233 let pat = format!("%{}%", needle.to_lowercase());
234 sqlx::query_as(
235 "SELECT COUNT(*) FROM auth.users
236 WHERE LOWER(COALESCE(email, '')) LIKE $1
237 OR LOWER(COALESCE(display_name, '')) LIKE $1",
238 )
239 .bind(pat)
240 .fetch_one(&self.pool)
241 .await
242 .context("auth.users count (search)")?
243 } else {
244 sqlx::query_as("SELECT COUNT(*) FROM auth.users")
245 .fetch_one(&self.pool)
246 .await
247 .context("auth.users count")?
248 };
249 Ok(row.0)
250 }
251
252 async fn delete_user(&self, id: &str) -> Result<bool> {
253 let res = sqlx::query("DELETE FROM auth.users WHERE id = $1")
254 .bind(id)
255 .execute(&self.pool)
256 .await
257 .context("auth.users delete")?;
258 Ok(res.rows_affected() > 0)
259 }
260
261 async fn list_upstream_for_user(&self, user_id: &str) -> Result<Vec<(String, String)>> {
262 let rows = sqlx::query(
263 "SELECT provider, subject FROM auth.user_upstream
264 WHERE user_id = $1 ORDER BY provider, subject",
265 )
266 .bind(user_id)
267 .fetch_all(&self.pool)
268 .await
269 .context("auth.user_upstream list")?;
270 Ok(rows
271 .into_iter()
272 .map(|r| (r.get::<String, _>("provider"), r.get::<String, _>("subject")))
273 .collect())
274 }
275}
276
277#[derive(Clone)]
281pub struct PostgresSessionStore {
282 pool: PgPool,
283}
284
285impl PostgresSessionStore {
286 pub fn new(pool: PgPool) -> Self {
287 Self { pool }
288 }
289
290 pub fn into_dyn(self) -> Arc<dyn SessionStore> {
291 Arc::new(self)
292 }
293}
294
295#[async_trait::async_trait]
296impl SessionStore for PostgresSessionStore {
297 async fn create(&self, session: &Session) -> Result<()> {
298 sqlx::query(
299 "INSERT INTO auth.sessions
300 (id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash)
301 VALUES ($1, $2, $3, $4, $5, $6, $7)",
302 )
303 .bind(&session.id)
304 .bind(&session.user_id)
305 .bind(&session.csrf_token)
306 .bind(session.created_at)
307 .bind(session.expires_at)
308 .bind(&session.ip_hash)
309 .bind(&session.user_agent_hash)
310 .execute(&self.pool)
311 .await
312 .context("auth.sessions insert")?;
313 Ok(())
314 }
315
316 async fn get(&self, id: &str) -> Result<Option<Session>> {
317 let row = sqlx::query(
318 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
319 FROM auth.sessions WHERE id = $1",
320 )
321 .bind(id)
322 .fetch_optional(&self.pool)
323 .await
324 .context("auth.sessions select")?;
325 Ok(row.map(map_session_row_pg))
326 }
327
328 async fn delete(&self, id: &str) -> Result<bool> {
329 let res = sqlx::query("DELETE FROM auth.sessions WHERE id = $1")
330 .bind(id)
331 .execute(&self.pool)
332 .await
333 .context("auth.sessions delete")?;
334 Ok(res.rows_affected() > 0)
335 }
336
337 async fn list_for_user(&self, user_id: &str) -> Result<Vec<Session>> {
338 let rows = sqlx::query(
339 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
340 FROM auth.sessions WHERE user_id = $1 ORDER BY created_at DESC",
341 )
342 .bind(user_id)
343 .fetch_all(&self.pool)
344 .await
345 .context("auth.sessions list_for_user")?;
346 Ok(rows.into_iter().map(map_session_row_pg).collect())
347 }
348
349 async fn delete_for_user(&self, user_id: &str) -> Result<u64> {
350 let res = sqlx::query("DELETE FROM auth.sessions WHERE user_id = $1")
351 .bind(user_id)
352 .execute(&self.pool)
353 .await
354 .context("auth.sessions delete_for_user")?;
355 Ok(res.rows_affected())
356 }
357
358 async fn purge_expired(&self, now: f64) -> Result<u64> {
359 let res = sqlx::query("DELETE FROM auth.sessions WHERE expires_at <= $1")
360 .bind(now)
361 .execute(&self.pool)
362 .await
363 .context("auth.sessions purge_expired")?;
364 Ok(res.rows_affected())
365 }
366
367 async fn list_all(
368 &self,
369 limit: i64,
370 offset: i64,
371 user_filter: Option<&str>,
372 ) -> Result<Vec<Session>> {
373 let lim = limit.clamp(1, 500);
374 let off = offset.max(0);
375 let rows = if let Some(uid) = user_filter {
376 sqlx::query(
377 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
378 FROM auth.sessions WHERE user_id = $1
379 ORDER BY created_at DESC
380 LIMIT $2 OFFSET $3",
381 )
382 .bind(uid)
383 .bind(lim)
384 .bind(off)
385 .fetch_all(&self.pool)
386 .await
387 .context("auth.sessions list_all (user filter)")?
388 } else {
389 sqlx::query(
390 "SELECT id, user_id, csrf_token, created_at, expires_at, ip_hash, user_agent_hash
391 FROM auth.sessions
392 ORDER BY created_at DESC
393 LIMIT $1 OFFSET $2",
394 )
395 .bind(lim)
396 .bind(off)
397 .fetch_all(&self.pool)
398 .await
399 .context("auth.sessions list_all")?
400 };
401 Ok(rows.into_iter().map(map_session_row_pg).collect())
402 }
403
404 async fn count_all(&self, user_filter: Option<&str>) -> Result<i64> {
405 let row: (i64,) = if let Some(uid) = user_filter {
406 sqlx::query_as("SELECT COUNT(*) FROM auth.sessions WHERE user_id = $1")
407 .bind(uid)
408 .fetch_one(&self.pool)
409 .await
410 .context("auth.sessions count_all (user filter)")?
411 } else {
412 sqlx::query_as("SELECT COUNT(*) FROM auth.sessions")
413 .fetch_one(&self.pool)
414 .await
415 .context("auth.sessions count_all")?
416 };
417 Ok(row.0)
418 }
419}
420
421fn map_user_row_pg(row: sqlx::postgres::PgRow) -> User {
422 User {
423 id: row.get("id"),
424 email: row.get("email"),
425 email_verified: row.get("email_verified"),
426 display_name: row.get("display_name"),
427 created_at: row.get("created_at"),
428 }
429}
430
431fn map_session_row_pg(row: sqlx::postgres::PgRow) -> Session {
432 Session {
433 id: row.get("id"),
434 user_id: row.get("user_id"),
435 csrf_token: row.get("csrf_token"),
436 created_at: row.get("created_at"),
437 expires_at: row.get("expires_at"),
438 ip_hash: row.get("ip_hash"),
439 user_agent_hash: row.get("user_agent_hash"),
440 }
441}
442
443fn map_passkey_row_pg(row: sqlx::postgres::PgRow) -> PasskeyCred {
444 let transports: String = row.get("transports");
445 let sign_count: i32 = row.get("sign_count");
446 PasskeyCred {
447 credential_id: row.get("credential_id"),
448 public_key: row.get("public_key"),
449 sign_count: sign_count.max(0) as u32,
450 transports: if transports.is_empty() {
451 Vec::new()
452 } else {
453 transports.split(',').map(|s| s.to_string()).collect()
454 },
455 created_at: row.get("created_at"),
456 }
457}