Skip to main content

pylon_runtime/
verification_backend.rs

1//! Persistent verification-token stores for password reset / email
2//! change / magic links. Schema is identical to the in-memory shape;
3//! `token_prefix` is indexed so consume-by-plaintext is fast.
4
5use std::sync::{Arc, Mutex};
6
7use pylon_auth::verification::{TokenKind, VerificationBackend, VerificationToken};
8use rusqlite::Connection;
9
/// Table used by the SQLite backend to persist verification tokens.
const SQLITE_TABLE: &str = "_pylon_verification_tokens";
/// Table used by the Postgres backend; currently the same name as the SQLite table.
const PG_TABLE: &str = "_pylon_verification_tokens";
12
13fn kind_to_str(k: TokenKind) -> &'static str {
14    k.as_str()
15}
16
17/// Parse a kind value from the DB. Returns `Err` for unknown values
18/// rather than silently defaulting — Wave-6 codex P3: a corrupted
19/// row shouldn't be silently re-categorized as a magic-link token,
20/// because that would let a stale password-reset row bypass its
21/// kind check.
22fn kind_from_str(s: &str) -> Result<TokenKind, String> {
23    match s {
24        "password_reset" => Ok(TokenKind::PasswordReset),
25        "email_change" => Ok(TokenKind::EmailChange),
26        "magic_link" => Ok(TokenKind::MagicLink),
27        other => Err(format!("verification: unknown kind '{other}'")),
28    }
29}
30
31// ---------------------------------------------------------------------------
32// SQLite
33// ---------------------------------------------------------------------------
34
/// [`VerificationBackend`] implementation that persists tokens in SQLite.
pub struct SqliteVerificationBackend {
    /// The SQLite connection; each trait method locks the mutex before
    /// running its statement.
    conn: Arc<Mutex<Connection>>,
}
38
39impl SqliteVerificationBackend {
40    pub fn open(path: &str) -> Result<Self, String> {
41        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
42        Self::from_connection(conn)
43    }
44    pub fn in_memory() -> Result<Self, String> {
45        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
46        Self::from_connection(conn)
47    }
48    fn from_connection(conn: Connection) -> Result<Self, String> {
49        conn.execute_batch(&format!(
50            "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
51                id TEXT PRIMARY KEY,
52                kind TEXT NOT NULL,
53                email TEXT NOT NULL,
54                user_id TEXT,
55                payload TEXT,
56                token_hash TEXT NOT NULL,
57                token_prefix TEXT NOT NULL,
58                created_at INTEGER NOT NULL,
59                expires_at INTEGER NOT NULL,
60                consumed_at INTEGER
61            );
62            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_prefix_idx ON {SQLITE_TABLE}(token_prefix);
63            CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
64        ))
65        .map_err(|e| format!("init schema: {e}"))?;
66        Ok(Self {
67            conn: Arc::new(Mutex::new(conn)),
68        })
69    }
70}
71
72impl VerificationBackend for SqliteVerificationBackend {
73    fn put(&self, t: &VerificationToken) {
74        if let Ok(c) = self.conn.lock() {
75            let _ = c.execute(
76                &format!(
77                    "INSERT INTO {SQLITE_TABLE}
78                       (id, kind, email, user_id, payload, token_hash, token_prefix,
79                        created_at, expires_at, consumed_at)
80                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
81                     ON CONFLICT(id) DO UPDATE SET
82                       consumed_at = excluded.consumed_at"
83                ),
84                rusqlite::params![
85                    t.id,
86                    kind_to_str(t.kind),
87                    t.email,
88                    t.user_id,
89                    t.payload,
90                    t.token_hash,
91                    t.token_prefix,
92                    t.created_at as i64,
93                    t.expires_at as i64,
94                    t.consumed_at.map(|v| v as i64),
95                ],
96            );
97        }
98    }
99
100    fn get(&self, id: &str) -> Option<VerificationToken> {
101        let c = self.conn.lock().ok()?;
102        c.query_row(
103            &format!(
104                "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
105                        created_at, expires_at, consumed_at
106                 FROM {SQLITE_TABLE} WHERE id = ?1"
107            ),
108            rusqlite::params![id],
109            row_to_token,
110        )
111        .ok()
112    }
113
114    fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
115        let Ok(c) = self.conn.lock() else {
116            return vec![];
117        };
118        let mut stmt = match c.prepare(&format!(
119            "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
120                    created_at, expires_at, consumed_at
121             FROM {SQLITE_TABLE} WHERE token_prefix = ?1"
122        )) {
123            Ok(s) => s,
124            Err(_) => return vec![],
125        };
126        let iter = match stmt.query_map(rusqlite::params![prefix], row_to_token) {
127            Ok(it) => it,
128            Err(_) => return vec![],
129        };
130        iter.filter_map(|r| r.ok()).collect()
131    }
132
133    fn mark_consumed(&self, id: &str, now: u64) -> bool {
134        let Ok(c) = self.conn.lock() else {
135            return false;
136        };
137        // CAS via SQL: only update when consumed_at IS NULL. The
138        // affected-row count tells us whether we won the race.
139        c.execute(
140            &format!(
141                "UPDATE {SQLITE_TABLE} SET consumed_at = ?2
142                 WHERE id = ?1 AND consumed_at IS NULL"
143            ),
144            rusqlite::params![id, now as i64],
145        )
146        .map(|n| n > 0)
147        .unwrap_or(false)
148    }
149
150    fn purge_expired(&self, now: u64) {
151        if let Ok(c) = self.conn.lock() {
152            let _ = c.execute(
153                &format!(
154                    "DELETE FROM {SQLITE_TABLE}
155                     WHERE expires_at <= ?1 AND consumed_at IS NOT NULL"
156                ),
157                rusqlite::params![now as i64],
158            );
159        }
160    }
161}
162
163fn row_to_token(row: &rusqlite::Row<'_>) -> rusqlite::Result<VerificationToken> {
164    let kind_raw: String = row.get(1)?;
165    let kind = kind_from_str(&kind_raw)
166        .map_err(|e| rusqlite::Error::InvalidColumnType(1, e, rusqlite::types::Type::Text))?;
167    Ok(VerificationToken {
168        id: row.get(0)?,
169        kind,
170        email: row.get(2)?,
171        user_id: row.get(3)?,
172        payload: row.get(4)?,
173        token_hash: row.get(5)?,
174        token_prefix: row.get(6)?,
175        created_at: row.get::<_, i64>(7)? as u64,
176        expires_at: row.get::<_, i64>(8)? as u64,
177        consumed_at: row.get::<_, Option<i64>>(9)?.map(|v| v as u64),
178    })
179}
180
181// ---------------------------------------------------------------------------
182// Postgres
183// ---------------------------------------------------------------------------
184
185pub use pg::PostgresVerificationBackend;
186
187mod pg {
188    use super::*;
189    use postgres::Client;
190
191    pub struct PostgresVerificationBackend {
192        client: Mutex<Client>,
193    }
194
195    impl PostgresVerificationBackend {
196        pub fn connect(url: &str) -> Result<Self, String> {
197            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
198            client
199                .batch_execute(&format!(
200                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
201                        id TEXT PRIMARY KEY,
202                        kind TEXT NOT NULL,
203                        email TEXT NOT NULL,
204                        user_id TEXT,
205                        payload TEXT,
206                        token_hash TEXT NOT NULL,
207                        token_prefix TEXT NOT NULL,
208                        created_at BIGINT NOT NULL,
209                        expires_at BIGINT NOT NULL,
210                        consumed_at BIGINT
211                    );
212                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_prefix_idx ON {PG_TABLE}(token_prefix);
213                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
214                ))
215                .map_err(|e| format!("PG init schema: {e}"))?;
216            Ok(Self {
217                client: Mutex::new(client),
218            })
219        }
220    }
221
222    impl VerificationBackend for PostgresVerificationBackend {
223        fn put(&self, t: &VerificationToken) {
224            if let Ok(mut c) = self.client.lock() {
225                let _ = c.execute(
226                    &format!(
227                        "INSERT INTO {PG_TABLE}
228                           (id, kind, email, user_id, payload, token_hash, token_prefix,
229                            created_at, expires_at, consumed_at)
230                         VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
231                         ON CONFLICT (id) DO UPDATE SET consumed_at = EXCLUDED.consumed_at"
232                    ),
233                    &[
234                        &t.id,
235                        &kind_to_str(t.kind),
236                        &t.email,
237                        &t.user_id,
238                        &t.payload,
239                        &t.token_hash,
240                        &t.token_prefix,
241                        &(t.created_at as i64),
242                        &(t.expires_at as i64),
243                        &t.consumed_at.map(|v| v as i64),
244                    ],
245                );
246            }
247        }
248
249        fn get(&self, id: &str) -> Option<VerificationToken> {
250            let mut c = self.client.lock().ok()?;
251            let row = c
252                .query_opt(
253                    &format!(
254                        "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
255                                created_at, expires_at, consumed_at
256                         FROM {PG_TABLE} WHERE id = $1"
257                    ),
258                    &[&id],
259                )
260                .ok()??;
261            pg_row_to_token(&row)
262        }
263
264        fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
265            let Ok(mut c) = self.client.lock() else {
266                return vec![];
267            };
268            let rows = match c.query(
269                &format!(
270                    "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
271                            created_at, expires_at, consumed_at
272                     FROM {PG_TABLE} WHERE token_prefix = $1"
273                ),
274                &[&prefix],
275            ) {
276                Ok(r) => r,
277                Err(_) => return vec![],
278            };
279            rows.iter().filter_map(pg_row_to_token).collect()
280        }
281
282        fn mark_consumed(&self, id: &str, now: u64) -> bool {
283            let Ok(mut c) = self.client.lock() else {
284                return false;
285            };
286            c.execute(
287                &format!(
288                    "UPDATE {PG_TABLE} SET consumed_at = $2
289                     WHERE id = $1 AND consumed_at IS NULL"
290                ),
291                &[&id, &(now as i64)],
292            )
293            .map(|n| n > 0)
294            .unwrap_or(false)
295        }
296
297        fn purge_expired(&self, now: u64) {
298            if let Ok(mut c) = self.client.lock() {
299                let _ = c.execute(
300                    &format!(
301                        "DELETE FROM {PG_TABLE}
302                         WHERE expires_at <= $1 AND consumed_at IS NOT NULL"
303                    ),
304                    &[&(now as i64)],
305                );
306            }
307        }
308    }
309
310    fn pg_row_to_token(row: &postgres::Row) -> Option<VerificationToken> {
311        let kind_raw: String = row.get(1);
312        let kind = kind_from_str(&kind_raw).ok()?;
313        Some(VerificationToken {
314            id: row.get(0),
315            kind,
316            email: row.get(2),
317            user_id: row.get(3),
318            payload: row.get(4),
319            token_hash: row.get(5),
320            token_prefix: row.get(6),
321            created_at: row.get::<_, i64>(7) as u64,
322            expires_at: row.get::<_, i64>(8) as u64,
323            consumed_at: row.get::<_, Option<i64>>(9).map(|v| v as u64),
324        })
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::verification::{TokenKind, VerificationToken};

    /// Build a token with the fixed email and hash shared by these tests;
    /// `user_id` and `payload` are always left empty.
    fn sample_token(
        id: &str,
        kind: TokenKind,
        prefix: &str,
        created_at: u64,
        expires_at: u64,
        consumed_at: Option<u64>,
    ) -> VerificationToken {
        VerificationToken {
            id: id.into(),
            kind,
            email: "a@b.com".into(),
            user_id: None,
            payload: None,
            token_hash: "h".into(),
            token_prefix: prefix.into(),
            created_at,
            expires_at,
            consumed_at,
        }
    }

    /// put → get / by_prefix round trip, then the consume-once behavior of
    /// `mark_consumed`.
    #[test]
    fn sqlite_round_trip() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        let token = sample_token(
            "vt_x",
            TokenKind::PasswordReset,
            "abcd1234",
            100,
            9_999_999_999,
            None,
        );
        backend.put(&token);

        let stored = backend.get("vt_x").unwrap();
        assert_eq!(stored.email, "a@b.com");
        assert_eq!(backend.by_prefix("abcd1234").len(), 1);
        assert_eq!(backend.by_prefix("nope0000").len(), 0);

        // mark_consumed CAS: the first call wins, the second loses.
        let first = backend.mark_consumed("vt_x", 200);
        assert!(first);
        let second = backend.mark_consumed("vt_x", 300);
        assert!(!second);
        assert_eq!(backend.get("vt_x").unwrap().consumed_at, Some(200));
    }

    /// A token that is both expired and consumed is removed by
    /// `purge_expired`.
    #[test]
    fn purge_drops_expired_consumed() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        let done = sample_token("vt_done", TokenKind::MagicLink, "p", 1, 2, Some(2));
        backend.put(&done);
        backend.purge_expired(100);
        assert!(backend.get("vt_done").is_none());
    }
}
376}