1use std::sync::{Arc, Mutex};
16
17use pylon_auth::{OAuthState, OAuthStateBackend};
18use rusqlite::Connection;
19
20const TABLE: &str = "_pylon_oauth_state";
21
22pub struct SqliteOAuthBackend {
27 conn: Arc<Mutex<Connection>>,
28}
29
30impl SqliteOAuthBackend {
31 pub fn open(path: &str) -> Result<Self, String> {
32 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
33 Self::from_connection(conn)
34 }
35
36 pub fn in_memory() -> Result<Self, String> {
37 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
38 Self::from_connection(conn)
39 }
40
41 fn from_connection(conn: Connection) -> Result<Self, String> {
42 conn.execute_batch(&format!(
49 "CREATE TABLE IF NOT EXISTS {TABLE} (
50 token TEXT PRIMARY KEY,
51 provider TEXT NOT NULL,
52 callback_url TEXT NOT NULL DEFAULT '',
53 error_callback_url TEXT NOT NULL DEFAULT '',
54 pkce_verifier TEXT,
55 expires_at INTEGER NOT NULL
56 );
57 CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
58 ))
59 .map_err(|e| format!("init schema: {e}"))?;
60 let _ = conn.execute(
61 &format!("ALTER TABLE {TABLE} ADD COLUMN callback_url TEXT NOT NULL DEFAULT ''"),
62 [],
63 );
64 let _ = conn.execute(
65 &format!("ALTER TABLE {TABLE} ADD COLUMN error_callback_url TEXT NOT NULL DEFAULT ''"),
66 [],
67 );
68 let _ = conn.execute(
71 &format!("ALTER TABLE {TABLE} ADD COLUMN pkce_verifier TEXT"),
72 [],
73 );
74 Ok(Self {
75 conn: Arc::new(Mutex::new(conn)),
76 })
77 }
78}
79
80impl OAuthStateBackend for SqliteOAuthBackend {
81 fn put(&self, token: &str, state: &OAuthState) {
82 if let Ok(guard) = self.conn.lock() {
83 let _ = guard.execute(
84 &format!(
85 "INSERT INTO {TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
86 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
87 ON CONFLICT(token) DO UPDATE SET
88 provider = excluded.provider,
89 callback_url = excluded.callback_url,
90 error_callback_url = excluded.error_callback_url,
91 pkce_verifier = excluded.pkce_verifier,
92 expires_at = excluded.expires_at"
93 ),
94 rusqlite::params![
95 token,
96 state.provider,
97 state.callback_url,
98 state.error_callback_url,
99 state.pkce_verifier,
100 state.expires_at as i64,
101 ],
102 );
103 }
104 }
105
106 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
107 let guard = self.conn.lock().ok()?;
108 let tx = guard.unchecked_transaction().ok()?;
111 let row: Option<(String, String, String, Option<String>, i64)> = tx
112 .query_row(
113 &format!(
114 "SELECT provider, callback_url, error_callback_url, pkce_verifier, expires_at
115 FROM {TABLE} WHERE token = ?1"
116 ),
117 rusqlite::params![token],
118 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?)),
119 )
120 .ok();
121 if row.is_some() {
123 let _ = tx.execute(
124 &format!("DELETE FROM {TABLE} WHERE token = ?1"),
125 rusqlite::params![token],
126 );
127 }
128 let _ = tx.commit();
129
130 let (provider, callback_url, error_callback_url, pkce_verifier, expires_at) = row?;
131 if (expires_at as u64) <= now_unix_secs {
132 return None;
133 }
134 Some(OAuthState {
135 provider,
136 callback_url,
137 error_callback_url,
138 pkce_verifier,
139 expires_at: expires_at as u64,
140 })
141 }
142}
143
144pub use pg::PostgresOAuthBackend;
149
150mod pg {
151 use super::*;
152 use postgres::Client;
153 use std::sync::Mutex;
154
155 const PG_TABLE: &str = "_pylon_oauth_state";
156
157 pub struct PostgresOAuthBackend {
158 client: Mutex<Client>,
159 }
160
161 impl PostgresOAuthBackend {
162 pub fn connect(url: &str) -> Result<Self, String> {
163 let mut client =
164 Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
165 client
170 .batch_execute(&format!(
171 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
172 token TEXT PRIMARY KEY,
173 provider TEXT NOT NULL,
174 callback_url TEXT NOT NULL DEFAULT '',
175 error_callback_url TEXT NOT NULL DEFAULT '',
176 pkce_verifier TEXT,
177 expires_at BIGINT NOT NULL
178 );
179 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS callback_url TEXT NOT NULL DEFAULT '';
180 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS error_callback_url TEXT NOT NULL DEFAULT '';
181 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS pkce_verifier TEXT;
182 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
183 ))
184 .map_err(|e| format!("PG init schema: {e}"))?;
185 Ok(Self {
186 client: Mutex::new(client),
187 })
188 }
189 }
190
191 impl OAuthStateBackend for PostgresOAuthBackend {
192 fn put(&self, token: &str, state: &OAuthState) {
193 if let Ok(mut c) = self.client.lock() {
194 let _ = c.execute(
195 &format!(
196 "INSERT INTO {PG_TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
197 VALUES ($1, $2, $3, $4, $5, $6)
198 ON CONFLICT (token) DO UPDATE SET
199 provider = EXCLUDED.provider,
200 callback_url = EXCLUDED.callback_url,
201 error_callback_url = EXCLUDED.error_callback_url,
202 pkce_verifier = EXCLUDED.pkce_verifier,
203 expires_at = EXCLUDED.expires_at"
204 ),
205 &[
206 &token,
207 &state.provider,
208 &state.callback_url,
209 &state.error_callback_url,
210 &state.pkce_verifier,
211 &(state.expires_at as i64),
212 ],
213 );
214 }
215 }
216
217 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
218 let mut c = self.client.lock().ok()?;
224 let row = c
225 .query_opt(
226 &format!(
227 "DELETE FROM {PG_TABLE} WHERE token = $1
228 RETURNING provider, callback_url, error_callback_url, pkce_verifier, expires_at"
229 ),
230 &[&token],
231 )
232 .ok()??;
233 let provider: String = row.get(0);
234 let callback_url: String = row.get(1);
235 let error_callback_url: String = row.get(2);
236 let pkce_verifier: Option<String> = row.get(3);
237 let expires_at: i64 = row.get(4);
238 if (expires_at as u64) <= now_unix_secs {
239 return None;
240 }
241 Some(OAuthState {
242 provider,
243 callback_url,
244 error_callback_url,
245 pkce_verifier,
246 expires_at: expires_at as u64,
247 })
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a far-future state with identical success/error callbacks and no PKCE verifier.
    fn fixture(provider: &str, callback: &str) -> OAuthState {
        OAuthState {
            provider: provider.to_string(),
            callback_url: callback.to_string(),
            error_callback_url: callback.to_string(),
            pkce_verifier: None,
            expires_at: 9_999_999_999,
        }
    }

    #[test]
    fn put_then_take_returns_full_state() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let s = fixture("google", "http://localhost:3000/dashboard");
        b.put("tok1", &s);
        let got = b.take("tok1", 100).expect("present");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.expires_at, 9_999_999_999);
    }

    #[test]
    fn take_is_single_use() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok2", &fixture("github", "http://localhost:3000/dash"));
        assert!(b.take("tok2", 100).is_some());
        assert!(b.take("tok2", 100).is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok3", &s);
        assert!(b.take("tok3", 200).is_none());
    }

    #[test]
    fn expired_token_is_consumed() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("stale", &s);
        assert!(b.take("stale", 200).is_none());
        // The expired row was deleted by the first take, so even a clock that
        // would consider it valid finds nothing.
        assert!(b.take("stale", 50).is_none());
    }

    #[test]
    fn token_expiring_exactly_now_is_rejected() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("edge", &s);
        assert!(b.take("edge", 100).is_none());
        b.put("edge", &s);
        assert!(b.take("edge", 99).is_some());
    }

    #[test]
    fn missing_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        assert!(b.take("never_existed", 0).is_none());
    }

    #[test]
    fn put_overwrites_existing_token() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok4", &fixture("google", "http://localhost:3000/a"));
        b.put("tok4", &fixture("github", "http://localhost:3000/b"));
        let got = b.take("tok4", 100).expect("present");
        assert_eq!(got.provider, "github");
        assert_eq!(got.callback_url, "http://localhost:3000/b");
    }

    #[test]
    fn pkce_verifier_round_trips() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.pkce_verifier = Some("verifier-123".to_string());
        b.put("tok5", &s);
        let got = b.take("tok5", 100).expect("present");
        assert_eq!(got.pkce_verifier.as_deref(), Some("verifier-123"));
    }

    #[test]
    fn legacy_schema_is_migrated() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                provider TEXT NOT NULL,
                expires_at INTEGER NOT NULL
            );
            INSERT INTO {TABLE} (token, provider, expires_at) VALUES ('old', 'google', 9999999999);"
        ))
        .unwrap();

        let b = SqliteOAuthBackend::from_connection(conn).expect("migration succeeds");

        let got = b.take("old", 100).expect("legacy row readable");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "");
        assert_eq!(got.error_callback_url, "");
        assert_eq!(got.pkce_verifier, None);

        let mut s = fixture("github", "http://localhost:3000/dash");
        s.pkce_verifier = Some("v".to_string());
        b.put("new", &s);
        let got = b.take("new", 100).expect("new row after migration");
        assert_eq!(got.pkce_verifier.as_deref(), Some("v"));
    }
}