Skip to main content

pylon_runtime/
session_backend.rs

//! SQLite- and Postgres-backed session persistence.
//!
//! Stores sessions in a dedicated `_pylon_sessions` table so users don't
//! get logged out when the server restarts.
//!
//! The schema is intentionally minimal and under-engineered: every session
//! mutation is a single UPSERT or DELETE, and reads happen only at startup
//! via `load_all`. If session churn ever outgrows this, sharding or further
//! indexing can come later without changing the `SessionBackend` trait contract.
10
11use std::sync::{Arc, Mutex};
12
13use pylon_auth::{Session, SessionBackend};
14use rusqlite::Connection;
15
/// Name of the table the SQLite backend creates and reads/writes session rows in.
const TABLE: &str = "_pylon_sessions";
17
/// Persistent session backend backed by a SQLite connection.
///
/// Build one with [`SqliteSessionBackend::open`] (file on disk) or
/// [`SqliteSessionBackend::in_memory`] (tests).
///
/// Holds the connection behind a `Mutex` because SQLite's `Connection`
/// isn't `Sync`. Sessions are low-frequency compared to CRUD — this lock
/// is not a hot path.
pub struct SqliteSessionBackend {
    // NOTE(review): this `Arc` is never cloned in this file, so a bare
    // `Mutex<Connection>` would suffice unless sharing is planned — confirm.
    conn: Arc<Mutex<Connection>>,
}
26
27impl SqliteSessionBackend {
28    /// Open or create a SQLite file and ensure the session table exists.
29    pub fn open(path: &str) -> Result<Self, String> {
30        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
31        Self::from_connection(conn)
32    }
33
34    /// Use an in-memory database (for tests).
35    pub fn in_memory() -> Result<Self, String> {
36        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
37        Self::from_connection(conn)
38    }
39
40    fn from_connection(conn: Connection) -> Result<Self, String> {
41        // Base table for new installs. Existing installs miss `tenant_id`
42        // and get an ALTER below — ADD COLUMN is a no-op on a table that
43        // already has the column, so we swallow its error for idempotency.
44        conn.execute_batch(&format!(
45            "CREATE TABLE IF NOT EXISTS {TABLE} (
46                token TEXT PRIMARY KEY,
47                user_id TEXT NOT NULL,
48                expires_at INTEGER NOT NULL,
49                created_at INTEGER NOT NULL,
50                device TEXT,
51                tenant_id TEXT
52            );
53            CREATE INDEX IF NOT EXISTS {TABLE}_user_idx ON {TABLE}(user_id);
54            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
55        ))
56        .map_err(|e| format!("init schema: {e}"))?;
57        // Idempotent migration for pre-existing session DBs.
58        let _ = conn.execute(
59            &format!("ALTER TABLE {TABLE} ADD COLUMN tenant_id TEXT"),
60            [],
61        );
62        Ok(Self {
63            conn: Arc::new(Mutex::new(conn)),
64        })
65    }
66}
67
68impl SessionBackend for SqliteSessionBackend {
69    fn load_all(&self) -> Vec<Session> {
70        let guard = match self.conn.lock() {
71            Ok(g) => g,
72            Err(_) => return Vec::new(),
73        };
74        let mut stmt = match guard.prepare(&format!(
75            "SELECT token, user_id, expires_at, created_at, device, tenant_id FROM {TABLE}"
76        )) {
77            Ok(s) => s,
78            Err(_) => return Vec::new(),
79        };
80        let iter = match stmt.query_map([], |row| {
81            Ok(Session {
82                token: row.get(0)?,
83                user_id: row.get(1)?,
84                expires_at: row.get::<_, i64>(2)? as u64,
85                created_at: row.get::<_, i64>(3)? as u64,
86                device: row.get::<_, Option<String>>(4)?,
87                tenant_id: row.get::<_, Option<String>>(5)?,
88            })
89        }) {
90            Ok(i) => i,
91            Err(_) => return Vec::new(),
92        };
93        iter.flatten().collect()
94    }
95
96    fn save(&self, session: &Session) {
97        if let Ok(guard) = self.conn.lock() {
98            let _ = guard.execute(
99                &format!(
100                    "INSERT INTO {TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
101                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)
102                     ON CONFLICT(token) DO UPDATE SET
103                       user_id=excluded.user_id,
104                       expires_at=excluded.expires_at,
105                       device=excluded.device,
106                       tenant_id=excluded.tenant_id"
107                ),
108                rusqlite::params![
109                    session.token,
110                    session.user_id,
111                    session.expires_at as i64,
112                    session.created_at as i64,
113                    session.device,
114                    session.tenant_id,
115                ],
116            );
117        }
118    }
119
120    fn remove(&self, token: &str) {
121        if let Ok(guard) = self.conn.lock() {
122            let _ = guard.execute(
123                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
124                rusqlite::params![token],
125            );
126        }
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Postgres backend
132// ---------------------------------------------------------------------------
133
134pub use pg::PostgresSessionBackend;
135
136mod pg {
137    use super::*;
138    use postgres::Client;
139    use std::sync::Mutex;
140
141    const PG_TABLE: &str = "_pylon_sessions";
142
143    /// Postgres-backed session store. Schema mirrors the SQLite version
144    /// — same column set + same indexes — so a deploy that flips
145    /// `DATABASE_URL` from a local SQLite file to a managed PG cluster
146    /// only changes WHERE the rows live, not what the rows mean.
147    pub struct PostgresSessionBackend {
148        client: Mutex<Client>,
149    }
150
151    impl PostgresSessionBackend {
152        pub fn connect(url: &str) -> Result<Self, String> {
153            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
154            client
155                .batch_execute(&format!(
156                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
157                        token TEXT PRIMARY KEY,
158                        user_id TEXT NOT NULL,
159                        expires_at BIGINT NOT NULL,
160                        created_at BIGINT NOT NULL,
161                        device TEXT,
162                        tenant_id TEXT
163                    );
164                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);
165                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
166                ))
167                .map_err(|e| format!("PG init schema: {e}"))?;
168            Ok(Self {
169                client: Mutex::new(client),
170            })
171        }
172    }
173
174    impl SessionBackend for PostgresSessionBackend {
175        fn load_all(&self) -> Vec<Session> {
176            let Ok(mut c) = self.client.lock() else {
177                return Vec::new();
178            };
179            let rows = c
180                .query(
181                    &format!(
182                        "SELECT token, user_id, expires_at, created_at, device, tenant_id
183                         FROM {PG_TABLE}"
184                    ),
185                    &[],
186                )
187                .unwrap_or_default();
188            rows.iter()
189                .map(|row| Session {
190                    token: row.get(0),
191                    user_id: row.get(1),
192                    expires_at: row.get::<_, i64>(2) as u64,
193                    created_at: row.get::<_, i64>(3) as u64,
194                    device: row.get::<_, Option<String>>(4),
195                    tenant_id: row.get::<_, Option<String>>(5),
196                })
197                .collect()
198        }
199
200        fn save(&self, session: &Session) {
201            if let Ok(mut c) = self.client.lock() {
202                let _ = c.execute(
203                    &format!(
204                        "INSERT INTO {PG_TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
205                         VALUES ($1, $2, $3, $4, $5, $6)
206                         ON CONFLICT (token) DO UPDATE SET
207                           user_id = EXCLUDED.user_id,
208                           expires_at = EXCLUDED.expires_at,
209                           device = EXCLUDED.device,
210                           tenant_id = EXCLUDED.tenant_id"
211                    ),
212                    &[
213                        &session.token,
214                        &session.user_id,
215                        &(session.expires_at as i64),
216                        &(session.created_at as i64),
217                        &session.device,
218                        &session.tenant_id,
219                    ],
220                );
221            }
222        }
223
224        fn remove(&self, token: &str) {
225            if let Ok(mut c) = self.client.lock() {
226                let _ = c.execute(
227                    &format!("DELETE FROM {PG_TABLE} WHERE token = $1"),
228                    &[&token],
229                );
230            }
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::Session;

    #[test]
    fn roundtrip_save_load() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("user_1".to_string());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].user_id, "user_1");
        assert_eq!(loaded[0].token, session.token);
    }

    #[test]
    fn remove_takes_effect() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("u".to_string());
        backend.save(&session);
        backend.remove(&session.token);
        assert!(backend.load_all().is_empty());
    }

    #[test]
    fn upsert_on_save_twice() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        backend.save(&session);
        session.device = Some("Safari on Mac".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].device.as_deref(), Some("Safari on Mac"));
    }

    /// An upsert persists a changed `tenant_id` but keeps the first
    /// `created_at`, because the conflict clause does not overwrite it.
    #[test]
    fn upsert_keeps_created_at_and_updates_tenant() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        let first_created = session.created_at;
        backend.save(&session);
        session.created_at = first_created + 1;
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].created_at, first_created);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }

    /// A table created before `tenant_id` existed gets the column added on
    /// open, and can store tenant-scoped sessions afterwards.
    #[test]
    fn migrates_legacy_table_without_tenant_id() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                expires_at INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                device TEXT
            );"
        ))
        .unwrap();
        let backend = SqliteSessionBackend::from_connection(conn).unwrap();
        let mut session = Session::new("legacy".to_string());
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }
}