1use std::collections::HashSet;
10use std::time::Duration;
11
12use chrono::{DateTime, Utc};
13use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
14use sqlx::SqlitePool;
15
16use crate::code;
17use crate::types::{AllowedSender, ApprovedRequest, PairingError, PendingRequest, UpsertOutcome};
18
/// How long a pending pairing request remains approvable; older requests are
/// rejected by `approve` and deleted by `purge_expired`.
const PENDING_TTL: Duration = Duration::from_secs(3_600);
/// Maximum number of simultaneously pending requests for one (channel, account_id) pair.
const MAX_PENDING_PER_ACCOUNT: usize = 3;
21
22pub struct PairingStore {
23 pool: SqlitePool,
24}
25
26impl PairingStore {
27 pub async fn open(path: &str) -> Result<Self, PairingError> {
28 let opts = SqliteConnectOptions::new()
29 .filename(path)
30 .create_if_missing(true);
31 let max_conns = if path == ":memory:" { 1 } else { 4 };
34 let pool = SqlitePoolOptions::new()
35 .max_connections(max_conns)
36 .connect_with(opts)
37 .await
38 .map_err(|e| PairingError::Storage(e.to_string()))?;
39 sqlx::query(
40 "CREATE TABLE IF NOT EXISTS pairing_pending (\
41 channel TEXT NOT NULL,\
42 account_id TEXT NOT NULL,\
43 sender_id TEXT NOT NULL,\
44 code TEXT NOT NULL,\
45 created_at INTEGER NOT NULL,\
46 meta_json TEXT NOT NULL DEFAULT '{}',\
47 PRIMARY KEY (channel, account_id, sender_id)\
48 )",
49 )
50 .execute(&pool)
51 .await
52 .map_err(|e| PairingError::Storage(e.to_string()))?;
53 sqlx::query("CREATE INDEX IF NOT EXISTS idx_pairing_pending_code ON pairing_pending(code)")
54 .execute(&pool)
55 .await
56 .map_err(|e| PairingError::Storage(e.to_string()))?;
57 sqlx::query(
58 "CREATE TABLE IF NOT EXISTS pairing_allow_from (\
59 channel TEXT NOT NULL,\
60 account_id TEXT NOT NULL,\
61 sender_id TEXT NOT NULL,\
62 approved_at INTEGER NOT NULL,\
63 approved_via TEXT NOT NULL DEFAULT 'cli',\
64 revoked_at INTEGER,\
65 PRIMARY KEY (channel, account_id, sender_id)\
66 )",
67 )
68 .execute(&pool)
69 .await
70 .map_err(|e| PairingError::Storage(e.to_string()))?;
71 Ok(Self { pool })
72 }
73
74 pub async fn open_memory() -> Result<Self, PairingError> {
75 Self::open(":memory:").await
76 }
77
78 pub async fn upsert_pending(
83 &self,
84 channel: &str,
85 account_id: &str,
86 sender_id: &str,
87 meta: serde_json::Value,
88 ) -> Result<UpsertOutcome, PairingError> {
89 self.purge_expired().await?;
91
92 let existing: Option<String> = sqlx::query_scalar(
97 "SELECT code FROM pairing_pending WHERE channel = ? AND account_id = ? AND sender_id = ?",
98 )
99 .bind(channel)
100 .bind(account_id)
101 .bind(sender_id)
102 .fetch_optional(&self.pool)
103 .await
104 .map_err(|e| PairingError::Storage(e.to_string()))?;
105 if let Some(code) = existing {
106 return Ok(UpsertOutcome {
107 code,
108 created: false,
109 });
110 }
111
112 let count: i64 = sqlx::query_scalar(
114 "SELECT COUNT(*) FROM pairing_pending WHERE channel = ? AND account_id = ?",
115 )
116 .bind(channel)
117 .bind(account_id)
118 .fetch_one(&self.pool)
119 .await
120 .map_err(|e| PairingError::Storage(e.to_string()))?;
121 if count as usize >= MAX_PENDING_PER_ACCOUNT {
122 return Err(PairingError::MaxPending {
123 channel: channel.into(),
124 account_id: account_id.into(),
125 });
126 }
127
128 let active_codes: Vec<String> = sqlx::query_scalar("SELECT code FROM pairing_pending")
131 .fetch_all(&self.pool)
132 .await
133 .map_err(|e| PairingError::Storage(e.to_string()))?;
134 let set: HashSet<String> = active_codes.into_iter().collect();
135 let code = code::generate_unique(&set).map_err(PairingError::Invalid)?;
136
137 let now = Utc::now().timestamp();
138 let meta_json =
139 serde_json::to_string(&meta).map_err(|e| PairingError::Storage(e.to_string()))?;
140 sqlx::query(
141 "INSERT INTO pairing_pending(channel, account_id, sender_id, code, created_at, meta_json) VALUES(?, ?, ?, ?, ?, ?)",
142 )
143 .bind(channel)
144 .bind(account_id)
145 .bind(sender_id)
146 .bind(&code)
147 .bind(now)
148 .bind(meta_json)
149 .execute(&self.pool)
150 .await
151 .map_err(|e| PairingError::Storage(e.to_string()))?;
152 crate::telemetry::inc_requests_pending(channel);
153 Ok(UpsertOutcome {
154 code,
155 created: true,
156 })
157 }
158
159 pub async fn list_pending(
160 &self,
161 channel: Option<&str>,
162 ) -> Result<Vec<PendingRequest>, PairingError> {
163 let rows: Vec<(String, String, String, String, i64, String)> = if let Some(c) = channel {
164 sqlx::query_as(
165 "SELECT channel, account_id, sender_id, code, created_at, meta_json FROM pairing_pending WHERE channel = ? ORDER BY created_at",
166 )
167 .bind(c)
168 .fetch_all(&self.pool)
169 .await
170 } else {
171 sqlx::query_as(
172 "SELECT channel, account_id, sender_id, code, created_at, meta_json FROM pairing_pending ORDER BY created_at",
173 )
174 .fetch_all(&self.pool)
175 .await
176 }
177 .map_err(|e| PairingError::Storage(e.to_string()))?;
178 let mut out = Vec::with_capacity(rows.len());
179 for (channel, account_id, sender_id, code, created_at, meta_json) in rows {
180 let meta: serde_json::Value =
181 serde_json::from_str(&meta_json).unwrap_or(serde_json::Value::Null);
182 let created_at =
183 DateTime::<Utc>::from_timestamp(created_at, 0).unwrap_or_else(Utc::now);
184 out.push(PendingRequest {
185 channel,
186 account_id,
187 sender_id,
188 code,
189 created_at,
190 meta,
191 });
192 }
193 Ok(out)
194 }
195
196 pub async fn list_allow(
204 &self,
205 channel: Option<&str>,
206 include_revoked: bool,
207 ) -> Result<Vec<AllowedSender>, PairingError> {
208 let mut sql = String::from(
209 "SELECT channel, account_id, sender_id, approved_at, approved_via, revoked_at \
210 FROM pairing_allow_from",
211 );
212 let mut clauses: Vec<&str> = Vec::new();
213 if !include_revoked {
214 clauses.push("revoked_at IS NULL");
215 }
216 if channel.is_some() {
217 clauses.push("channel = ?");
218 }
219 if !clauses.is_empty() {
220 sql.push_str(" WHERE ");
221 sql.push_str(&clauses.join(" AND "));
222 }
223 sql.push_str(" ORDER BY channel, account_id, sender_id");
224 let rows: Vec<(String, String, String, i64, String, Option<i64>)> =
225 if let Some(c) = channel {
226 sqlx::query_as(&sql).bind(c).fetch_all(&self.pool).await
227 } else {
228 sqlx::query_as(&sql).fetch_all(&self.pool).await
229 }
230 .map_err(|e| PairingError::Storage(e.to_string()))?;
231 let mut out = Vec::with_capacity(rows.len());
232 for (channel, account_id, sender_id, approved_at, approved_via, revoked_at) in rows {
233 let approved_at =
234 DateTime::<Utc>::from_timestamp(approved_at, 0).unwrap_or_else(Utc::now);
235 let revoked_at = revoked_at.and_then(|t| DateTime::<Utc>::from_timestamp(t, 0));
236 out.push(AllowedSender {
237 channel,
238 account_id,
239 sender_id,
240 approved_at,
241 approved_via,
242 revoked_at,
243 });
244 }
245 Ok(out)
246 }
247
248 pub async fn approve(&self, code_value: &str) -> Result<ApprovedRequest, PairingError> {
251 let mut tx = self
252 .pool
253 .begin()
254 .await
255 .map_err(|e| PairingError::Storage(e.to_string()))?;
256 let row: Option<(String, String, String, i64)> = sqlx::query_as(
257 "SELECT channel, account_id, sender_id, created_at FROM pairing_pending WHERE code = ?",
258 )
259 .bind(code_value)
260 .fetch_optional(&mut *tx)
261 .await
262 .map_err(|e| PairingError::Storage(e.to_string()))?;
263 let Some((channel, account_id, sender_id, created_at)) = row else {
264 crate::telemetry::inc_approvals("", "not_found");
265 return Err(PairingError::UnknownCode);
266 };
267 let age = Utc::now().timestamp() - created_at;
269 if age > PENDING_TTL.as_secs() as i64 {
270 crate::telemetry::inc_approvals(&channel, "expired");
271 crate::telemetry::add_codes_expired(1);
272 return Err(PairingError::Expired);
273 }
274 sqlx::query(
275 "INSERT INTO pairing_allow_from(channel, account_id, sender_id, approved_at, approved_via, revoked_at) VALUES(?, ?, ?, ?, 'cli', NULL) ON CONFLICT(channel, account_id, sender_id) DO UPDATE SET revoked_at = NULL, approved_at = excluded.approved_at, approved_via = excluded.approved_via",
276 )
277 .bind(&channel)
278 .bind(&account_id)
279 .bind(&sender_id)
280 .bind(Utc::now().timestamp())
281 .execute(&mut *tx)
282 .await
283 .map_err(|e| PairingError::Storage(e.to_string()))?;
284 sqlx::query("DELETE FROM pairing_pending WHERE code = ?")
285 .bind(code_value)
286 .execute(&mut *tx)
287 .await
288 .map_err(|e| PairingError::Storage(e.to_string()))?;
289 tx.commit()
290 .await
291 .map_err(|e| PairingError::Storage(e.to_string()))?;
292 crate::telemetry::inc_approvals(&channel, "ok");
293 crate::telemetry::dec_requests_pending(&channel);
294 Ok(ApprovedRequest {
295 channel,
296 account_id,
297 sender_id,
298 approved_at: Utc::now(),
299 })
300 }
301
302 pub async fn revoke(&self, channel: &str, sender_id: &str) -> Result<bool, PairingError> {
306 let res = sqlx::query(
307 "UPDATE pairing_allow_from SET revoked_at = ? WHERE channel = ? AND sender_id = ? AND revoked_at IS NULL",
308 )
309 .bind(Utc::now().timestamp())
310 .bind(channel)
311 .bind(sender_id)
312 .execute(&self.pool)
313 .await
314 .map_err(|e| PairingError::Storage(e.to_string()))?;
315 Ok(res.rows_affected() > 0)
316 }
317
318 pub async fn is_allowed(
319 &self,
320 channel: &str,
321 account_id: &str,
322 sender_id: &str,
323 ) -> Result<bool, PairingError> {
324 let row: Option<i64> = sqlx::query_scalar(
325 "SELECT 1 FROM pairing_allow_from WHERE channel = ? AND account_id = ? AND sender_id = ? AND revoked_at IS NULL",
326 )
327 .bind(channel)
328 .bind(account_id)
329 .bind(sender_id)
330 .fetch_optional(&self.pool)
331 .await
332 .map_err(|e| PairingError::Storage(e.to_string()))?;
333 Ok(row.is_some())
334 }
335
336 pub async fn seed(
339 &self,
340 channel: &str,
341 account_id: &str,
342 sender_ids: &[String],
343 ) -> Result<usize, PairingError> {
344 let mut count = 0usize;
345 let now = Utc::now().timestamp();
346 for sender in sender_ids {
347 let res = sqlx::query(
348 "INSERT INTO pairing_allow_from(channel, account_id, sender_id, approved_at, approved_via, revoked_at) VALUES(?, ?, ?, ?, 'seed', NULL) ON CONFLICT(channel, account_id, sender_id) DO UPDATE SET revoked_at = NULL",
349 )
350 .bind(channel)
351 .bind(account_id)
352 .bind(sender)
353 .bind(now)
354 .execute(&self.pool)
355 .await
356 .map_err(|e| PairingError::Storage(e.to_string()))?;
357 count += res.rows_affected() as usize;
358 }
359 Ok(count)
360 }
361
362 #[doc(hidden)]
367 pub fn pool_for_test(&self) -> &SqlitePool {
368 &self.pool
369 }
370
371 pub async fn refresh_pending_gauge(&self) -> Result<(), PairingError> {
377 let rows: Vec<(String, i64)> =
378 sqlx::query_as("SELECT channel, COUNT(*) FROM pairing_pending GROUP BY channel")
379 .fetch_all(&self.pool)
380 .await
381 .map_err(|e| PairingError::Storage(e.to_string()))?;
382 let live: std::collections::HashSet<String> = rows.iter().map(|(c, _)| c.clone()).collect();
383 for prior in crate::telemetry::pending_channels() {
384 if !live.contains(&prior) {
385 crate::telemetry::set_requests_pending(&prior, 0);
386 }
387 }
388 for (channel, count) in rows {
389 crate::telemetry::set_requests_pending(&channel, count);
390 }
391 Ok(())
392 }
393
394 pub async fn purge_expired(&self) -> Result<u64, PairingError> {
395 let cutoff = Utc::now().timestamp() - PENDING_TTL.as_secs() as i64;
396 let by_channel: Vec<(String, i64)> = sqlx::query_as(
399 "SELECT channel, COUNT(*) FROM pairing_pending WHERE created_at < ? GROUP BY channel",
400 )
401 .bind(cutoff)
402 .fetch_all(&self.pool)
403 .await
404 .map_err(|e| PairingError::Storage(e.to_string()))?;
405 let res = sqlx::query("DELETE FROM pairing_pending WHERE created_at < ?")
406 .bind(cutoff)
407 .execute(&self.pool)
408 .await
409 .map_err(|e| PairingError::Storage(e.to_string()))?;
410 let n = res.rows_affected();
411 if n > 0 {
412 crate::telemetry::add_codes_expired(n);
413 for (channel, count) in by_channel {
414 crate::telemetry::sub_requests_pending(&channel, count);
415 }
416 }
417 Ok(n)
418 }
419}