1use std::sync::{Arc, Mutex};
29
30use pylon_auth::{Account, AccountBackend};
31use rusqlite::Connection;
32
/// Accounts table used by [`SqliteAccountBackend`].
///
/// The name is interpolated directly into SQL text with `format!`, so it must stay a fixed,
/// trusted identifier — never derive it from input.
const SQLITE_TABLE: &str = "_pylon_accounts";

/// Accounts table used by the Postgres backend; interpolated into SQL text the same way.
const PG_TABLE: &str = "_pylon_accounts";
35
36pub struct SqliteAccountBackend {
41 conn: Arc<Mutex<Connection>>,
42}
43
44impl SqliteAccountBackend {
45 pub fn open(path: &str) -> Result<Self, String> {
46 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
47 Self::from_connection(conn)
48 }
49
50 pub fn in_memory() -> Result<Self, String> {
51 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
52 Self::from_connection(conn)
53 }
54
55 fn from_connection(conn: Connection) -> Result<Self, String> {
56 conn.execute_batch(&format!(
61 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
62 id TEXT PRIMARY KEY,
63 user_id TEXT NOT NULL,
64 provider_id TEXT NOT NULL,
65 account_id TEXT NOT NULL,
66 access_token TEXT,
67 refresh_token TEXT,
68 id_token TEXT,
69 access_token_expires_at INTEGER,
70 refresh_token_expires_at INTEGER,
71 scope TEXT,
72 password TEXT,
73 created_at INTEGER NOT NULL,
74 updated_at INTEGER NOT NULL,
75 UNIQUE (provider_id, account_id)
76 );
77 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_user_idx ON {SQLITE_TABLE}(user_id);"
78 ))
79 .map_err(|e| format!("init schema: {e}"))?;
80 Ok(Self {
81 conn: Arc::new(Mutex::new(conn)),
82 })
83 }
84}
85
86#[allow(clippy::too_many_arguments)]
87fn row_to_account(
88 id: String,
89 user_id: String,
90 provider_id: String,
91 account_id: String,
92 access_token: Option<String>,
93 refresh_token: Option<String>,
94 id_token: Option<String>,
95 access_token_expires_at: Option<i64>,
96 refresh_token_expires_at: Option<i64>,
97 scope: Option<String>,
98 password: Option<String>,
99 created_at: i64,
100 updated_at: i64,
101) -> Account {
102 Account {
103 id,
104 user_id,
105 provider_id,
106 account_id,
107 access_token,
108 refresh_token,
109 id_token,
110 access_token_expires_at: access_token_expires_at.map(|n| n as u64),
111 refresh_token_expires_at: refresh_token_expires_at.map(|n| n as u64),
112 scope,
113 password,
114 created_at: created_at as u64,
115 updated_at: updated_at as u64,
116 }
117}
118
/// Column list shared by every `SELECT` in both backends.
///
/// The order is load-bearing: rows are decoded by position into `row_to_account`'s
/// parameters, so the two must be changed together.
const SELECT_COLS: &str = concat!(
    "id, user_id, provider_id, account_id, ",
    "access_token, refresh_token, id_token, ",
    "access_token_expires_at, refresh_token_expires_at, ",
    "scope, password, created_at, updated_at"
);
122
123impl AccountBackend for SqliteAccountBackend {
124 fn upsert(&self, a: &Account) {
125 if let Ok(guard) = self.conn.lock() {
126 let _ = guard.execute(
130 &format!(
131 "INSERT INTO {SQLITE_TABLE}
132 (id, user_id, provider_id, account_id, access_token, refresh_token,
133 id_token, access_token_expires_at, refresh_token_expires_at,
134 scope, password, created_at, updated_at)
135 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)
136 ON CONFLICT(provider_id, account_id) DO UPDATE SET
137 user_id = excluded.user_id,
138 access_token = excluded.access_token,
139 refresh_token = excluded.refresh_token,
140 id_token = excluded.id_token,
141 access_token_expires_at = excluded.access_token_expires_at,
142 refresh_token_expires_at = excluded.refresh_token_expires_at,
143 scope = excluded.scope,
144 password = excluded.password,
145 updated_at = excluded.updated_at"
146 ),
147 rusqlite::params![
148 a.id,
149 a.user_id,
150 a.provider_id,
151 a.account_id,
152 a.access_token,
153 a.refresh_token,
154 a.id_token,
155 a.access_token_expires_at.map(|n| n as i64),
156 a.refresh_token_expires_at.map(|n| n as i64),
157 a.scope,
158 a.password,
159 a.created_at as i64,
160 a.updated_at as i64,
161 ],
162 );
163 }
164 }
165
166 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
167 let guard = self.conn.lock().ok()?;
168 guard
169 .query_row(
170 &format!(
171 "SELECT {SELECT_COLS}
172 FROM {SQLITE_TABLE}
173 WHERE provider_id = ?1 AND account_id = ?2"
174 ),
175 rusqlite::params![provider_id, account_id],
176 |row| {
177 Ok(row_to_account(
178 row.get(0)?,
179 row.get(1)?,
180 row.get(2)?,
181 row.get(3)?,
182 row.get::<_, Option<String>>(4)?,
183 row.get::<_, Option<String>>(5)?,
184 row.get::<_, Option<String>>(6)?,
185 row.get::<_, Option<i64>>(7)?,
186 row.get::<_, Option<i64>>(8)?,
187 row.get::<_, Option<String>>(9)?,
188 row.get::<_, Option<String>>(10)?,
189 row.get(11)?,
190 row.get(12)?,
191 ))
192 },
193 )
194 .ok()
195 }
196
197 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
198 let Ok(guard) = self.conn.lock() else {
199 return Vec::new();
200 };
201 let mut stmt = match guard.prepare(&format!(
202 "SELECT {SELECT_COLS} FROM {SQLITE_TABLE} WHERE user_id = ?1"
203 )) {
204 Ok(s) => s,
205 Err(_) => return Vec::new(),
206 };
207 let iter = match stmt.query_map(rusqlite::params![user_id], |row| {
208 Ok(row_to_account(
209 row.get(0)?,
210 row.get(1)?,
211 row.get(2)?,
212 row.get(3)?,
213 row.get::<_, Option<String>>(4)?,
214 row.get::<_, Option<String>>(5)?,
215 row.get::<_, Option<String>>(6)?,
216 row.get::<_, Option<i64>>(7)?,
217 row.get::<_, Option<i64>>(8)?,
218 row.get::<_, Option<String>>(9)?,
219 row.get::<_, Option<String>>(10)?,
220 row.get(11)?,
221 row.get(12)?,
222 ))
223 }) {
224 Ok(i) => i,
225 Err(_) => return Vec::new(),
226 };
227 iter.flatten().collect()
228 }
229
230 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
231 let Ok(guard) = self.conn.lock() else {
232 return false;
233 };
234 guard
235 .execute(
236 &format!("DELETE FROM {SQLITE_TABLE} WHERE provider_id = ?1 AND account_id = ?2"),
237 rusqlite::params![provider_id, account_id],
238 )
239 .map(|n| n > 0)
240 .unwrap_or(false)
241 }
242
243 fn list_all(&self) -> Vec<Account> {
244 let Ok(guard) = self.conn.lock() else {
245 return Vec::new();
246 };
247 let mut stmt = match guard.prepare(&format!("SELECT {SELECT_COLS} FROM {SQLITE_TABLE}")) {
248 Ok(s) => s,
249 Err(_) => return Vec::new(),
250 };
251 let iter = match stmt.query_map([], |row| {
252 Ok(row_to_account(
253 row.get(0)?,
254 row.get(1)?,
255 row.get(2)?,
256 row.get(3)?,
257 row.get::<_, Option<String>>(4)?,
258 row.get::<_, Option<String>>(5)?,
259 row.get::<_, Option<String>>(6)?,
260 row.get::<_, Option<i64>>(7)?,
261 row.get::<_, Option<i64>>(8)?,
262 row.get::<_, Option<String>>(9)?,
263 row.get::<_, Option<String>>(10)?,
264 row.get(11)?,
265 row.get(12)?,
266 ))
267 }) {
268 Ok(i) => i,
269 Err(_) => return Vec::new(),
270 };
271 iter.flatten().collect()
272 }
273}
274
275pub use pg::PostgresAccountBackend;
280
281mod pg {
282 use super::*;
283 use postgres::Client;
284
285 pub struct PostgresAccountBackend {
286 client: Mutex<Client>,
287 }
288
289 impl PostgresAccountBackend {
290 pub fn connect(url: &str) -> Result<Self, String> {
291 let mut client =
292 Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
293 client
294 .batch_execute(&format!(
295 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
296 id TEXT PRIMARY KEY,
297 user_id TEXT NOT NULL,
298 provider_id TEXT NOT NULL,
299 account_id TEXT NOT NULL,
300 access_token TEXT,
301 refresh_token TEXT,
302 id_token TEXT,
303 access_token_expires_at BIGINT,
304 refresh_token_expires_at BIGINT,
305 scope TEXT,
306 password TEXT,
307 created_at BIGINT NOT NULL,
308 updated_at BIGINT NOT NULL,
309 UNIQUE (provider_id, account_id)
310 );
311 CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);"
312 ))
313 .map_err(|e| format!("PG init schema: {e}"))?;
314 Ok(Self {
315 client: Mutex::new(client),
316 })
317 }
318 }
319
320 impl AccountBackend for PostgresAccountBackend {
321 fn upsert(&self, a: &Account) {
322 let Ok(mut c) = self.client.lock() else {
323 return;
324 };
325 let _ = c.execute(
326 &format!(
327 "INSERT INTO {PG_TABLE}
328 (id, user_id, provider_id, account_id, access_token, refresh_token,
329 id_token, access_token_expires_at, refresh_token_expires_at,
330 scope, password, created_at, updated_at)
331 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
332 ON CONFLICT (provider_id, account_id) DO UPDATE SET
333 user_id = EXCLUDED.user_id,
334 access_token = EXCLUDED.access_token,
335 refresh_token = EXCLUDED.refresh_token,
336 id_token = EXCLUDED.id_token,
337 access_token_expires_at = EXCLUDED.access_token_expires_at,
338 refresh_token_expires_at = EXCLUDED.refresh_token_expires_at,
339 scope = EXCLUDED.scope,
340 password = EXCLUDED.password,
341 updated_at = EXCLUDED.updated_at"
342 ),
343 &[
344 &a.id,
345 &a.user_id,
346 &a.provider_id,
347 &a.account_id,
348 &a.access_token,
349 &a.refresh_token,
350 &a.id_token,
351 &a.access_token_expires_at.map(|n| n as i64),
352 &a.refresh_token_expires_at.map(|n| n as i64),
353 &a.scope,
354 &a.password,
355 &(a.created_at as i64),
356 &(a.updated_at as i64),
357 ],
358 );
359 }
360
361 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
362 let mut c = self.client.lock().ok()?;
363 let row = c
364 .query_opt(
365 &format!(
366 "SELECT {SELECT_COLS}
367 FROM {PG_TABLE}
368 WHERE provider_id = $1 AND account_id = $2"
369 ),
370 &[&provider_id, &account_id],
371 )
372 .ok()??;
373 Some(row_to_account(
374 row.get(0),
375 row.get(1),
376 row.get(2),
377 row.get(3),
378 row.get::<_, Option<String>>(4),
379 row.get::<_, Option<String>>(5),
380 row.get::<_, Option<String>>(6),
381 row.get::<_, Option<i64>>(7),
382 row.get::<_, Option<i64>>(8),
383 row.get::<_, Option<String>>(9),
384 row.get::<_, Option<String>>(10),
385 row.get(11),
386 row.get(12),
387 ))
388 }
389
390 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
391 let Ok(mut c) = self.client.lock() else {
392 return Vec::new();
393 };
394 let rows = c
395 .query(
396 &format!("SELECT {SELECT_COLS} FROM {PG_TABLE} WHERE user_id = $1"),
397 &[&user_id],
398 )
399 .unwrap_or_default();
400 rows.iter()
401 .map(|row| {
402 row_to_account(
403 row.get(0),
404 row.get(1),
405 row.get(2),
406 row.get(3),
407 row.get::<_, Option<String>>(4),
408 row.get::<_, Option<String>>(5),
409 row.get::<_, Option<String>>(6),
410 row.get::<_, Option<i64>>(7),
411 row.get::<_, Option<i64>>(8),
412 row.get::<_, Option<String>>(9),
413 row.get::<_, Option<String>>(10),
414 row.get(11),
415 row.get(12),
416 )
417 })
418 .collect()
419 }
420
421 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
422 let Ok(mut c) = self.client.lock() else {
423 return false;
424 };
425 c.execute(
426 &format!("DELETE FROM {PG_TABLE} WHERE provider_id = $1 AND account_id = $2"),
427 &[&provider_id, &account_id],
428 )
429 .map(|n| n > 0)
430 .unwrap_or(false)
431 }
432
433 fn list_all(&self) -> Vec<Account> {
434 let Ok(mut c) = self.client.lock() else {
435 return Vec::new();
436 };
437 let rows = c
438 .query(&format!("SELECT {SELECT_COLS} FROM {PG_TABLE}"), &[])
439 .unwrap_or_default();
440 rows.iter()
441 .map(|row| {
442 row_to_account(
443 row.get(0),
444 row.get(1),
445 row.get(2),
446 row.get(3),
447 row.get::<_, Option<String>>(4),
448 row.get::<_, Option<String>>(5),
449 row.get::<_, Option<String>>(6),
450 row.get::<_, Option<i64>>(7),
451 row.get::<_, Option<i64>>(8),
452 row.get::<_, Option<String>>(9),
453 row.get::<_, Option<String>>(10),
454 row.get(11),
455 row.get(12),
456 )
457 })
458 .collect()
459 }
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::{Account, AccountBackend};

    /// Builds an OAuth-style account with both tokens set and timestamps of 1.
    fn fixture(provider_id: &str, user: &str, account_id: &str) -> Account {
        Account {
            id: format!("acct_{provider_id}_{account_id}"),
            user_id: user.into(),
            provider_id: provider_id.into(),
            account_id: account_id.into(),
            access_token: Some("at".into()),
            refresh_token: Some("rt".into()),
            id_token: None,
            access_token_expires_at: Some(9999999999),
            refresh_token_expires_at: None,
            scope: Some("email profile".into()),
            password: None,
            created_at: 1,
            updated_at: 1,
        }
    }

    #[test]
    fn sqlite_upsert_then_find_by_provider() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub_x"));
        let got = b.find_by_provider("google", "sub_x").unwrap();
        assert_eq!(got.user_id, "u1");
        assert_eq!(got.refresh_token.as_deref(), Some("rt"));
    }

    #[test]
    fn sqlite_find_by_provider_misses_return_none() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub"));
        assert!(b.find_by_provider("google", "other").is_none());
        assert!(b.find_by_provider("github", "sub").is_none());
    }

    #[test]
    fn sqlite_find_for_user_lists_multiple_providers() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "g_sub"));
        b.upsert(&fixture("github", "u1", "gh_sub"));
        b.upsert(&fixture("google", "u2", "other"));
        let mine = b.find_for_user("u1");
        assert_eq!(mine.len(), 2);
        assert!(mine.iter().any(|a| a.provider_id == "google"));
        assert!(mine.iter().any(|a| a.provider_id == "github"));
    }

    #[test]
    fn sqlite_upsert_is_idempotent_and_refreshes_tokens() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("google", "u1", "sub");
        b.upsert(&a);
        a.access_token = Some("new_at".into());
        a.updated_at = 99;
        b.upsert(&a);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.access_token.as_deref(), Some("new_at"));
        assert_eq!(got.updated_at, 99);
        assert_eq!(b.find_for_user("u1").len(), 1);
    }

    #[test]
    fn sqlite_upsert_conflict_keeps_original_created_at() {
        // The ON CONFLICT clause updates `updated_at` but deliberately not `created_at`.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("google", "u1", "sub");
        b.upsert(&a);
        a.created_at = 50;
        a.updated_at = 60;
        b.upsert(&a);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.created_at, 1);
        assert_eq!(got.updated_at, 60);
    }

    #[test]
    fn sqlite_timestamps_above_i64_max_round_trip() {
        // Timestamps are stored as signed integers; the u64<->i64 casts must be lossless.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("google", "u1", "sub");
        a.access_token_expires_at = Some(u64::MAX);
        a.refresh_token_expires_at = Some(i64::MAX as u64 + 1);
        a.created_at = u64::MAX;
        a.updated_at = u64::MAX;
        b.upsert(&a);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.access_token_expires_at, Some(u64::MAX));
        assert_eq!(got.refresh_token_expires_at, Some(i64::MAX as u64 + 1));
        assert_eq!(got.created_at, u64::MAX);
        assert_eq!(got.updated_at, u64::MAX);
    }

    #[test]
    fn sqlite_unlink_removes_row() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub"));
        assert!(b.unlink("google", "sub"));
        assert!(b.find_by_provider("google", "sub").is_none());
        assert!(!b.unlink("google", "sub"), "second unlink is a no-op");
    }

    #[test]
    fn sqlite_list_all_returns_every_account() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        assert!(b.list_all().is_empty());
        b.upsert(&fixture("google", "u1", "a"));
        b.upsert(&fixture("github", "u2", "b"));
        let mut ids: Vec<String> = b.list_all().into_iter().map(|a| a.id).collect();
        ids.sort();
        assert_eq!(
            ids,
            vec!["acct_github_b".to_string(), "acct_google_a".to_string()]
        );
    }

    #[test]
    fn sqlite_password_column_is_present_for_future_credential_provider() {
        // A credential-style account: no OAuth tokens, only a stored password hash.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("credential", "u1", "u1");
        a.access_token = None;
        a.refresh_token = None;
        a.password = Some("argon2id$v=19$m=65536,t=3,p=4$...".into());
        b.upsert(&a);
        let got = b.find_by_provider("credential", "u1").unwrap();
        assert_eq!(
            got.password.as_deref(),
            Some("argon2id$v=19$m=65536,t=3,p=4$...")
        );
    }
}