1use std::sync::{Arc, Mutex};
29
30use pylon_auth::{Account, AccountBackend};
31use rusqlite::Connection;
32
/// Name of the table the SQLite backend creates and queries.
const SQLITE_TABLE: &str = "_pylon_accounts";
/// Name of the table the Postgres backend creates and queries
/// (currently the same name as the SQLite one).
const PG_TABLE: &str = "_pylon_accounts";
35
/// [`AccountBackend`] implementation that stores accounts in a SQLite database.
///
/// The single `rusqlite` connection is kept behind a mutex; every trait
/// method locks it for the duration of its query.
pub struct SqliteAccountBackend {
    /// Shared, mutex-guarded handle to the underlying SQLite connection.
    conn: Arc<Mutex<Connection>>,
}
43
44impl SqliteAccountBackend {
45 pub fn open(path: &str) -> Result<Self, String> {
46 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
47 Self::from_connection(conn)
48 }
49
50 pub fn in_memory() -> Result<Self, String> {
51 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
52 Self::from_connection(conn)
53 }
54
55 fn from_connection(conn: Connection) -> Result<Self, String> {
56 conn.execute_batch(&format!(
61 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
62 id TEXT PRIMARY KEY,
63 user_id TEXT NOT NULL,
64 provider_id TEXT NOT NULL,
65 account_id TEXT NOT NULL,
66 access_token TEXT,
67 refresh_token TEXT,
68 id_token TEXT,
69 access_token_expires_at INTEGER,
70 refresh_token_expires_at INTEGER,
71 scope TEXT,
72 password TEXT,
73 created_at INTEGER NOT NULL,
74 updated_at INTEGER NOT NULL,
75 UNIQUE (provider_id, account_id)
76 );
77 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_user_idx ON {SQLITE_TABLE}(user_id);"
78 ))
79 .map_err(|e| format!("init schema: {e}"))?;
80 Ok(Self {
81 conn: Arc::new(Mutex::new(conn)),
82 })
83 }
84}
85
86#[allow(clippy::too_many_arguments)]
87fn row_to_account(
88 id: String,
89 user_id: String,
90 provider_id: String,
91 account_id: String,
92 access_token: Option<String>,
93 refresh_token: Option<String>,
94 id_token: Option<String>,
95 access_token_expires_at: Option<i64>,
96 refresh_token_expires_at: Option<i64>,
97 scope: Option<String>,
98 password: Option<String>,
99 created_at: i64,
100 updated_at: i64,
101) -> Account {
102 Account {
103 id,
104 user_id,
105 provider_id,
106 account_id,
107 access_token,
108 refresh_token,
109 id_token,
110 access_token_expires_at: access_token_expires_at.map(|n| n as u64),
111 refresh_token_expires_at: refresh_token_expires_at.map(|n| n as u64),
112 scope,
113 password,
114 created_at: created_at as u64,
115 updated_at: updated_at as u64,
116 }
117}
118
/// Column list shared by every `SELECT` in this file.
///
/// The order matters: result rows are read positionally (indices 0 through
/// 12) and handed to `row_to_account` in this same order.
const SELECT_COLS: &str = "id, user_id, provider_id, account_id, access_token, \
    refresh_token, id_token, access_token_expires_at, refresh_token_expires_at, \
    scope, password, created_at, updated_at";
122
123impl AccountBackend for SqliteAccountBackend {
124 fn upsert(&self, a: &Account) {
125 if let Ok(guard) = self.conn.lock() {
126 let _ = guard.execute(
130 &format!(
131 "INSERT INTO {SQLITE_TABLE}
132 (id, user_id, provider_id, account_id, access_token, refresh_token,
133 id_token, access_token_expires_at, refresh_token_expires_at,
134 scope, password, created_at, updated_at)
135 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)
136 ON CONFLICT(provider_id, account_id) DO UPDATE SET
137 user_id = excluded.user_id,
138 access_token = excluded.access_token,
139 refresh_token = excluded.refresh_token,
140 id_token = excluded.id_token,
141 access_token_expires_at = excluded.access_token_expires_at,
142 refresh_token_expires_at = excluded.refresh_token_expires_at,
143 scope = excluded.scope,
144 password = excluded.password,
145 updated_at = excluded.updated_at"
146 ),
147 rusqlite::params![
148 a.id,
149 a.user_id,
150 a.provider_id,
151 a.account_id,
152 a.access_token,
153 a.refresh_token,
154 a.id_token,
155 a.access_token_expires_at.map(|n| n as i64),
156 a.refresh_token_expires_at.map(|n| n as i64),
157 a.scope,
158 a.password,
159 a.created_at as i64,
160 a.updated_at as i64,
161 ],
162 );
163 }
164 }
165
166 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
167 let guard = self.conn.lock().ok()?;
168 guard
169 .query_row(
170 &format!(
171 "SELECT {SELECT_COLS}
172 FROM {SQLITE_TABLE}
173 WHERE provider_id = ?1 AND account_id = ?2"
174 ),
175 rusqlite::params![provider_id, account_id],
176 |row| {
177 Ok(row_to_account(
178 row.get(0)?,
179 row.get(1)?,
180 row.get(2)?,
181 row.get(3)?,
182 row.get::<_, Option<String>>(4)?,
183 row.get::<_, Option<String>>(5)?,
184 row.get::<_, Option<String>>(6)?,
185 row.get::<_, Option<i64>>(7)?,
186 row.get::<_, Option<i64>>(8)?,
187 row.get::<_, Option<String>>(9)?,
188 row.get::<_, Option<String>>(10)?,
189 row.get(11)?,
190 row.get(12)?,
191 ))
192 },
193 )
194 .ok()
195 }
196
197 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
198 let Ok(guard) = self.conn.lock() else {
199 return Vec::new();
200 };
201 let mut stmt = match guard.prepare(&format!(
202 "SELECT {SELECT_COLS} FROM {SQLITE_TABLE} WHERE user_id = ?1"
203 )) {
204 Ok(s) => s,
205 Err(_) => return Vec::new(),
206 };
207 let iter = match stmt.query_map(rusqlite::params![user_id], |row| {
208 Ok(row_to_account(
209 row.get(0)?,
210 row.get(1)?,
211 row.get(2)?,
212 row.get(3)?,
213 row.get::<_, Option<String>>(4)?,
214 row.get::<_, Option<String>>(5)?,
215 row.get::<_, Option<String>>(6)?,
216 row.get::<_, Option<i64>>(7)?,
217 row.get::<_, Option<i64>>(8)?,
218 row.get::<_, Option<String>>(9)?,
219 row.get::<_, Option<String>>(10)?,
220 row.get(11)?,
221 row.get(12)?,
222 ))
223 }) {
224 Ok(i) => i,
225 Err(_) => return Vec::new(),
226 };
227 iter.flatten().collect()
228 }
229
230 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
231 let Ok(guard) = self.conn.lock() else {
232 return false;
233 };
234 guard
235 .execute(
236 &format!("DELETE FROM {SQLITE_TABLE} WHERE provider_id = ?1 AND account_id = ?2"),
237 rusqlite::params![provider_id, account_id],
238 )
239 .map(|n| n > 0)
240 .unwrap_or(false)
241 }
242
243 fn list_all(&self) -> Vec<Account> {
244 let Ok(guard) = self.conn.lock() else {
245 return Vec::new();
246 };
247 let mut stmt = match guard.prepare(&format!("SELECT {SELECT_COLS} FROM {SQLITE_TABLE}")) {
248 Ok(s) => s,
249 Err(_) => return Vec::new(),
250 };
251 let iter = match stmt.query_map([], |row| {
252 Ok(row_to_account(
253 row.get(0)?,
254 row.get(1)?,
255 row.get(2)?,
256 row.get(3)?,
257 row.get::<_, Option<String>>(4)?,
258 row.get::<_, Option<String>>(5)?,
259 row.get::<_, Option<String>>(6)?,
260 row.get::<_, Option<i64>>(7)?,
261 row.get::<_, Option<i64>>(8)?,
262 row.get::<_, Option<String>>(9)?,
263 row.get::<_, Option<String>>(10)?,
264 row.get(11)?,
265 row.get(12)?,
266 ))
267 }) {
268 Ok(i) => i,
269 Err(_) => return Vec::new(),
270 };
271 iter.flatten().collect()
272 }
273}
274
275pub use pg::PostgresAccountBackend;
280
281mod pg {
282 use super::*;
283 use postgres::Client;
284
285 pub struct PostgresAccountBackend {
286 client: Mutex<Client>,
287 }
288
289 impl PostgresAccountBackend {
290 pub fn connect(url: &str) -> Result<Self, String> {
291 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
292 client
293 .batch_execute(&format!(
294 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
295 id TEXT PRIMARY KEY,
296 user_id TEXT NOT NULL,
297 provider_id TEXT NOT NULL,
298 account_id TEXT NOT NULL,
299 access_token TEXT,
300 refresh_token TEXT,
301 id_token TEXT,
302 access_token_expires_at BIGINT,
303 refresh_token_expires_at BIGINT,
304 scope TEXT,
305 password TEXT,
306 created_at BIGINT NOT NULL,
307 updated_at BIGINT NOT NULL,
308 UNIQUE (provider_id, account_id)
309 );
310 CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);"
311 ))
312 .map_err(|e| format!("PG init schema: {e}"))?;
313 Ok(Self {
314 client: Mutex::new(client),
315 })
316 }
317 }
318
319 impl AccountBackend for PostgresAccountBackend {
320 fn upsert(&self, a: &Account) {
321 let Ok(mut c) = self.client.lock() else {
322 return;
323 };
324 let _ = c.execute(
325 &format!(
326 "INSERT INTO {PG_TABLE}
327 (id, user_id, provider_id, account_id, access_token, refresh_token,
328 id_token, access_token_expires_at, refresh_token_expires_at,
329 scope, password, created_at, updated_at)
330 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
331 ON CONFLICT (provider_id, account_id) DO UPDATE SET
332 user_id = EXCLUDED.user_id,
333 access_token = EXCLUDED.access_token,
334 refresh_token = EXCLUDED.refresh_token,
335 id_token = EXCLUDED.id_token,
336 access_token_expires_at = EXCLUDED.access_token_expires_at,
337 refresh_token_expires_at = EXCLUDED.refresh_token_expires_at,
338 scope = EXCLUDED.scope,
339 password = EXCLUDED.password,
340 updated_at = EXCLUDED.updated_at"
341 ),
342 &[
343 &a.id,
344 &a.user_id,
345 &a.provider_id,
346 &a.account_id,
347 &a.access_token,
348 &a.refresh_token,
349 &a.id_token,
350 &a.access_token_expires_at.map(|n| n as i64),
351 &a.refresh_token_expires_at.map(|n| n as i64),
352 &a.scope,
353 &a.password,
354 &(a.created_at as i64),
355 &(a.updated_at as i64),
356 ],
357 );
358 }
359
360 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
361 let mut c = self.client.lock().ok()?;
362 let row = c
363 .query_opt(
364 &format!(
365 "SELECT {SELECT_COLS}
366 FROM {PG_TABLE}
367 WHERE provider_id = $1 AND account_id = $2"
368 ),
369 &[&provider_id, &account_id],
370 )
371 .ok()??;
372 Some(row_to_account(
373 row.get(0),
374 row.get(1),
375 row.get(2),
376 row.get(3),
377 row.get::<_, Option<String>>(4),
378 row.get::<_, Option<String>>(5),
379 row.get::<_, Option<String>>(6),
380 row.get::<_, Option<i64>>(7),
381 row.get::<_, Option<i64>>(8),
382 row.get::<_, Option<String>>(9),
383 row.get::<_, Option<String>>(10),
384 row.get(11),
385 row.get(12),
386 ))
387 }
388
389 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
390 let Ok(mut c) = self.client.lock() else {
391 return Vec::new();
392 };
393 let rows = c
394 .query(
395 &format!("SELECT {SELECT_COLS} FROM {PG_TABLE} WHERE user_id = $1"),
396 &[&user_id],
397 )
398 .unwrap_or_default();
399 rows.iter()
400 .map(|row| {
401 row_to_account(
402 row.get(0),
403 row.get(1),
404 row.get(2),
405 row.get(3),
406 row.get::<_, Option<String>>(4),
407 row.get::<_, Option<String>>(5),
408 row.get::<_, Option<String>>(6),
409 row.get::<_, Option<i64>>(7),
410 row.get::<_, Option<i64>>(8),
411 row.get::<_, Option<String>>(9),
412 row.get::<_, Option<String>>(10),
413 row.get(11),
414 row.get(12),
415 )
416 })
417 .collect()
418 }
419
420 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
421 let Ok(mut c) = self.client.lock() else {
422 return false;
423 };
424 c.execute(
425 &format!("DELETE FROM {PG_TABLE} WHERE provider_id = $1 AND account_id = $2"),
426 &[&provider_id, &account_id],
427 )
428 .map(|n| n > 0)
429 .unwrap_or(false)
430 }
431
432 fn list_all(&self) -> Vec<Account> {
433 let Ok(mut c) = self.client.lock() else {
434 return Vec::new();
435 };
436 let rows = c
437 .query(&format!("SELECT {SELECT_COLS} FROM {PG_TABLE}"), &[])
438 .unwrap_or_default();
439 rows.iter()
440 .map(|row| {
441 row_to_account(
442 row.get(0),
443 row.get(1),
444 row.get(2),
445 row.get(3),
446 row.get::<_, Option<String>>(4),
447 row.get::<_, Option<String>>(5),
448 row.get::<_, Option<String>>(6),
449 row.get::<_, Option<i64>>(7),
450 row.get::<_, Option<i64>>(8),
451 row.get::<_, Option<String>>(9),
452 row.get::<_, Option<String>>(10),
453 row.get(11),
454 row.get(12),
455 )
456 })
457 .collect()
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::{Account, AccountBackend};

    /// Builds an OAuth-style account for `user` linked to
    /// `(provider_id, account_id)`, with fixed token values and timestamps.
    fn fixture(provider_id: &str, user: &str, account_id: &str) -> Account {
        let id = format!("acct_{provider_id}_{account_id}");
        Account {
            id,
            user_id: user.to_string(),
            provider_id: provider_id.to_string(),
            account_id: account_id.to_string(),
            access_token: Some(String::from("at")),
            refresh_token: Some(String::from("rt")),
            id_token: None,
            access_token_expires_at: Some(9999999999),
            refresh_token_expires_at: None,
            scope: Some(String::from("email profile")),
            password: None,
            created_at: 1,
            updated_at: 1,
        }
    }

    /// A stored account can be read back by its provider key.
    #[test]
    fn sqlite_upsert_then_find_by_provider() {
        let backend = SqliteAccountBackend::in_memory().unwrap();
        backend.upsert(&fixture("google", "u1", "sub_x"));

        let stored = backend.find_by_provider("google", "sub_x").unwrap();
        assert_eq!(stored.user_id, "u1");
        assert_eq!(stored.refresh_token.as_deref(), Some("rt"));
    }

    /// `find_for_user` returns all of one user's links and nobody else's.
    #[test]
    fn sqlite_find_for_user_lists_multiple_providers() {
        let backend = SqliteAccountBackend::in_memory().unwrap();
        let links = [
            ("google", "u1", "g_sub"),
            ("github", "u1", "gh_sub"),
            ("google", "u2", "other"),
        ];
        for (provider, user, subject) in links {
            backend.upsert(&fixture(provider, user, subject));
        }

        let linked = backend.find_for_user("u1");
        assert_eq!(linked.len(), 2);
        for expected in ["google", "github"] {
            assert!(linked.iter().any(|a| a.provider_id == expected));
        }
    }

    /// Upserting the same provider key twice updates the row in place.
    #[test]
    fn sqlite_upsert_is_idempotent_and_refreshes_tokens() {
        let backend = SqliteAccountBackend::in_memory().unwrap();
        let mut account = fixture("google", "u1", "sub");
        backend.upsert(&account);

        account.access_token = Some(String::from("new_at"));
        account.updated_at = 99;
        backend.upsert(&account);

        let stored = backend.find_by_provider("google", "sub").unwrap();
        assert_eq!(stored.access_token.as_deref(), Some("new_at"));
        assert_eq!(stored.updated_at, 99);
        assert_eq!(backend.find_for_user("u1").len(), 1);
    }

    /// `unlink` deletes the row and reports whether anything was deleted.
    #[test]
    fn sqlite_unlink_removes_row() {
        let backend = SqliteAccountBackend::in_memory().unwrap();
        backend.upsert(&fixture("google", "u1", "sub"));

        assert!(backend.unlink("google", "sub"));
        assert!(backend.find_by_provider("google", "sub").is_none());
        assert!(!backend.unlink("google", "sub"), "second unlink is a no-op");
    }

    /// A token-less account carrying only a password hash round-trips.
    #[test]
    fn sqlite_password_column_is_present_for_future_credential_provider() {
        let hash = "argon2id$v=19$m=65536,t=3,p=4$...";
        let backend = SqliteAccountBackend::in_memory().unwrap();

        let mut account = fixture("credential", "u1", "u1");
        account.access_token = None;
        account.refresh_token = None;
        account.password = Some(hash.to_string());
        backend.upsert(&account);

        let stored = backend.find_by_provider("credential", "u1").unwrap();
        assert_eq!(stored.password.as_deref(), Some(hash));
    }
}