1use std::sync::{Arc, Mutex};
12
13use pylon_auth::{Session, SessionBackend};
14use rusqlite::Connection;
15
/// Name of the SQLite table holding persisted sessions; also used as the prefix of its
/// index names.
const TABLE: &str = "_pylon_sessions";

/// [`SessionBackend`] that stores sessions in a SQLite database.
///
/// A single `rusqlite::Connection` is kept behind a `Mutex`, so all operations on one
/// backend are serialized on that connection.
pub struct SqliteSessionBackend {
    /// The one connection used by every operation; locked for the duration of each query.
    conn: Arc<Mutex<Connection>>,
}
26
27impl SqliteSessionBackend {
28 pub fn open(path: &str) -> Result<Self, String> {
30 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
31 Self::from_connection(conn)
32 }
33
34 pub fn in_memory() -> Result<Self, String> {
36 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
37 Self::from_connection(conn)
38 }
39
40 fn from_connection(conn: Connection) -> Result<Self, String> {
41 conn.execute_batch(&format!(
45 "CREATE TABLE IF NOT EXISTS {TABLE} (
46 token TEXT PRIMARY KEY,
47 user_id TEXT NOT NULL,
48 expires_at INTEGER NOT NULL,
49 created_at INTEGER NOT NULL,
50 device TEXT,
51 tenant_id TEXT
52 );
53 CREATE INDEX IF NOT EXISTS {TABLE}_user_idx ON {TABLE}(user_id);
54 CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
55 ))
56 .map_err(|e| format!("init schema: {e}"))?;
57 let _ = conn.execute(
59 &format!("ALTER TABLE {TABLE} ADD COLUMN tenant_id TEXT"),
60 [],
61 );
62 Ok(Self {
63 conn: Arc::new(Mutex::new(conn)),
64 })
65 }
66}
67
68impl SessionBackend for SqliteSessionBackend {
69 fn load_all(&self) -> Vec<Session> {
70 let guard = match self.conn.lock() {
71 Ok(g) => g,
72 Err(_) => return Vec::new(),
73 };
74 let mut stmt = match guard.prepare(&format!(
75 "SELECT token, user_id, expires_at, created_at, device, tenant_id FROM {TABLE}"
76 )) {
77 Ok(s) => s,
78 Err(_) => return Vec::new(),
79 };
80 let iter = match stmt.query_map([], |row| {
81 Ok(Session {
82 token: row.get(0)?,
83 user_id: row.get(1)?,
84 expires_at: row.get::<_, i64>(2)? as u64,
85 created_at: row.get::<_, i64>(3)? as u64,
86 device: row.get::<_, Option<String>>(4)?,
87 tenant_id: row.get::<_, Option<String>>(5)?,
88 })
89 }) {
90 Ok(i) => i,
91 Err(_) => return Vec::new(),
92 };
93 iter.flatten().collect()
94 }
95
96 fn save(&self, session: &Session) {
97 if let Ok(guard) = self.conn.lock() {
98 let _ = guard.execute(
99 &format!(
100 "INSERT INTO {TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
101 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
102 ON CONFLICT(token) DO UPDATE SET
103 user_id=excluded.user_id,
104 expires_at=excluded.expires_at,
105 device=excluded.device,
106 tenant_id=excluded.tenant_id"
107 ),
108 rusqlite::params![
109 session.token,
110 session.user_id,
111 session.expires_at as i64,
112 session.created_at as i64,
113 session.device,
114 session.tenant_id,
115 ],
116 );
117 }
118 }
119
120 fn remove(&self, token: &str) {
121 if let Ok(guard) = self.conn.lock() {
122 let _ = guard.execute(
123 &format!("DELETE FROM {TABLE} WHERE token = ?1"),
124 rusqlite::params![token],
125 );
126 }
127 }
128}
129
130pub use pg::PostgresSessionBackend;
135
136mod pg {
137 use super::*;
138 use postgres::Client;
139 use std::sync::Mutex;
140
141 const PG_TABLE: &str = "_pylon_sessions";
142
143 pub struct PostgresSessionBackend {
148 client: Mutex<Client>,
149 }
150
151 impl PostgresSessionBackend {
152 pub fn connect(url: &str) -> Result<Self, String> {
153 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
154 client
155 .batch_execute(&format!(
156 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
157 token TEXT PRIMARY KEY,
158 user_id TEXT NOT NULL,
159 expires_at BIGINT NOT NULL,
160 created_at BIGINT NOT NULL,
161 device TEXT,
162 tenant_id TEXT
163 );
164 CREATE INDEX IF NOT EXISTS {PG_TABLE}_user_idx ON {PG_TABLE}(user_id);
165 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
166 ))
167 .map_err(|e| format!("PG init schema: {e}"))?;
168 Ok(Self {
169 client: Mutex::new(client),
170 })
171 }
172 }
173
174 impl SessionBackend for PostgresSessionBackend {
175 fn load_all(&self) -> Vec<Session> {
176 let Ok(mut c) = self.client.lock() else {
177 return Vec::new();
178 };
179 let rows = c
180 .query(
181 &format!(
182 "SELECT token, user_id, expires_at, created_at, device, tenant_id
183 FROM {PG_TABLE}"
184 ),
185 &[],
186 )
187 .unwrap_or_default();
188 rows.iter()
189 .map(|row| Session {
190 token: row.get(0),
191 user_id: row.get(1),
192 expires_at: row.get::<_, i64>(2) as u64,
193 created_at: row.get::<_, i64>(3) as u64,
194 device: row.get::<_, Option<String>>(4),
195 tenant_id: row.get::<_, Option<String>>(5),
196 })
197 .collect()
198 }
199
200 fn save(&self, session: &Session) {
201 if let Ok(mut c) = self.client.lock() {
202 let _ = c.execute(
203 &format!(
204 "INSERT INTO {PG_TABLE} (token, user_id, expires_at, created_at, device, tenant_id)
205 VALUES ($1, $2, $3, $4, $5, $6)
206 ON CONFLICT (token) DO UPDATE SET
207 user_id = EXCLUDED.user_id,
208 expires_at = EXCLUDED.expires_at,
209 device = EXCLUDED.device,
210 tenant_id = EXCLUDED.tenant_id"
211 ),
212 &[
213 &session.token,
214 &session.user_id,
215 &(session.expires_at as i64),
216 &(session.created_at as i64),
217 &session.device,
218 &session.tenant_id,
219 ],
220 );
221 }
222 }
223
224 fn remove(&self, token: &str) {
225 if let Ok(mut c) = self.client.lock() {
226 let _ = c.execute(
227 &format!("DELETE FROM {PG_TABLE} WHERE token = $1"),
228 &[&token],
229 );
230 }
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::Session;

    #[test]
    fn roundtrip_save_load() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("user_1".to_string());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].user_id, "user_1");
        assert_eq!(loaded[0].token, session.token);
    }

    #[test]
    fn remove_takes_effect() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let session = Session::new("u".to_string());
        backend.save(&session);
        backend.remove(&session.token);
        assert!(backend.load_all().is_empty());
    }

    #[test]
    fn upsert_on_save_twice() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        backend.save(&session);
        session.device = Some("Safari on Mac".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].device.as_deref(), Some("Safari on Mac"));
    }

    #[test]
    fn tenant_id_roundtrips() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }

    /// The upsert deliberately leaves `created_at` out of its UPDATE clause; pin that.
    #[test]
    fn upsert_preserves_created_at_but_updates_expiry() {
        let backend = SqliteSessionBackend::in_memory().unwrap();
        let mut session = Session::new("u".to_string());
        backend.save(&session);
        let first_created = session.created_at;
        session.created_at = session.created_at + 1_000;
        session.expires_at = session.expires_at + 1_000;
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].created_at, first_created);
        assert_eq!(loaded[0].expires_at, session.expires_at);
    }

    /// A table created before `tenant_id` existed must be upgraded on open.
    #[test]
    fn migrates_legacy_table_without_tenant_id() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                expires_at INTEGER NOT NULL,
                created_at INTEGER NOT NULL,
                device TEXT
            );"
        ))
        .unwrap();
        let backend = SqliteSessionBackend::from_connection(conn).unwrap();
        let mut session = Session::new("u".to_string());
        session.tenant_id = Some("acme".into());
        backend.save(&session);
        let loaded = backend.load_all();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].tenant_id.as_deref(), Some("acme"));
    }
}