plexus_substrate/activations/claudecode_loopback/
storage.rs1use super::types::{ApprovalId, ApprovalRequest, ApprovalStatus};
2use serde_json::Value;
3use sqlx::{sqlite::{SqliteConnectOptions, SqlitePool}, ConnectOptions, Row};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::sync::RwLock;
7use std::time::{SystemTime, UNIX_EPOCH};
8use uuid::Uuid;
9
/// Configuration for opening a `LoopbackStorage`.
#[derive(Debug, Clone)]
pub struct LoopbackStorageConfig {
    /// Location of the SQLite database file that backs the approval store.
    pub db_path: PathBuf,
}

impl Default for LoopbackStorageConfig {
    /// Defaults to the relative path `loopback.db`.
    fn default() -> Self {
        let db_path = PathBuf::from("loopback.db");
        Self { db_path }
    }
}
22
23pub struct LoopbackStorage {
24 pool: SqlitePool,
25 tool_session_map: RwLock<HashMap<String, String>>,
28}
29
30impl LoopbackStorage {
31 pub async fn new(config: LoopbackStorageConfig) -> Result<Self, String> {
32 let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
33 let options: SqliteConnectOptions = db_url.parse()
34 .map_err(|e| format!("Failed to parse DB URL: {}", e))?;
35 let options = options.disable_statement_logging();
36
37 let pool = SqlitePool::connect_with(options)
38 .await
39 .map_err(|e| format!("Failed to connect: {}", e))?;
40
41 let storage = Self {
42 pool,
43 tool_session_map: RwLock::new(HashMap::new()),
44 };
45 storage.run_migrations().await?;
46 Ok(storage)
47 }
48
49 pub fn register_tool_session(&self, tool_use_id: &str, session_id: &str) {
52 if let Ok(mut map) = self.tool_session_map.write() {
53 map.insert(tool_use_id.to_string(), session_id.to_string());
54 }
55 }
56
57 pub fn lookup_session_by_tool(&self, tool_use_id: &str) -> Option<String> {
60 self.tool_session_map.read().ok()?.get(tool_use_id).cloned()
61 }
62
63 pub fn remove_tool_mapping(&self, tool_use_id: &str) {
65 if let Ok(mut map) = self.tool_session_map.write() {
66 map.remove(tool_use_id);
67 }
68 }
69
70 async fn run_migrations(&self) -> Result<(), String> {
71 sqlx::query(r#"
72 CREATE TABLE IF NOT EXISTS loopback_approvals (
73 id TEXT PRIMARY KEY,
74 session_id TEXT NOT NULL,
75 tool_name TEXT NOT NULL,
76 tool_use_id TEXT NOT NULL,
77 input TEXT NOT NULL,
78 status TEXT NOT NULL DEFAULT 'pending',
79 response_message TEXT,
80 created_at INTEGER NOT NULL,
81 resolved_at INTEGER
82 );
83 CREATE INDEX IF NOT EXISTS idx_loopback_session ON loopback_approvals(session_id);
84 CREATE INDEX IF NOT EXISTS idx_loopback_status ON loopback_approvals(status);
85 "#)
86 .execute(&self.pool)
87 .await
88 .map_err(|e| format!("Migration failed: {}", e))?;
89 Ok(())
90 }
91
92 pub async fn create_approval(
93 &self,
94 session_id: &str,
95 tool_name: &str,
96 tool_use_id: &str,
97 input: &Value,
98 ) -> Result<ApprovalRequest, String> {
99 let id = Uuid::new_v4();
100 let now = current_timestamp();
101 let input_json = serde_json::to_string(input)
102 .map_err(|e| format!("Failed to serialize input: {}", e))?;
103
104 sqlx::query(
105 "INSERT INTO loopback_approvals (id, session_id, tool_name, tool_use_id, input, status, created_at)
106 VALUES (?, ?, ?, ?, ?, 'pending', ?)"
107 )
108 .bind(id.to_string())
109 .bind(session_id)
110 .bind(tool_name)
111 .bind(tool_use_id)
112 .bind(&input_json)
113 .bind(now)
114 .execute(&self.pool)
115 .await
116 .map_err(|e| format!("Failed to create approval: {}", e))?;
117
118 Ok(ApprovalRequest {
119 id,
120 session_id: session_id.to_string(),
121 tool_name: tool_name.to_string(),
122 tool_use_id: tool_use_id.to_string(),
123 input: input.clone(),
124 status: ApprovalStatus::Pending,
125 response_message: None,
126 created_at: now,
127 resolved_at: None,
128 })
129 }
130
131 pub async fn get_approval(&self, id: &ApprovalId) -> Result<ApprovalRequest, String> {
132 let row = sqlx::query(
133 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
134 FROM loopback_approvals WHERE id = ?"
135 )
136 .bind(id.to_string())
137 .fetch_optional(&self.pool)
138 .await
139 .map_err(|e| format!("Failed to fetch approval: {}", e))?
140 .ok_or_else(|| format!("Approval not found: {}", id))?;
141
142 self.row_to_approval(row)
143 }
144
145 pub async fn resolve_approval(
146 &self,
147 id: &ApprovalId,
148 approved: bool,
149 message: Option<String>,
150 ) -> Result<(), String> {
151 let now = current_timestamp();
152 let status = if approved { "approved" } else { "denied" };
153
154 let result = sqlx::query(
155 "UPDATE loopback_approvals SET status = ?, response_message = ?, resolved_at = ? WHERE id = ?"
156 )
157 .bind(status)
158 .bind(&message)
159 .bind(now)
160 .bind(id.to_string())
161 .execute(&self.pool)
162 .await
163 .map_err(|e| format!("Failed to resolve approval: {}", e))?;
164
165 if result.rows_affected() == 0 {
166 return Err(format!("Approval not found: {}", id));
167 }
168 Ok(())
169 }
170
171 pub async fn list_pending(&self, session_id: Option<&str>) -> Result<Vec<ApprovalRequest>, String> {
172 let rows = if let Some(sid) = session_id {
173 sqlx::query(
174 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
175 FROM loopback_approvals WHERE session_id = ? AND status = 'pending' ORDER BY created_at"
176 )
177 .bind(sid)
178 .fetch_all(&self.pool)
179 .await
180 } else {
181 sqlx::query(
182 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
183 FROM loopback_approvals WHERE status = 'pending' ORDER BY created_at"
184 )
185 .fetch_all(&self.pool)
186 .await
187 }
188 .map_err(|e| format!("Failed to list pending: {}", e))?;
189
190 rows.into_iter().map(|r| self.row_to_approval(r)).collect()
191 }
192
193 fn row_to_approval(&self, row: sqlx::sqlite::SqliteRow) -> Result<ApprovalRequest, String> {
194 let id_str: String = row.get("id");
195 let input_json: String = row.get("input");
196 let status_str: String = row.get("status");
197
198 let status = match status_str.as_str() {
199 "pending" => ApprovalStatus::Pending,
200 "approved" => ApprovalStatus::Approved,
201 "denied" => ApprovalStatus::Denied,
202 "timed_out" => ApprovalStatus::TimedOut,
203 _ => ApprovalStatus::Pending,
204 };
205
206 Ok(ApprovalRequest {
207 id: Uuid::parse_str(&id_str).map_err(|e| format!("Invalid UUID: {}", e))?,
208 session_id: row.get("session_id"),
209 tool_name: row.get("tool_name"),
210 tool_use_id: row.get("tool_use_id"),
211 input: serde_json::from_str(&input_json).unwrap_or(Value::Null),
212 status,
213 response_message: row.get("response_message"),
214 created_at: row.get("created_at"),
215 resolved_at: row.get("resolved_at"),
216 })
217 }
218}
219
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Never panics:
/// - A system clock set before 1970 yields a negative value instead of aborting the caller.
/// - A seconds count that does not fit in `i64` saturates at `i64::MAX` (or `i64::MIN`
///   for the pre-epoch case).
fn current_timestamp() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX),
        // The error carries how far *before* the epoch "now" is.
        Err(before_epoch) => i64::try_from(before_epoch.duration().as_secs())
            .map_or(i64::MIN, |secs| -secs),
    }
}