1use std::sync::{Arc, Mutex};
17
18use pylon_auth::{MagicCode, MagicCodeBackend};
19use rusqlite::Connection;
20
/// Table holding pending magic codes in the SQLite backend (one row per email).
const SQLITE_TABLE: &str = "_pylon_magic_codes";
/// Table holding pending magic codes in the Postgres backend; it uses the same
/// name as the SQLite table.
const PG_TABLE: &str = "_pylon_magic_codes";
23
24pub struct SqliteMagicCodeBackend {
29 conn: Arc<Mutex<Connection>>,
30}
31
32impl SqliteMagicCodeBackend {
33 pub fn open(path: &str) -> Result<Self, String> {
34 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
35 Self::from_connection(conn)
36 }
37
38 pub fn in_memory() -> Result<Self, String> {
39 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
40 Self::from_connection(conn)
41 }
42
43 fn from_connection(conn: Connection) -> Result<Self, String> {
44 conn.execute_batch(&format!(
45 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
46 email TEXT PRIMARY KEY,
47 code TEXT NOT NULL,
48 expires_at INTEGER NOT NULL,
49 attempts INTEGER NOT NULL DEFAULT 0
50 );
51 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
52 ))
53 .map_err(|e| format!("init schema: {e}"))?;
54 Ok(Self {
55 conn: Arc::new(Mutex::new(conn)),
56 })
57 }
58}
59
60impl MagicCodeBackend for SqliteMagicCodeBackend {
61 fn put(&self, email: &str, code: &MagicCode) {
62 if let Ok(guard) = self.conn.lock() {
63 let _ = guard.execute(
64 &format!(
65 "INSERT INTO {SQLITE_TABLE} (email, code, expires_at, attempts)
66 VALUES (?1, ?2, ?3, ?4)
67 ON CONFLICT(email) DO UPDATE SET
68 code=excluded.code,
69 expires_at=excluded.expires_at,
70 attempts=excluded.attempts"
71 ),
72 rusqlite::params![
73 email,
74 code.code,
75 code.expires_at as i64,
76 code.attempts as i64
77 ],
78 );
79 }
80 }
81 fn get(&self, email: &str) -> Option<MagicCode> {
82 let guard = self.conn.lock().ok()?;
83 guard
84 .query_row(
85 &format!(
86 "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE} WHERE email = ?1"
87 ),
88 rusqlite::params![email],
89 |row| {
90 Ok(MagicCode {
91 email: row.get(0)?,
92 code: row.get(1)?,
93 expires_at: row.get::<_, i64>(2)? as u64,
94 attempts: row.get::<_, i64>(3)? as u32,
95 })
96 },
97 )
98 .ok()
99 }
100 fn remove(&self, email: &str) {
101 if let Ok(guard) = self.conn.lock() {
102 let _ = guard.execute(
103 &format!("DELETE FROM {SQLITE_TABLE} WHERE email = ?1"),
104 rusqlite::params![email],
105 );
106 }
107 }
108 fn bump_attempts(&self, email: &str) {
109 if let Ok(guard) = self.conn.lock() {
110 let _ = guard.execute(
111 &format!("UPDATE {SQLITE_TABLE} SET attempts = attempts + 1 WHERE email = ?1"),
112 rusqlite::params![email],
113 );
114 }
115 }
116 fn load_all(&self) -> Vec<MagicCode> {
117 let Ok(guard) = self.conn.lock() else {
118 return Vec::new();
119 };
120 let mut stmt = match guard.prepare(&format!(
121 "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE}"
122 )) {
123 Ok(s) => s,
124 Err(_) => return Vec::new(),
125 };
126 let iter = match stmt.query_map([], |row| {
127 Ok(MagicCode {
128 email: row.get(0)?,
129 code: row.get(1)?,
130 expires_at: row.get::<_, i64>(2)? as u64,
131 attempts: row.get::<_, i64>(3)? as u32,
132 })
133 }) {
134 Ok(i) => i,
135 Err(_) => return Vec::new(),
136 };
137 iter.flatten().collect()
138 }
139}
140
141pub use pg::PostgresMagicCodeBackend;
146
147mod pg {
148 use super::*;
149 use postgres::Client;
150
151 pub struct PostgresMagicCodeBackend {
152 client: Mutex<Client>,
153 }
154
155 impl PostgresMagicCodeBackend {
156 pub fn connect(url: &str) -> Result<Self, String> {
157 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
158 client
159 .batch_execute(&format!(
160 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
161 email TEXT PRIMARY KEY,
162 code TEXT NOT NULL,
163 expires_at BIGINT NOT NULL,
164 attempts INTEGER NOT NULL DEFAULT 0
165 );
166 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
167 ))
168 .map_err(|e| format!("PG init schema: {e}"))?;
169 Ok(Self {
170 client: Mutex::new(client),
171 })
172 }
173 }
174
175 impl MagicCodeBackend for PostgresMagicCodeBackend {
176 fn put(&self, email: &str, code: &MagicCode) {
177 let Ok(mut c) = self.client.lock() else {
178 return;
179 };
180 let _ = c.execute(
181 &format!(
182 "INSERT INTO {PG_TABLE} (email, code, expires_at, attempts)
183 VALUES ($1, $2, $3, $4)
184 ON CONFLICT (email) DO UPDATE SET
185 code = EXCLUDED.code,
186 expires_at = EXCLUDED.expires_at,
187 attempts = EXCLUDED.attempts"
188 ),
189 &[
190 &email,
191 &code.code,
192 &(code.expires_at as i64),
193 &(code.attempts as i32),
194 ],
195 );
196 }
197 fn get(&self, email: &str) -> Option<MagicCode> {
198 let mut c = self.client.lock().ok()?;
199 let row = c
200 .query_opt(
201 &format!(
202 "SELECT email, code, expires_at, attempts FROM {PG_TABLE} WHERE email = $1"
203 ),
204 &[&email],
205 )
206 .ok()??;
207 Some(MagicCode {
208 email: row.get(0),
209 code: row.get(1),
210 expires_at: row.get::<_, i64>(2) as u64,
211 attempts: row.get::<_, i32>(3) as u32,
212 })
213 }
214 fn remove(&self, email: &str) {
215 if let Ok(mut c) = self.client.lock() {
216 let _ = c.execute(
217 &format!("DELETE FROM {PG_TABLE} WHERE email = $1"),
218 &[&email],
219 );
220 }
221 }
222 fn bump_attempts(&self, email: &str) {
223 if let Ok(mut c) = self.client.lock() {
224 let _ = c.execute(
225 &format!("UPDATE {PG_TABLE} SET attempts = attempts + 1 WHERE email = $1"),
226 &[&email],
227 );
228 }
229 }
230 fn load_all(&self) -> Vec<MagicCode> {
231 let Ok(mut c) = self.client.lock() else {
232 return Vec::new();
233 };
234 let rows = c
235 .query(
236 &format!("SELECT email, code, expires_at, attempts FROM {PG_TABLE}"),
237 &[],
238 )
239 .unwrap_or_default();
240 rows.iter()
241 .map(|row| MagicCode {
242 email: row.get(0),
243 code: row.get(1),
244 expires_at: row.get::<_, i64>(2) as u64,
245 attempts: row.get::<_, i32>(3) as u32,
246 })
247 .collect()
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unattempted code for `email` that expires far in the future.
    fn sample(email: &str, code: &str) -> MagicCode {
        MagicCode {
            email: email.into(),
            code: code.into(),
            expires_at: 9999999999,
            attempts: 0,
        }
    }

    #[test]
    fn sqlite_roundtrip_put_get_remove() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let mc = sample("a@b.com", "123456");
        b.put(&mc.email, &mc);
        let got = b.get(&mc.email).unwrap();
        assert_eq!(got.code, "123456");
        b.bump_attempts(&mc.email);
        assert_eq!(b.get(&mc.email).unwrap().attempts, 1);
        b.remove(&mc.email);
        assert!(b.get(&mc.email).is_none());
    }

    #[test]
    fn sqlite_put_overwrites_existing_row() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let first = sample("a@b.com", "111111");
        b.put(&first.email, &first);
        b.bump_attempts(&first.email);

        // Re-issuing a code replaces the old one and resets the counter.
        let second = sample("a@b.com", "222222");
        b.put(&second.email, &second);
        let got = b.get("a@b.com").unwrap();
        assert_eq!(got.email, "a@b.com");
        assert_eq!(got.code, "222222");
        assert_eq!(got.expires_at, 9999999999);
        assert_eq!(got.attempts, 0);
    }

    #[test]
    fn sqlite_missing_email_is_a_noop() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        assert!(b.get("nobody@x.com").is_none());
        b.bump_attempts("nobody@x.com");
        b.remove("nobody@x.com");
        assert!(b.get("nobody@x.com").is_none());
        assert!(b.load_all().is_empty());
    }

    #[test]
    fn sqlite_load_all_returns_every_row() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        for (email, code) in [("a@b.com", "111111"), ("c@d.com", "222222")] {
            b.put(email, &sample(email, code));
        }
        let mut emails: Vec<String> = b.load_all().into_iter().map(|mc| mc.email).collect();
        emails.sort();
        assert_eq!(emails, ["a@b.com", "c@d.com"]);
    }
}