Skip to main content

pylon_runtime/
oauth_backend.rs

//! Persistent OAuth state stores (SQLite + Postgres).
//!
//! State tokens are short-lived (10 min) and single-use. Persisting
//! them to durable storage lets the OAuth flow survive a server
//! restart that happens between the user clicking "Sign in with
//! Google" and the provider redirecting back. The schema carries the
//! callback / error_callback URLs (validated against PYLON_TRUSTED_ORIGINS
//! at create time) so the callback handler doesn't need any env var
//! to know where to redirect after success or failure. It also stores
//! the PKCE code verifier for providers whose flow uses one.
//!
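//! Cleanup is partial. `take()` deletes a token's row whenever it finds
//! one, expired or not, so every redeemed token is removed. Rows for flows
//! that are never completed (e.g. the user never comes back from the
//! provider) are not deleted by this module. VACUUM only reclaims free
//! pages and will not remove them. At the volumes OAuth flows actually
//! generate this is not a problem in practice. The `expires_at` index
//! would support a periodic `DELETE ... WHERE expires_at <= now` sweep
//! if one is ever needed.
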
15use std::sync::{Arc, Mutex};
16
17use pylon_auth::{OAuthState, OAuthStateBackend};
18use rusqlite::Connection;
19
/// Table holding one row per outstanding OAuth state token, keyed by the
/// token itself. The Postgres backend uses the same table name.
const TABLE: &str = "_pylon_oauth_state";
21
22// ---------------------------------------------------------------------------
23// SQLite backend
24// ---------------------------------------------------------------------------
25
26pub struct SqliteOAuthBackend {
27    conn: Arc<Mutex<Connection>>,
28}
29
30impl SqliteOAuthBackend {
31    pub fn open(path: &str) -> Result<Self, String> {
32        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
33        Self::from_connection(conn)
34    }
35
36    pub fn in_memory() -> Result<Self, String> {
37        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
38        Self::from_connection(conn)
39    }
40
41    fn from_connection(conn: Connection) -> Result<Self, String> {
42        // Base table for new installs. Existing installs predate the
43        // callback_url / error_callback_url columns and get an
44        // ALTER TABLE ADD COLUMN below — ADD COLUMN is a no-op when
45        // the column already exists, so we swallow its error for
46        // idempotency. Same pattern as session_backend's tenant_id
47        // migration.
48        conn.execute_batch(&format!(
49            "CREATE TABLE IF NOT EXISTS {TABLE} (
50                token TEXT PRIMARY KEY,
51                provider TEXT NOT NULL,
52                callback_url TEXT NOT NULL DEFAULT '',
53                error_callback_url TEXT NOT NULL DEFAULT '',
54                pkce_verifier TEXT,
55                expires_at INTEGER NOT NULL
56            );
57            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
58        ))
59        .map_err(|e| format!("init schema: {e}"))?;
60        let _ = conn.execute(
61            &format!("ALTER TABLE {TABLE} ADD COLUMN callback_url TEXT NOT NULL DEFAULT ''"),
62            [],
63        );
64        let _ = conn.execute(
65            &format!("ALTER TABLE {TABLE} ADD COLUMN error_callback_url TEXT NOT NULL DEFAULT ''"),
66            [],
67        );
68        // PKCE column was added when Twitter/X support landed. Existing
69        // installs need an idempotent ADD COLUMN.
70        let _ = conn.execute(
71            &format!("ALTER TABLE {TABLE} ADD COLUMN pkce_verifier TEXT"),
72            [],
73        );
74        Ok(Self {
75            conn: Arc::new(Mutex::new(conn)),
76        })
77    }
78}
79
80impl OAuthStateBackend for SqliteOAuthBackend {
81    fn put(&self, token: &str, state: &OAuthState) {
82        if let Ok(guard) = self.conn.lock() {
83            let _ = guard.execute(
84                &format!(
85                    "INSERT INTO {TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
86                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)
87                     ON CONFLICT(token) DO UPDATE SET
88                       provider = excluded.provider,
89                       callback_url = excluded.callback_url,
90                       error_callback_url = excluded.error_callback_url,
91                       pkce_verifier = excluded.pkce_verifier,
92                       expires_at = excluded.expires_at"
93                ),
94                rusqlite::params![
95                    token,
96                    state.provider,
97                    state.callback_url,
98                    state.error_callback_url,
99                    state.pkce_verifier,
100                    state.expires_at as i64,
101                ],
102            );
103        }
104    }
105
106    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
107        let guard = self.conn.lock().ok()?;
108        // Read first, then delete — must be a transaction so concurrent
109        // callbacks can't both succeed with the same token.
110        let tx = guard.unchecked_transaction().ok()?;
111        let row: Option<(String, String, String, Option<String>, i64)> = tx
112            .query_row(
113                &format!(
114                    "SELECT provider, callback_url, error_callback_url, pkce_verifier, expires_at
115                     FROM {TABLE} WHERE token = ?1"
116                ),
117                rusqlite::params![token],
118                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?)),
119            )
120            .ok();
121        // Always delete what we read — single-use even if expired.
122        if row.is_some() {
123            let _ = tx.execute(
124                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
125                rusqlite::params![token],
126            );
127        }
128        let _ = tx.commit();
129
130        let (provider, callback_url, error_callback_url, pkce_verifier, expires_at) = row?;
131        if (expires_at as u64) <= now_unix_secs {
132            return None;
133        }
134        Some(OAuthState {
135            provider,
136            callback_url,
137            error_callback_url,
138            pkce_verifier,
139            expires_at: expires_at as u64,
140        })
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Postgres backend
146// ---------------------------------------------------------------------------
147
148pub use pg::PostgresOAuthBackend;
149
150mod pg {
151    use super::*;
152    use postgres::Client;
153    use std::sync::Mutex;
154
155    const PG_TABLE: &str = "_pylon_oauth_state";
156
157    pub struct PostgresOAuthBackend {
158        client: Mutex<Client>,
159    }
160
161    impl PostgresOAuthBackend {
162        pub fn connect(url: &str) -> Result<Self, String> {
163            let mut client = pylon_storage::postgres::live::connect_pg(url)?;
164            // Same shape as the SQLite version — declare the columns
165            // up front for new installs, and idempotent ALTER TABLEs
166            // for ones that predate the callback URL fields. Postgres'
167            // IF NOT EXISTS on ADD COLUMN is fine here.
168            client
169                .batch_execute(&format!(
170                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
171                        token TEXT PRIMARY KEY,
172                        provider TEXT NOT NULL,
173                        callback_url TEXT NOT NULL DEFAULT '',
174                        error_callback_url TEXT NOT NULL DEFAULT '',
175                        pkce_verifier TEXT,
176                        expires_at BIGINT NOT NULL
177                    );
178                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS callback_url TEXT NOT NULL DEFAULT '';
179                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS error_callback_url TEXT NOT NULL DEFAULT '';
180                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS pkce_verifier TEXT;
181                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
182                ))
183                .map_err(|e| format!("PG init schema: {e}"))?;
184            Ok(Self {
185                client: Mutex::new(client),
186            })
187        }
188    }
189
190    impl OAuthStateBackend for PostgresOAuthBackend {
191        fn put(&self, token: &str, state: &OAuthState) {
192            if let Ok(mut c) = self.client.lock() {
193                let _ = c.execute(
194                    &format!(
195                        "INSERT INTO {PG_TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
196                         VALUES ($1, $2, $3, $4, $5, $6)
197                         ON CONFLICT (token) DO UPDATE SET
198                           provider = EXCLUDED.provider,
199                           callback_url = EXCLUDED.callback_url,
200                           error_callback_url = EXCLUDED.error_callback_url,
201                           pkce_verifier = EXCLUDED.pkce_verifier,
202                           expires_at = EXCLUDED.expires_at"
203                    ),
204                    &[
205                        &token,
206                        &state.provider,
207                        &state.callback_url,
208                        &state.error_callback_url,
209                        &state.pkce_verifier,
210                        &(state.expires_at as i64),
211                    ],
212                );
213            }
214        }
215
216        fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
217            // Single round-trip with `RETURNING` is atomic enough — the
218            // DELETE removes the row whether it's expired or not (single-use),
219            // and we filter the returned state by expires_at after.
220            // Concurrent callbacks for the same token can't both succeed
221            // because only one DELETE will return a row.
222            let mut c = self.client.lock().ok()?;
223            let row = c
224                .query_opt(
225                    &format!(
226                        "DELETE FROM {PG_TABLE} WHERE token = $1
227                         RETURNING provider, callback_url, error_callback_url, pkce_verifier, expires_at"
228                    ),
229                    &[&token],
230                )
231                .ok()??;
232            let provider: String = row.get(0);
233            let callback_url: String = row.get(1);
234            let error_callback_url: String = row.get(2);
235            let pkce_verifier: Option<String> = row.get(3);
236            let expires_at: i64 = row.get(4);
237            if (expires_at as u64) <= now_unix_secs {
238                return None;
239            }
240            Some(OAuthState {
241                provider,
242                callback_url,
243                error_callback_url,
244                pkce_verifier,
245                expires_at: expires_at as u64,
246            })
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// A non-expired state (expires far in the future) with both redirect
    /// URLs set to `callback` and no PKCE verifier.
    fn fixture(provider: &str, callback: &str) -> OAuthState {
        OAuthState {
            provider: provider.to_string(),
            callback_url: callback.to_string(),
            error_callback_url: callback.to_string(),
            pkce_verifier: None,
            expires_at: 9_999_999_999,
        }
    }

    #[test]
    fn put_then_take_returns_full_state() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let s = fixture("google", "http://localhost:3000/dashboard");
        b.put("tok1", &s);
        let got = b.take("tok1", 100).expect("present");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dashboard");
    }

    #[test]
    fn take_is_single_use() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok2", &fixture("github", "http://localhost:3000/dash"));
        assert!(b.take("tok2", 100).is_some());
        assert!(b.take("tok2", 100).is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok3", &s);
        assert!(b.take("tok3", 200).is_none());
    }

    #[test]
    fn missing_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        assert!(b.take("never_existed", 0).is_none());
    }

    #[test]
    fn token_expiring_exactly_now_is_rejected() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok_edge", &s);
        assert!(b.take("tok_edge", 100).is_none());
    }

    #[test]
    fn expired_token_is_still_consumed() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok_exp", &s);
        assert!(b.take("tok_exp", 200).is_none());
        // Even with a clock claiming the token is still valid, the row is gone.
        assert!(b.take("tok_exp", 0).is_none());
    }

    #[test]
    fn put_overwrites_existing_token() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok_dup", &fixture("google", "http://localhost:3000/a"));
        b.put("tok_dup", &fixture("github", "http://localhost:3000/b"));
        let got = b.take("tok_dup", 100).expect("present");
        assert_eq!(got.provider, "github");
        assert_eq!(got.callback_url, "http://localhost:3000/b");
        assert!(b.take("tok_dup", 100).is_none());
    }

    #[test]
    fn pkce_verifier_round_trips() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("twitter", "http://localhost:3000/dash");
        s.pkce_verifier = Some("verifier-123".to_string());
        b.put("tok_pkce", &s);
        let got = b.take("tok_pkce", 100).expect("present");
        assert_eq!(got.pkce_verifier.as_deref(), Some("verifier-123"));
    }

    #[test]
    fn legacy_schema_is_migrated() {
        // A table created before the callback URL / PKCE columns existed.
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                provider TEXT NOT NULL,
                expires_at INTEGER NOT NULL
            );
            INSERT INTO {TABLE} (token, provider, expires_at)
                VALUES ('legacy', 'google', 9999999999);"
        ))
        .unwrap();
        let b = SqliteOAuthBackend::from_connection(conn).unwrap();

        // Pre-existing rows pick up the column defaults.
        let got = b.take("legacy", 100).expect("legacy row survives migration");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "");
        assert_eq!(got.error_callback_url, "");
        assert_eq!(got.pkce_verifier, None);

        // New rows use the full schema.
        let mut s = fixture("twitter", "http://localhost:3000/cb");
        s.pkce_verifier = Some("v".to_string());
        b.put("fresh", &s);
        let got = b.take("fresh", 100).expect("present");
        assert_eq!(got.callback_url, "http://localhost:3000/cb");
        assert_eq!(got.pkce_verifier.as_deref(), Some("v"));
    }

    #[test]
    fn reopening_a_file_database_keeps_state() {
        let path = std::env::temp_dir().join(format!(
            "pylon_oauth_backend_test_{}.db",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let p = path.to_str().expect("utf-8 temp path");
        {
            let b = SqliteOAuthBackend::open(p).unwrap();
            b.put("tok_file", &fixture("google", "http://localhost:3000/dash"));
        }
        // Second open runs schema setup/migration again on an up-to-date table.
        let b = SqliteOAuthBackend::open(p).unwrap();
        assert!(b.take("tok_file", 100).is_some());
        drop(b);
        let _ = std::fs::remove_file(&path);
    }
}