Skip to main content

pylon_runtime/
magic_code_backend.rs

//! Persistent magic-code stores. Two backends ship today:
//!
//! - [`SqliteMagicCodeBackend`] — single-file durability. Default for
//!   self-hosted single-replica deploys. Lives in the same SQLite file
//!   as sessions/oauth-state (`PYLON_SESSION_DB`).
//! - [`PostgresMagicCodeBackend`] — multi-replica deploys against
//!   `DATABASE_URL=postgres://...`. Magic codes are short-lived (10 min),
//!   so a primary key on `email` plus a plain `expires_at` index is
//!   sufficient — no GIN or other specialized indexes are needed.
//!
//! Both implement the [`pylon_auth::MagicCodeBackend`] trait. The store layer
//! in pylon-auth keeps an in-memory cache as the authoritative read path;
//! these backends are write-through + load-on-startup so a server
//! restart between "send magic code" and "verify" doesn't kill the
//! user's pending login.
15
16use std::sync::{Arc, Mutex};
17
18use pylon_auth::{MagicCode, MagicCodeBackend};
19use rusqlite::Connection;
20
/// Table holding one pending magic code per email address (SQLite backend).
const SQLITE_TABLE: &str = "_pylon_magic_codes";
/// Table used by the Postgres backend; intentionally the same name as
/// [`SQLITE_TABLE`].
const PG_TABLE: &str = SQLITE_TABLE;
23
24// ---------------------------------------------------------------------------
25// SQLite backend
26// ---------------------------------------------------------------------------
27
28pub struct SqliteMagicCodeBackend {
29    conn: Arc<Mutex<Connection>>,
30}
31
32impl SqliteMagicCodeBackend {
33    pub fn open(path: &str) -> Result<Self, String> {
34        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
35        Self::from_connection(conn)
36    }
37
38    pub fn in_memory() -> Result<Self, String> {
39        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
40        Self::from_connection(conn)
41    }
42
43    fn from_connection(conn: Connection) -> Result<Self, String> {
44        conn.execute_batch(&format!(
45            "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
46                email TEXT PRIMARY KEY,
47                code TEXT NOT NULL,
48                expires_at INTEGER NOT NULL,
49                attempts INTEGER NOT NULL DEFAULT 0
50            );
51            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
52        ))
53        .map_err(|e| format!("init schema: {e}"))?;
54        Ok(Self {
55            conn: Arc::new(Mutex::new(conn)),
56        })
57    }
58}
59
60impl MagicCodeBackend for SqliteMagicCodeBackend {
61    fn put(&self, email: &str, code: &MagicCode) {
62        if let Ok(guard) = self.conn.lock() {
63            let _ = guard.execute(
64                &format!(
65                    "INSERT INTO {SQLITE_TABLE} (email, code, expires_at, attempts)
66                     VALUES (?1, ?2, ?3, ?4)
67                     ON CONFLICT(email) DO UPDATE SET
68                       code=excluded.code,
69                       expires_at=excluded.expires_at,
70                       attempts=excluded.attempts"
71                ),
72                rusqlite::params![
73                    email,
74                    code.code,
75                    code.expires_at as i64,
76                    code.attempts as i64
77                ],
78            );
79        }
80    }
81    fn get(&self, email: &str) -> Option<MagicCode> {
82        let guard = self.conn.lock().ok()?;
83        guard
84            .query_row(
85                &format!(
86                    "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE} WHERE email = ?1"
87                ),
88                rusqlite::params![email],
89                |row| {
90                    Ok(MagicCode {
91                        email: row.get(0)?,
92                        code: row.get(1)?,
93                        expires_at: row.get::<_, i64>(2)? as u64,
94                        attempts: row.get::<_, i64>(3)? as u32,
95                    })
96                },
97            )
98            .ok()
99    }
100    fn remove(&self, email: &str) {
101        if let Ok(guard) = self.conn.lock() {
102            let _ = guard.execute(
103                &format!("DELETE FROM {SQLITE_TABLE} WHERE email = ?1"),
104                rusqlite::params![email],
105            );
106        }
107    }
108    fn bump_attempts(&self, email: &str) {
109        if let Ok(guard) = self.conn.lock() {
110            let _ = guard.execute(
111                &format!("UPDATE {SQLITE_TABLE} SET attempts = attempts + 1 WHERE email = ?1"),
112                rusqlite::params![email],
113            );
114        }
115    }
116    fn load_all(&self) -> Vec<MagicCode> {
117        let Ok(guard) = self.conn.lock() else {
118            return Vec::new();
119        };
120        let mut stmt = match guard.prepare(&format!(
121            "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE}"
122        )) {
123            Ok(s) => s,
124            Err(_) => return Vec::new(),
125        };
126        let iter = match stmt.query_map([], |row| {
127            Ok(MagicCode {
128                email: row.get(0)?,
129                code: row.get(1)?,
130                expires_at: row.get::<_, i64>(2)? as u64,
131                attempts: row.get::<_, i64>(3)? as u32,
132            })
133        }) {
134            Ok(i) => i,
135            Err(_) => return Vec::new(),
136        };
137        iter.flatten().collect()
138    }
139}
140
141// ---------------------------------------------------------------------------
142// Postgres backend
143// ---------------------------------------------------------------------------
144
145pub use pg::PostgresMagicCodeBackend;
146
147mod pg {
148    use super::*;
149    use postgres::Client;
150
151    pub struct PostgresMagicCodeBackend {
152        client: Mutex<Client>,
153    }
154
155    impl PostgresMagicCodeBackend {
156        pub fn connect(url: &str) -> Result<Self, String> {
157            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
158            client
159                .batch_execute(&format!(
160                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
161                        email TEXT PRIMARY KEY,
162                        code TEXT NOT NULL,
163                        expires_at BIGINT NOT NULL,
164                        attempts INTEGER NOT NULL DEFAULT 0
165                    );
166                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
167                ))
168                .map_err(|e| format!("PG init schema: {e}"))?;
169            Ok(Self {
170                client: Mutex::new(client),
171            })
172        }
173    }
174
175    impl MagicCodeBackend for PostgresMagicCodeBackend {
176        fn put(&self, email: &str, code: &MagicCode) {
177            let Ok(mut c) = self.client.lock() else {
178                return;
179            };
180            let _ = c.execute(
181                &format!(
182                    "INSERT INTO {PG_TABLE} (email, code, expires_at, attempts)
183                     VALUES ($1, $2, $3, $4)
184                     ON CONFLICT (email) DO UPDATE SET
185                       code = EXCLUDED.code,
186                       expires_at = EXCLUDED.expires_at,
187                       attempts = EXCLUDED.attempts"
188                ),
189                &[
190                    &email,
191                    &code.code,
192                    &(code.expires_at as i64),
193                    &(code.attempts as i32),
194                ],
195            );
196        }
197        fn get(&self, email: &str) -> Option<MagicCode> {
198            let mut c = self.client.lock().ok()?;
199            let row = c
200                .query_opt(
201                    &format!(
202                        "SELECT email, code, expires_at, attempts FROM {PG_TABLE} WHERE email = $1"
203                    ),
204                    &[&email],
205                )
206                .ok()??;
207            Some(MagicCode {
208                email: row.get(0),
209                code: row.get(1),
210                expires_at: row.get::<_, i64>(2) as u64,
211                attempts: row.get::<_, i32>(3) as u32,
212            })
213        }
214        fn remove(&self, email: &str) {
215            if let Ok(mut c) = self.client.lock() {
216                let _ = c.execute(
217                    &format!("DELETE FROM {PG_TABLE} WHERE email = $1"),
218                    &[&email],
219                );
220            }
221        }
222        fn bump_attempts(&self, email: &str) {
223            if let Ok(mut c) = self.client.lock() {
224                let _ = c.execute(
225                    &format!("UPDATE {PG_TABLE} SET attempts = attempts + 1 WHERE email = $1"),
226                    &[&email],
227                );
228            }
229        }
230        fn load_all(&self) -> Vec<MagicCode> {
231            let Ok(mut c) = self.client.lock() else {
232                return Vec::new();
233            };
234            let rows = c
235                .query(
236                    &format!("SELECT email, code, expires_at, attempts FROM {PG_TABLE}"),
237                    &[],
238                )
239                .unwrap_or_default();
240            rows.iter()
241                .map(|row| MagicCode {
242                    email: row.get(0),
243                    code: row.get(1),
244                    expires_at: row.get::<_, i64>(2) as u64,
245                    attempts: row.get::<_, i32>(3) as u32,
246                })
247                .collect()
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a far-future, zero-attempt code for `email`.
    fn sample(email: &str, code: &str) -> MagicCode {
        MagicCode {
            email: email.into(),
            code: code.into(),
            expires_at: 9_999_999_999,
            attempts: 0,
        }
    }

    #[test]
    fn sqlite_roundtrip_put_get_remove() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let mc = sample("a@b.com", "123456");
        b.put(&mc.email, &mc);
        let got = b.get(&mc.email).unwrap();
        assert_eq!(got.code, "123456");
        assert_eq!(got.expires_at, 9_999_999_999);
        b.bump_attempts(&mc.email);
        assert_eq!(b.get(&mc.email).unwrap().attempts, 1);
        b.remove(&mc.email);
        assert!(b.get(&mc.email).is_none());
    }

    #[test]
    fn sqlite_put_overwrites_existing_row() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let first = sample("a@b.com", "111111");
        b.put(&first.email, &first);
        b.bump_attempts(&first.email);

        // Re-sending a code replaces the code and resets attempts via the upsert.
        let second = sample("a@b.com", "222222");
        b.put(&second.email, &second);
        let got = b.get("a@b.com").unwrap();
        assert_eq!(got.code, "222222");
        assert_eq!(got.attempts, 0);
    }

    #[test]
    fn sqlite_load_all_returns_every_pending_code() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        for &(email, code) in &[("a@b.com", "111111"), ("c@d.com", "222222")] {
            b.put(email, &sample(email, code));
        }
        let mut loaded: Vec<(String, String)> = b
            .load_all()
            .into_iter()
            .map(|m| (m.email, m.code))
            .collect();
        loaded.sort();
        assert_eq!(
            loaded,
            vec![
                ("a@b.com".to_string(), "111111".to_string()),
                ("c@d.com".to_string(), "222222".to_string()),
            ]
        );
    }

    #[test]
    fn sqlite_operations_on_missing_email_are_noops() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        b.bump_attempts("nobody@x.com");
        b.remove("nobody@x.com");
        assert!(b.get("nobody@x.com").is_none());
        assert!(b.load_all().is_empty());
    }
}