1use std::sync::{Arc, Mutex};
17
18use pylon_auth::{MagicCode, MagicCodeBackend};
19use rusqlite::Connection;
20
/// Table used by `SqliteMagicCodeBackend` (one row per email, keyed by `email`).
const SQLITE_TABLE: &str = "_pylon_magic_codes";
/// Table used by `PostgresMagicCodeBackend`; currently the same name as the SQLite table.
const PG_TABLE: &str = "_pylon_magic_codes";
23
/// [`MagicCodeBackend`] that persists magic codes in a SQLite table.
///
/// The trait methods take `&self`, so the single `rusqlite::Connection` sits
/// behind a `Mutex` that serializes every statement.
pub struct SqliteMagicCodeBackend {
    // NOTE(review): the `Arc` is never cloned in the code visible here — confirm
    // whether anything shares this connection, otherwise a bare `Mutex` would do.
    conn: Arc<Mutex<Connection>>,
}
31
32impl SqliteMagicCodeBackend {
33 pub fn open(path: &str) -> Result<Self, String> {
34 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
35 Self::from_connection(conn)
36 }
37
38 pub fn in_memory() -> Result<Self, String> {
39 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
40 Self::from_connection(conn)
41 }
42
43 fn from_connection(conn: Connection) -> Result<Self, String> {
44 conn.execute_batch(&format!(
45 "CREATE TABLE IF NOT EXISTS {SQLITE_TABLE} (
46 email TEXT PRIMARY KEY,
47 code TEXT NOT NULL,
48 expires_at INTEGER NOT NULL,
49 attempts INTEGER NOT NULL DEFAULT 0
50 );
51 CREATE INDEX IF NOT EXISTS {SQLITE_TABLE}_exp_idx ON {SQLITE_TABLE}(expires_at);"
52 ))
53 .map_err(|e| format!("init schema: {e}"))?;
54 Ok(Self {
55 conn: Arc::new(Mutex::new(conn)),
56 })
57 }
58}
59
60impl MagicCodeBackend for SqliteMagicCodeBackend {
61 fn put(&self, email: &str, code: &MagicCode) {
62 if let Ok(guard) = self.conn.lock() {
63 let _ = guard.execute(
64 &format!(
65 "INSERT INTO {SQLITE_TABLE} (email, code, expires_at, attempts)
66 VALUES (?1, ?2, ?3, ?4)
67 ON CONFLICT(email) DO UPDATE SET
68 code=excluded.code,
69 expires_at=excluded.expires_at,
70 attempts=excluded.attempts"
71 ),
72 rusqlite::params![
73 email,
74 code.code,
75 code.expires_at as i64,
76 code.attempts as i64
77 ],
78 );
79 }
80 }
81 fn get(&self, email: &str) -> Option<MagicCode> {
82 let guard = self.conn.lock().ok()?;
83 guard
84 .query_row(
85 &format!(
86 "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE} WHERE email = ?1"
87 ),
88 rusqlite::params![email],
89 |row| {
90 Ok(MagicCode {
91 email: row.get(0)?,
92 code: row.get(1)?,
93 expires_at: row.get::<_, i64>(2)? as u64,
94 attempts: row.get::<_, i64>(3)? as u32,
95 })
96 },
97 )
98 .ok()
99 }
100 fn remove(&self, email: &str) {
101 if let Ok(guard) = self.conn.lock() {
102 let _ = guard.execute(
103 &format!("DELETE FROM {SQLITE_TABLE} WHERE email = ?1"),
104 rusqlite::params![email],
105 );
106 }
107 }
108 fn bump_attempts(&self, email: &str) {
109 if let Ok(guard) = self.conn.lock() {
110 let _ = guard.execute(
111 &format!("UPDATE {SQLITE_TABLE} SET attempts = attempts + 1 WHERE email = ?1"),
112 rusqlite::params![email],
113 );
114 }
115 }
116 fn load_all(&self) -> Vec<MagicCode> {
117 let Ok(guard) = self.conn.lock() else {
118 return Vec::new();
119 };
120 let mut stmt = match guard.prepare(&format!(
121 "SELECT email, code, expires_at, attempts FROM {SQLITE_TABLE}"
122 )) {
123 Ok(s) => s,
124 Err(_) => return Vec::new(),
125 };
126 let iter = match stmt.query_map([], |row| {
127 Ok(MagicCode {
128 email: row.get(0)?,
129 code: row.get(1)?,
130 expires_at: row.get::<_, i64>(2)? as u64,
131 attempts: row.get::<_, i64>(3)? as u32,
132 })
133 }) {
134 Ok(i) => i,
135 Err(_) => return Vec::new(),
136 };
137 iter.flatten().collect()
138 }
139}
140
141pub use pg::PostgresMagicCodeBackend;
146
147mod pg {
148 use super::*;
149 use postgres::Client;
150
151 pub struct PostgresMagicCodeBackend {
152 client: Mutex<Client>,
153 }
154
155 impl PostgresMagicCodeBackend {
156 pub fn connect(url: &str) -> Result<Self, String> {
157 let mut client =
158 Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
159 client
160 .batch_execute(&format!(
161 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
162 email TEXT PRIMARY KEY,
163 code TEXT NOT NULL,
164 expires_at BIGINT NOT NULL,
165 attempts INTEGER NOT NULL DEFAULT 0
166 );
167 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
168 ))
169 .map_err(|e| format!("PG init schema: {e}"))?;
170 Ok(Self {
171 client: Mutex::new(client),
172 })
173 }
174 }
175
176 impl MagicCodeBackend for PostgresMagicCodeBackend {
177 fn put(&self, email: &str, code: &MagicCode) {
178 let Ok(mut c) = self.client.lock() else {
179 return;
180 };
181 let _ = c.execute(
182 &format!(
183 "INSERT INTO {PG_TABLE} (email, code, expires_at, attempts)
184 VALUES ($1, $2, $3, $4)
185 ON CONFLICT (email) DO UPDATE SET
186 code = EXCLUDED.code,
187 expires_at = EXCLUDED.expires_at,
188 attempts = EXCLUDED.attempts"
189 ),
190 &[
191 &email,
192 &code.code,
193 &(code.expires_at as i64),
194 &(code.attempts as i32),
195 ],
196 );
197 }
198 fn get(&self, email: &str) -> Option<MagicCode> {
199 let mut c = self.client.lock().ok()?;
200 let row = c
201 .query_opt(
202 &format!(
203 "SELECT email, code, expires_at, attempts FROM {PG_TABLE} WHERE email = $1"
204 ),
205 &[&email],
206 )
207 .ok()??;
208 Some(MagicCode {
209 email: row.get(0),
210 code: row.get(1),
211 expires_at: row.get::<_, i64>(2) as u64,
212 attempts: row.get::<_, i32>(3) as u32,
213 })
214 }
215 fn remove(&self, email: &str) {
216 if let Ok(mut c) = self.client.lock() {
217 let _ = c.execute(
218 &format!("DELETE FROM {PG_TABLE} WHERE email = $1"),
219 &[&email],
220 );
221 }
222 }
223 fn bump_attempts(&self, email: &str) {
224 if let Ok(mut c) = self.client.lock() {
225 let _ = c.execute(
226 &format!("UPDATE {PG_TABLE} SET attempts = attempts + 1 WHERE email = $1"),
227 &[&email],
228 );
229 }
230 }
231 fn load_all(&self) -> Vec<MagicCode> {
232 let Ok(mut c) = self.client.lock() else {
233 return Vec::new();
234 };
235 let rows = c
236 .query(
237 &format!("SELECT email, code, expires_at, attempts FROM {PG_TABLE}"),
238 &[],
239 )
240 .unwrap_or_default();
241 rows.iter()
242 .map(|row| MagicCode {
243 email: row.get(0),
244 code: row.get(1),
245 expires_at: row.get::<_, i64>(2) as u64,
246 attempts: row.get::<_, i32>(3) as u32,
247 })
248 .collect()
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a far-future, zero-attempt code for `email`.
    fn sample(email: &str, code: &str) -> MagicCode {
        MagicCode {
            email: email.into(),
            code: code.into(),
            expires_at: 9999999999,
            attempts: 0,
        }
    }

    #[test]
    fn sqlite_roundtrip_put_get_remove() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let mc = sample("a@b.com", "123456");
        b.put(&mc.email, &mc);
        let got = b.get(&mc.email).unwrap();
        assert_eq!(got.code, "123456");
        b.bump_attempts(&mc.email);
        assert_eq!(b.get(&mc.email).unwrap().attempts, 1);
        b.remove(&mc.email);
        assert!(b.get(&mc.email).is_none());
    }

    #[test]
    fn sqlite_put_overwrites_code_and_resets_attempts() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        b.put("a@b.com", &sample("a@b.com", "111111"));
        b.bump_attempts("a@b.com");
        assert_eq!(b.get("a@b.com").unwrap().attempts, 1);

        // The upsert replaces every column, including the attempt counter.
        b.put("a@b.com", &sample("a@b.com", "222222"));
        let got = b.get("a@b.com").unwrap();
        assert_eq!(got.code, "222222");
        assert_eq!(got.attempts, 0);
    }

    #[test]
    fn sqlite_load_all_returns_every_stored_code() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        assert!(b.load_all().is_empty());
        b.put("c@d.com", &sample("c@d.com", "222222"));
        b.put("a@b.com", &sample("a@b.com", "111111"));

        let mut emails: Vec<_> = b.load_all().into_iter().map(|m| m.email).collect();
        emails.sort();
        assert_eq!(emails, ["a@b.com", "c@d.com"]);
    }

    #[test]
    fn sqlite_extreme_values_roundtrip_through_signed_columns() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        let mc = MagicCode {
            email: "max@b.com".into(),
            code: "000000".into(),
            expires_at: u64::MAX,
            attempts: u32::MAX,
        };
        b.put(&mc.email, &mc);
        let got = b.get(&mc.email).unwrap();
        assert_eq!(got.expires_at, u64::MAX);
        assert_eq!(got.attempts, u32::MAX);
    }

    #[test]
    fn sqlite_operations_on_unknown_email_are_noops() {
        let b = SqliteMagicCodeBackend::in_memory().unwrap();
        b.bump_attempts("nobody@x.com");
        b.remove("nobody@x.com");
        assert!(b.get("nobody@x.com").is_none());
        assert!(b.load_all().is_empty());
    }
}