1use std::sync::{Arc, Mutex};
6
7use pylon_auth::verification::{TokenKind, VerificationBackend, VerificationToken};
8use rusqlite::Connection;
9
/// Table holding verification tokens in the SQLite backend.
const SQLITE_TABLE: &str = "_pylon_verification_tokens";
/// Table holding verification tokens in the Postgres backend.
///
/// Spelled out separately (rather than aliasing `SQLITE_TABLE`) so the two
/// backends can be renamed independently if ever needed.
const PG_TABLE: &str = "_pylon_verification_tokens";
12
13fn kind_to_str(k: TokenKind) -> &'static str {
14 k.as_str()
15}
16
17fn kind_from_str(s: &str) -> TokenKind {
18 match s {
19 "password_reset" => TokenKind::PasswordReset,
20 "email_change" => TokenKind::EmailChange,
21 _ => TokenKind::MagicLink,
22 }
23}
24
25pub struct SqliteVerificationBackend {
30 conn: Arc<Mutex<Connection>>,
31}
32
33impl SqliteVerificationBackend {
34 pub fn open(path: &str) -> Result<Self, String> {
35 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
36 Self::from_connection(conn)
37 }
38 pub fn in_memory() -> Result<Self, String> {
39 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
40 Self::from_connection(conn)
41 }
42 fn from_connection(conn: Connection) -> Result<Self, String> {
43 conn.execute_batch(&format!(
44 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
45 id TEXT PRIMARY KEY,
46 kind TEXT NOT NULL,
47 email TEXT NOT NULL,
48 user_id TEXT,
49 payload TEXT,
50 token_hash TEXT NOT NULL,
51 token_prefix TEXT NOT NULL,
52 created_at INTEGER NOT NULL,
53 expires_at INTEGER NOT NULL,
54 consumed_at INTEGER
55 );
56 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_prefix_idx ON {SQLITE_TABLE}(token_prefix);
57 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
58 ))
59 .map_err(|e| format!("init schema: {e}"))?;
60 Ok(Self {
61 conn: Arc::new(Mutex::new(conn)),
62 })
63 }
64}
65
66impl VerificationBackend for SqliteVerificationBackend {
67 fn put(&self, t: &VerificationToken) {
68 if let Ok(c) = self.conn.lock() {
69 let _ = c.execute(
70 &format!(
71 "INSERT INTO {SQLITE_TABLE}
72 (id, kind, email, user_id, payload, token_hash, token_prefix,
73 created_at, expires_at, consumed_at)
74 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
75 ON CONFLICT(id) DO UPDATE SET
76 consumed_at = excluded.consumed_at"
77 ),
78 rusqlite::params![
79 t.id,
80 kind_to_str(t.kind),
81 t.email,
82 t.user_id,
83 t.payload,
84 t.token_hash,
85 t.token_prefix,
86 t.created_at as i64,
87 t.expires_at as i64,
88 t.consumed_at.map(|v| v as i64),
89 ],
90 );
91 }
92 }
93
94 fn get(&self, id: &str) -> Option<VerificationToken> {
95 let c = self.conn.lock().ok()?;
96 c.query_row(
97 &format!(
98 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
99 created_at, expires_at, consumed_at
100 FROM {SQLITE_TABLE} WHERE id = ?1"
101 ),
102 rusqlite::params![id],
103 row_to_token,
104 )
105 .ok()
106 }
107
108 fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
109 let Ok(c) = self.conn.lock() else {
110 return vec![];
111 };
112 let mut stmt = match c.prepare(&format!(
113 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
114 created_at, expires_at, consumed_at
115 FROM {SQLITE_TABLE} WHERE token_prefix = ?1"
116 )) {
117 Ok(s) => s,
118 Err(_) => return vec![],
119 };
120 let iter = match stmt.query_map(rusqlite::params![prefix], row_to_token) {
121 Ok(it) => it,
122 Err(_) => return vec![],
123 };
124 iter.filter_map(|r| r.ok()).collect()
125 }
126
127 fn mark_consumed(&self, id: &str, now: u64) -> bool {
128 let Ok(c) = self.conn.lock() else {
129 return false;
130 };
131 c.execute(
134 &format!(
135 "UPDATE {SQLITE_TABLE} SET consumed_at = ?2
136 WHERE id = ?1 AND consumed_at IS NULL"
137 ),
138 rusqlite::params![id, now as i64],
139 )
140 .map(|n| n > 0)
141 .unwrap_or(false)
142 }
143
144 fn purge_expired(&self, now: u64) {
145 if let Ok(c) = self.conn.lock() {
146 let _ = c.execute(
147 &format!(
148 "DELETE FROM {SQLITE_TABLE}
149 WHERE expires_at <= ?1 AND consumed_at IS NOT NULL"
150 ),
151 rusqlite::params![now as i64],
152 );
153 }
154 }
155}
156
157fn row_to_token(row: &rusqlite::Row<'_>) -> rusqlite::Result<VerificationToken> {
158 let kind: String = row.get(1)?;
159 Ok(VerificationToken {
160 id: row.get(0)?,
161 kind: kind_from_str(&kind),
162 email: row.get(2)?,
163 user_id: row.get(3)?,
164 payload: row.get(4)?,
165 token_hash: row.get(5)?,
166 token_prefix: row.get(6)?,
167 created_at: row.get::<_, i64>(7)? as u64,
168 expires_at: row.get::<_, i64>(8)? as u64,
169 consumed_at: row.get::<_, Option<i64>>(9)?.map(|v| v as u64),
170 })
171}
172
173pub use pg::PostgresVerificationBackend;
178
179mod pg {
180 use super::*;
181 use postgres::Client;
182
183 pub struct PostgresVerificationBackend {
184 client: Mutex<Client>,
185 }
186
187 impl PostgresVerificationBackend {
188 pub fn connect(url: &str) -> Result<Self, String> {
189 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
190 client
191 .batch_execute(&format!(
192 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
193 id TEXT PRIMARY KEY,
194 kind TEXT NOT NULL,
195 email TEXT NOT NULL,
196 user_id TEXT,
197 payload TEXT,
198 token_hash TEXT NOT NULL,
199 token_prefix TEXT NOT NULL,
200 created_at BIGINT NOT NULL,
201 expires_at BIGINT NOT NULL,
202 consumed_at BIGINT
203 );
204 CREATE INDEX IF NOT EXISTS {PG_TABLE}_prefix_idx ON {PG_TABLE}(token_prefix);
205 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
206 ))
207 .map_err(|e| format!("PG init schema: {e}"))?;
208 Ok(Self {
209 client: Mutex::new(client),
210 })
211 }
212 }
213
214 impl VerificationBackend for PostgresVerificationBackend {
215 fn put(&self, t: &VerificationToken) {
216 if let Ok(mut c) = self.client.lock() {
217 let _ = c.execute(
218 &format!(
219 "INSERT INTO {PG_TABLE}
220 (id, kind, email, user_id, payload, token_hash, token_prefix,
221 created_at, expires_at, consumed_at)
222 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
223 ON CONFLICT (id) DO UPDATE SET consumed_at = EXCLUDED.consumed_at"
224 ),
225 &[
226 &t.id,
227 &kind_to_str(t.kind),
228 &t.email,
229 &t.user_id,
230 &t.payload,
231 &t.token_hash,
232 &t.token_prefix,
233 &(t.created_at as i64),
234 &(t.expires_at as i64),
235 &t.consumed_at.map(|v| v as i64),
236 ],
237 );
238 }
239 }
240
241 fn get(&self, id: &str) -> Option<VerificationToken> {
242 let mut c = self.client.lock().ok()?;
243 let row = c
244 .query_opt(
245 &format!(
246 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
247 created_at, expires_at, consumed_at
248 FROM {PG_TABLE} WHERE id = $1"
249 ),
250 &[&id],
251 )
252 .ok()??;
253 Some(pg_row_to_token(&row))
254 }
255
256 fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
257 let Ok(mut c) = self.client.lock() else {
258 return vec![];
259 };
260 let rows = match c.query(
261 &format!(
262 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
263 created_at, expires_at, consumed_at
264 FROM {PG_TABLE} WHERE token_prefix = $1"
265 ),
266 &[&prefix],
267 ) {
268 Ok(r) => r,
269 Err(_) => return vec![],
270 };
271 rows.iter().map(pg_row_to_token).collect()
272 }
273
274 fn mark_consumed(&self, id: &str, now: u64) -> bool {
275 let Ok(mut c) = self.client.lock() else {
276 return false;
277 };
278 c.execute(
279 &format!(
280 "UPDATE {PG_TABLE} SET consumed_at = $2
281 WHERE id = $1 AND consumed_at IS NULL"
282 ),
283 &[&id, &(now as i64)],
284 )
285 .map(|n| n > 0)
286 .unwrap_or(false)
287 }
288
289 fn purge_expired(&self, now: u64) {
290 if let Ok(mut c) = self.client.lock() {
291 let _ = c.execute(
292 &format!(
293 "DELETE FROM {PG_TABLE}
294 WHERE expires_at <= $1 AND consumed_at IS NOT NULL"
295 ),
296 &[&(now as i64)],
297 );
298 }
299 }
300 }
301
302 fn pg_row_to_token(row: &postgres::Row) -> VerificationToken {
303 let kind: String = row.get(1);
304 VerificationToken {
305 id: row.get(0),
306 kind: kind_from_str(&kind),
307 email: row.get(2),
308 user_id: row.get(3),
309 payload: row.get(4),
310 token_hash: row.get(5),
311 token_prefix: row.get(6),
312 created_at: row.get::<_, i64>(7) as u64,
313 expires_at: row.get::<_, i64>(8) as u64,
314 consumed_at: row.get::<_, Option<i64>>(9).map(|v| v as u64),
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::verification::{TokenKind, VerificationToken};

    /// Builds a token for `a@b.com` with hash `"h"` and no user id/payload.
    fn token(
        id: &'static str,
        kind: TokenKind,
        prefix: &'static str,
        created_at: u64,
        expires_at: u64,
        consumed_at: Option<u64>,
    ) -> VerificationToken {
        VerificationToken {
            id: id.into(),
            kind,
            email: "a@b.com".into(),
            user_id: None,
            payload: None,
            token_hash: "h".into(),
            token_prefix: prefix.into(),
            created_at,
            expires_at,
            consumed_at,
        }
    }

    #[test]
    fn sqlite_round_trip() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        backend.put(&token(
            "vt_x",
            TokenKind::PasswordReset,
            "abcd1234",
            100,
            9_999_999_999,
            None,
        ));

        assert_eq!(backend.get("vt_x").unwrap().email, "a@b.com");
        assert_eq!(backend.by_prefix("abcd1234").len(), 1);
        assert_eq!(backend.by_prefix("nope0000").len(), 0);

        // First consumption wins; the second is rejected and leaves the
        // original timestamp in place.
        assert!(backend.mark_consumed("vt_x", 200));
        assert!(!backend.mark_consumed("vt_x", 300));
        assert_eq!(backend.get("vt_x").unwrap().consumed_at, Some(200));
    }

    #[test]
    fn purge_drops_expired_consumed() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        backend.put(&token("vt_done", TokenKind::MagicLink, "p", 1, 2, Some(2)));
        backend.purge_expired(100);
        assert!(backend.get("vt_done").is_none());
    }
}