Skip to main content

pylon_runtime/
session_backend.rs

//! SQLite-backed session persistence.
//!
//! Stores sessions in a dedicated `_pylon_sessions` table so users don't
//! get logged out when the server restarts.
//!
//! The schema is intentionally minimal: every session mutation is a single
//! UPSERT or DELETE, and reads happen only at startup via `load_all`. If
//! session churn ever outgrows this, sharding or additional indexing can be
//! added later without changing the trait contract.
10
11use std::sync::{Arc, Mutex};
12
13use pylon_auth::{Session, SessionBackend};
14use rusqlite::Connection;
15
16const TABLE: &str = "_pylon_sessions";
17
18/// Persistent session backend backed by a SQLite connection.
19///
20/// Holds the connection behind a `Mutex` because SQLite's `Connection`
21/// isn't `Sync`. Sessions are low-frequency compared to CRUD — this lock
22/// is not a hot path.
23pub struct SqliteSessionBackend {
24    conn: Arc<Mutex<Connection>>,
25}
26
27impl SqliteSessionBackend {
28    /// Open or create a SQLite file and ensure the session table exists.
29    pub fn open(path: &str) -> Result<Self, String> {
30        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
31        Self::from_connection(conn)
32    }
33
34    /// Use an in-memory database (for tests).
35    pub fn in_memory() -> Result<Self, String> {
36        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
37        Self::from_connection(conn)
38    }
39
40    fn from_connection(conn: Connection) -> Result<Self, String> {
41        // Base table for new installs. Existing installs miss `tenant_id`
42        // and get an ALTER below — ADD COLUMN is a no-op on a table that
43        // already has the column, so we swallow its error for idempotency.
44        conn.execute_batch(&format!(
45            "CREATE TABLE IF NOT EXISTS {TABLE} (
46                token TEXT PRIMARY KEY,
47                user_id TEXT NOT NULL,
48                expires_at INTEGER NOT NULL,
49                created_at INTEGER NOT NULL,
50                device TEXT,
51                tenant_id TEXT
52            );
53            CREATE INDEX IF NOT EXISTS {TABLE}_user_idx ON {TABLE}(user_id);
54            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
55        ))
56        .map_err(|e| format!("init schema: {e}"))?;
57        // Idempotent migration for pre-existing session DBs.
58        let _ = conn.execute(
59            &format!("ALTER TABLE {TABLE} ADD COLUMN tenant_id TEXT"),
60            [],
61        );
62        Ok(Self {
63            conn: Arc::new(Mutex::new(conn)),
64        })
65    }
66}
67
68impl SessionBackend for SqliteSessionBackend {
69    fn load_all(&self) -> Vec<Session> {
70        let guard = match self.conn.lock() {
71            Ok(g) => g,
72            Err(_) => return Vec::new(),
73        };
74        let mut stmt = match guard.prepare(&format!(
75            "SELECT token, user_id, expires_at, created_at, device, tenant_id FROM {TABLE}"
76        )) {
77            Ok(s) => s,
78            Err(_) => return Vec::new(),
79        };
80        let iter = match stmt.query_map([], |row| {
81            Ok(Session {
82                token: row.get(0)?,
83                user_id: row.get(1)?,
84                expires_at: row.get::<_, i64>(2)? as u64,
85                created_at: row.get::<_, i64>(3)? as u64,
86                device: row.get::<_, Option<String>>(4)?,
87                tenant_id: row.get::<_, Option<String>>(5)?,
88            })
89        }) {
90            Ok(i) => i,
91            Err(_) => return Vec::new(),
92        };
93        iter.flatten().collect()
94    }
95
96    fn save(&self, session: &Session) {
97        if let Ok(guard) = self.conn.lock() {
98            let _ = guard.execute(
99                &format!(
100                    "INSERT INTO {TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
101                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)
102                     ON CONFLICT(token) DO UPDATE SET
103                       user_id=excluded.user_id,
104                       expires_at=excluded.expires_at,
105                       device=excluded.device,
106                       tenant_id=excluded.tenant_id"
107                ),
108                rusqlite::params![
109                    session.token,
110                    session.user_id,
111                    session.expires_at as i64,
112                    session.created_at as i64,
113                    session.device,
114                    session.tenant_id,
115                ],
116            );
117        }
118    }
119
120    fn remove(&self, token: &str) {
121        if let Ok(guard) = self.conn.lock() {
122            let _ = guard.execute(
123                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
124                rusqlite::params![token],
125            );
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::Session;

    #[test]
    fn roundtrip_save_load() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("user_1".to_string());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].user_id, "user_1");
        assert_eq!(loaded[0].token, session.token);
    }

    #[test]
    fn remove_takes_effect() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("u".to_string());
        backend.save(&session);
        backend.remove(&session.token);
        assert!(backend.load_all().is_empty());
    }

    #[test]
    fn upsert_on_save_twice() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        backend.save(&session);
        session.device = Some("Safari on Mac".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].device.as_deref(), Some("Safari on Mac"));
    }

    /// A table created before `tenant_id` existed must be migrated in place
    /// and then store tenant ids like a fresh table.
    #[test]
    fn migrates_legacy_table_without_tenant_id() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                expires_at INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                device TEXT
            );"
        ))
        .unwrap();
        let backend = SqliteSessionBackend::from_connection(conn).unwrap();
        let mut session = Session::new("u".to_string());
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }

    /// Sessions written through one backend are visible after reopening the
    /// same file. Reopening also runs the schema setup against an existing,
    /// already-migrated table.
    #[test]
    fn sessions_survive_reopen() {
        let path = std::env::temp_dir().join(format!(
            "pylon_session_backend_reopen_{}.db",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let path_str = path.to_str().expect("temp dir path is valid UTF-8");
        let session = Session::new("u".to_string());
        {
            let backend = SqliteSessionBackend::open(path_str).unwrap();
            backend.save(&session);
        }
        let reopened = SqliteSessionBackend::open(path_str).unwrap();
        let loaded = reopened.load_all();
        drop(reopened);
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].token, session.token);
    }
}