plexus_substrate/activations/claudecode_loopback/
storage.rs1use super::types::{ApprovalId, ApprovalRequest, ApprovalStatus, LoopbackError};
2use crate::activations::storage::init_sqlite_pool;
3use crate::activation_db_path_from_module;
4use serde_json::Value;
5use sqlx::{sqlite::SqlitePool, Row};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::sync::Notify;
11use uuid::Uuid;
12
/// Settings used to construct a [`LoopbackStorage`].
#[derive(Debug, Clone)]
pub struct LoopbackStorageConfig {
    /// Location of the SQLite database file that holds the `loopback_approvals` table.
    pub db_path: PathBuf,
}
17
18impl Default for LoopbackStorageConfig {
19 fn default() -> Self {
20 Self {
21 db_path: activation_db_path_from_module!("loopback.db"),
22 }
23 }
24}
25
26pub struct LoopbackStorage {
27 pool: SqlitePool,
28 tool_session_map: RwLock<HashMap<String, String>>,
31 session_notifiers: Arc<RwLock<HashMap<String, Arc<Notify>>>>,
34 session_parents: RwLock<HashMap<String, String>>,
37 session_children: RwLock<HashMap<String, Vec<String>>>,
40}
41
42impl LoopbackStorage {
43 pub async fn new(config: LoopbackStorageConfig) -> Result<Self, String> {
44 let pool = init_sqlite_pool(config.db_path).await?;
45
46 let storage = Self {
47 pool,
48 tool_session_map: RwLock::new(HashMap::new()),
49 session_notifiers: Arc::new(RwLock::new(HashMap::new())),
50 session_parents: RwLock::new(HashMap::new()),
51 session_children: RwLock::new(HashMap::new()),
52 };
53 storage.run_migrations().await?;
54 Ok(storage)
55 }
56
57 pub fn register_tool_session(&self, tool_use_id: &str, session_id: &str) {
60 if let Ok(mut map) = self.tool_session_map.write() {
61 map.insert(tool_use_id.to_string(), session_id.to_string());
62 }
63 }
64
65 pub fn lookup_session_by_tool(&self, tool_use_id: &str) -> Option<String> {
68 self.tool_session_map.read().ok()?.get(tool_use_id).cloned()
69 }
70
71 pub fn remove_tool_mapping(&self, tool_use_id: &str) {
73 if let Ok(mut map) = self.tool_session_map.write() {
74 map.remove(tool_use_id);
75 }
76 }
77
78 async fn run_migrations(&self) -> Result<(), LoopbackError> {
79 sqlx::query(r#"
80 CREATE TABLE IF NOT EXISTS loopback_approvals (
81 id TEXT PRIMARY KEY,
82 session_id TEXT NOT NULL,
83 tool_name TEXT NOT NULL,
84 tool_use_id TEXT NOT NULL,
85 input TEXT NOT NULL,
86 status TEXT NOT NULL DEFAULT 'pending',
87 response_message TEXT,
88 created_at INTEGER NOT NULL,
89 resolved_at INTEGER
90 );
91 CREATE INDEX IF NOT EXISTS idx_loopback_session ON loopback_approvals(session_id);
92 CREATE INDEX IF NOT EXISTS idx_loopback_status ON loopback_approvals(status);
93 "#)
94 .execute(&self.pool)
95 .await
96 .map_err(|e| LoopbackError::Storage { operation: "migration", detail: e.to_string() })?;
97 Ok(())
98 }
99
100 pub async fn create_approval(
101 &self,
102 session_id: &str,
103 tool_name: &str,
104 tool_use_id: &str,
105 input: &Value,
106 ) -> Result<ApprovalRequest, LoopbackError> {
107 let id = Uuid::new_v4();
108 let now = current_timestamp();
109 let input_json = serde_json::to_string(input)
110 .map_err(|e| LoopbackError::Serialization { detail: e.to_string() })?;
111
112 sqlx::query(
113 "INSERT INTO loopback_approvals (id, session_id, tool_name, tool_use_id, input, status, created_at)
114 VALUES (?, ?, ?, ?, ?, 'pending', ?)"
115 )
116 .bind(id.to_string())
117 .bind(session_id)
118 .bind(tool_name)
119 .bind(tool_use_id)
120 .bind(&input_json)
121 .bind(now)
122 .execute(&self.pool)
123 .await
124 .map_err(|e| LoopbackError::Storage { operation: "create_approval", detail: e.to_string() })?;
125
126 self.notify_session(session_id);
128
129 Ok(ApprovalRequest {
130 id,
131 session_id: session_id.to_string(),
132 tool_name: tool_name.to_string(),
133 tool_use_id: tool_use_id.to_string(),
134 input: input.clone(),
135 status: ApprovalStatus::Pending,
136 response_message: None,
137 created_at: now,
138 resolved_at: None,
139 })
140 }
141
142 pub async fn get_approval(&self, id: &ApprovalId) -> Result<ApprovalRequest, LoopbackError> {
143 let row = sqlx::query(
144 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
145 FROM loopback_approvals WHERE id = ?"
146 )
147 .bind(id.to_string())
148 .fetch_optional(&self.pool)
149 .await
150 .map_err(|e| LoopbackError::Storage { operation: "get_approval", detail: e.to_string() })?
151 .ok_or_else(|| LoopbackError::ApprovalNotFound { id: id.to_string() })?;
152
153 self.row_to_approval(row)
154 }
155
156 pub async fn resolve_approval(
157 &self,
158 id: &ApprovalId,
159 approved: bool,
160 message: Option<String>,
161 ) -> Result<(), LoopbackError> {
162 let now = current_timestamp();
163 let status = if approved { "approved" } else { "denied" };
164
165 let result = sqlx::query(
166 "UPDATE loopback_approvals SET status = ?, response_message = ?, resolved_at = ? WHERE id = ?"
167 )
168 .bind(status)
169 .bind(&message)
170 .bind(now)
171 .bind(id.to_string())
172 .execute(&self.pool)
173 .await
174 .map_err(|e| LoopbackError::Storage { operation: "resolve_approval", detail: e.to_string() })?;
175
176 if result.rows_affected() == 0 {
177 return Err(LoopbackError::ApprovalNotFound { id: id.to_string() });
178 }
179 Ok(())
180 }
181
182 pub async fn get_pending_approvals(&self, session_id: &str) -> Vec<ApprovalRequest> {
184 let rows = sqlx::query(
185 "SELECT * FROM loopback_approvals WHERE session_id = ? AND status = 'pending'"
186 )
187 .bind(session_id)
188 .fetch_all(&self.pool)
189 .await;
190
191 match rows {
192 Ok(rows) => rows.into_iter().filter_map(|row| self.row_to_approval(row).ok()).collect(),
193 Err(_) => vec![],
194 }
195 }
196
197 pub async fn list_pending(&self, session_id: Option<&str>) -> Result<Vec<ApprovalRequest>, LoopbackError> {
198 let rows = if let Some(sid) = session_id {
199 let mut session_ids = vec![sid.to_string()];
201 if let Ok(children) = self.session_children.read() {
202 if let Some(child_ids) = children.get(sid) {
203 session_ids.extend(child_ids.iter().cloned());
204 }
205 }
206
207 if session_ids.len() == 1 {
208 sqlx::query(
209 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
210 FROM loopback_approvals WHERE session_id = ? AND status = 'pending' ORDER BY created_at"
211 )
212 .bind(&session_ids[0])
213 .fetch_all(&self.pool)
214 .await
215 } else {
216 let placeholders = session_ids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
218 let query_str = format!(
219 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
220 FROM loopback_approvals WHERE session_id IN ({}) AND status = 'pending' ORDER BY created_at",
221 placeholders
222 );
223 let mut q = sqlx::query(&query_str);
224 for sid in &session_ids {
225 q = q.bind(sid);
226 }
227 q.fetch_all(&self.pool).await
228 }
229 } else {
230 sqlx::query(
231 "SELECT id, session_id, tool_name, tool_use_id, input, status, response_message, created_at, resolved_at
232 FROM loopback_approvals WHERE status = 'pending' ORDER BY created_at"
233 )
234 .fetch_all(&self.pool)
235 .await
236 }
237 .map_err(|e| LoopbackError::Storage { operation: "list_pending", detail: e.to_string() })?;
238
239 rows.into_iter().map(|r| self.row_to_approval(r)).collect()
240 }
241
242 fn row_to_approval(&self, row: sqlx::sqlite::SqliteRow) -> Result<ApprovalRequest, LoopbackError> {
243 let id_str: String = row.get("id");
244 let input_json: String = row.get("input");
245 let status_str: String = row.get("status");
246
247 let status = match status_str.as_str() {
248 "pending" => ApprovalStatus::Pending,
249 "approved" => ApprovalStatus::Approved,
250 "denied" => ApprovalStatus::Denied,
251 "timed_out" => ApprovalStatus::TimedOut,
252 _ => ApprovalStatus::Pending,
253 };
254
255 Ok(ApprovalRequest {
256 id: Uuid::parse_str(&id_str).map_err(|e| LoopbackError::InvalidData { detail: format!("Invalid UUID '{}': {}", id_str, e) })?,
257 session_id: row.get("session_id"),
258 tool_name: row.get("tool_name"),
259 tool_use_id: row.get("tool_use_id"),
260 input: serde_json::from_str(&input_json).unwrap_or(Value::Null),
261 status,
262 response_message: row.get("response_message"),
263 created_at: row.get("created_at"),
264 resolved_at: row.get("resolved_at"),
265 })
266 }
267
268 pub fn get_or_create_notifier(&self, session_id: &str) -> Arc<Notify> {
271 let mut notifiers = self.session_notifiers.write().unwrap();
272 notifiers
273 .entry(session_id.to_string())
274 .or_insert_with(|| Arc::new(Notify::new()))
275 .clone()
276 }
277
278 pub fn register_session_parent(&self, child_session_id: &str, parent_session_id: &str) {
282 if let Ok(mut map) = self.session_parents.write() {
283 map.insert(child_session_id.to_string(), parent_session_id.to_string());
284 }
285 if let Ok(mut map) = self.session_children.write() {
286 map.entry(parent_session_id.to_string())
287 .or_default()
288 .push(child_session_id.to_string());
289 }
290 }
291
292 fn notify_session(&self, session_id: &str) {
298 if let Ok(notifiers) = self.session_notifiers.read() {
299 if let Some(notifier) = notifiers.get(session_id) {
300 notifier.notify_one();
301 }
302 }
303 if let Ok(parents) = self.session_parents.read() {
305 if let Some(parent_id) = parents.get(session_id) {
306 if let Ok(notifiers) = self.session_notifiers.read() {
307 if let Some(notifier) = notifiers.get(parent_id.as_str()) {
308 notifier.notify_one();
309 }
310 }
311 }
312 }
313 }
314
315 pub fn remove_notifier(&self, session_id: &str) {
317 if let Ok(mut notifiers) = self.session_notifiers.write() {
318 notifiers.remove(session_id);
319 }
320 }
321}
322
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking. The seconds count is
/// clamped to `i64::MAX` before the cast, so the conversion can never wrap negative.
fn current_timestamp() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs().min(i64::MAX as u64) as i64
}