pylon_runtime/
oauth_backend.rs1use std::sync::{Arc, Mutex};
12
13use pylon_auth::OAuthStateBackend;
14use rusqlite::Connection;
15
16const TABLE: &str = "_pylon_oauth_state";
17
18pub struct SqliteOAuthBackend {
19 conn: Arc<Mutex<Connection>>,
20}
21
22impl SqliteOAuthBackend {
23 pub fn open(path: &str) -> Result<Self, String> {
24 let conn = Connection::open(path).map_err(|e| format!("open: {e}"))?;
25 Self::from_connection(conn)
26 }
27
28 pub fn in_memory() -> Result<Self, String> {
29 let conn = Connection::open_in_memory().map_err(|e| format!("open: {e}"))?;
30 Self::from_connection(conn)
31 }
32
33 fn from_connection(conn: Connection) -> Result<Self, String> {
34 conn.execute_batch(&format!(
35 "CREATE TABLE IF NOT EXISTS {TABLE} (
36 token TEXT PRIMARY KEY,
37 provider TEXT NOT NULL,
38 expires_at INTEGER NOT NULL
39 );
40 CREATE INDEX IF NOT EXISTS {TABLE}_exp_idx ON {TABLE}(expires_at);"
41 ))
42 .map_err(|e| format!("init schema: {e}"))?;
43 Ok(Self {
44 conn: Arc::new(Mutex::new(conn)),
45 })
46 }
47}
48
49impl OAuthStateBackend for SqliteOAuthBackend {
50 fn put(&self, token: &str, provider: &str, expires_at: u64) {
51 if let Ok(guard) = self.conn.lock() {
52 let _ = guard.execute(
53 &format!(
54 "INSERT INTO {TABLE} (token, provider, expires_at) VALUES (?1, ?2, ?3)
55 ON CONFLICT(token) DO UPDATE SET
56 provider = excluded.provider,
57 expires_at = excluded.expires_at"
58 ),
59 rusqlite::params![token, provider, expires_at as i64],
60 );
61 }
62 }
63
64 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
65 let guard = self.conn.lock().ok()?;
66 let tx = guard.unchecked_transaction().ok()?;
69 let row: Option<(String, i64)> = tx
70 .query_row(
71 &format!("SELECT provider, expires_at FROM {TABLE} WHERE token = ?1"),
72 rusqlite::params![token],
73 |r| Ok((r.get(0)?, r.get(1)?)),
74 )
75 .ok();
76 if row.is_some() {
78 let _ = tx.execute(
79 &format!("DELETE FROM {TABLE} WHERE token = ?1"),
80 rusqlite::params![token],
81 );
82 }
83 let _ = tx.commit();
84
85 let (provider, expires_at) = row?;
86 if (expires_at as u64) <= now_unix_secs {
87 return None;
88 }
89 Some(provider)
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// An expiry far enough ahead that it never matters to the test using it.
    const FAR_FUTURE: u64 = 9_999_999_999;

    #[test]
    fn put_then_take_returns_provider() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok1", "google", FAR_FUTURE);
        assert_eq!(b.take("tok1", 100).as_deref(), Some("google"));
    }

    #[test]
    fn take_is_single_use() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok2", "github", FAR_FUTURE);
        assert!(b.take("tok2", 100).is_some());
        assert!(b.take("tok2", 100).is_none());
    }

    #[test]
    fn expired_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("tok3", "google", 100);
        assert!(b.take("tok3", 200).is_none());
    }

    #[test]
    fn missing_token_returns_none() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        assert!(b.take("never_existed", 0).is_none());
    }

    #[test]
    fn put_overwrites_previous_token() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("dup", "google", FAR_FUTURE);
        b.put("dup", "github", FAR_FUTURE);
        assert_eq!(b.take("dup", 100).as_deref(), Some("github"));
    }

    #[test]
    fn token_expiring_exactly_now_is_rejected() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("edge", "google", 100);
        assert!(b.take("edge", 100).is_none());
    }

    #[test]
    fn token_is_valid_until_its_expiry_second() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("edge2", "google", 100);
        assert_eq!(b.take("edge2", 99).as_deref(), Some("google"));
    }

    #[test]
    fn expired_token_is_still_consumed() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("stale", "github", 100);
        assert!(b.take("stale", 200).is_none());
        // Even with an earlier clock, the token has already been spent.
        assert!(b.take("stale", 0).is_none());
    }

    #[test]
    fn expiry_above_i64_max_round_trips() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("forever", "github", u64::MAX);
        assert_eq!(b.take("forever", 100).as_deref(), Some("github"));
    }

    #[test]
    fn taking_one_token_leaves_other_live_tokens_intact() {
        let b = SqliteOAuthBackend::in_memory().unwrap();
        b.put("old", "google", 50);
        b.put("live", "github", FAR_FUTURE);
        assert!(b.take("unrelated", 100).is_none());
        assert_eq!(b.take("live", 100).as_deref(), Some("github"));
    }
}