1use std::sync::{Arc, Mutex};
6
7use pylon_auth::verification::{TokenKind, VerificationBackend, VerificationToken};
8use rusqlite::Connection;
9
/// Name of the SQLite table that stores verification tokens (also used as the
/// prefix for its index names).
const SQLITE_TABLE: &str = "_pylon_verification_tokens";
/// Name of the Postgres table that stores verification tokens (also used as the
/// prefix for its index names). Kept as its own constant even though the value
/// currently matches `SQLITE_TABLE`.
const PG_TABLE: &str = "_pylon_verification_tokens";
12
13fn kind_to_str(k: TokenKind) -> &'static str {
14 k.as_str()
15}
16
17fn kind_from_str(s: &str) -> Result<TokenKind, String> {
23 match s {
24 "password_reset" => Ok(TokenKind::PasswordReset),
25 "email_change" => Ok(TokenKind::EmailChange),
26 "magic_link" => Ok(TokenKind::MagicLink),
27 other => Err(format!("verification: unknown kind '{other}'")),
28 }
29}
30
31pub struct SqliteVerificationBackend {
36 conn: Arc<Mutex<Connection>>,
37}
38
39impl SqliteVerificationBackend {
40 pub fn open(path: &str) -> Result<Self, String> {
41 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
42 Self::from_connection(conn)
43 }
44 pub fn in_memory() -> Result<Self, String> {
45 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
46 Self::from_connection(conn)
47 }
48 fn from_connection(conn: Connection) -> Result<Self, String> {
49 conn.execute_batch(&format!(
50 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
51 id TEXT PRIMARY KEY,
52 kind TEXT NOT NULL,
53 email TEXT NOT NULL,
54 user_id TEXT,
55 payload TEXT,
56 token_hash TEXT NOT NULL,
57 token_prefix TEXT NOT NULL,
58 created_at INTEGER NOT NULL,
59 expires_at INTEGER NOT NULL,
60 consumed_at INTEGER
61 );
62 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_prefix_idx ON {SQLITE_TABLE}(token_prefix);
63 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
64 ))
65 .map_err(|e| format!("init schema: {e}"))?;
66 Ok(Self {
67 conn: Arc::new(Mutex::new(conn)),
68 })
69 }
70}
71
72impl VerificationBackend for SqliteVerificationBackend {
73 fn put(&self, t: &VerificationToken) {
74 if let Ok(c) = self.conn.lock() {
75 let _ = c.execute(
76 &format!(
77 "INSERT INTO {SQLITE_TABLE}
78 (id, kind, email, user_id, payload, token_hash, token_prefix,
79 created_at, expires_at, consumed_at)
80 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
81 ON CONFLICT(id) DO UPDATE SET
82 consumed_at = excluded.consumed_at"
83 ),
84 rusqlite::params![
85 t.id,
86 kind_to_str(t.kind),
87 t.email,
88 t.user_id,
89 t.payload,
90 t.token_hash,
91 t.token_prefix,
92 t.created_at as i64,
93 t.expires_at as i64,
94 t.consumed_at.map(|v| v as i64),
95 ],
96 );
97 }
98 }
99
100 fn get(&self, id: &str) -> Option<VerificationToken> {
101 let c = self.conn.lock().ok()?;
102 c.query_row(
103 &format!(
104 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
105 created_at, expires_at, consumed_at
106 FROM {SQLITE_TABLE} WHERE id = ?1"
107 ),
108 rusqlite::params![id],
109 row_to_token,
110 )
111 .ok()
112 }
113
114 fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
115 let Ok(c) = self.conn.lock() else {
116 return vec![];
117 };
118 let mut stmt = match c.prepare(&format!(
119 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
120 created_at, expires_at, consumed_at
121 FROM {SQLITE_TABLE} WHERE token_prefix = ?1"
122 )) {
123 Ok(s) => s,
124 Err(_) => return vec![],
125 };
126 let iter = match stmt.query_map(rusqlite::params![prefix], row_to_token) {
127 Ok(it) => it,
128 Err(_) => return vec![],
129 };
130 iter.filter_map(|r| r.ok()).collect()
131 }
132
133 fn mark_consumed(&self, id: &str, now: u64) -> bool {
134 let Ok(c) = self.conn.lock() else {
135 return false;
136 };
137 c.execute(
140 &format!(
141 "UPDATE {SQLITE_TABLE} SET consumed_at = ?2
142 WHERE id = ?1 AND consumed_at IS NULL"
143 ),
144 rusqlite::params![id, now as i64],
145 )
146 .map(|n| n > 0)
147 .unwrap_or(false)
148 }
149
150 fn purge_expired(&self, now: u64) {
151 if let Ok(c) = self.conn.lock() {
152 let _ = c.execute(
153 &format!(
154 "DELETE FROM {SQLITE_TABLE}
155 WHERE expires_at <= ?1 AND consumed_at IS NOT NULL"
156 ),
157 rusqlite::params![now as i64],
158 );
159 }
160 }
161}
162
163fn row_to_token(row: &rusqlite::Row<'_>) -> rusqlite::Result<VerificationToken> {
164 let kind_raw: String = row.get(1)?;
165 let kind = kind_from_str(&kind_raw)
166 .map_err(|e| rusqlite::Error::InvalidColumnType(1, e, rusqlite::types::Type::Text))?;
167 Ok(VerificationToken {
168 id: row.get(0)?,
169 kind,
170 email: row.get(2)?,
171 user_id: row.get(3)?,
172 payload: row.get(4)?,
173 token_hash: row.get(5)?,
174 token_prefix: row.get(6)?,
175 created_at: row.get::<_, i64>(7)? as u64,
176 expires_at: row.get::<_, i64>(8)? as u64,
177 consumed_at: row.get::<_, Option<i64>>(9)?.map(|v| v as u64),
178 })
179}
180
181pub use pg::PostgresVerificationBackend;
186
187mod pg {
188 use super::*;
189 use postgres::Client;
190
191 pub struct PostgresVerificationBackend {
192 client: Mutex<Client>,
193 }
194
195 impl PostgresVerificationBackend {
196 pub fn connect(url: &str) -> Result<Self, String> {
197 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
198 client
199 .batch_execute(&format!(
200 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
201 id TEXT PRIMARY KEY,
202 kind TEXT NOT NULL,
203 email TEXT NOT NULL,
204 user_id TEXT,
205 payload TEXT,
206 token_hash TEXT NOT NULL,
207 token_prefix TEXT NOT NULL,
208 created_at BIGINT NOT NULL,
209 expires_at BIGINT NOT NULL,
210 consumed_at BIGINT
211 );
212 CREATE INDEX IF NOT EXISTS {PG_TABLE}_prefix_idx ON {PG_TABLE}(token_prefix);
213 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
214 ))
215 .map_err(|e| format!("PG init schema: {e}"))?;
216 Ok(Self {
217 client: Mutex::new(client),
218 })
219 }
220 }
221
222 impl VerificationBackend for PostgresVerificationBackend {
223 fn put(&self, t: &VerificationToken) {
224 if let Ok(mut c) = self.client.lock() {
225 let _ = c.execute(
226 &format!(
227 "INSERT INTO {PG_TABLE}
228 (id, kind, email, user_id, payload, token_hash, token_prefix,
229 created_at, expires_at, consumed_at)
230 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
231 ON CONFLICT (id) DO UPDATE SET consumed_at = EXCLUDED.consumed_at"
232 ),
233 &[
234 &t.id,
235 &kind_to_str(t.kind),
236 &t.email,
237 &t.user_id,
238 &t.payload,
239 &t.token_hash,
240 &t.token_prefix,
241 &(t.created_at as i64),
242 &(t.expires_at as i64),
243 &t.consumed_at.map(|v| v as i64),
244 ],
245 );
246 }
247 }
248
249 fn get(&self, id: &str) -> Option<VerificationToken> {
250 let mut c = self.client.lock().ok()?;
251 let row = c
252 .query_opt(
253 &format!(
254 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
255 created_at, expires_at, consumed_at
256 FROM {PG_TABLE} WHERE id = $1"
257 ),
258 &[&id],
259 )
260 .ok()??;
261 pg_row_to_token(&row)
262 }
263
264 fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
265 let Ok(mut c) = self.client.lock() else {
266 return vec![];
267 };
268 let rows = match c.query(
269 &format!(
270 "SELECT id, kind, email, user_id, payload, token_hash, token_prefix,
271 created_at, expires_at, consumed_at
272 FROM {PG_TABLE} WHERE token_prefix = $1"
273 ),
274 &[&prefix],
275 ) {
276 Ok(r) => r,
277 Err(_) => return vec![],
278 };
279 rows.iter().filter_map(pg_row_to_token).collect()
280 }
281
282 fn mark_consumed(&self, id: &str, now: u64) -> bool {
283 let Ok(mut c) = self.client.lock() else {
284 return false;
285 };
286 c.execute(
287 &format!(
288 "UPDATE {PG_TABLE} SET consumed_at = $2
289 WHERE id = $1 AND consumed_at IS NULL"
290 ),
291 &[&id, &(now as i64)],
292 )
293 .map(|n| n > 0)
294 .unwrap_or(false)
295 }
296
297 fn purge_expired(&self, now: u64) {
298 if let Ok(mut c) = self.client.lock() {
299 let _ = c.execute(
300 &format!(
301 "DELETE FROM {PG_TABLE}
302 WHERE expires_at <= $1 AND consumed_at IS NOT NULL"
303 ),
304 &[&(now as i64)],
305 );
306 }
307 }
308 }
309
310 fn pg_row_to_token(row: &postgres::Row) -> Option<VerificationToken> {
311 let kind_raw: String = row.get(1);
312 let kind = kind_from_str(&kind_raw).ok()?;
313 Some(VerificationToken {
314 id: row.get(0),
315 kind,
316 email: row.get(2),
317 user_id: row.get(3),
318 payload: row.get(4),
319 token_hash: row.get(5),
320 token_prefix: row.get(6),
321 created_at: row.get::<_, i64>(7) as u64,
322 expires_at: row.get::<_, i64>(8) as u64,
323 consumed_at: row.get::<_, Option<i64>>(9).map(|v| v as u64),
324 })
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use pylon_auth::verification::{TokenKind, VerificationToken};

    /// Builds a token for `a@b.com` with hash `h` and no user/payload, varying
    /// only the fields these tests care about.
    fn token(
        id: &str,
        kind: TokenKind,
        prefix: &str,
        created_at: u64,
        expires_at: u64,
        consumed_at: Option<u64>,
    ) -> VerificationToken {
        VerificationToken {
            id: id.into(),
            kind,
            email: "a@b.com".into(),
            user_id: None,
            payload: None,
            token_hash: "h".into(),
            token_prefix: prefix.into(),
            created_at,
            expires_at,
            consumed_at,
        }
    }

    #[test]
    fn sqlite_round_trip() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        backend.put(&token(
            "vt_x",
            TokenKind::PasswordReset,
            "abcd1234",
            100,
            9_999_999_999,
            None,
        ));

        let fetched = backend.get("vt_x").unwrap();
        assert_eq!(fetched.email, "a@b.com");
        assert_eq!(backend.by_prefix("abcd1234").len(), 1);
        assert!(backend.by_prefix("nope0000").is_empty());

        // The first consume wins; the second is rejected and leaves the stored
        // timestamp untouched.
        assert!(backend.mark_consumed("vt_x", 200));
        assert!(!backend.mark_consumed("vt_x", 300));
        assert_eq!(backend.get("vt_x").unwrap().consumed_at, Some(200));
    }

    #[test]
    fn purge_drops_expired_consumed() {
        let backend = SqliteVerificationBackend::in_memory().unwrap();
        backend.put(&token("vt_done", TokenKind::MagicLink, "p", 1, 2, Some(2)));

        backend.purge_expired(100);

        assert!(backend.get("vt_done").is_none());
    }
}