1use std::sync::Arc;
7
8use async_trait::async_trait;
9use chrono::{TimeZone, Utc};
10use nexo_driver_claude::ClaudeError;
11use nexo_driver_permission::PermissionRequest;
12use nexo_driver_types::{Decision, DecisionChoice, DecisionId, GoalId};
13use nexo_memory::{vector, EmbeddingProvider};
14use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
15use sqlx::SqlitePool;
16use uuid::Uuid;
17
18use crate::memory::prompt::{decision_to_text, request_to_text};
19use crate::memory::trait_def::{DecisionMemory, Namespace};
20
/// Schema revision stamped into `PRAGMA user_version` once migrations finish.
const SCHEMA_VERSION: i64 = 1;
22
23pub struct SqliteVecDecisionMemory {
24 pool: SqlitePool,
25 embedder: Arc<dyn EmbeddingProvider>,
26 namespace: Namespace,
27 dim: usize,
28}
29
30impl SqliteVecDecisionMemory {
31 pub async fn open(
32 path: &str,
33 embedder: Arc<dyn EmbeddingProvider>,
34 ) -> Result<Self, ClaudeError> {
35 vector::enable();
37
38 let opts = SqliteConnectOptions::new()
39 .filename(path)
40 .create_if_missing(true);
41 let max_conns = if path == ":memory:" { 1 } else { 4 };
42 let pool = SqlitePoolOptions::new()
43 .max_connections(max_conns)
44 .connect_with(opts)
45 .await
46 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
47
48 if path != ":memory:" {
49 sqlx::query("PRAGMA journal_mode = WAL")
50 .execute(&pool)
51 .await
52 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
53 sqlx::query("PRAGMA synchronous = NORMAL")
54 .execute(&pool)
55 .await
56 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
57 }
58
59 let dim = embedder.dimension();
60 Self::migrate(&pool, dim).await?;
61
62 Ok(Self {
63 pool,
64 embedder,
65 namespace: Namespace::Global,
66 dim,
67 })
68 }
69
70 pub async fn open_memory(embedder: Arc<dyn EmbeddingProvider>) -> Result<Self, ClaudeError> {
71 Self::open(":memory:", embedder).await
72 }
73
74 pub fn with_namespace(mut self, ns: Namespace) -> Self {
75 self.namespace = ns;
76 self
77 }
78
79 #[doc(hidden)]
81 pub fn pool_for_test(&self) -> &SqlitePool {
82 &self.pool
83 }
84
85 #[doc(hidden)]
87 pub async fn count(&self) -> Result<u64, ClaudeError> {
88 let (n,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM driver_decisions")
89 .fetch_one(&self.pool)
90 .await
91 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
92 Ok(n as u64)
93 }
94
95 async fn migrate(pool: &SqlitePool, dim: usize) -> Result<(), ClaudeError> {
96 sqlx::query(
97 "CREATE TABLE IF NOT EXISTS driver_decisions (\
98 id TEXT PRIMARY KEY,\
99 goal_id TEXT NOT NULL,\
100 turn_index INTEGER NOT NULL,\
101 tool TEXT NOT NULL,\
102 input_summary TEXT NOT NULL,\
103 choice_kind TEXT NOT NULL,\
104 choice_message TEXT,\
105 rationale TEXT NOT NULL,\
106 decided_at INTEGER NOT NULL,\
107 full_input_json TEXT NOT NULL,\
108 schema_version INTEGER NOT NULL DEFAULT 1\
109 )",
110 )
111 .execute(pool)
112 .await
113 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
114 sqlx::query("CREATE INDEX IF NOT EXISTS idx_dd_goal_id ON driver_decisions(goal_id)")
115 .execute(pool)
116 .await
117 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
118 sqlx::query("CREATE INDEX IF NOT EXISTS idx_dd_decided_at ON driver_decisions(decided_at)")
119 .execute(pool)
120 .await
121 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
122
123 let exists: Option<(String,)> = sqlx::query_as(
125 "SELECT name FROM sqlite_master \
126 WHERE type='table' AND name='driver_decisions_vec'",
127 )
128 .fetch_optional(pool)
129 .await
130 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
131
132 if exists.is_none() {
133 let sql = format!(
134 "CREATE VIRTUAL TABLE driver_decisions_vec USING vec0(embedding FLOAT[{dim}])"
135 );
136 sqlx::query(&sql)
137 .execute(pool)
138 .await
139 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
140 } else {
141 let sample: Option<(Vec<u8>,)> =
142 sqlx::query_as("SELECT embedding FROM driver_decisions_vec LIMIT 1")
143 .fetch_optional(pool)
144 .await
145 .ok()
146 .flatten();
147 if let Some((bytes,)) = sample {
148 let existing_dim = bytes.len() / 4;
149 if existing_dim != dim {
150 return Err(ClaudeError::Binding(format!(
151 "decision-memory dim mismatch: schema={existing_dim}, embedder={dim}; \
152 drop the table or reset the file"
153 )));
154 }
155 }
156 }
157
158 sqlx::query(&format!("PRAGMA user_version = {SCHEMA_VERSION}"))
159 .execute(pool)
160 .await
161 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
162 Ok(())
163 }
164}
165
166fn choice_kind_label(c: &DecisionChoice) -> (&'static str, Option<String>) {
167 match c {
168 DecisionChoice::Allow => ("allow", None),
169 DecisionChoice::Deny { message } => ("deny", Some(message.clone())),
170 DecisionChoice::Observe { note } => ("observe", Some(note.clone())),
171 }
172}
173
174fn parse_choice(kind: &str, message: Option<String>) -> DecisionChoice {
175 match kind {
176 "allow" => DecisionChoice::Allow,
177 "deny" => DecisionChoice::Deny {
178 message: message.unwrap_or_default(),
179 },
180 "observe" => DecisionChoice::Observe {
181 note: message.unwrap_or_default(),
182 },
183 _ => DecisionChoice::Allow,
184 }
185}
186
187#[async_trait]
188impl DecisionMemory for SqliteVecDecisionMemory {
189 async fn record(&self, decision: &Decision) -> Result<(), ClaudeError> {
190 let text = decision_to_text(decision);
191 let mut vecs = match self.embedder.embed(&[text.as_str()]).await {
192 Ok(v) => v,
193 Err(e) => {
194 tracing::warn!(target: "decision-memory", "embed record failed: {e}");
195 return Ok(());
196 }
197 };
198 if vecs.is_empty() {
199 return Ok(());
200 }
201 let v = vecs.remove(0);
202 if v.len() != self.dim {
203 tracing::warn!(
204 target: "decision-memory",
205 "embed dim mismatch: got {}, expected {}",
206 v.len(),
207 self.dim
208 );
209 return Ok(());
210 }
211 let bytes = vector::pack_f32(&v);
212
213 let (choice_kind, choice_message) = choice_kind_label(&decision.choice);
214 let full_input_json =
215 serde_json::to_string(&decision.input).unwrap_or_else(|_| "null".into());
216
217 let mut tx = self
220 .pool
221 .begin()
222 .await
223 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
224
225 let inserted = sqlx::query(
226 "INSERT INTO driver_decisions (\
227 id, goal_id, turn_index, tool, input_summary, \
228 choice_kind, choice_message, rationale, decided_at, full_input_json\
229 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) \
230 ON CONFLICT(id) DO NOTHING",
231 )
232 .bind(decision.id.0.to_string())
233 .bind(decision.goal_id.0.to_string())
234 .bind(decision.turn_index as i64)
235 .bind(&decision.tool)
236 .bind(&text)
237 .bind(choice_kind)
238 .bind(choice_message)
239 .bind(&decision.rationale)
240 .bind(decision.decided_at.timestamp())
241 .bind(&full_input_json)
242 .execute(&mut *tx)
243 .await
244 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
245
246 if inserted.rows_affected() == 0 {
247 tx.commit()
249 .await
250 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
251 return Ok(());
252 }
253
254 let rowid: (i64,) = sqlx::query_as("SELECT rowid FROM driver_decisions WHERE id = ?")
255 .bind(decision.id.0.to_string())
256 .fetch_one(&mut *tx)
257 .await
258 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
259
260 sqlx::query("INSERT INTO driver_decisions_vec(rowid, embedding) VALUES (?, ?)")
261 .bind(rowid.0)
262 .bind(bytes)
263 .execute(&mut *tx)
264 .await
265 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
266
267 tx.commit()
268 .await
269 .map_err(|e| ClaudeError::Binding(e.to_string()))?;
270 Ok(())
271 }
272
273 async fn recall(&self, req: &PermissionRequest, k: usize) -> Vec<Decision> {
274 if k == 0 {
275 return Vec::new();
276 }
277 let text = request_to_text(req);
278 let mut vecs = match self.embedder.embed(&[text.as_str()]).await {
279 Ok(v) => v,
280 Err(e) => {
281 tracing::warn!(target: "decision-memory", "embed recall failed: {e}");
282 return Vec::new();
283 }
284 };
285 if vecs.is_empty() {
286 return Vec::new();
287 }
288 let v = vecs.remove(0);
289 if v.len() != self.dim {
290 return Vec::new();
291 }
292 let bytes = vector::pack_f32(&v);
293
294 let goal_filter: Option<String> = match &self.namespace {
295 Namespace::PerGoal(g) => Some(g.0.to_string()),
296 Namespace::Global => None,
297 };
298
299 let rows = sqlx::query_as::<
300 _,
301 (
302 String,
303 String,
304 i64,
305 String,
306 String,
307 Option<String>,
308 String,
309 i64,
310 String,
311 ),
312 >(
313 "SELECT d.id, d.goal_id, d.turn_index, d.tool, \
314 d.choice_kind, d.choice_message, d.rationale, \
315 d.decided_at, d.full_input_json \
316 FROM driver_decisions_vec v \
317 JOIN driver_decisions d ON d.rowid = v.rowid \
318 WHERE v.embedding MATCH ?1 \
319 AND v.k = ?2 \
320 AND (?3 IS NULL OR d.goal_id = ?3) \
321 ORDER BY v.distance",
322 )
323 .bind(bytes)
324 .bind(k as i64)
325 .bind(goal_filter)
326 .fetch_all(&self.pool)
327 .await;
328
329 let rows = match rows {
330 Ok(r) => r,
331 Err(e) => {
332 tracing::warn!(target: "decision-memory", "recall query failed: {e}");
333 return Vec::new();
334 }
335 };
336
337 let mut out = Vec::with_capacity(rows.len());
338 for (
339 id,
340 goal_id,
341 turn_index,
342 tool,
343 choice_kind,
344 choice_msg,
345 rationale,
346 decided_at,
347 input_json,
348 ) in rows
349 {
350 let id = match Uuid::parse_str(&id) {
351 Ok(u) => DecisionId(u),
352 Err(_) => continue,
353 };
354 let goal_id = match Uuid::parse_str(&goal_id) {
355 Ok(u) => GoalId(u),
356 Err(_) => continue,
357 };
358 let input: serde_json::Value =
359 serde_json::from_str(&input_json).unwrap_or(serde_json::Value::Null);
360 let decided_at = Utc
361 .timestamp_opt(decided_at, 0)
362 .single()
363 .unwrap_or_else(Utc::now);
364 out.push(Decision {
365 id,
366 goal_id,
367 turn_index: turn_index as u32,
368 tool,
369 input,
370 choice: parse_choice(&choice_kind, choice_msg),
371 rationale,
372 decided_at,
373 });
374 }
375 out
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::mock::MockEmbedder;
    use chrono::Utc;
    use nexo_driver_types::DecisionId;
    use serde_json::json;

    /// Opens an empty in-memory store backed by `MockEmbedder`.
    async fn fresh_store() -> SqliteVecDecisionMemory {
        SqliteVecDecisionMemory::open_memory(Arc::new(MockEmbedder::new()))
            .await
            .unwrap()
    }

    /// Builds an `Allow` decision for `tool` with fresh decision and goal ids.
    fn sample_decision(tool: &str, input: serde_json::Value) -> Decision {
        Decision {
            id: DecisionId::new(),
            goal_id: GoalId::new(),
            turn_index: 0,
            tool: tool.into(),
            input,
            choice: DecisionChoice::Allow,
            rationale: "ok".into(),
            decided_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn open_creates_schema_and_count_zero() {
        let store = fresh_store().await;
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn record_persists_and_count_increments() {
        let store = fresh_store().await;
        for (tool, input) in [("Edit", json!({"file": "x"})), ("Bash", json!({"cmd": "ls"}))] {
            store.record(&sample_decision(tool, input)).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn recall_returns_at_most_k() {
        let store = fresh_store().await;
        for file in (0..5).map(|i| format!("f{i}.rs")) {
            store
                .record(&sample_decision("Edit", json!({ "file": file })))
                .await
                .unwrap();
        }
        let request = PermissionRequest {
            goal_id: GoalId::new(),
            tool_use_id: "tu".into(),
            tool_name: "Edit".into(),
            input: json!({"file": "f0.rs"}),
            metadata: serde_json::Map::new(),
        };
        let hits = store.recall(&request, 3).await;
        assert!(hits.len() <= 3);
        assert!(!hits.is_empty(), "expected at least one hit");
    }

    #[tokio::test]
    async fn record_idempotent_on_duplicate_id() {
        let store = fresh_store().await;
        let decision = sample_decision("Edit", json!({"a": 1}));
        for _ in 0..2 {
            store.record(&decision).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 1);
    }
}