Skip to main content

pylon_runtime/
verification_backend.rs

1//! Persistent verification-token stores for password reset / email
2//! change / magic links. Schema is identical to the in-memory shape;
3//! `token_prefix` is indexed so consume-by-plaintext is fast.
4
5use std::sync::{Arc, Mutex};
6
7use pylon_auth::verification::{TokenKind, VerificationBackend, VerificationToken};
8use rusqlite::Connection;
9
/// Name of the table holding verification tokens in the SQLite backend.
const SQLITE_TABLE: &str = "_pylon_verification_tokens";
/// Name of the same table in the Postgres backend; deliberately identical to
/// [`SQLITE_TABLE`].
const PG_TABLE: &str = SQLITE_TABLE;
12
13fn kind_to_str(k: TokenKind) -> &'static str {
14    k.as_str()
15}
16
17fn kind_from_str(s: &str) -> TokenKind {
18    match s {
19        "password_reset" => TokenKind::PasswordReset,
20        "email_change" => TokenKind::EmailChange,
21        _ => TokenKind::MagicLink,
22    }
23}
24
25// ---------------------------------------------------------------------------
26// SQLite
27// ---------------------------------------------------------------------------
28
29pub struct SqliteVerificationBackend {
30    conn: Arc<Mutex<Connection>>,
31}
32
33impl SqliteVerificationBackend {
34    pub fn open(path: &str) -> Result<Self, String> {
35        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
36        Self::from_connection(conn)
37    }
38    pub fn in_memory() -> Result<Self, String> {
39        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
40        Self::from_connection(conn)
41    }
42    fn from_connection(conn: Connection) -> Result<Self, String> {
43        conn.execute_batch(&format!(
44            "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
45                id TEXT PRIMARY KEY,
46                kind TEXT NOT NULL,
47                email TEXT NOT NULL,
48                user_id TEXT,
49                payload TEXT,
50                token_hash TEXT NOT NULL,
51                token_prefix TEXT NOT NULL,
52                created_at INTEGER NOT NULL,
53                expires_at INTEGER NOT NULL,
54                consumed_at INTEGER
55            );
56            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_prefix_idx ON {SQLITE_TABLE}(token_prefix);
57            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
58        ))
59        .map_err(|e| format!("init schema: {e}"))?;
60        Ok(Self {
61            conn: Arc::new(Mutex::new(conn)),
62        })
63    }
64}
65
66impl VerificationBackend for SqliteVerificationBackend {
67    fn put(&self, t: &VerificationToken) {
68        if let Ok(c) = self.conn.lock() {
69            let _ = c.execute(
70                &format!(
71                    "INSERT INTO {SQLITE_TABLE}
72                       (id, kind, email, user_id, payload, token_hash, token_prefix,
73                        created_at, expires_at, consumed_at)
74                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
75                     ON CONFLICT(id) DO UPDATE SET
76                       consumed_at = excluded.consumed_at"
77                ),
78                rusqlite::params![
79                    t.id,
80                    kind_to_str(t.kind),
81                    t.email,
82                    t.user_id,
83                    t.payload,
84                    t.token_hash,
85                    t.token_prefix,
86                    t.created_at as i64,
87                    t.expires_at as i64,
88                    t.consumed_at.map(|v| v as i64),
89                ],
90            );
91        }
92    }
93
94    fn get(&self, id: &str) -> Option<VerificationToken> {
95        let c = self.conn.lock().ok()?;
96        c.query_row(
97            &format!(
98                "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
99                        created_at, expires_at, consumed_at
100                 FROM {SQLITE_TABLE} WHERE id = ?1"
101            ),
102            rusqlite::params![id],
103            row_to_token,
104        )
105        .ok()
106    }
107
108    fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
109        let Ok(c) = self.conn.lock() else {
110            return vec![];
111        };
112        let mut stmt = match c.prepare(&format!(
113            "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
114                    created_at, expires_at, consumed_at
115             FROM {SQLITE_TABLE} WHERE token_prefix = ?1"
116        )) {
117            Ok(s) => s,
118            Err(_) => return vec![],
119        };
120        let iter = match stmt.query_map(rusqlite::params![prefix], row_to_token) {
121            Ok(it) => it,
122            Err(_) => return vec![],
123        };
124        iter.filter_map(|r| r.ok()).collect()
125    }
126
127    fn mark_consumed(&self, id: &str, now: u64) -> bool {
128        let Ok(c) = self.conn.lock() else {
129            return false;
130        };
131        // CAS via SQL: only update when consumed_at IS NULL. The
132        // affected-row count tells us whether we won the race.
133        c.execute(
134            &format!(
135                "UPDATE {SQLITE_TABLE} SET consumed_at = ?2
136                 WHERE id = ?1 AND consumed_at IS NULL"
137            ),
138            rusqlite::params![id, now as i64],
139        )
140        .map(|n| n > 0)
141        .unwrap_or(false)
142    }
143
144    fn purge_expired(&self, now: u64) {
145        if let Ok(c) = self.conn.lock() {
146            let _ = c.execute(
147                &format!(
148                    "DELETE FROM {SQLITE_TABLE}
149                     WHERE expires_at <= ?1 AND consumed_at IS NOT NULL"
150                ),
151                rusqlite::params![now as i64],
152            );
153        }
154    }
155}
156
157fn row_to_token(row: &rusqlite::Row<'_>) -> rusqlite::Result<VerificationToken> {
158    let kind: String = row.get(1)?;
159    Ok(VerificationToken {
160        id: row.get(0)?,
161        kind: kind_from_str(&kind),
162        email: row.get(2)?,
163        user_id: row.get(3)?,
164        payload: row.get(4)?,
165        token_hash: row.get(5)?,
166        token_prefix: row.get(6)?,
167        created_at: row.get::<_, i64>(7)? as u64,
168        expires_at: row.get::<_, i64>(8)? as u64,
169        consumed_at: row.get::<_, Option<i64>>(9)?.map(|v| v as u64),
170    })
171}
172
173// ---------------------------------------------------------------------------
174// Postgres
175// ---------------------------------------------------------------------------
176
177pub use pg::PostgresVerificationBackend;
178
179mod pg {
180    use super::*;
181    use postgres::Client;
182
183    pub struct PostgresVerificationBackend {
184        client: Mutex<Client>,
185    }
186
187    impl PostgresVerificationBackend {
188        pub fn connect(url: &str) -> Result<Self, String> {
189            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
190            client
191                .batch_execute(&format!(
192                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
193                        id TEXT PRIMARY KEY,
194                        kind TEXT NOT NULL,
195                        email TEXT NOT NULL,
196                        user_id TEXT,
197                        payload TEXT,
198                        token_hash TEXT NOT NULL,
199                        token_prefix TEXT NOT NULL,
200                        created_at BIGINT NOT NULL,
201                        expires_at BIGINT NOT NULL,
202                        consumed_at BIGINT
203                    );
204                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_prefix_idx ON {PG_TABLE}(token_prefix);
205                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
206                ))
207                .map_err(|e| format!("PG init schema: {e}"))?;
208            Ok(Self {
209                client: Mutex::new(client),
210            })
211        }
212    }
213
214    impl VerificationBackend for PostgresVerificationBackend {
215        fn put(&self, t: &VerificationToken) {
216            if let Ok(mut c) = self.client.lock() {
217                let _ = c.execute(
218                    &format!(
219                        "INSERT INTO {PG_TABLE}
220                           (id, kind, email, user_id, payload, token_hash, token_prefix,
221                            created_at, expires_at, consumed_at)
222                         VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
223                         ON CONFLICT (id) DO UPDATE SET consumed_at = EXCLUDED.consumed_at"
224                    ),
225                    &[
226                        &t.id,
227                        &kind_to_str(t.kind),
228                        &t.email,
229                        &t.user_id,
230                        &t.payload,
231                        &t.token_hash,
232                        &t.token_prefix,
233                        &(t.created_at as i64),
234                        &(t.expires_at as i64),
235                        &t.consumed_at.map(|v| v as i64),
236                    ],
237                );
238            }
239        }
240
241        fn get(&self, id: &str) -> Option<VerificationToken> {
242            let mut c = self.client.lock().ok()?;
243            let row = c
244                .query_opt(
245                    &format!(
246                        "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
247                                created_at, expires_at, consumed_at
248                         FROM {PG_TABLE} WHERE id = $1"
249                    ),
250                    &[&id],
251                )
252                .ok()??;
253            Some(pg_row_to_token(&row))
254        }
255
256        fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
257            let Ok(mut c) = self.client.lock() else {
258                return vec![];
259            };
260            let rows = match c.query(
261                &format!(
262                    "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
263                            created_at, expires_at, consumed_at
264                     FROM {PG_TABLE} WHERE token_prefix = $1"
265                ),
266                &[&prefix],
267            ) {
268                Ok(r) => r,
269                Err(_) => return vec![],
270            };
271            rows.iter().map(pg_row_to_token).collect()
272        }
273
274        fn mark_consumed(&self, id: &str, now: u64) -> bool {
275            let Ok(mut c) = self.client.lock() else {
276                return false;
277            };
278            c.execute(
279                &format!(
280                    "UPDATE {PG_TABLE} SET consumed_at = $2
281                     WHERE id = $1 AND consumed_at IS NULL"
282                ),
283                &[&id, &(now as i64)],
284            )
285            .map(|n| n > 0)
286            .unwrap_or(false)
287        }
288
289        fn purge_expired(&self, now: u64) {
290            if let Ok(mut c) = self.client.lock() {
291                let _ = c.execute(
292                    &format!(
293                        "DELETE FROM {PG_TABLE}
294                         WHERE expires_at <= $1 AND consumed_at IS NOT NULL"
295                    ),
296                    &[&(now as i64)],
297                );
298            }
299        }
300    }
301
302    fn pg_row_to_token(row: &postgres::Row) -> VerificationToken {
303        let kind: String = row.get(1);
304        VerificationToken {
305            id: row.get(0),
306            kind: kind_from_str(&kind),
307            email: row.get(2),
308            user_id: row.get(3),
309            payload: row.get(4),
310            token_hash: row.get(5),
311            token_prefix: row.get(6),
312            created_at: row.get::<_, i64>(7) as u64,
313            expires_at: row.get::<_, i64>(8) as u64,
314            consumed_at: row.get::<_, Option<i64>>(9).map(|v| v as u64),
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::verification::{TokenKind, VerificationToken};

    /// Builds a test token with a fixed email (`a@b.com`), hash (`h`), and no
    /// user id or payload.
    fn token(
        id: &str,
        kind: TokenKind,
        prefix: &str,
        created_at: u64,
        expires_at: u64,
        consumed_at: Option<u64>,
    ) -> VerificationToken {
        VerificationToken {
            id: id.into(),
            kind,
            email: "a@b.com".into(),
            user_id: None,
            payload: None,
            token_hash: "h".into(),
            token_prefix: prefix.into(),
            created_at,
            expires_at,
            consumed_at,
        }
    }

    #[test]
    fn sqlite_round_trip() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        let fresh = token(
            "vt_x",
            TokenKind::PasswordReset,
            "abcd1234",
            100,
            9_999_999_999,
            None,
        );
        backend.put(&fresh);

        assert_eq!(backend.get("vt_x").unwrap().email, "a@b.com");
        assert_eq!(backend.by_prefix("abcd1234").len(), 1);
        assert_eq!(backend.by_prefix("nope0000").len(), 0);

        // The first consumer wins the compare-and-set; the second loses and
        // must not overwrite the original timestamp.
        assert!(backend.mark_consumed("vt_x", 200));
        assert!(!backend.mark_consumed("vt_x", 300));
        assert_eq!(backend.get("vt_x").unwrap().consumed_at, Some(200));
    }

    #[test]
    fn purge_drops_expired_consumed() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        backend.put(&token("vt_done", TokenKind::MagicLink, "p", 1, 2, Some(2)));

        backend.purge_expired(100);

        assert!(backend.get("vt_done").is_none());
    }
}