Skip to main content

nexo_pairing/
store.rs

1//! SQLite-backed pairing storage.
2//!
3//! Two tables in one DB file:
4//! - `pairing_pending` — short-lived (TTL 60 min) requests issued via
5//!   the DM challenge flow. Pruned eagerly on insert.
6//! - `pairing_allow_from` — durable per-channel allowlist. Soft-delete
7//!   on revoke (`revoked_at` timestamp) so the operator can audit.
8
use std::collections::{HashMap, HashSet};
use std::time::Duration;

use chrono::{DateTime, Utc};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::SqlitePool;

use crate::code;
use crate::types::{AllowedSender, ApprovedRequest, PairingError, PendingRequest, UpsertOutcome};
18
/// Lifetime of a pending pairing request: one hour after `created_at`.
const PENDING_TTL: Duration = Duration::from_secs(3_600);
/// Maximum number of simultaneously pending requests per `(channel, account_id)`.
const MAX_PENDING_PER_ACCOUNT: usize = 3;
21
22pub struct PairingStore {
23    pool: SqlitePool,
24}
25
26impl PairingStore {
27    pub async fn open(path: &str) -> Result<Self, PairingError> {
28        let opts = SqliteConnectOptions::new()
29            .filename(path)
30            .create_if_missing(true);
31        // SQLite's `:memory:` is per-connection, so pin to one
32        // connection in tests; file-backed paths use the normal pool.
33        let max_conns = if path == ":memory:" { 1 } else { 4 };
34        let pool = SqlitePoolOptions::new()
35            .max_connections(max_conns)
36            .connect_with(opts)
37            .await
38            .map_err(|e| PairingError::Storage(e.to_string()))?;
39        sqlx::query(
40            "CREATE TABLE IF NOT EXISTS pairing_pending (\
41                channel    TEXT NOT NULL,\
42                account_id TEXT NOT NULL,\
43                sender_id  TEXT NOT NULL,\
44                code       TEXT NOT NULL,\
45                created_at INTEGER NOT NULL,\
46                meta_json  TEXT NOT NULL DEFAULT '{}',\
47                PRIMARY KEY (channel, account_id, sender_id)\
48            )",
49        )
50        .execute(&pool)
51        .await
52        .map_err(|e| PairingError::Storage(e.to_string()))?;
53        sqlx::query("CREATE INDEX IF NOT EXISTS idx_pairing_pending_code ON pairing_pending(code)")
54            .execute(&pool)
55            .await
56            .map_err(|e| PairingError::Storage(e.to_string()))?;
57        sqlx::query(
58            "CREATE TABLE IF NOT EXISTS pairing_allow_from (\
59                channel       TEXT NOT NULL,\
60                account_id    TEXT NOT NULL,\
61                sender_id     TEXT NOT NULL,\
62                approved_at   INTEGER NOT NULL,\
63                approved_via  TEXT NOT NULL DEFAULT 'cli',\
64                revoked_at    INTEGER,\
65                PRIMARY KEY (channel, account_id, sender_id)\
66            )",
67        )
68        .execute(&pool)
69        .await
70        .map_err(|e| PairingError::Storage(e.to_string()))?;
71        Ok(Self { pool })
72    }
73
74    pub async fn open_memory() -> Result<Self, PairingError> {
75        Self::open(":memory:").await
76    }
77
78    /// Insert (or refresh `created_at` on) a pending request. Enforces
79    /// TTL prune + max-pending per (channel, account). Returns the
80    /// active code (existing or new) and `created=true` when this
81    /// call inserted a fresh row.
82    pub async fn upsert_pending(
83        &self,
84        channel: &str,
85        account_id: &str,
86        sender_id: &str,
87        meta: serde_json::Value,
88    ) -> Result<UpsertOutcome, PairingError> {
89        // Prune expired everywhere first (cheap, O(rows)).
90        self.purge_expired().await?;
91
92        // Already pending for this sender? Refresh `created_at` and
93        // return the existing code so repeated DMs don't keep
94        // generating new codes (matches OpenClaw's `lastSeenAt`
95        // behaviour).
96        let existing: Option<String> = sqlx::query_scalar(
97            "SELECT code FROM pairing_pending WHERE channel = ? AND account_id = ? AND sender_id = ?",
98        )
99        .bind(channel)
100        .bind(account_id)
101        .bind(sender_id)
102        .fetch_optional(&self.pool)
103        .await
104        .map_err(|e| PairingError::Storage(e.to_string()))?;
105        if let Some(code) = existing {
106            return Ok(UpsertOutcome {
107                code,
108                created: false,
109            });
110        }
111
112        // Enforce per-(channel, account) cap before inserting.
113        let count: i64 = sqlx::query_scalar(
114            "SELECT COUNT(*) FROM pairing_pending WHERE channel = ? AND account_id = ?",
115        )
116        .bind(channel)
117        .bind(account_id)
118        .fetch_one(&self.pool)
119        .await
120        .map_err(|e| PairingError::Storage(e.to_string()))?;
121        if count as usize >= MAX_PENDING_PER_ACCOUNT {
122            return Err(PairingError::MaxPending {
123                channel: channel.into(),
124                account_id: account_id.into(),
125            });
126        }
127
128        // Generate a code that does not collide with any *active* code
129        // anywhere in the table.
130        let active_codes: Vec<String> = sqlx::query_scalar("SELECT code FROM pairing_pending")
131            .fetch_all(&self.pool)
132            .await
133            .map_err(|e| PairingError::Storage(e.to_string()))?;
134        let set: HashSet<String> = active_codes.into_iter().collect();
135        let code = code::generate_unique(&set).map_err(PairingError::Invalid)?;
136
137        let now = Utc::now().timestamp();
138        let meta_json =
139            serde_json::to_string(&meta).map_err(|e| PairingError::Storage(e.to_string()))?;
140        sqlx::query(
141            "INSERT INTO pairing_pending(channel, account_id, sender_id, code, created_at, meta_json) VALUES(?, ?, ?, ?, ?, ?)",
142        )
143        .bind(channel)
144        .bind(account_id)
145        .bind(sender_id)
146        .bind(&code)
147        .bind(now)
148        .bind(meta_json)
149        .execute(&self.pool)
150        .await
151        .map_err(|e| PairingError::Storage(e.to_string()))?;
152        crate::telemetry::inc_requests_pending(channel);
153        Ok(UpsertOutcome {
154            code,
155            created: true,
156        })
157    }
158
159    pub async fn list_pending(
160        &self,
161        channel: Option<&str>,
162    ) -> Result<Vec<PendingRequest>, PairingError> {
163        let rows: Vec<(String, String, String, String, i64, String)> = if let Some(c) = channel {
164            sqlx::query_as(
165                "SELECT channel, account_id, sender_id, code, created_at, meta_json FROM pairing_pending WHERE channel = ? ORDER BY created_at",
166            )
167            .bind(c)
168            .fetch_all(&self.pool)
169            .await
170        } else {
171            sqlx::query_as(
172                "SELECT channel, account_id, sender_id, code, created_at, meta_json FROM pairing_pending ORDER BY created_at",
173            )
174            .fetch_all(&self.pool)
175            .await
176        }
177        .map_err(|e| PairingError::Storage(e.to_string()))?;
178        let mut out = Vec::with_capacity(rows.len());
179        for (channel, account_id, sender_id, code, created_at, meta_json) in rows {
180            let meta: serde_json::Value =
181                serde_json::from_str(&meta_json).unwrap_or(serde_json::Value::Null);
182            let created_at =
183                DateTime::<Utc>::from_timestamp(created_at, 0).unwrap_or_else(Utc::now);
184            out.push(PendingRequest {
185                channel,
186                account_id,
187                sender_id,
188                code,
189                created_at,
190                meta,
191            });
192        }
193        Ok(out)
194    }
195
196    /// Dump every row from `pairing_allow_from`. `include_revoked=false`
197    /// hides soft-deleted rows; `true` returns them too with
198    /// `revoked_at` populated. `channel` filters when `Some(_)`. The
199    /// `nexo pair list --all` operator surface relies on this to make
200    /// seeded senders visible (the legacy `list_pending` only shows
201    /// in-flight challenges, which left operators unable to confirm a
202    /// `pair seed` actually landed).
203    pub async fn list_allow(
204        &self,
205        channel: Option<&str>,
206        include_revoked: bool,
207    ) -> Result<Vec<AllowedSender>, PairingError> {
208        let mut sql = String::from(
209            "SELECT channel, account_id, sender_id, approved_at, approved_via, revoked_at \
210             FROM pairing_allow_from",
211        );
212        let mut clauses: Vec<&str> = Vec::new();
213        if !include_revoked {
214            clauses.push("revoked_at IS NULL");
215        }
216        if channel.is_some() {
217            clauses.push("channel = ?");
218        }
219        if !clauses.is_empty() {
220            sql.push_str(" WHERE ");
221            sql.push_str(&clauses.join(" AND "));
222        }
223        sql.push_str(" ORDER BY channel, account_id, sender_id");
224        let rows: Vec<(String, String, String, i64, String, Option<i64>)> =
225            if let Some(c) = channel {
226                sqlx::query_as(&sql).bind(c).fetch_all(&self.pool).await
227            } else {
228                sqlx::query_as(&sql).fetch_all(&self.pool).await
229            }
230            .map_err(|e| PairingError::Storage(e.to_string()))?;
231        let mut out = Vec::with_capacity(rows.len());
232        for (channel, account_id, sender_id, approved_at, approved_via, revoked_at) in rows {
233            let approved_at =
234                DateTime::<Utc>::from_timestamp(approved_at, 0).unwrap_or_else(Utc::now);
235            let revoked_at = revoked_at.and_then(|t| DateTime::<Utc>::from_timestamp(t, 0));
236            out.push(AllowedSender {
237                channel,
238                account_id,
239                sender_id,
240                approved_at,
241                approved_via,
242                revoked_at,
243            });
244        }
245        Ok(out)
246    }
247
248    /// Approve a pending request by its code. Moves the row from
249    /// `pairing_pending` into `pairing_allow_from` atomically.
250    pub async fn approve(&self, code_value: &str) -> Result<ApprovedRequest, PairingError> {
251        let mut tx = self
252            .pool
253            .begin()
254            .await
255            .map_err(|e| PairingError::Storage(e.to_string()))?;
256        let row: Option<(String, String, String, i64)> = sqlx::query_as(
257            "SELECT channel, account_id, sender_id, created_at FROM pairing_pending WHERE code = ?",
258        )
259        .bind(code_value)
260        .fetch_optional(&mut *tx)
261        .await
262        .map_err(|e| PairingError::Storage(e.to_string()))?;
263        let Some((channel, account_id, sender_id, created_at)) = row else {
264            crate::telemetry::inc_approvals("", "not_found");
265            return Err(PairingError::UnknownCode);
266        };
267        // Reject if expired (the prune may not have run since insert).
268        let age = Utc::now().timestamp() - created_at;
269        if age > PENDING_TTL.as_secs() as i64 {
270            crate::telemetry::inc_approvals(&channel, "expired");
271            crate::telemetry::add_codes_expired(1);
272            return Err(PairingError::Expired);
273        }
274        sqlx::query(
275            "INSERT INTO pairing_allow_from(channel, account_id, sender_id, approved_at, approved_via, revoked_at) VALUES(?, ?, ?, ?, 'cli', NULL) ON CONFLICT(channel, account_id, sender_id) DO UPDATE SET revoked_at = NULL, approved_at = excluded.approved_at, approved_via = excluded.approved_via",
276        )
277        .bind(&channel)
278        .bind(&account_id)
279        .bind(&sender_id)
280        .bind(Utc::now().timestamp())
281        .execute(&mut *tx)
282        .await
283        .map_err(|e| PairingError::Storage(e.to_string()))?;
284        sqlx::query("DELETE FROM pairing_pending WHERE code = ?")
285            .bind(code_value)
286            .execute(&mut *tx)
287            .await
288            .map_err(|e| PairingError::Storage(e.to_string()))?;
289        tx.commit()
290            .await
291            .map_err(|e| PairingError::Storage(e.to_string()))?;
292        crate::telemetry::inc_approvals(&channel, "ok");
293        crate::telemetry::dec_requests_pending(&channel);
294        Ok(ApprovedRequest {
295            channel,
296            account_id,
297            sender_id,
298            approved_at: Utc::now(),
299        })
300    }
301
302    /// Soft-delete by setting `revoked_at`. The row stays for audit.
303    /// Returns `true` if a row was updated (caller decides whether to
304    /// surface "already revoked / not found").
305    pub async fn revoke(&self, channel: &str, sender_id: &str) -> Result<bool, PairingError> {
306        let res = sqlx::query(
307            "UPDATE pairing_allow_from SET revoked_at = ? WHERE channel = ? AND sender_id = ? AND revoked_at IS NULL",
308        )
309        .bind(Utc::now().timestamp())
310        .bind(channel)
311        .bind(sender_id)
312        .execute(&self.pool)
313        .await
314        .map_err(|e| PairingError::Storage(e.to_string()))?;
315        Ok(res.rows_affected() > 0)
316    }
317
318    pub async fn is_allowed(
319        &self,
320        channel: &str,
321        account_id: &str,
322        sender_id: &str,
323    ) -> Result<bool, PairingError> {
324        let row: Option<i64> = sqlx::query_scalar(
325            "SELECT 1 FROM pairing_allow_from WHERE channel = ? AND account_id = ? AND sender_id = ? AND revoked_at IS NULL",
326        )
327        .bind(channel)
328        .bind(account_id)
329        .bind(sender_id)
330        .fetch_optional(&self.pool)
331        .await
332        .map_err(|e| PairingError::Storage(e.to_string()))?;
333        Ok(row.is_some())
334    }
335
336    /// Bulk insert (idempotent) — preload allow-from from a known
337    /// list of senders, e.g. when migrating from a non-pairing setup.
338    pub async fn seed(
339        &self,
340        channel: &str,
341        account_id: &str,
342        sender_ids: &[String],
343    ) -> Result<usize, PairingError> {
344        let mut count = 0usize;
345        let now = Utc::now().timestamp();
346        for sender in sender_ids {
347            let res = sqlx::query(
348                "INSERT INTO pairing_allow_from(channel, account_id, sender_id, approved_at, approved_via, revoked_at) VALUES(?, ?, ?, ?, 'seed', NULL) ON CONFLICT(channel, account_id, sender_id) DO UPDATE SET revoked_at = NULL",
349            )
350            .bind(channel)
351            .bind(account_id)
352            .bind(sender)
353            .bind(now)
354            .execute(&self.pool)
355            .await
356            .map_err(|e| PairingError::Storage(e.to_string()))?;
357            count += res.rows_affected() as usize;
358        }
359        Ok(count)
360    }
361
362    /// Test-only access to the underlying pool. Lets integration tests
363    /// in this crate backdate rows or assert raw state without
364    /// duplicating the schema setup. Hidden from rustdoc; do not
365    /// rely on this from production callers.
366    #[doc(hidden)]
367    pub fn pool_for_test(&self) -> &SqlitePool {
368        &self.pool
369    }
370
371    /// Resync the `pairing_requests_pending` gauge from the database.
372    /// Call after process restart (the gauge is in-memory state and
373    /// resets to 0, so without a refresh it under-reports until the
374    /// next `upsert_pending`). Channels that had a value but no longer
375    /// have any pending rows are clamped to 0 to avoid ghost gauges.
376    pub async fn refresh_pending_gauge(&self) -> Result<(), PairingError> {
377        let rows: Vec<(String, i64)> =
378            sqlx::query_as("SELECT channel, COUNT(*) FROM pairing_pending GROUP BY channel")
379                .fetch_all(&self.pool)
380                .await
381                .map_err(|e| PairingError::Storage(e.to_string()))?;
382        let live: std::collections::HashSet<String> = rows.iter().map(|(c, _)| c.clone()).collect();
383        for prior in crate::telemetry::pending_channels() {
384            if !live.contains(&prior) {
385                crate::telemetry::set_requests_pending(&prior, 0);
386            }
387        }
388        for (channel, count) in rows {
389            crate::telemetry::set_requests_pending(&channel, count);
390        }
391        Ok(())
392    }
393
394    pub async fn purge_expired(&self) -> Result<u64, PairingError> {
395        let cutoff = Utc::now().timestamp() - PENDING_TTL.as_secs() as i64;
396        // Count rows about to die per-channel so we can keep the
397        // pending gauge in sync without a follow-up query / refresh.
398        let by_channel: Vec<(String, i64)> = sqlx::query_as(
399            "SELECT channel, COUNT(*) FROM pairing_pending WHERE created_at < ? GROUP BY channel",
400        )
401        .bind(cutoff)
402        .fetch_all(&self.pool)
403        .await
404        .map_err(|e| PairingError::Storage(e.to_string()))?;
405        let res = sqlx::query("DELETE FROM pairing_pending WHERE created_at < ?")
406            .bind(cutoff)
407            .execute(&self.pool)
408            .await
409            .map_err(|e| PairingError::Storage(e.to_string()))?;
410        let n = res.rows_affected();
411        if n > 0 {
412            crate::telemetry::add_codes_expired(n);
413            for (channel, count) in by_channel {
414                crate::telemetry::sub_requests_pending(&channel, count);
415            }
416        }
417        Ok(n)
418    }
419}