// plexus_substrate/activations/claudecode_loopback/storage.rs

1use super::types::{ApprovalId, ApprovalRequest, ApprovalStatus, LoopbackError};
2use crate::activations::storage::init_sqlite_pool;
3use crate::activation_db_path_from_module;
4use serde_json::Value;
5use sqlx::{sqlite::SqlitePool, Row};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::sync::Notify;
11use uuid::Uuid;
12
/// Settings needed to open a `LoopbackStorage`.
#[derive(Clone, Debug)]
pub struct LoopbackStorageConfig {
    /// Location of the SQLite database file backing the approval table.
    pub db_path: PathBuf,
}
17
18impl Default for LoopbackStorageConfig {
19    fn default() -> Self {
20        Self {
21            db_path: activation_db_path_from_module!("loopback.db"),
22        }
23    }
24}
25
/// SQLite-backed store for tool-use approval requests, plus in-memory maps
/// that correlate tool calls with sessions and wake tasks waiting on them.
pub struct LoopbackStorage {
    /// Connection pool for the `loopback_approvals` table.
    pool: SqlitePool,
    /// Maps tool_use_id -> session_id for correlation.
    /// This allows loopback_permit to find the session_id when called via MCP.
    tool_session_map: RwLock<HashMap<String, String>>,
    /// Maps session_id -> Notify for blocking wait on new approvals.
    /// Allows wait_for_approval to block until an approval arrives for that session.
    // NOTE(review): the Arc is not cloned anywhere in this file; it may be
    // removable if no other module depends on it — confirm before changing.
    session_notifiers: Arc<RwLock<HashMap<String, Arc<Notify>>>>,
    /// Maps child_session_id -> parent_session_id.
    /// When a child session gets an approval, the parent is also notified.
    session_parents: RwLock<HashMap<String, String>>,
    /// Maps parent_session_id -> [child_session_id].
    /// Allows list_pending to include child session approvals when querying by parent.
    session_children: RwLock<HashMap<String, Vec<String>>>,
}
41
42impl LoopbackStorage {
43    pub async fn new(config: LoopbackStorageConfig) -> Result<Self, String> {
44        let pool = init_sqlite_pool(config.db_path).await?;
45
46        let storage = Self {
47            pool,
48            tool_session_map: RwLock::new(HashMap::new()),
49            session_notifiers: Arc::new(RwLock::new(HashMap::new())),
50            session_parents: RwLock::new(HashMap::new()),
51            session_children: RwLock::new(HashMap::new()),
52        };
53        storage.run_migrations().await?;
54        Ok(storage)
55    }
56
57    /// Register a tool_use_id -> session_id mapping
58    /// Called by the background task when it sees a ToolUse event
59    pub fn register_tool_session(&self, tool_use_id: &str, session_id: &str) {
60        if let Ok(mut map) = self.tool_session_map.write() {
61            map.insert(tool_use_id.to_string(), session_id.to_string());
62        }
63    }
64
65    /// Lookup session_id by tool_use_id
66    /// Called by loopback_permit to find the correct session_id
67    pub fn lookup_session_by_tool(&self, tool_use_id: &str) -> Option<String> {
68        self.tool_session_map.read().ok()?.get(tool_use_id).cloned()
69    }
70
71    /// Remove a tool_use_id mapping (called after approval is resolved)
72    pub fn remove_tool_mapping(&self, tool_use_id: &str) {
73        if let Ok(mut map) = self.tool_session_map.write() {
74            map.remove(tool_use_id);
75        }
76    }
77
78    async fn run_migrations(&self) -> Result<(), LoopbackError> {
79        sqlx::query(r#"
80            CREATE TABLE IF NOT EXISTS loopback_approvals (
81                id TEXT PRIMARY KEY,
82                session_id TEXT NOT NULL,
83                tool_name TEXT NOT NULL,
84                tool_use_id TEXT NOT NULL,
85                input TEXT NOT NULL,
86                status TEXT NOT NULL DEFAULT 'pending',
87                response_message TEXT,
88                created_at INTEGER NOT NULL,
89                resolved_at INTEGER
90            );
91            CREATE INDEX IF NOT EXISTS idx_loopback_session ON loopback_approvals(session_id);
92            CREATE INDEX IF NOT EXISTS idx_loopback_status ON loopback_approvals(status);
93        "#)
94        .execute(&self.pool)
95        .await
96        .map_err(|e| LoopbackError::Storage { operation: "migration", detail: e.to_string() })?;
97        Ok(())
98    }
99
100    pub async fn create_approval(
101        &self,
102        session_id: &str,
103        tool_name: &str,
104        tool_use_id: &str,
105        input: &Value,
106    ) -> Result<ApprovalRequest, LoopbackError> {
107        let id = Uuid::new_v4();
108        let now = current_timestamp();
109        let input_json = serde_json::to_string(input)
110            .map_err(|e| LoopbackError::Serialization { detail: e.to_string() })?;
111
112        sqlx::query(
113            "INSERT INTO loopback_approvals (id, session_id, tool_name, tool_use_id, input, status, created_at)
114             VALUES (?, ?, ?, ?, ?, 'pending', ?)"
115        )
116        .bind(id.to_string())
117        .bind(session_id)
118        .bind(tool_name)
119        .bind(tool_use_id)
120        .bind(&input_json)
121        .bind(now)
122        .execute(&self.pool)
123        .await
124        .map_err(|e| LoopbackError::Storage { operation: "create_approval", detail: e.to_string() })?;
125
126        // Notify any waiters that a new approval has arrived
127        self.notify_session(session_id);
128
129        Ok(ApprovalRequest {
130            id,
131            session_id: session_id.to_string(),
132            tool_name: tool_name.to_string(),
133            tool_use_id: tool_use_id.to_string(),
134            input: input.clone(),
135            status: ApprovalStatus::Pending,
136            response_message: None,
137            created_at: now,
138            resolved_at: None,
139        })
140    }
141
142    pub async fn get_approval(&self, id: &ApprovalId) -> Result<ApprovalRequest, LoopbackError> {
143        let row = sqlx::query(
144            "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
145             FROM loopback_approvals WHERE id = ?"
146        )
147        .bind(id.to_string())
148        .fetch_optional(&self.pool)
149        .await
150        .map_err(|e| LoopbackError::Storage { operation: "get_approval", detail: e.to_string() })?
151        .ok_or_else(|| LoopbackError::ApprovalNotFound { id: id.to_string() })?;
152
153        self.row_to_approval(row)
154    }
155
156    pub async fn resolve_approval(
157        &self,
158        id: &ApprovalId,
159        approved: bool,
160        message: Option<String>,
161    ) -> Result<(), LoopbackError> {
162        let now = current_timestamp();
163        let status = if approved { "approved" } else { "denied" };
164
165        let result = sqlx::query(
166            "UPDATE loopback_approvals SET status = ?, response_message = ?, resolved_at = ? WHERE id = ?"
167        )
168        .bind(status)
169        .bind(&message)
170        .bind(now)
171        .bind(id.to_string())
172        .execute(&self.pool)
173        .await
174        .map_err(|e| LoopbackError::Storage { operation: "resolve_approval", detail: e.to_string() })?;
175
176        if result.rows_affected() == 0 {
177            return Err(LoopbackError::ApprovalNotFound { id: id.to_string() });
178        }
179        Ok(())
180    }
181
182    /// Get all pending approvals for a session
183    pub async fn get_pending_approvals(&self, session_id: &str) -> Vec<ApprovalRequest> {
184        let rows = sqlx::query(
185            "SELECT * FROM loopback_approvals WHERE session_id = ? AND status = 'pending'"
186        )
187        .bind(session_id)
188        .fetch_all(&self.pool)
189        .await;
190
191        match rows {
192            Ok(rows) => rows.into_iter().filter_map(|row| self.row_to_approval(row).ok()).collect(),
193            Err(_) => vec![],
194        }
195    }
196
197    pub async fn list_pending(&self, session_id: Option<&str>) -> Result<Vec<ApprovalRequest>, LoopbackError> {
198        let rows = if let Some(sid) = session_id {
199            // Collect all session IDs to query: the given one plus any registered children
200            let mut session_ids = vec![sid.to_string()];
201            if let Ok(children) = self.session_children.read() {
202                if let Some(child_ids) = children.get(sid) {
203                    session_ids.extend(child_ids.iter().cloned());
204                }
205            }
206
207            if session_ids.len() == 1 {
208                sqlx::query(
209                    "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
210                     FROM loopback_approvals WHERE session_id = ? AND status = 'pending' ORDER BY created_at"
211                )
212                .bind(&session_ids[0])
213                .fetch_all(&self.pool)
214                .await
215            } else {
216                // Build IN clause for multiple session IDs
217                let placeholders = session_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
218                let query_str = format!(
219                    "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
220                     FROM loopback_approvals WHERE session_id IN ({}) AND status = 'pending' ORDER BY created_at",
221                    placeholders
222                );
223                let mut q = sqlx::query(&query_str);
224                for sid in &session_ids {
225                    q = q.bind(sid);
226                }
227                q.fetch_all(&self.pool).await
228            }
229        } else {
230            sqlx::query(
231                "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
232                 FROM loopback_approvals WHERE status = 'pending' ORDER BY created_at"
233            )
234            .fetch_all(&self.pool)
235            .await
236        }
237        .map_err(|e| LoopbackError::Storage { operation: "list_pending", detail: e.to_string() })?;
238
239        rows.into_iter().map(|r| self.row_to_approval(r)).collect()
240    }
241
242    fn row_to_approval(&self, row: sqlx::sqlite::SqliteRow) -> Result<ApprovalRequest, LoopbackError> {
243        let id_str: String = row.get("id");
244        let input_json: String = row.get("input");
245        let status_str: String = row.get("status");
246
247        let status = match status_str.as_str() {
248            "pending" => ApprovalStatus::Pending,
249            "approved" => ApprovalStatus::Approved,
250            "denied" => ApprovalStatus::Denied,
251            "timed_out" => ApprovalStatus::TimedOut,
252            _ => ApprovalStatus::Pending,
253        };
254
255        Ok(ApprovalRequest {
256            id: Uuid::parse_str(&id_str).map_err(|e| LoopbackError::InvalidData { detail: format!("Invalid UUID '{}': {}", id_str, e) })?,
257            session_id: row.get("session_id"),
258            tool_name: row.get("tool_name"),
259            tool_use_id: row.get("tool_use_id"),
260            input: serde_json::from_str(&input_json).unwrap_or(Value::Null),
261            status,
262            response_message: row.get("response_message"),
263            created_at: row.get("created_at"),
264            resolved_at: row.get("resolved_at"),
265        })
266    }
267
268    /// Get or create a notifier for a session
269    /// This allows multiple wait_for_approval calls to wait on the same session
270    pub fn get_or_create_notifier(&self, session_id: &str) -> Arc<Notify> {
271        let mut notifiers = self.session_notifiers.write().unwrap();
272        notifiers
273            .entry(session_id.to_string())
274            .or_insert_with(|| Arc::new(Notify::new()))
275            .clone()
276    }
277
278    /// Register a parent session for a child session.
279    /// When the child gets an approval, the parent notifier is also woken.
280    /// Also registers the inverse mapping so list_pending can find child approvals.
281    pub fn register_session_parent(&self, child_session_id: &str, parent_session_id: &str) {
282        if let Ok(mut map) = self.session_parents.write() {
283            map.insert(child_session_id.to_string(), parent_session_id.to_string());
284        }
285        if let Ok(mut map) = self.session_children.write() {
286            map.entry(parent_session_id.to_string())
287                .or_default()
288                .push(child_session_id.to_string());
289        }
290    }
291
292    /// Notify waiters on a session that a new approval has arrived.
293    /// Uses notify_one() so the permit is stored even if no task is currently
294    /// suspended in notified() — preventing lost wakeups when the auto-approver
295    /// is busy processing a previous batch.
296    /// Also notifies the parent session if one is registered.
297    fn notify_session(&self, session_id: &str) {
298        if let Ok(notifiers) = self.session_notifiers.read() {
299            if let Some(notifier) = notifiers.get(session_id) {
300                notifier.notify_one();
301            }
302        }
303        // Propagate to parent (e.g., Orcha session waiting on any child approval)
304        if let Ok(parents) = self.session_parents.read() {
305            if let Some(parent_id) = parents.get(session_id) {
306                if let Ok(notifiers) = self.session_notifiers.read() {
307                    if let Some(notifier) = notifiers.get(parent_id.as_str()) {
308                        notifier.notify_one();
309                    }
310                }
311            }
312        }
313    }
314
315    /// Clean up notifier for a session (optional, for resource cleanup)
316    pub fn remove_notifier(&self, session_id: &str) {
317        if let Ok(mut notifiers) = self.session_notifiers.write() {
318            notifiers.remove(session_id);
319        }
320    }
321}
322
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Used for the `created_at` / `resolved_at` columns. If the system clock is
/// set before 1970, the result is the negative offset instead of a panic.
/// Values that do not fit in `i64` saturate instead of wrapping.
fn current_timestamp() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(i64::MAX),
        Err(before_epoch) => {
            i64::try_from(before_epoch.duration().as_secs()).map_or(i64::MIN, |secs| -secs)
        }
    }
}