plexus_substrate/activations/claudecode_loopback/
storage.rs1use super::types::{ApprovalId, ApprovalRequest, ApprovalStatus};
2use serde_json::Value;
3use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::RwLock;
7use std::time::{SystemTime, UNIX_EPOCH};
8use uuid::Uuid;
9
/// Settings required to open the loopback approval store.
#[derive(Debug, Clone)]
pub struct LoopbackStorageConfig {
    /// Path of the SQLite database file that backs the store.
    pub db_path: PathBuf,
}
14
15impl Default for LoopbackStorageConfig {
16 fn default() -> Self {
17 Self {
18 db_path: PathBuf::from("loopback.db"),
19 }
20 }
21}
22
23pub struct LoopbackStorage {
24 pool: SqlitePool,
25 tool_session_map: RwLock<HashMap<String, String>>,
28}
29
30impl LoopbackStorage {
31 pub async fn new(config: LoopbackStorageConfig) -> Result<Self, String> {
32 let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
33 let options: SqliteConnectOptions = db_url.parse()
34 .map_err(|e| format!("Failed to parse DB URL: {}", e))?;
35 let options = options.disable_statement_logging();
36
37 let pool = SqlitePool::connect_with(options)
38 .await
39 .map_err(|e| format!("Failed to connect: {}", e))?;
40
41 let storage = Self {
42 pool,
43 tool_session_map: RwLock::new(HashMap::new()),
44 };
45 storage.run_migrations().await?;
46 Ok(storage)
47 }
48
49 pub fn register_tool_session(&self, tool_use_id: &str, session_id: &str) {
52 if let Ok(mut map) = self.tool_session_map.write() {
53 map.insert(tool_use_id.to_string(), session_id.to_string());
54 }
55 }
56
57 pub fn lookup_session_by_tool(&self, tool_use_id: &str) -> Option<String> {
60 self.tool_session_map.read().ok()?.get(tool_use_id).cloned()
61 }
62
63 pub fn remove_tool_mapping(&self, tool_use_id: &str) {
65 if let Ok(mut map) = self.tool_session_map.write() {
66 map.remove(tool_use_id);
67 }
68 }
69
70 async fn run_migrations(&self) -> Result<(), String> {
71 sqlx::query(r#"
72 CREATE TABLE IF NOT EXISTS loopback_approvals (
73 id TEXT PRIMARY KEY,
74 session_id TEXT NOT NULL,
75 tool_name TEXT NOT NULL,
76 tool_use_id TEXT NOT NULL,
77 input TEXT NOT NULL,
78 status TEXT NOT NULL DEFAULT 'pending',
79 response_message TEXT,
80 created_at INTEGER NOT NULL,
81 resolved_at INTEGER
82 );
83 CREATE INDEX IF NOT EXISTS idx_loopback_session ON loopback_approvals(session_id);
84 CREATE INDEX IF NOT EXISTS idx_loopback_status ON loopback_approvals(status);
85 "#)
86 .execute(&self.pool)
87 .await
88 .map_err(|e| format!("Migration failed: {}", e))?;
89 Ok(())
90 }
91
92 pub async fn create_approval(
93 &self,
94 session_id: &str,
95 tool_name: &str,
96 tool_use_id: &str,
97 input: &Value,
98 ) -> Result<ApprovalRequest, String> {
99 let id = Uuid::new_v4();
100 let now = current_timestamp();
101 let input_json = serde_json::to_string(input)
102 .map_err(|e| format!("Failed to serialize input: {}", e))?;
103
104 sqlx::query(
105 "INSERT INTO loopback_approvals (id, session_id, tool_name, tool_use_id, input, status, created_at)
106 VALUES (?, ?, ?, ?, ?, 'pending', ?)"
107 )
108 .bind(id.to_string())
109 .bind(session_id)
110 .bind(tool_name)
111 .bind(tool_use_id)
112 .bind(&input_json)
113 .bind(now)
114 .execute(&self.pool)
115 .await
116 .map_err(|e| format!("Failed to create approval: {}", e))?;
117
118 Ok(ApprovalRequest {
119 id,
120 session_id: session_id.to_string(),
121 tool_name: tool_name.to_string(),
122 tool_use_id: tool_use_id.to_string(),
123 input: input.clone(),
124 status: ApprovalStatus::Pending,
125 response_message: None,
126 created_at: now,
127 resolved_at: None,
128 })
129 }
130
131 pub async fn get_approval(&self, id: &ApprovalId) -> Result<ApprovalRequest, String> {
132 let row = sqlx::query(
133 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
134 FROM loopback_approvals WHERE id = ?"
135 )
136 .bind(id.to_string())
137 .fetch_optional(&self.pool)
138 .await
139 .map_err(|e| format!("Failed to fetch approval: {}", e))?
140 .ok_or_else(|| format!("Approval not found: {}", id))?;
141
142 self.row_to_approval(row)
143 }
144
145 pub async fn resolve_approval(
146 &self,
147 id: &ApprovalId,
148 approved: bool,
149 message: Option<String>,
150 ) -> Result<(), String> {
151 let now = current_timestamp();
152 let status = if approved { "approved" } else { "denied" };
153
154 let result = sqlx::query(
155 "UPDATE loopback_approvals SET status = ?, response_message = ?, resolved_at = ? WHERE id = ?"
156 )
157 .bind(status)
158 .bind(&message)
159 .bind(now)
160 .bind(id.to_string())
161 .execute(&self.pool)
162 .await
163 .map_err(|e| format!("Failed to resolve approval: {}", e))?;
164
165 if result.rows_affected() == 0 {
166 return Err(format!("Approval not found: {}", id));
167 }
168 Ok(())
169 }
170
171 pub async fn get_pending_approvals(&self, session_id: &str) -> Vec<ApprovalRequest> {
173 let rows = sqlx::query(
174 "SELECT * FROM loopback_approvals WHERE session_id = ? AND status = 'pending'"
175 )
176 .bind(session_id)
177 .fetch_all(&self.pool)
178 .await;
179
180 match rows {
181 Ok(rows) => rows.into_iter().filter_map(|row| self.row_to_approval(row).ok()).collect(),
182 Err(_) => vec![],
183 }
184 }
185
186 pub async fn list_pending(&self, session_id: Option<&str>) -> Result<Vec<ApprovalRequest>, String> {
187 let rows = if let Some(sid) = session_id {
188 sqlx::query(
189 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
190 FROM loopback_approvals WHERE session_id = ? AND status = 'pending' ORDER BY created_at"
191 )
192 .bind(sid)
193 .fetch_all(&self.pool)
194 .await
195 } else {
196 sqlx::query(
197 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
198 FROM loopback_approvals WHERE status = 'pending' ORDER BY created_at"
199 )
200 .fetch_all(&self.pool)
201 .await
202 }
203 .map_err(|e| format!("Failed to list pending: {}", e))?;
204
205 rows.into_iter().map(|r| self.row_to_approval(r)).collect()
206 }
207
208 fn row_to_approval(&self, row: sqlx::sqlite::SqliteRow) -> Result<ApprovalRequest, String> {
209 let id_str: String = row.get("id");
210 let input_json: String = row.get("input");
211 let status_str: String = row.get("status");
212
213 let status = match status_str.as_str() {
214 "pending" => ApprovalStatus::Pending,
215 "approved" => ApprovalStatus::Approved,
216 "denied" => ApprovalStatus::Denied,
217 "timed_out" => ApprovalStatus::TimedOut,
218 _ => ApprovalStatus::Pending,
219 };
220
221 Ok(ApprovalRequest {
222 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid UUID: {}", e))?,
223 session_id: row.get("session_id"),
224 tool_name: row.get("tool_name"),
225 tool_use_id: row.get("tool_use_id"),
226 input: serde_json::from_str(&input_json).unwrap_or(Value::Null),
227 status,
228 response_message: row.get("response_message"),
229 created_at: row.get("created_at"),
230 resolved_at: row.get("resolved_at"),
231 })
232 }
233}
234
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Never panics: a system clock set before the epoch yields a negative value
/// (the whole seconds before the epoch, negated), and values that do not fit
/// in `i64` saturate instead of wrapping.
fn current_timestamp() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(i64::MAX),
        // `duration()` is how far the clock is *behind* the epoch.
        Err(before_epoch) => i64::try_from(before_epoch.duration().as_secs())
            .map(|secs| -secs)
            .unwrap_or(i64::MIN),
    }
}