Skip to main content

pylon_runtime/
oauth_backend.rs

//! SQLite-backed OAuth state persistence.
//!
//! State tokens are short-lived (the caller sets the expiry, nominally 10 min)
//! and single-use. Persisting them to SQLite lets the OAuth flow survive a
//! server restart that happens between the user clicking "Sign in with Google"
//! and the provider redirecting back.
//!
//! Schema is one row per token. Cleanup happens lazily inside `take()`; there
//! is no background sweeper. `take()` deletes the row it reads, so a token is
//! consumed even when it has already expired (in which case `take()` returns
//! `None`). A periodic VACUUM is unnecessary at the volumes OAuth flows
//! actually generate.
10
11use std::sync::{Arc, Mutex};
12
13use pylon_auth::OAuthStateBackend;
14use rusqlite::Connection;
15
/// Name of the table holding pending OAuth state tokens.
///
/// SQL identifiers cannot be bound as statement parameters, so this value is
/// interpolated directly into every query with `format!`. It must therefore
/// remain a trusted literal made only of ASCII letters, digits and `_`.
const TABLE: &str = "_pylon_oauth_state";
17
18pub struct SqliteOAuthBackend {
19    conn: Arc<Mutex<Connection>>,
20}
21
22impl SqliteOAuthBackend {
23    pub fn open(path: &str) -> Result<Self, String> {
24        let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
25        Self::from_connection(conn)
26    }
27
28    pub fn in_memory() -> Result<Self, String> {
29        let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
30        Self::from_connection(conn)
31    }
32
33    fn from_connection(conn: Connection) -> Result<Self, String> {
34        conn.execute_batch(&format!(
35            "CREATE TABLE IF NOT EXISTS {TABLE} (
36                token TEXT PRIMARY KEY,
37                provider TEXT NOT NULL,
38                expires_at INTEGER NOT NULL
39            );
40            CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
41        ))
42        .map_err(|e| format!("init schema: {e}"))?;
43        Ok(Self {
44            conn: Arc::new(Mutex::new(conn)),
45        })
46    }
47}
48
49impl OAuthStateBackend for SqliteOAuthBackend {
50    fn put(&self, token: &str, provider: &str, expires_at: u64) {
51        if let Ok(guard) = self.conn.lock() {
52            let _ = guard.execute(
53                &format!(
54                    "INSERT INTO {TABLE} (token, provider, expires_at) VALUES (?1, ?2, ?3)
55                     ON CONFLICT(token) DO UPDATE SET
56                       provider = excluded.provider,
57                       expires_at = excluded.expires_at"
58                ),
59                rusqlite::params![token, provider, expires_at as i64],
60            );
61        }
62    }
63
64    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
65        let guard = self.conn.lock().ok()?;
66        // Read first, then delete — must be a transaction so concurrent
67        // callbacks can't both succeed with the same token.
68        let tx = guard.unchecked_transaction().ok()?;
69        let row: Option<(String, i64)> = tx
70            .query_row(
71                &format!("SELECT provider, expires_at FROM {TABLE} WHERE token = ?1"),
72                rusqlite::params![token],
73                |r| Ok((r.get(0)?, r.get(1)?)),
74            )
75            .ok();
76        // Always delete what we read — single-use even if expired.
77        if row.is_some() {
78            let _ = tx.execute(
79                &format!("DELETE FROM {TABLE} WHERE token = ?1"),
80                rusqlite::params![token],
81            );
82        }
83        let _ = tx.commit();
84
85        let (provider, expires_at) = row?;
86        if (expires_at as u64) <= now_unix_secs {
87            return None;
88        }
89        Some(provider)
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// An expiry comfortably beyond any `now` value these tests pass in.
    const FAR_FUTURE: u64 = 9_999_999_999;

    /// Each test gets its own throwaway in-memory database.
    fn fresh() -> SqliteOAuthBackend {
        SqliteOAuthBackend::in_memory().unwrap()
    }

    #[test]
    fn put_then_take_returns_provider() {
        let backend = fresh();
        backend.put("tok1", "google", FAR_FUTURE);
        let provider = backend.take("tok1", 100);
        assert_eq!(provider.as_deref(), Some("google"));
    }

    #[test]
    fn take_is_single_use() {
        let backend = fresh();
        backend.put("tok2", "github", FAR_FUTURE);
        let first = backend.take("tok2", 100);
        assert!(first.is_some());
        let second = backend.take("tok2", 100);
        assert!(second.is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let backend = fresh();
        backend.put("tok3", "google", 100);
        // `now` (200) is already past the stored expiry (100).
        assert_eq!(backend.take("tok3", 200), None);
    }

    #[test]
    fn missing_token_returns_none() {
        assert_eq!(fresh().take("never_existed", 0), None);
    }

    #[test]
    fn put_overwrites_previous_token() {
        let backend = fresh();
        for provider in ["google", "github"] {
            backend.put("dup", provider, FAR_FUTURE);
        }
        assert_eq!(backend.take("dup", 100).as_deref(), Some("github"));
    }
}