// pylon_runtime/account_backend.rs

//! Persistent account-link stores. Schema mirrors better-auth's
//! [`account` table](https://www.better-auth.com/docs/concepts/database)
//! so apps migrating between the two see the same field names + meanings.
//!
//! - [`SqliteAccountBackend`] — single-file durability for self-hosted
//!   single-replica deploys. Lives alongside sessions/oauth-state in
//!   `PYLON_SESSION_DB`.
//! - [`PostgresAccountBackend`] — for `DATABASE_URL=postgres://...`
//!   deploys. Stores account rows in the same Postgres cluster as the
//!   user's data so a join across `_pylon_accounts` and the manifest's
//!   `User` entity works without cross-database queries.
//!
//! Both implement [`pylon_auth::AccountBackend`]. Why a dedicated table
//! instead of folding the link into the manifest's `User` entity:
//!
//! 1. **Multi-provider**: a single user can link Google + GitHub +
//!    custom IdPs + a password (`provider_id="credential"`). The `User`
//!    row needs one identity-of-record (email), not N nullable
//!    provider columns.
//! 2. **Refresh token + password storage**: provider secrets and
//!    password hashes shouldn't be in the user-visible entity surface
//!    — hiding them in a framework table keeps them out of
//!    `/api/entities/User` responses by default.
//! 3. **Schema agility**: the framework owns the account schema, so
//!    adding a new provider column doesn't require a manifest
//!    migration in every consumer app.
27
28use std::sync::{Arc, Mutex};
29
30use pylon_auth::{Account, AccountBackend};
31use rusqlite::Connection;
32
/// Table that holds account-link rows in the SQLite store.
const SQLITE_TABLE: &str = "_pylon_accounts";
/// Table that holds account-link rows in the Postgres store.
const PG_TABLE: &str = "_pylon_accounts";
35
36// ---------------------------------------------------------------------------
37// SQLite backend
38// ---------------------------------------------------------------------------
39
40pub struct SqliteAccountBackend {
41    conn: Arc<Mutex<Connection>>,
42}
43
44impl SqliteAccountBackend {
45    pub fn open(path: &str) -> Result<Self, String> {
46        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
47        Self::from_connection(conn)
48    }
49
50    pub fn in_memory() -> Result<Self, String> {
51        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
52        Self::from_connection(conn)
53    }
54
55    fn from_connection(conn: Connection) -> Result<Self, String> {
56        // `id` is the row PK; `(provider_id, account_id)` is a UNIQUE
57        // composite for the OAuth lookup path. Lookups by user_id hit
58        // the secondary index — typically called on /api/auth/me to
59        // render "connected providers" UI.
60        conn.execute_batch(&format!(
61            "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
62                id TEXT PRIMARY KEY,
63                user_id TEXT NOT NULL,
64                provider_id TEXT NOT NULL,
65                account_id TEXT NOT NULL,
66                access_token TEXT,
67                refresh_token TEXT,
68                id_token TEXT,
69                access_token_expires_at INTEGER,
70                refresh_token_expires_at INTEGER,
71                scope TEXT,
72                password TEXT,
73                created_at INTEGER NOT NULL,
74                updated_at INTEGER NOT NULL,
75                UNIQUE (provider_id, account_id)
76            );
77            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_user_idx ON {SQLITE_TABLE}(user_id);"
78        ))
79        .map_err(|e| format!("init schema: {e}"))?;
80        Ok(Self {
81            conn: Arc::new(Mutex::new(conn)),
82        })
83    }
84}
85
86#[allow(clippy::too_many_arguments)]
87fn row_to_account(
88    id: String,
89    user_id: String,
90    provider_id: String,
91    account_id: String,
92    access_token: Option<String>,
93    refresh_token: Option<String>,
94    id_token: Option<String>,
95    access_token_expires_at: Option<i64>,
96    refresh_token_expires_at: Option<i64>,
97    scope: Option<String>,
98    password: Option<String>,
99    created_at: i64,
100    updated_at: i64,
101) -> Account {
102    Account {
103        id,
104        user_id,
105        provider_id,
106        account_id,
107        access_token,
108        refresh_token,
109        id_token,
110        access_token_expires_at: access_token_expires_at.map(|n| n as u64),
111        refresh_token_expires_at: refresh_token_expires_at.map(|n| n as u64),
112        scope,
113        password,
114        created_at: created_at as u64,
115        updated_at: updated_at as u64,
116    }
117}
118
/// Column list interpolated into every `SELECT` issued by both backends.
///
/// The order is significant: row decoding reads columns by position
/// (indices 0 through 12) and passes them to `row_to_account` in this order.
const SELECT_COLS: &str = concat!(
    "id, user_id, provider_id, account_id, ",
    "access_token, refresh_token, id_token, ",
    "access_token_expires_at, refresh_token_expires_at, ",
    "scope, password, created_at, updated_at",
);
122
123impl AccountBackend for SqliteAccountBackend {
124    fn upsert(&self, a: &Account) {
125        if let Ok(guard) = self.conn.lock() {
126            // ON CONFLICT on the composite key — preserves the original
127            // row id so external references stay valid across
128            // re-authentications.
129            let _ = guard.execute(
130                &format!(
131                    "INSERT INTO {SQLITE_TABLE}
132                       (id, user_id, provider_id, account_id, access_token, refresh_token,
133                        id_token, access_token_expires_at, refresh_token_expires_at,
134                        scope, password, created_at, updated_at)
135                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)
136                     ON CONFLICT(provider_id, account_id) DO UPDATE SET
137                       user_id = excluded.user_id,
138                       access_token = excluded.access_token,
139                       refresh_token = excluded.refresh_token,
140                       id_token = excluded.id_token,
141                       access_token_expires_at = excluded.access_token_expires_at,
142                       refresh_token_expires_at = excluded.refresh_token_expires_at,
143                       scope = excluded.scope,
144                       password = excluded.password,
145                       updated_at = excluded.updated_at"
146                ),
147                rusqlite::params![
148                    a.id,
149                    a.user_id,
150                    a.provider_id,
151                    a.account_id,
152                    a.access_token,
153                    a.refresh_token,
154                    a.id_token,
155                    a.access_token_expires_at.map(|n| n as i64),
156                    a.refresh_token_expires_at.map(|n| n as i64),
157                    a.scope,
158                    a.password,
159                    a.created_at as i64,
160                    a.updated_at as i64,
161                ],
162            );
163        }
164    }
165
166    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
167        let guard = self.conn.lock().ok()?;
168        guard
169            .query_row(
170                &format!(
171                    "SELECT {SELECT_COLS}
172                     FROM {SQLITE_TABLE}
173                     WHERE provider_id = ?1 AND account_id = ?2"
174                ),
175                rusqlite::params![provider_id, account_id],
176                |row| {
177                    Ok(row_to_account(
178                        row.get(0)?,
179                        row.get(1)?,
180                        row.get(2)?,
181                        row.get(3)?,
182                        row.get::<_, Option<String>>(4)?,
183                        row.get::<_, Option<String>>(5)?,
184                        row.get::<_, Option<String>>(6)?,
185                        row.get::<_, Option<i64>>(7)?,
186                        row.get::<_, Option<i64>>(8)?,
187                        row.get::<_, Option<String>>(9)?,
188                        row.get::<_, Option<String>>(10)?,
189                        row.get(11)?,
190                        row.get(12)?,
191                    ))
192                },
193            )
194            .ok()
195    }
196
197    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
198        let Ok(guard) = self.conn.lock() else {
199            return Vec::new();
200        };
201        let mut stmt = match guard.prepare(&format!(
202            "SELECT {SELECT_COLS} FROM {SQLITE_TABLE} WHERE user_id = ?1"
203        )) {
204            Ok(s) => s,
205            Err(_) => return Vec::new(),
206        };
207        let iter = match stmt.query_map(rusqlite::params![user_id], |row| {
208            Ok(row_to_account(
209                row.get(0)?,
210                row.get(1)?,
211                row.get(2)?,
212                row.get(3)?,
213                row.get::<_, Option<String>>(4)?,
214                row.get::<_, Option<String>>(5)?,
215                row.get::<_, Option<String>>(6)?,
216                row.get::<_, Option<i64>>(7)?,
217                row.get::<_, Option<i64>>(8)?,
218                row.get::<_, Option<String>>(9)?,
219                row.get::<_, Option<String>>(10)?,
220                row.get(11)?,
221                row.get(12)?,
222            ))
223        }) {
224            Ok(i) => i,
225            Err(_) => return Vec::new(),
226        };
227        iter.flatten().collect()
228    }
229
230    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
231        let Ok(guard) = self.conn.lock() else {
232            return false;
233        };
234        guard
235            .execute(
236                &format!("DELETE FROM {SQLITE_TABLE} WHERE provider_id = ?1 AND account_id = ?2"),
237                rusqlite::params![provider_id, account_id],
238            )
239            .map(|n| n > 0)
240            .unwrap_or(false)
241    }
242
243    fn list_all(&self) -> Vec<Account> {
244        let Ok(guard) = self.conn.lock() else {
245            return Vec::new();
246        };
247        let mut stmt = match guard.prepare(&format!("SELECT {SELECT_COLS} FROM {SQLITE_TABLE}")) {
248            Ok(s) => s,
249            Err(_) => return Vec::new(),
250        };
251        let iter = match stmt.query_map([], |row| {
252            Ok(row_to_account(
253                row.get(0)?,
254                row.get(1)?,
255                row.get(2)?,
256                row.get(3)?,
257                row.get::<_, Option<String>>(4)?,
258                row.get::<_, Option<String>>(5)?,
259                row.get::<_, Option<String>>(6)?,
260                row.get::<_, Option<i64>>(7)?,
261                row.get::<_, Option<i64>>(8)?,
262                row.get::<_, Option<String>>(9)?,
263                row.get::<_, Option<String>>(10)?,
264                row.get(11)?,
265                row.get(12)?,
266            ))
267        }) {
268            Ok(i) => i,
269            Err(_) => return Vec::new(),
270        };
271        iter.flatten().collect()
272    }
273}
274
275// ---------------------------------------------------------------------------
276// Postgres backend
277// ---------------------------------------------------------------------------
278
279pub use pg::PostgresAccountBackend;
280
281mod pg {
282    use super::*;
283    use postgres::Client;
284
285    pub struct PostgresAccountBackend {
286        client: Mutex<Client>,
287    }
288
289    impl PostgresAccountBackend {
290        pub fn connect(url: &str) -> Result<Self, String> {
291            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
292            client
293                .batch_execute(&format!(
294                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
295                        id TEXT PRIMARY KEY,
296                        user_id TEXT NOT NULL,
297                        provider_id TEXT NOT NULL,
298                        account_id TEXT NOT NULL,
299                        access_token TEXT,
300                        refresh_token TEXT,
301                        id_token TEXT,
302                        access_token_expires_at BIGINT,
303                        refresh_token_expires_at BIGINT,
304                        scope TEXT,
305                        password TEXT,
306                        created_at BIGINT NOT NULL,
307                        updated_at BIGINT NOT NULL,
308                        UNIQUE (provider_id, account_id)
309                    );
310                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);"
311                ))
312                .map_err(|e| format!("PG init schema: {e}"))?;
313            Ok(Self {
314                client: Mutex::new(client),
315            })
316        }
317    }
318
319    impl AccountBackend for PostgresAccountBackend {
320        fn upsert(&self, a: &Account) {
321            let Ok(mut c) = self.client.lock() else {
322                return;
323            };
324            let _ = c.execute(
325                &format!(
326                    "INSERT INTO {PG_TABLE}
327                       (id, user_id, provider_id, account_id, access_token, refresh_token,
328                        id_token, access_token_expires_at, refresh_token_expires_at,
329                        scope, password, created_at, updated_at)
330                     VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
331                     ON CONFLICT (provider_id, account_id) DO UPDATE SET
332                       user_id = EXCLUDED.user_id,
333                       access_token = EXCLUDED.access_token,
334                       refresh_token = EXCLUDED.refresh_token,
335                       id_token = EXCLUDED.id_token,
336                       access_token_expires_at = EXCLUDED.access_token_expires_at,
337                       refresh_token_expires_at = EXCLUDED.refresh_token_expires_at,
338                       scope = EXCLUDED.scope,
339                       password = EXCLUDED.password,
340                       updated_at = EXCLUDED.updated_at"
341                ),
342                &[
343                    &a.id,
344                    &a.user_id,
345                    &a.provider_id,
346                    &a.account_id,
347                    &a.access_token,
348                    &a.refresh_token,
349                    &a.id_token,
350                    &a.access_token_expires_at.map(|n| n as i64),
351                    &a.refresh_token_expires_at.map(|n| n as i64),
352                    &a.scope,
353                    &a.password,
354                    &(a.created_at as i64),
355                    &(a.updated_at as i64),
356                ],
357            );
358        }
359
360        fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
361            let mut c = self.client.lock().ok()?;
362            let row = c
363                .query_opt(
364                    &format!(
365                        "SELECT {SELECT_COLS}
366                         FROM {PG_TABLE}
367                         WHERE provider_id = $1 AND account_id = $2"
368                    ),
369                    &[&provider_id, &account_id],
370                )
371                .ok()??;
372            Some(row_to_account(
373                row.get(0),
374                row.get(1),
375                row.get(2),
376                row.get(3),
377                row.get::<_, Option<String>>(4),
378                row.get::<_, Option<String>>(5),
379                row.get::<_, Option<String>>(6),
380                row.get::<_, Option<i64>>(7),
381                row.get::<_, Option<i64>>(8),
382                row.get::<_, Option<String>>(9),
383                row.get::<_, Option<String>>(10),
384                row.get(11),
385                row.get(12),
386            ))
387        }
388
389        fn find_for_user(&self, user_id: &str) -> Vec<Account> {
390            let Ok(mut c) = self.client.lock() else {
391                return Vec::new();
392            };
393            let rows = c
394                .query(
395                    &format!("SELECT {SELECT_COLS} FROM {PG_TABLE} WHERE user_id = $1"),
396                    &[&user_id],
397                )
398                .unwrap_or_default();
399            rows.iter()
400                .map(|row| {
401                    row_to_account(
402                        row.get(0),
403                        row.get(1),
404                        row.get(2),
405                        row.get(3),
406                        row.get::<_, Option<String>>(4),
407                        row.get::<_, Option<String>>(5),
408                        row.get::<_, Option<String>>(6),
409                        row.get::<_, Option<i64>>(7),
410                        row.get::<_, Option<i64>>(8),
411                        row.get::<_, Option<String>>(9),
412                        row.get::<_, Option<String>>(10),
413                        row.get(11),
414                        row.get(12),
415                    )
416                })
417                .collect()
418        }
419
420        fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
421            let Ok(mut c) = self.client.lock() else {
422                return false;
423            };
424            c.execute(
425                &format!("DELETE FROM {PG_TABLE} WHERE provider_id = $1 AND account_id = $2"),
426                &[&provider_id, &account_id],
427            )
428            .map(|n| n > 0)
429            .unwrap_or(false)
430        }
431
432        fn list_all(&self) -> Vec<Account> {
433            let Ok(mut c) = self.client.lock() else {
434                return Vec::new();
435            };
436            let rows = c
437                .query(&format!("SELECT {SELECT_COLS} FROM {PG_TABLE}"), &[])
438                .unwrap_or_default();
439            rows.iter()
440                .map(|row| {
441                    row_to_account(
442                        row.get(0),
443                        row.get(1),
444                        row.get(2),
445                        row.get(3),
446                        row.get::<_, Option<String>>(4),
447                        row.get::<_, Option<String>>(5),
448                        row.get::<_, Option<String>>(6),
449                        row.get::<_, Option<i64>>(7),
450                        row.get::<_, Option<i64>>(8),
451                        row.get::<_, Option<String>>(9),
452                        row.get::<_, Option<String>>(10),
453                        row.get(11),
454                        row.get(12),
455                    )
456                })
457                .collect()
458        }
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::{Account, AccountBackend};

    /// Builds an OAuth-style account row with deterministic field values.
    /// The row id is derived from `provider_id` and `account_id`.
    fn fixture(provider_id: &str, user: &str, account_id: &str) -> Account {
        Account {
            id: format!("acct_{provider_id}_{account_id}"),
            user_id: user.into(),
            provider_id: provider_id.into(),
            account_id: account_id.into(),
            access_token: Some("at".into()),
            refresh_token: Some("rt".into()),
            id_token: None,
            access_token_expires_at: Some(9999999999),
            refresh_token_expires_at: None,
            scope: Some("email profile".into()),
            password: None,
            created_at: 1,
            updated_at: 1,
        }
    }

    #[test]
    fn sqlite_upsert_then_find_by_provider() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub_x"));
        let got = b.find_by_provider("google", "sub_x").unwrap();
        assert_eq!(got.user_id, "u1");
        assert_eq!(got.refresh_token.as_deref(), Some("rt"));
    }

    #[test]
    fn sqlite_find_for_user_lists_multiple_providers() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "g_sub"));
        b.upsert(&fixture("github", "u1", "gh_sub"));
        b.upsert(&fixture("google", "u2", "other"));
        let mine = b.find_for_user("u1");
        assert_eq!(mine.len(), 2);
        assert!(mine.iter().any(|a| a.provider_id == "google"));
        assert!(mine.iter().any(|a| a.provider_id == "github"));
    }

    #[test]
    fn sqlite_upsert_is_idempotent_and_refreshes_tokens() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("google", "u1", "sub");
        b.upsert(&a);
        a.access_token = Some("new_at".into());
        a.updated_at = 99;
        b.upsert(&a);
        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.access_token.as_deref(), Some("new_at"));
        assert_eq!(got.updated_at, 99);
        assert_eq!(b.find_for_user("u1").len(), 1);
    }

    #[test]
    fn sqlite_upsert_conflict_keeps_original_row_id() {
        // The upsert's conflict target is (provider_id, account_id) and its
        // SET list omits `id` and `created_at`, so a re-link arriving with a
        // freshly generated id must not replace the stored one.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let first = fixture("google", "u1", "sub");
        b.upsert(&first);

        let mut again = fixture("google", "u1", "sub");
        again.id = "acct_regenerated".into();
        again.created_at = 50;
        again.updated_at = 50;
        b.upsert(&again);

        let got = b.find_by_provider("google", "sub").unwrap();
        assert_eq!(got.id, first.id);
        assert_eq!(got.created_at, 1);
        assert_eq!(got.updated_at, 50);
        assert_eq!(b.list_all().len(), 1);
    }

    #[test]
    fn sqlite_unlink_removes_row() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        b.upsert(&fixture("google", "u1", "sub"));
        assert!(b.unlink("google", "sub"));
        assert!(b.find_by_provider("google", "sub").is_none());
        assert!(!b.unlink("google", "sub"), "second unlink is a no-op");
    }

    #[test]
    fn sqlite_list_all_returns_rows_across_users() {
        let b = SqliteAccountBackend::in_memory().unwrap();
        assert!(b.list_all().is_empty());
        b.upsert(&fixture("google", "u1", "g_sub"));
        b.upsert(&fixture("github", "u1", "gh_sub"));
        b.upsert(&fixture("google", "u2", "other"));
        let all = b.list_all();
        assert_eq!(all.len(), 3);
        assert!(all.iter().any(|a| a.user_id == "u2"));
    }

    #[test]
    fn sqlite_password_column_is_present_for_future_credential_provider() {
        // Confirms the column is wired end-to-end so adding email/password
        // auth later doesn't need a schema migration. Forms the basis of
        // the future provider_id="credential" rows.
        let b = SqliteAccountBackend::in_memory().unwrap();
        let mut a = fixture("credential", "u1", "u1");
        a.access_token = None;
        a.refresh_token = None;
        a.password = Some("argon2id$v=19$m=65536,t=3,p=4$...".into());
        b.upsert(&a);
        let got = b.find_by_provider("credential", "u1").unwrap();
        assert_eq!(
            got.password.as_deref(),
            Some("argon2id$v=19$m=65536,t=3,p=4$...")
        );
    }
}