Skip to main content

pylon_runtime/
oauth_backend.rs

//! Persistent OAuth state stores (SQLite + Postgres).
//!
//! State tokens are short-lived (10 min) and single-use. Persisting
//! them to durable storage lets the OAuth flow survive a server
//! restart that happens between the user clicking "Sign in with
//! Google" and the provider redirecting back. The schema carries the
//! callback / error_callback URLs (validated against PYLON_TRUSTED_ORIGINS
//! at create time) so the callback handler doesn't need any env var
//! to know where to redirect after success or failure.
//!
//! Cleanup happens lazily — `take()` deletes the row it reads, even when
//! the token has already expired (in which case it returns None). Rows
//! for flows that are never completed are not swept by anything in this
//! file and stay in the table. At the volumes OAuth flows actually
//! generate this is never a problem.
14
15use std::sync::{Arc, Mutex};
16
17use pylon_auth::{OAuthState, OAuthStateBackend};
18use rusqlite::Connection;
19
20const TABLE: &str = "_pylon_oauth_state";
21
22// ---------------------------------------------------------------------------
23// SQLite backend
24// ---------------------------------------------------------------------------
25
26pub struct SqliteOAuthBackend {
27    conn: Arc<Mutex<Connection>>,
28}
29
30impl SqliteOAuthBackend {
31    pub fn open(path: &str) -> Result<Self, String> {
32        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
33        Self::from_connection(conn)
34    }
35
36    pub fn in_memory() -> Result<Self, String> {
37        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
38        Self::from_connection(conn)
39    }
40
41    fn from_connection(conn: Connection) -> Result<Self, String> {
42        // Base table for new installs. Existing installs predate the
43        // callback_url / error_callback_url columns and get an
44        // ALTER TABLE ADD COLUMN below — ADD COLUMN is a no-op when
45        // the column already exists, so we swallow its error for
46        // idempotency. Same pattern as session_backend's tenant_id
47        // migration.
48        conn.execute_batch(&format!(
49            "CREATE TABLE IF NOT EXISTS {TABLE} (
50                token TEXT PRIMARY KEY,
51                provider TEXT NOT NULL,
52                callback_url TEXT NOT NULL DEFAULT '',
53                error_callback_url TEXT NOT NULL DEFAULT '',
54                pkce_verifier TEXT,
55                expires_at INTEGER NOT NULL
56            );
57            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
58        ))
59        .map_err(|e| format!("init schema: {e}"))?;
60        let _ = conn.execute(
61            &format!("ALTER TABLE {TABLE} ADD COLUMN callback_url TEXT NOT NULL DEFAULT ''"),
62            [],
63        );
64        let _ = conn.execute(
65            &format!("ALTER TABLE {TABLE} ADD COLUMN error_callback_url TEXT NOT NULL DEFAULT ''"),
66            [],
67        );
68        // PKCE column was added when Twitter/X support landed. Existing
69        // installs need an idempotent ADD COLUMN.
70        let _ = conn.execute(
71            &format!("ALTER TABLE {TABLE} ADD COLUMN pkce_verifier TEXT"),
72            [],
73        );
74        Ok(Self {
75            conn: Arc::new(Mutex::new(conn)),
76        })
77    }
78}
79
80impl OAuthStateBackend for SqliteOAuthBackend {
81    fn put(&self, token: &str, state: &OAuthState) {
82        if let Ok(guard) = self.conn.lock() {
83            let _ = guard.execute(
84                &format!(
85                    "INSERT INTO {TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
86                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)
87                     ON CONFLICT(token) DO UPDATE SET
88                       provider = excluded.provider,
89                       callback_url = excluded.callback_url,
90                       error_callback_url = excluded.error_callback_url,
91                       pkce_verifier = excluded.pkce_verifier,
92                       expires_at = excluded.expires_at"
93                ),
94                rusqlite::params![
95                    token,
96                    state.provider,
97                    state.callback_url,
98                    state.error_callback_url,
99                    state.pkce_verifier,
100                    state.expires_at as i64,
101                ],
102            );
103        }
104    }
105
106    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
107        let guard = self.conn.lock().ok()?;
108        // Read first, then delete — must be a transaction so concurrent
109        // callbacks can't both succeed with the same token.
110        let tx = guard.unchecked_transaction().ok()?;
111        let row: Option<(String, String, String, Option<String>, i64)> = tx
112            .query_row(
113                &format!(
114                    "SELECT provider, callback_url, error_callback_url, pkce_verifier, expires_at
115                     FROM {TABLE} WHERE token = ?1"
116                ),
117                rusqlite::params![token],
118                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?, r.get(4)?)),
119            )
120            .ok();
121        // Always delete what we read — single-use even if expired.
122        if row.is_some() {
123            let _ = tx.execute(
124                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
125                rusqlite::params![token],
126            );
127        }
128        let _ = tx.commit();
129
130        let (provider, callback_url, error_callback_url, pkce_verifier, expires_at) = row?;
131        if (expires_at as u64) <= now_unix_secs {
132            return None;
133        }
134        Some(OAuthState {
135            provider,
136            callback_url,
137            error_callback_url,
138            pkce_verifier,
139            expires_at: expires_at as u64,
140        })
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Postgres backend
146// ---------------------------------------------------------------------------
147
148pub use pg::PostgresOAuthBackend;
149
150mod pg {
151    use super::*;
152    use postgres::Client;
153    use std::sync::Mutex;
154
155    const PG_TABLE: &str = "_pylon_oauth_state";
156
157    pub struct PostgresOAuthBackend {
158        client: Mutex<Client>,
159    }
160
161    impl PostgresOAuthBackend {
162        pub fn connect(url: &str) -> Result<Self, String> {
163            let mut client =
164                Client::connect(url, postgres::NoTls).map_err(|e| format!("PG connect: {e}"))?;
165            // Same shape as the SQLite version — declare the columns
166            // up front for new installs, and idempotent ALTER TABLEs
167            // for ones that predate the callback URL fields. Postgres'
168            // IF NOT EXISTS on ADD COLUMN is fine here.
169            client
170                .batch_execute(&format!(
171                    "CREATE TABLE IF NOT EXISTS {PG_TABLE} (
172                        token TEXT PRIMARY KEY,
173                        provider TEXT NOT NULL,
174                        callback_url TEXT NOT NULL DEFAULT '',
175                        error_callback_url TEXT NOT NULL DEFAULT '',
176                        pkce_verifier TEXT,
177                        expires_at BIGINT NOT NULL
178                    );
179                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS callback_url TEXT NOT NULL DEFAULT '';
180                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS error_callback_url TEXT NOT NULL DEFAULT '';
181                    ALTER TABLE {PG_TABLE} ADD COLUMN IF NOT EXISTS pkce_verifier TEXT;
182                    CREATE INDEX IF NOT EXISTS {PG_TABLE}_exp_idx ON {PG_TABLE}(expires_at);"
183                ))
184                .map_err(|e| format!("PG init schema: {e}"))?;
185            Ok(Self {
186                client: Mutex::new(client),
187            })
188        }
189    }
190
191    impl OAuthStateBackend for PostgresOAuthBackend {
192        fn put(&self, token: &str, state: &OAuthState) {
193            if let Ok(mut c) = self.client.lock() {
194                let _ = c.execute(
195                    &format!(
196                        "INSERT INTO {PG_TABLE} (token, provider, callback_url, error_callback_url, pkce_verifier, expires_at)
197                         VALUES ($1, $2, $3, $4, $5, $6)
198                         ON CONFLICT (token) DO UPDATE SET
199                           provider = EXCLUDED.provider,
200                           callback_url = EXCLUDED.callback_url,
201                           error_callback_url = EXCLUDED.error_callback_url,
202                           pkce_verifier = EXCLUDED.pkce_verifier,
203                           expires_at = EXCLUDED.expires_at"
204                    ),
205                    &[
206                        &token,
207                        &state.provider,
208                        &state.callback_url,
209                        &state.error_callback_url,
210                        &state.pkce_verifier,
211                        &(state.expires_at as i64),
212                    ],
213                );
214            }
215        }
216
217        fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
218            // Single round-trip with `RETURNING` is atomic enough — the
219            // DELETE removes the row whether it's expired or not (single-use),
220            // and we filter the returned state by expires_at after.
221            // Concurrent callbacks for the same token can't both succeed
222            // because only one DELETE will return a row.
223            let mut c = self.client.lock().ok()?;
224            let row = c
225                .query_opt(
226                    &format!(
227                        "DELETE FROM {PG_TABLE} WHERE token = $1
228                         RETURNING provider, callback_url, error_callback_url, pkce_verifier, expires_at"
229                    ),
230                    &[&token],
231                )
232                .ok()??;
233            let provider: String = row.get(0);
234            let callback_url: String = row.get(1);
235            let error_callback_url: String = row.get(2);
236            let pkce_verifier: Option<String> = row.get(3);
237            let expires_at: i64 = row.get(4);
238            if (expires_at as u64) <= now_unix_secs {
239                return None;
240            }
241            Some(OAuthState {
242                provider,
243                callback_url,
244                error_callback_url,
245                pkce_verifier,
246                expires_at: expires_at as u64,
247            })
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a far-future, PKCE-less state whose success and error
    /// callbacks both point at `callback`.
    fn fixture(provider: &str, callback: &str) -> OAuthState {
        OAuthState {
            provider: provider.to_string(),
            callback_url: callback.to_string(),
            error_callback_url: callback.to_string(),
            pkce_verifier: None,
            expires_at: 9_999_999_999,
        }
    }

    #[test]
    fn put_then_take_returns_full_state() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let s = fixture("google", "http://localhost:3000/dashboard");
        b.put("tok1", &s);
        let got = b.take("tok1", 100).expect("present");
        assert_eq!(got.provider, "google");
        assert_eq!(got.callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dashboard");
        assert_eq!(got.pkce_verifier, None);
        assert_eq!(got.expires_at, 9_999_999_999);
    }

    #[test]
    fn pkce_verifier_round_trips() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("twitter", "http://localhost:3000/dash");
        s.pkce_verifier = Some("verifier-123".to_string());
        b.put("tok_pkce", &s);
        let got = b.take("tok_pkce", 100).expect("present");
        assert_eq!(got.pkce_verifier.as_deref(), Some("verifier-123"));
    }

    #[test]
    fn put_overwrites_existing_token() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok_dup", &fixture("google", "http://localhost:3000/a"));
        b.put("tok_dup", &fixture("github", "http://localhost:3000/b"));
        let got = b.take("tok_dup", 100).expect("present");
        assert_eq!(got.provider, "github");
        assert_eq!(got.callback_url, "http://localhost:3000/b");
    }

    #[test]
    fn take_is_single_use() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok2", &fixture("github", "http://localhost:3000/dash"));
        assert!(b.take("tok2", 100).is_some());
        assert!(b.take("tok2", 100).is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok3", &s);
        assert!(b.take("tok3", 200).is_none());
    }

    #[test]
    fn expired_token_is_still_consumed() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        let mut s = fixture("google", "http://localhost:3000/dash");
        s.expires_at = 100;
        b.put("tok4", &s);
        // The expired read deletes the row...
        assert!(b.take("tok4", 200).is_none());
        // ...so even a clock value before the expiry finds nothing.
        assert!(b.take("tok4", 50).is_none());
    }

    #[test]
    fn missing_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        assert!(b.take("never_existed", 0).is_none());
    }

    #[test]
    fn legacy_schema_gains_missing_columns() {
        // Table as it looked before the callback URL / PKCE columns existed.
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(&format!(
            "CREATE TABLE {TABLE} (
                token TEXT PRIMARY KEY,
                provider TEXT NOT NULL,
                expires_at INTEGER NOT NULL
            );"
        ))
        .unwrap();
        let b = SqliteOAuthBackend::from_connection(conn).unwrap();

        let mut s = fixture("google", "http://localhost:3000/dash");
        s.pkce_verifier = Some("v".to_string());
        b.put("tok_legacy", &s);
        let got = b.take("tok_legacy", 100).expect("present");
        assert_eq!(got.callback_url, "http://localhost:3000/dash");
        assert_eq!(got.error_callback_url, "http://localhost:3000/dash");
        assert_eq!(got.pkce_verifier.as_deref(), Some("v"));
    }
}
299}