Skip to main content

pylon_runtime/
account_backend.rs

1//! Persistent account-link stores. Schema mirrors better-auth's
2//! [`account` table](https://www.better-auth.com/docs/concepts/database)
3//! so apps migrating between the two see the same field names + meanings.
4//!
5//! - [`SqliteAccountBackend`] — single-file durability for self-hosted
6//!   single-replica deploys. Lives alongside sessions/oauth-state in
7//!   `PYLON_SESSION_DB`.
8//! - [`PostgresAccountBackend`] — for `DATABASE_URL=postgres://...`
9//!   deploys. Stores account rows in the same Postgres cluster as the
10//!   user's data so a join across `_pylon_accounts` and the manifest's
11//!   `User` entity works without cross-database queries.
12//!
13//! Both implement [`pylon_auth::AccountBackend`]. Why a dedicated table
14//! instead of folding the link into the manifest's `User` entity:
15//!
16//! 1. **Multi-provider**: a single user can link Google + GitHub +
17//!    custom IdPs + a password (`provider_id="credential"`). The `User`
18//!    row needs one identity-of-record (email), not N nullable
19//!    provider columns.
20//! 2. **Refresh token + password storage**: provider secrets and
21//!    password hashes shouldn't be in the user-visible entity surface
22//!    — hiding them in a framework table keeps them out of
23//!    `/api/entities/User` responses by default.
24//! 3. **Schema agility**: the framework owns the account schema, so
25//!    adding a new provider column doesn't require a manifest
26//!    migration in every consumer app.
27
28use std::sync::{Arc, Mutex};
29
30use pylon_auth::{Account, AccountBackend};
31use rusqlite::Connection;
32
/// Name of the account-link table in the SQLite backend.
const SQLITE_TABLE: &str = "_pylon_accounts";
/// Name of the account-link table in the Postgres backend — deliberately the
/// same name as the SQLite table.
const PG_TABLE: &str = SQLITE_TABLE;
35
36// ---------------------------------------------------------------------------
37// SQLite backend
38// ---------------------------------------------------------------------------
39
40pub struct SqliteAccountBackend {
41    conn: Arc<Mutex<Connection>>,
42}
43
44impl SqliteAccountBackend {
45    pub fn open(path: &str) -> Result<Self, String> {
46        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
47        Self::from_connection(conn)
48    }
49
50    pub fn in_memory() -> Result<Self, String> {
51        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
52        Self::from_connection(conn)
53    }
54
55    fn from_connection(conn: Connection) -> Result<Self, String> {
56        // `id` is the row PK; `(provider_id, account_id)` is a UNIQUE
57        // composite for the OAuth lookup path. Lookups by user_id hit
58        // the secondary index — typically called on /api/auth/me to
59        // render "connected providers" UI.
60        conn.execute_batch(&format!(
61            "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
62                id TEXT PRIMARY KEY,
63                user_id TEXT NOT NULL,
64                provider_id TEXT NOT NULL,
65                account_id TEXT NOT NULL,
66                access_token TEXT,
67                refresh_token TEXT,
68                id_token TEXT,
69                access_token_expires_at INTEGER,
70                refresh_token_expires_at INTEGER,
71                scope TEXT,
72                password TEXT,
73                created_at INTEGER NOT NULL,
74                updated_at INTEGER NOT NULL,
75                UNIQUE (provider_id, account_id)
76            );
77            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_user_idx ON {SQLITE_TABLE}(user_id);"
78        ))
79        .map_err(|e| format!("init schema: {e}"))?;
80        Ok(Self {
81            conn: Arc::new(Mutex::new(conn)),
82        })
83    }
84}
85
86#[allow(clippy::too_many_arguments)]
87fn row_to_account(
88    id: String,
89    user_id: String,
90    provider_id: String,
91    account_id: String,
92    access_token: Option<String>,
93    refresh_token: Option<String>,
94    id_token: Option<String>,
95    access_token_expires_at: Option<i64>,
96    refresh_token_expires_at: Option<i64>,
97    scope: Option<String>,
98    password: Option<String>,
99    created_at: i64,
100    updated_at: i64,
101) -> Account {
102    Account {
103        id,
104        user_id,
105        provider_id,
106        account_id,
107        access_token,
108        refresh_token,
109        id_token,
110        access_token_expires_at: access_token_expires_at.map(|n| n as u64),
111        refresh_token_expires_at: refresh_token_expires_at.map(|n| n as u64),
112        scope,
113        password,
114        created_at: created_at as u64,
115        updated_at: updated_at as u64,
116    }
117}
118
/// Column list shared by every `SELECT` on the account table.
///
/// The order is load-bearing: rows are decoded by position (0..=12), and
/// those positions must line up with the parameter order of
/// `row_to_account`.
const SELECT_COLS: &str = concat!(
    "id, user_id, provider_id, account_id, ",
    "access_token, refresh_token, id_token, ",
    "access_token_expires_at, refresh_token_expires_at, ",
    "scope, password, created_at, updated_at"
);
122
123impl AccountBackend for SqliteAccountBackend {
124    fn upsert(&self, a: &Account) {
125        if let Ok(guard) = self.conn.lock() {
126            // ON CONFLICT on the composite key — preserves the original
127            // row id so external references stay valid across
128            // re-authentications.
129            let _ = guard.execute(
130                &format!(
131                    "INSERT INTO {SQLITE_TABLE}
132                       (id, user_id, provider_id, account_id, access_token, refresh_token,
133                        id_token, access_token_expires_at, refresh_token_expires_at,
134                        scope, password, created_at, updated_at)
135                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)
136                     ON CONFLICT(provider_id, account_id) DO UPDATE SET
137                       user_id = excluded.user_id,
138                       access_token = excluded.access_token,
139                       refresh_token = excluded.refresh_token,
140                       id_token = excluded.id_token,
141                       access_token_expires_at = excluded.access_token_expires_at,
142                       refresh_token_expires_at = excluded.refresh_token_expires_at,
143                       scope = excluded.scope,
144                       password = excluded.password,
145                       updated_at = excluded.updated_at"
146                ),
147                rusqlite::params![
148                    a.id,
149                    a.user_id,
150                    a.provider_id,
151                    a.account_id,
152                    a.access_token,
153                    a.refresh_token,
154                    a.id_token,
155                    a.access_token_expires_at.map(|n| n as i64),
156                    a.refresh_token_expires_at.map(|n| n as i64),
157                    a.scope,
158                    a.password,
159                    a.created_at as i64,
160                    a.updated_at as i64,
161                ],
162            );
163        }
164    }
165
166    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
167        let guard = self.conn.lock().ok()?;
168        guard
169            .query_row(
170                &format!(
171                    "SELECT {SELECT_COLS}
172                     FROM {SQLITE_TABLE}
173                     WHERE provider_id = ?1 AND account_id = ?2"
174                ),
175                rusqlite::params![provider_id, account_id],
176                |row| {
177                    Ok(row_to_account(
178                        row.get(0)?,
179                        row.get(1)?,
180                        row.get(2)?,
181                        row.get(3)?,
182                        row.get::<_, Option<String>>(4)?,
183                        row.get::<_, Option<String>>(5)?,
184                        row.get::<_, Option<String>>(6)?,
185                        row.get::<_, Option<i64>>(7)?,
186                        row.get::<_, Option<i64>>(8)?,
187                        row.get::<_, Option<String>>(9)?,
188                        row.get::<_, Option<String>>(10)?,
189                        row.get(11)?,
190                        row.get(12)?,
191                    ))
192                },
193            )
194            .ok()
195    }
196
197    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
198        let Ok(guard) = self.conn.lock() else {
199            return Vec::new();
200        };
201        let mut stmt = match guard.prepare(&format!(
202            "SELECT {SELECT_COLS} FROM {SQLITE_TABLE} WHERE user_id = ?1"
203        )) {
204            Ok(s) => s,
205            Err(_) => return Vec::new(),
206        };
207        let iter = match stmt.query_map(rusqlite::params![user_id], |row| {
208            Ok(row_to_account(
209                row.get(0)?,
210                row.get(1)?,
211                row.get(2)?,
212                row.get(3)?,
213                row.get::<_, Option<String>>(4)?,
214                row.get::<_, Option<String>>(5)?,
215                row.get::<_, Option<String>>(6)?,
216                row.get::<_, Option<i64>>(7)?,
217                row.get::<_, Option<i64>>(8)?,
218                row.get::<_, Option<String>>(9)?,
219                row.get::<_, Option<String>>(10)?,
220                row.get(11)?,
221                row.get(12)?,
222            ))
223        }) {
224            Ok(i) => i,
225            Err(_) => return Vec::new(),
226        };
227        iter.flatten().collect()
228    }
229
230    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
231        let Ok(guard) = self.conn.lock() else {
232            return false;
233        };
234        guard
235            .execute(
236                &format!("DELETE FROM {SQLITE_TABLE} WHERE provider_id = ?1 AND account_id = ?2"),
237                rusqlite::params![provider_id, account_id],
238            )
239            .map(|n| n > 0)
240            .unwrap_or(false)
241    }
242
243    fn list_all(&self) -> Vec<Account> {
244        let Ok(guard) = self.conn.lock() else {
245            return Vec::new();
246        };
247        let mut stmt = match guard.prepare(&format!("SELECT {SELECT_COLS} FROM {SQLITE_TABLE}")) {
248            Ok(s) => s,
249            Err(_) => return Vec::new(),
250        };
251        let iter = match stmt.query_map([], |row| {
252            Ok(row_to_account(
253                row.get(0)?,
254                row.get(1)?,
255                row.get(2)?,
256                row.get(3)?,
257                row.get::<_, Option<String>>(4)?,
258                row.get::<_, Option<String>>(5)?,
259                row.get::<_, Option<String>>(6)?,
260                row.get::<_, Option<i64>>(7)?,
261                row.get::<_, Option<i64>>(8)?,
262                row.get::<_, Option<String>>(9)?,
263                row.get::<_, Option<String>>(10)?,
264                row.get(11)?,
265                row.get(12)?,
266            ))
267        }) {
268            Ok(i) => i,
269            Err(_) => return Vec::new(),
270        };
271        iter.flatten().collect()
272    }
273}
274
275// ---------------------------------------------------------------------------
276// Postgres backend
277// ---------------------------------------------------------------------------
278
279pub use pg::PostgresAccountBackend;
280
281mod pg {
282    use super::*;
283    use postgres::Client;
284
285    pub struct PostgresAccountBackend {
286        client: Mutex<Client>,
287    }
288
289    impl PostgresAccountBackend {
290        pub fn connect(url: &str) -> Result<Self, String> {
291            let mut client =
292                Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
293            client
294                .batch_execute(&format!(
295                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
296                        id TEXT PRIMARY KEY,
297                        user_id TEXT NOT NULL,
298                        provider_id TEXT NOT NULL,
299                        account_id TEXT NOT NULL,
300                        access_token TEXT,
301                        refresh_token TEXT,
302                        id_token TEXT,
303                        access_token_expires_at BIGINT,
304                        refresh_token_expires_at BIGINT,
305                        scope TEXT,
306                        password TEXT,
307                        created_at BIGINT NOT NULL,
308                        updated_at BIGINT NOT NULL,
309                        UNIQUE (provider_id, account_id)
310                    );
311                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);"
312                ))
313                .map_err(|e| format!("PG init schema: {e}"))?;
314            Ok(Self {
315                client: Mutex::new(client),
316            })
317        }
318    }
319
320    impl AccountBackend for PostgresAccountBackend {
321        fn upsert(&self, a: &Account) {
322            let Ok(mut c) = self.client.lock() else {
323                return;
324            };
325            let _ = c.execute(
326                &format!(
327                    "INSERT INTO {PG_TABLE}
328                       (id, user_id, provider_id, account_id, access_token, refresh_token,
329                        id_token, access_token_expires_at, refresh_token_expires_at,
330                        scope, password, created_at, updated_at)
331                     VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
332                     ON CONFLICT (provider_id, account_id) DO UPDATE SET
333                       user_id = EXCLUDED.user_id,
334                       access_token = EXCLUDED.access_token,
335                       refresh_token = EXCLUDED.refresh_token,
336                       id_token = EXCLUDED.id_token,
337                       access_token_expires_at = EXCLUDED.access_token_expires_at,
338                       refresh_token_expires_at = EXCLUDED.refresh_token_expires_at,
339                       scope = EXCLUDED.scope,
340                       password = EXCLUDED.password,
341                       updated_at = EXCLUDED.updated_at"
342                ),
343                &[
344                    &a.id,
345                    &a.user_id,
346                    &a.provider_id,
347                    &a.account_id,
348                    &a.access_token,
349                    &a.refresh_token,
350                    &a.id_token,
351                    &a.access_token_expires_at.map(|n| n as i64),
352                    &a.refresh_token_expires_at.map(|n| n as i64),
353                    &a.scope,
354                    &a.password,
355                    &(a.created_at as i64),
356                    &(a.updated_at as i64),
357                ],
358            );
359        }
360
361        fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
362            let mut c = self.client.lock().ok()?;
363            let row = c
364                .query_opt(
365                    &format!(
366                        "SELECT {SELECT_COLS}
367                         FROM {PG_TABLE}
368                         WHERE provider_id = $1 AND account_id = $2"
369                    ),
370                    &[&provider_id, &account_id],
371                )
372                .ok()??;
373            Some(row_to_account(
374                row.get(0),
375                row.get(1),
376                row.get(2),
377                row.get(3),
378                row.get::<_, Option<String>>(4),
379                row.get::<_, Option<String>>(5),
380                row.get::<_, Option<String>>(6),
381                row.get::<_, Option<i64>>(7),
382                row.get::<_, Option<i64>>(8),
383                row.get::<_, Option<String>>(9),
384                row.get::<_, Option<String>>(10),
385                row.get(11),
386                row.get(12),
387            ))
388        }
389
390        fn find_for_user(&self, user_id: &str) -> Vec<Account> {
391            let Ok(mut c) = self.client.lock() else {
392                return Vec::new();
393            };
394            let rows = c
395                .query(
396                    &format!("SELECT {SELECT_COLS} FROM {PG_TABLE} WHERE user_id = $1"),
397                    &[&user_id],
398                )
399                .unwrap_or_default();
400            rows.iter()
401                .map(|row| {
402                    row_to_account(
403                        row.get(0),
404                        row.get(1),
405                        row.get(2),
406                        row.get(3),
407                        row.get::<_, Option<String>>(4),
408                        row.get::<_, Option<String>>(5),
409                        row.get::<_, Option<String>>(6),
410                        row.get::<_, Option<i64>>(7),
411                        row.get::<_, Option<i64>>(8),
412                        row.get::<_, Option<String>>(9),
413                        row.get::<_, Option<String>>(10),
414                        row.get(11),
415                        row.get(12),
416                    )
417                })
418                .collect()
419        }
420
421        fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
422            let Ok(mut c) = self.client.lock() else {
423                return false;
424            };
425            c.execute(
426                &format!("DELETE FROM {PG_TABLE} WHERE provider_id = $1 AND account_id = $2"),
427                &[&provider_id, &account_id],
428            )
429            .map(|n| n > 0)
430            .unwrap_or(false)
431        }
432
433        fn list_all(&self) -> Vec<Account> {
434            let Ok(mut c) = self.client.lock() else {
435                return Vec::new();
436            };
437            let rows = c
438                .query(&format!("SELECT {SELECT_COLS} FROM {PG_TABLE}"), &[])
439                .unwrap_or_default();
440            rows.iter()
441                .map(|row| {
442                    row_to_account(
443                        row.get(0),
444                        row.get(1),
445                        row.get(2),
446                        row.get(3),
447                        row.get::<_, Option<String>>(4),
448                        row.get::<_, Option<String>>(5),
449                        row.get::<_, Option<String>>(6),
450                        row.get::<_, Option<i64>>(7),
451                        row.get::<_, Option<i64>>(8),
452                        row.get::<_, Option<String>>(9),
453                        row.get::<_, Option<String>>(10),
454                        row.get(11),
455                        row.get(12),
456                    )
457                })
458                .collect()
459        }
460    }
461}
462
#[cfg(test)]
mod tests {
    //! SQLite-backend tests against an in-memory database. The Postgres
    //! backend needs a live server and is not exercised here.
    use super::*;
    use pylon_auth::{Account, AccountBackend};

    /// Builds a token-bearing OAuth link whose row id is derived from the
    /// provider and the provider-side account id.
    fn fixture(provider_id: &str, user: &str, account_id: &str) -> Account {
        Account {
            id: format!("acct_{provider_id}_{account_id}"),
            user_id: user.into(),
            provider_id: provider_id.into(),
            account_id: account_id.into(),
            access_token: Some("at".into()),
            refresh_token: Some("rt".into()),
            id_token: None,
            access_token_expires_at: Some(9999999999),
            refresh_token_expires_at: None,
            scope: Some("email profile".into()),
            password: None,
            created_at: 1,
            updated_at: 1,
        }
    }

    #[test]
    fn sqlite_upsert_then_find_by_provider() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub_x"));
        let got = b.find_by_provider("google", "sub_x").unwrap();
        assert_eq!(got.user_id, "u1");
        assert_eq!(got.refresh_token.as_deref(), Some("rt"));
    }

    #[test]
    fn sqlite_find_for_user_lists_multiple_providers() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "g_sub"));
        b.upsert(&fixture("github", "u1", "gh_sub"));
        b.upsert(&fixture("google", "u2", "other"));
        let mine = b.find_for_user("u1");
        assert_eq!(mine.len(), 2);
        assert!(mine.iter().any(|a| a.provider_id == "google"));
        assert!(mine.iter().any(|a| a.provider_id == "github"));
    }

    #[test]
    fn sqlite_upsert_is_idempotent_and_refreshes_tokens() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("google", "u1", "sub");
        b.upsert(&a);
        a.access_token = Some("new_at".into());
        a.updated_at = 99;
        b.upsert(&a);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.access_token.as_deref(), Some("new_at"));
        assert_eq!(got.updated_at, 99);
        assert_eq!(b.find_for_user("u1").len(), 1);
    }

    #[test]
    fn sqlite_reupsert_keeps_original_row_id_and_created_at() {
        // The upsert's ON CONFLICT clause promises that external references
        // to the row id survive re-authentication, even if the caller mints
        // a fresh id for the same provider identity.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let first = fixture("google", "u1", "sub");
        b.upsert(&first);
        let mut again = fixture("google", "u1", "sub");
        again.id = "acct_regenerated".into();
        again.created_at = 50;
        b.upsert(&again);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.id, first.id);
        assert_eq!(got.created_at, 1);
        assert_eq!(b.list_all().len(), 1);
    }

    #[test]
    fn sqlite_same_account_id_under_different_providers_is_distinct() {
        // Uniqueness is on (provider_id, account_id), not account_id alone.
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "shared"));
        b.upsert(&fixture("github", "u2", "shared"));
        assert_eq!(b.find_by_provider("google", "shared").unwrap().user_id, "u1");
        assert_eq!(b.find_by_provider("github", "shared").unwrap().user_id, "u2");
    }

    #[test]
    fn sqlite_list_all_returns_every_row() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        assert!(b.list_all().is_empty());
        b.upsert(&fixture("google", "u1", "g"));
        b.upsert(&fixture("github", "u2", "gh"));
        let mut providers: Vec<_> = b.list_all().into_iter().map(|a| a.provider_id).collect();
        providers.sort();
        assert_eq!(providers, ["github", "google"]);
    }

    #[test]
    fn sqlite_lookups_on_empty_store_find_nothing() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        assert!(b.find_by_provider("google", "nope").is_none());
        assert!(b.find_for_user("nobody").is_empty());
        assert!(!b.unlink("google", "nope"));
    }

    #[test]
    fn sqlite_unlink_removes_row() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub"));
        assert!(b.unlink("google", "sub"));
        assert!(b.find_by_provider("google", "sub").is_none());
        assert!(!b.unlink("google", "sub"), "second unlink is a no-op");
    }

    #[test]
    fn sqlite_password_column_is_present_for_future_credential_provider() {
        // Confirms the column is wired end-to-end so adding email/password
        // auth later doesn't need a schema migration. Forms the basis of
        // the future provider_id="credential" rows.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("credential", "u1", "u1");
        a.access_token = None;
        a.refresh_token = None;
        a.password = Some("argon2id$v=19$m=65536,t=3,p=4$...".into());
        b.upsert(&a);
        let got = b.find_by_provider("credential", "u1").unwrap();
        assert_eq!(
            got.password.as_deref(),
            Some("argon2id$v=19$m=65536,t=3,p=4$...")
        );
    }
}