Skip to main content

plexus_substrate/activations/claudecode_loopback/
storage.rs

1use super::types::{ApprovalId, ApprovalRequest, ApprovalStatus};
2use serde_json::Value;
3use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::RwLock;
7use std::time::{SystemTime, UNIX_EPOCH};
8use uuid::Uuid;
9
/// Settings used to open the loopback approval store.
#[derive(Debug, Clone)]
pub struct LoopbackStorageConfig {
    /// Location of the SQLite database file backing the store.
    pub db_path: PathBuf,
}
14
15impl Default for LoopbackStorageConfig {
16    fn default() -> Self {
17        Self {
18            db_path: PathBuf::from("loopback.db"),
19        }
20    }
21}
22
23pub struct LoopbackStorage {
24    pool: SqlitePool,
25    /// Maps tool_use_id -> session_id for correlation
26    /// This allows loopback_permit to find the session_id when called via MCP
27    tool_session_map: RwLock<HashMap<String, String>>,
28}
29
30impl LoopbackStorage {
31    pub async fn new(config: LoopbackStorageConfig) -> Result<Self, String> {
32        let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
33        let options: SqliteConnectOptions = db_url.parse()
34            .map_err(|e| format!("Failed to parse DB URL: {}", e))?;
35        let options = options.disable_statement_logging();
36
37        let pool = SqlitePool::connect_with(options)
38            .await
39            .map_err(|e| format!("Failed to connect: {}", e))?;
40
41        let storage = Self {
42            pool,
43            tool_session_map: RwLock::new(HashMap::new()),
44        };
45        storage.run_migrations().await?;
46        Ok(storage)
47    }
48
49    /// Register a tool_use_id -> session_id mapping
50    /// Called by the background task when it sees a ToolUse event
51    pub fn register_tool_session(&self, tool_use_id: &str, session_id: &str) {
52        if let Ok(mut map) = self.tool_session_map.write() {
53            map.insert(tool_use_id.to_string(), session_id.to_string());
54        }
55    }
56
57    /// Lookup session_id by tool_use_id
58    /// Called by loopback_permit to find the correct session_id
59    pub fn lookup_session_by_tool(&self, tool_use_id: &str) -> Option<String> {
60        self.tool_session_map.read().ok()?.get(tool_use_id).cloned()
61    }
62
63    /// Remove a tool_use_id mapping (called after approval is resolved)
64    pub fn remove_tool_mapping(&self, tool_use_id: &str) {
65        if let Ok(mut map) = self.tool_session_map.write() {
66            map.remove(tool_use_id);
67        }
68    }
69
70    async fn run_migrations(&self) -> Result<(), String> {
71        sqlx::query(r#"
72            CREATE TABLE IF NOT EXISTS loopback_approvals (
73                id TEXT PRIMARY KEY,
74                session_id TEXT NOT NULL,
75                tool_name TEXT NOT NULL,
76                tool_use_id TEXT NOT NULL,
77                input TEXT NOT NULL,
78                status TEXT NOT NULL DEFAULT 'pending',
79                response_message TEXT,
80                created_at INTEGER NOT NULL,
81                resolved_at INTEGER
82            );
83            CREATE INDEX IF NOT EXISTS idx_loopback_session ON loopback_approvals(session_id);
84            CREATE INDEX IF NOT EXISTS idx_loopback_status ON loopback_approvals(status);
85        "#)
86        .execute(&self.pool)
87        .await
88        .map_err(|e| format!("Migration failed: {}", e))?;
89        Ok(())
90    }
91
92    pub async fn create_approval(
93        &self,
94        session_id: &str,
95        tool_name: &str,
96        tool_use_id: &str,
97        input: &Value,
98    ) -> Result<ApprovalRequest, String> {
99        let id = Uuid::new_v4();
100        let now = current_timestamp();
101        let input_json = serde_json::to_string(input)
102            .map_err(|e| format!("Failed to serialize input: {}", e))?;
103
104        sqlx::query(
105            "INSERT INTO loopback_approvals (id, session_id, tool_name, tool_use_id, input, status, created_at)
106             VALUES (?, ?, ?, ?, ?, 'pending', ?)"
107        )
108        .bind(id.to_string())
109        .bind(session_id)
110        .bind(tool_name)
111        .bind(tool_use_id)
112        .bind(&input_json)
113        .bind(now)
114        .execute(&self.pool)
115        .await
116        .map_err(|e| format!("Failed to create approval: {}", e))?;
117
118        Ok(ApprovalRequest {
119            id,
120            session_id: session_id.to_string(),
121            tool_name: tool_name.to_string(),
122            tool_use_id: tool_use_id.to_string(),
123            input: input.clone(),
124            status: ApprovalStatus::Pending,
125            response_message: None,
126            created_at: now,
127            resolved_at: None,
128        })
129    }
130
131    pub async fn get_approval(&self, id: &ApprovalId) -> Result<ApprovalRequest, String> {
132        let row = sqlx::query(
133            "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
134             FROM loopback_approvals WHERE id = ?"
135        )
136        .bind(id.to_string())
137        .fetch_optional(&self.pool)
138        .await
139        .map_err(|e| format!("Failed to fetch approval: {}", e))?
140        .ok_or_else(|| format!("Approval not found: {}", id))?;
141
142        self.row_to_approval(row)
143    }
144
145    pub async fn resolve_approval(
146        &self,
147        id: &ApprovalId,
148        approved: bool,
149        message: Option<String>,
150    ) -> Result<(), String> {
151        let now = current_timestamp();
152        let status = if approved { "approved" } else { "denied" };
153
154        let result = sqlx::query(
155            "UPDATE loopback_approvals SET status = ?, response_message = ?, resolved_at = ? WHERE id = ?"
156        )
157        .bind(status)
158        .bind(&message)
159        .bind(now)
160        .bind(id.to_string())
161        .execute(&self.pool)
162        .await
163        .map_err(|e| format!("Failed to resolve approval: {}", e))?;
164
165        if result.rows_affected() == 0 {
166            return Err(format!("Approval not found: {}", id));
167        }
168        Ok(())
169    }
170
171    pub async fn list_pending(&self, session_id: Option<&str>) -> Result<Vec<ApprovalRequest>, String> {
172        let rows = if let Some(sid) = session_id {
173            sqlx::query(
174                "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
175                 FROM loopback_approvals WHERE session_id = ? AND status = 'pending' ORDER BY created_at"
176            )
177            .bind(sid)
178            .fetch_all(&self.pool)
179            .await
180        } else {
181            sqlx::query(
182                "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
183                 FROM loopback_approvals WHERE status = 'pending' ORDER BY created_at"
184            )
185            .fetch_all(&self.pool)
186            .await
187        }
188        .map_err(|e| format!("Failed to list pending: {}", e))?;
189
190        rows.into_iter().map(|r| self.row_to_approval(r)).collect()
191    }
192
193    fn row_to_approval(&self, row: sqlx::sqlite::SqliteRow) -> Result<ApprovalRequest, String> {
194        let id_str: String = row.get("id");
195        let input_json: String = row.get("input");
196        let status_str: String = row.get("status");
197
198        let status = match status_str.as_str() {
199            "pending" => ApprovalStatus::Pending,
200            "approved" => ApprovalStatus::Approved,
201            "denied" => ApprovalStatus::Denied,
202            "timed_out" => ApprovalStatus::TimedOut,
203            _ => ApprovalStatus::Pending,
204        };
205
206        Ok(ApprovalRequest {
207            id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid UUID: {}", e))?,
208            session_id: row.get("session_id"),
209            tool_name: row.get("tool_name"),
210            tool_use_id: row.get("tool_use_id"),
211            input: serde_json::from_str(&input_json).unwrap_or(Value::Null),
212            status,
213            response_message: row.get("response_message"),
214            created_at: row.get("created_at"),
215            resolved_at: row.get("resolved_at"),
216        })
217    }
218}
219
/// Returns the current wall-clock time as whole seconds relative to the Unix
/// epoch.
///
/// Never panics: a system clock set before the epoch yields a negative value,
/// and values that do not fit in `i64` saturate at the type's bounds.
fn current_timestamp() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX),
        // The error carries how far the clock is *behind* the epoch.
        Err(before_epoch) => i64::try_from(before_epoch.duration().as_secs())
            .map(|secs| -secs)
            .unwrap_or(i64::MIN),
    }
}