Skip to main content

pylon_runtime/
session_backend.rs

//! Persistent session storage with SQLite and Postgres backends.
//!
//! Stores sessions in a dedicated `_pylon_sessions` table so users don't
//! get logged out when the server restarts.
//!
//! The schema is intentionally minimal: every session mutation is a single
//! UPSERT/DELETE, and reads happen only at startup via `load_all`. The table
//! is already indexed on `user_id` and `expires_at`; if session churn ever
//! outgrows this, sharding can come later without changing the trait contract.
10
11use std::sync::{Arc, Mutex};
12
13use pylon_auth::{Session, SessionBackend};
14use rusqlite::Connection;
15
/// Name of the SQLite table that holds persisted sessions.
const TABLE: &str = "_pylon_sessions";
17
18/// Persistent session backend backed by a SQLite connection.
19///
20/// Holds the connection behind a `Mutex` because SQLite's `Connection`
21/// isn't `Sync`. Sessions are low-frequency compared to CRUD — this lock
22/// is not a hot path.
23pub struct SqliteSessionBackend {
24    conn: Arc<Mutex<Connection>>,
25}
26
27impl SqliteSessionBackend {
28    /// Open or create a SQLite file and ensure the session table exists.
29    pub fn open(path: &str) -> Result<Self, String> {
30        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
31        Self::from_connection(conn)
32    }
33
34    /// Use an in-memory database (for tests).
35    pub fn in_memory() -> Result<Self, String> {
36        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
37        Self::from_connection(conn)
38    }
39
40    fn from_connection(conn: Connection) -> Result<Self, String> {
41        // Base table for new installs. Existing installs miss `tenant_id`
42        // and get an ALTER below — ADD COLUMN is a no-op on a table that
43        // already has the column, so we swallow its error for idempotency.
44        conn.execute_batch(&format!(
45            "CREATE TABLE IF NOT EXISTS {TABLE} (
46                token TEXT PRIMARY KEY,
47                user_id TEXT NOT NULL,
48                expires_at INTEGER NOT NULL,
49                created_at INTEGER NOT NULL,
50                device TEXT,
51                tenant_id TEXT
52            );
53            CREATE INDEX IF NOT EXISTS {TABLE}_user_idx ON {TABLE}(user_id);
54            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
55        ))
56        .map_err(|e| format!("init schema: {e}"))?;
57        // Idempotent migration for pre-existing session DBs.
58        let _ = conn.execute(
59            &format!("ALTER TABLE {TABLE} ADD COLUMN tenant_id TEXT"),
60            [],
61        );
62        Ok(Self {
63            conn: Arc::new(Mutex::new(conn)),
64        })
65    }
66}
67
68impl SessionBackend for SqliteSessionBackend {
69    fn load_all(&self) -> Vec<Session> {
70        let guard = match self.conn.lock() {
71            Ok(g) => g,
72            Err(_) => return Vec::new(),
73        };
74        let mut stmt = match guard.prepare(&format!(
75            "SELECT token, user_id, expires_at, created_at, device, tenant_id FROM {TABLE}"
76        )) {
77            Ok(s) => s,
78            Err(_) => return Vec::new(),
79        };
80        let iter = match stmt.query_map([], |row| {
81            Ok(Session {
82                token: row.get(0)?,
83                user_id: row.get(1)?,
84                expires_at: row.get::<_, i64>(2)? as u64,
85                created_at: row.get::<_, i64>(3)? as u64,
86                device: row.get::<_, Option<String>>(4)?,
87                tenant_id: row.get::<_, Option<String>>(5)?,
88            })
89        }) {
90            Ok(i) => i,
91            Err(_) => return Vec::new(),
92        };
93        iter.flatten().collect()
94    }
95
96    fn save(&self, session: &Session) {
97        if let Ok(guard) = self.conn.lock() {
98            let _ = guard.execute(
99                &format!(
100                    "INSERT INTO {TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
101                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)
102                     ON CONFLICT(token) DO UPDATE SET
103                       user_id=excluded.user_id,
104                       expires_at=excluded.expires_at,
105                       device=excluded.device,
106                       tenant_id=excluded.tenant_id"
107                ),
108                rusqlite::params![
109                    session.token,
110                    session.user_id,
111                    session.expires_at as i64,
112                    session.created_at as i64,
113                    session.device,
114                    session.tenant_id,
115                ],
116            );
117        }
118    }
119
120    fn remove(&self, token: &str) {
121        if let Ok(guard) = self.conn.lock() {
122            let _ = guard.execute(
123                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
124                rusqlite::params![token],
125            );
126        }
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Postgres backend
132// ---------------------------------------------------------------------------
133
134pub use pg::PostgresSessionBackend;
135
136mod pg {
137    use super::*;
138    use postgres::Client;
139    use std::sync::Mutex;
140
141    const PG_TABLE: &str = "_pylon_sessions";
142
143    /// Postgres-backed session store. Schema mirrors the SQLite version
144    /// — same column set + same indexes — so a deploy that flips
145    /// `DATABASE_URL` from a local SQLite file to a managed PG cluster
146    /// only changes WHERE the rows live, not what the rows mean.
147    pub struct PostgresSessionBackend {
148        client: Mutex<Client>,
149    }
150
151    impl PostgresSessionBackend {
152        pub fn connect(url: &str) -> Result<Self, String> {
153            let mut client =
154                Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
155            client
156                .batch_execute(&format!(
157                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
158                        token TEXT PRIMARY KEY,
159                        user_id TEXT NOT NULL,
160                        expires_at BIGINT NOT NULL,
161                        created_at BIGINT NOT NULL,
162                        device TEXT,
163                        tenant_id TEXT
164                    );
165                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);
166                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
167                ))
168                .map_err(|e| format!("PG init schema: {e}"))?;
169            Ok(Self {
170                client: Mutex::new(client),
171            })
172        }
173    }
174
175    impl SessionBackend for PostgresSessionBackend {
176        fn load_all(&self) -> Vec<Session> {
177            let Ok(mut c) = self.client.lock() else {
178                return Vec::new();
179            };
180            let rows = c
181                .query(
182                    &format!(
183                        "SELECT token, user_id, expires_at, created_at, device, tenant_id
184                         FROM {PG_TABLE}"
185                    ),
186                    &[],
187                )
188                .unwrap_or_default();
189            rows.iter()
190                .map(|row| Session {
191                    token: row.get(0),
192                    user_id: row.get(1),
193                    expires_at: row.get::<_, i64>(2) as u64,
194                    created_at: row.get::<_, i64>(3) as u64,
195                    device: row.get::<_, Option<String>>(4),
196                    tenant_id: row.get::<_, Option<String>>(5),
197                })
198                .collect()
199        }
200
201        fn save(&self, session: &Session) {
202            if let Ok(mut c) = self.client.lock() {
203                let _ = c.execute(
204                    &format!(
205                        "INSERT INTO {PG_TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
206                         VALUES ($1, $2, $3, $4, $5, $6)
207                         ON CONFLICT (token) DO UPDATE SET
208                           user_id = EXCLUDED.user_id,
209                           expires_at = EXCLUDED.expires_at,
210                           device = EXCLUDED.device,
211                           tenant_id = EXCLUDED.tenant_id"
212                    ),
213                    &[
214                        &session.token,
215                        &session.user_id,
216                        &(session.expires_at as i64),
217                        &(session.created_at as i64),
218                        &session.device,
219                        &session.tenant_id,
220                    ],
221                );
222            }
223        }
224
225        fn remove(&self, token: &str) {
226            if let Ok(mut c) = self.client.lock() {
227                let _ = c.execute(
228                    &format!("DELETE FROM {PG_TABLE} WHERE token = $1"),
229                    &[&token],
230                );
231            }
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::Session;
    use rusqlite::Connection;

    #[test]
    fn roundtrip_save_load() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("user_1".to_string());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].user_id, "user_1");
        assert_eq!(loaded[0].token, session.token);
    }

    #[test]
    fn remove_takes_effect() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("u".to_string());
        backend.save(&session);
        backend.remove(&session.token);
        assert!(backend.load_all().is_empty());
    }

    #[test]
    fn upsert_on_save_twice() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        backend.save(&session);
        session.device = Some("Safari on Mac".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].device.as_deref(), Some("Safari on Mac"));
    }

    /// A table written by a release that predates `tenant_id` must gain the
    /// column on open and then round-trip tenant ids.
    #[test]
    fn migrates_legacy_table_without_tenant_id() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                expires_at INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                device TEXT
            );"
        ))
        .unwrap();
        let backend = SqliteSessionBackend::from_connection(conn).unwrap();
        let mut session = Session::new("legacy".to_string());
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }

    /// Re-initialising over a table that already has the current schema must
    /// succeed (the migration is idempotent).
    #[test]
    fn init_accepts_table_with_current_schema() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                expires_at INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                device TEXT,
                tenant_id TEXT
            );"
        ))
        .unwrap();
        let backend = SqliteSessionBackend::from_connection(conn).unwrap();
        backend.save(&Session::new("u".to_string()));
        assert_eq!(backend.load_all().len(), 1);
    }
}