1use std::sync::{Arc, Mutex};
16
17use pylon_auth::{OAuthState, OAuthStateBackend};
18use rusqlite::Connection;
19
/// Name of the SQLite table holding pending OAuth state rows, one per token.
const TABLE: &str = "_pylon_oauth_state";
21
/// SQLite-backed [`OAuthStateBackend`].
///
/// The single connection lives behind a mutex, so `put`/`take` calls on one
/// backend are serialized.
pub struct SqliteOAuthBackend {
    conn: Arc<Mutex<Connection>>,
}
29
30impl SqliteOAuthBackend {
31 pub fn open(path: &str) -> Result<Self, String> {
32 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
33 Self::from_connection(conn)
34 }
35
36 pub fn in_memory() -> Result<Self, String> {
37 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
38 Self::from_connection(conn)
39 }
40
41 fn from_connection(conn: Connection) -> Result<Self, String> {
42 conn.execute_batch(&format!(
49 "CREATE TABLE IF NOT EXISTS {TABLE} (
50 token TEXT PRIMARY KEY,
51 provider TEXT NOT NULL,
52 callback_url TEXT NOT NULL DEFAULT '',
53 error_callback_url TEXT NOT NULL DEFAULT '',
54 pkce_verifier TEXT,
55 expires_at INTEGER NOT NULL
56 );
57 CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
58 ))
59 .map_err(|e| format!("init schema: {e}"))?;
60 let _ = conn.execute(
61 &format!("ALTER TABLE {TABLE} ADD COLUMN callback_url TEXT NOT NULL DEFAULT ''"),
62 [],
63 );
64 let _ = conn.execute(
65 &format!("ALTER TABLE {TABLE} ADD COLUMN error_callback_url TEXT NOT NULL DEFAULT ''"),
66 [],
67 );
68 let _ = conn.execute(
71 &format!("ALTER TABLE {TABLE} ADD COLUMN pkce_verifier TEXT"),
72 [],
73 );
74 Ok(Self {
75 conn: Arc::new(Mutex::new(conn)),
76 })
77 }
78}
79
80impl OAuthStateBackend for SqliteOAuthBackend {
81 fn put(&self, token: &str, state: &OAuthState) {
82 if let Ok(guard) = self.conn.lock() {
83 let _ = guard.execute(
84 &format!(
85 "INSERT INTO {TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
86 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
87 ON CONFLICT(token) DO UPDATE SET
88 provider = excluded.provider,
89 callback_url = excluded.callback_url,
90 error_callback_url = excluded.error_callback_url,
91 pkce_verifier = excluded.pkce_verifier,
92 expires_at = excluded.expires_at"
93 ),
94 rusqlite::params![
95 token,
96 state.provider,
97 state.callback_url,
98 state.error_callback_url,
99 state.pkce_verifier,
100 state.expires_at as i64,
101 ],
102 );
103 }
104 }
105
106 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
107 let guard = self.conn.lock().ok()?;
108 let tx = guard.unchecked_transaction().ok()?;
111 let row: Option<(String, String, String, Option<String>, i64)> = tx
112 .query_row(
113 &format!(
114 "SELECT provider, callback_url, error_callback_url, pkce_verifier, expires_at
115 FROM {TABLE} WHERE token = ?1"
116 ),
117 rusqlite::params![token],
118 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?)),
119 )
120 .ok();
121 if row.is_some() {
123 let _ = tx.execute(
124 &format!("DELETE FROM {TABLE} WHERE token = ?1"),
125 rusqlite::params![token],
126 );
127 }
128 let _ = tx.commit();
129
130 let (provider, callback_url, error_callback_url, pkce_verifier, expires_at) = row?;
131 if (expires_at as u64) <= now_unix_secs {
132 return None;
133 }
134 Some(OAuthState {
135 provider,
136 callback_url,
137 error_callback_url,
138 pkce_verifier,
139 expires_at: expires_at as u64,
140 })
141 }
142}
143
144pub use pg::PostgresOAuthBackend;
149
150mod pg {
151 use super::*;
152 use postgres::Client;
153 use std::sync::Mutex;
154
155 const PG_TABLE: &str = "_pylon_oauth_state";
156
157 pub struct PostgresOAuthBackend {
158 client: Mutex<Client>,
159 }
160
161 impl PostgresOAuthBackend {
162 pub fn connect(url: &str) -> Result<Self, String> {
163 let mut client = pylon_storage::postgres::live::connect_pg(url)?;
164 client
169 .batch_execute(&format!(
170 "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
171 token TEXT PRIMARY KEY,
172 provider TEXT NOT NULL,
173 callback_url TEXT NOT NULL DEFAULT '',
174 error_callback_url TEXT NOT NULL DEFAULT '',
175 pkce_verifier TEXT,
176 expires_at BIGINT NOT NULL
177 );
178 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS callback_url TEXT NOT NULL DEFAULT '';
179 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS error_callback_url TEXT NOT NULL DEFAULT '';
180 ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS pkce_verifier TEXT;
181 CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
182 ))
183 .map_err(|e| format!("PG init schema: {e}"))?;
184 Ok(Self {
185 client: Mutex::new(client),
186 })
187 }
188 }
189
190 impl OAuthStateBackend for PostgresOAuthBackend {
191 fn put(&self, token: &str, state: &OAuthState) {
192 if let Ok(mut c) = self.client.lock() {
193 let _ = c.execute(
194 &format!(
195 "INSERT INTO {PG_TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
196 VALUES ($1, $2, $3, $4, $5, $6)
197 ON CONFLICT (token) DO UPDATE SET
198 provider = EXCLUDED.provider,
199 callback_url = EXCLUDED.callback_url,
200 error_callback_url = EXCLUDED.error_callback_url,
201 pkce_verifier = EXCLUDED.pkce_verifier,
202 expires_at = EXCLUDED.expires_at"
203 ),
204 &[
205 &token,
206 &state.provider,
207 &state.callback_url,
208 &state.error_callback_url,
209 &state.pkce_verifier,
210 &(state.expires_at as i64),
211 ],
212 );
213 }
214 }
215
216 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
217 let mut c = self.client.lock().ok()?;
223 let row = c
224 .query_opt(
225 &format!(
226 "DELETE FROM {PG_TABLE} WHERE token = $1
227 RETURNING provider, callback_url, error_callback_url, pkce_verifier, expires_at"
228 ),
229 &[&token],
230 )
231 .ok()??;
232 let provider: String = row.get(0);
233 let callback_url: String = row.get(1);
234 let error_callback_url: String = row.get(2);
235 let pkce_verifier: Option<String> = row.get(3);
236 let expires_at: i64 = row.get(4);
237 if (expires_at as u64) <= now_unix_secs {
238 return None;
239 }
240 Some(OAuthState {
241 provider,
242 callback_url,
243 error_callback_url,
244 pkce_verifier,
245 expires_at: expires_at as u64,
246 })
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-expiring state whose success and error callbacks are both `callback`.
    fn fixture(provider: &str, callback: &str) -> OAuthState {
        OAuthState {
            provider: provider.to_string(),
            callback_url: callback.to_string(),
            error_callback_url: callback.to_string(),
            pkce_verifier: None,
            expires_at: 9_999_999_999,
        }
    }

    #[test]
    fn put_then_take_returns_full_state() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let s = fixture("google", "http://localhost:3000/dashboard");
        b.put("tok1", &s);
        let got = b.take("tok1", 100).expect("present");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.expires_at, 9_999_999_999);
    }

    #[test]
    fn take_is_single_use() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok2", &fixture("github", "http://localhost:3000/dash"));
        assert!(b.take("tok2", 100).is_some());
        assert!(b.take("tok2", 100).is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok3", &s);
        assert!(b.take("tok3", 200).is_none());
    }

    #[test]
    fn missing_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        assert!(b.take("never_existed", 0).is_none());
    }

    #[test]
    fn pkce_verifier_round_trips() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.pkce_verifier = Some("verifier-abc".to_string());
        b.put("tok4", &s);
        let got = b.take("tok4", 100).expect("present");
        assert_eq!(got.pkce_verifier.as_deref(), Some("verifier-abc"));
    }

    #[test]
    fn put_overwrites_existing_token() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok5", &fixture("google", "http://a.example"));
        b.put("tok5", &fixture("github", "http://b.example"));
        let got = b.take("tok5", 100).expect("present");
        assert_eq!(got.provider, "github");
        assert_eq!(got.callback_url, "http://b.example");
    }

    #[test]
    fn token_expiring_exactly_now_is_rejected() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok6", &s);
        assert!(b.take("tok6", 100).is_none());
    }

    #[test]
    fn expired_token_is_still_consumed() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok7", &s);
        assert!(b.take("tok7", 200).is_none());
        // An earlier clock must not resurrect a token already taken.
        assert!(b.take("tok7", 50).is_none());
    }

    #[test]
    fn legacy_table_is_migrated() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                provider TEXT NOT NULL,
                expires_at INTEGER NOT NULL
            );"
        ))
        .unwrap();
        let b = SqliteOAuthBackend::from_connection(conn).expect("migration succeeds");
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.pkce_verifier = Some("v".to_string());
        b.put("tok8", &s);
        let got = b.take("tok8", 100).expect("present after migration");
        assert_eq!(got.callback_url, "http://localhost:3000/dash");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dash");
        assert_eq!(got.pkce_verifier.as_deref(), Some("v"));
    }
}