// nexo_driver_loop/memory/sqlite_vec.rs

//! `SqliteVecDecisionMemory` — sqlite-vec backed `DecisionMemory`.
//!
//! Schema lives in `driver_decisions` + `driver_decisions_vec` tables
//! and is created on first `open` (idempotent migration).

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use chrono::{TimeZone, Utc};
use nexo_driver_claude::ClaudeError;
use nexo_driver_permission::PermissionRequest;
use nexo_driver_types::{Decision, DecisionChoice, DecisionId, GoalId};
use nexo_memory::{vector, EmbeddingProvider};
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous};
use sqlx::SqlitePool;
use uuid::Uuid;

use crate::memory::prompt::{decision_to_text, request_to_text};
use crate::memory::trait_def::{DecisionMemory, Namespace};

/// Schema revision written to `PRAGMA user_version` at the end of every migration.
const SCHEMA_VERSION: i64 = 1;

23pub struct SqliteVecDecisionMemory {
24    pool: SqlitePool,
25    embedder: Arc<dyn EmbeddingProvider>,
26    namespace: Namespace,
27    dim: usize,
28}
29
30impl SqliteVecDecisionMemory {
31    pub async fn open(
32        path: &str,
33        embedder: Arc<dyn EmbeddingProvider>,
34    ) -> Result<Self, ClaudeError> {
35        // Register sqlite-vec as an auto-extension. Idempotent.
36        vector::enable();
37
38        let opts = SqliteConnectOptions::new()
39            .filename(path)
40            .create_if_missing(true);
41        let max_conns = if path == ":memory:" { 1 } else { 4 };
42        let pool = SqlitePoolOptions::new()
43            .max_connections(max_conns)
44            .connect_with(opts)
45            .await
46            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
47
48        if path != ":memory:" {
49            sqlx::query("PRAGMA journal_mode = WAL")
50                .execute(&pool)
51                .await
52                .map_err(|e| ClaudeError::Binding(e.to_string()))?;
53            sqlx::query("PRAGMA synchronous = NORMAL")
54                .execute(&pool)
55                .await
56                .map_err(|e| ClaudeError::Binding(e.to_string()))?;
57        }
58
59        let dim = embedder.dimension();
60        Self::migrate(&pool, dim).await?;
61
62        Ok(Self {
63            pool,
64            embedder,
65            namespace: Namespace::Global,
66            dim,
67        })
68    }
69
70    pub async fn open_memory(embedder: Arc<dyn EmbeddingProvider>) -> Result<Self, ClaudeError> {
71        Self::open(":memory:", embedder).await
72    }
73
74    pub fn with_namespace(mut self, ns: Namespace) -> Self {
75        self.namespace = ns;
76        self
77    }
78
79    /// Test helper.
80    #[doc(hidden)]
81    pub fn pool_for_test(&self) -> &SqlitePool {
82        &self.pool
83    }
84
85    /// Test helper.
86    #[doc(hidden)]
87    pub async fn count(&self) -> Result<u64, ClaudeError> {
88        let (n,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM driver_decisions")
89            .fetch_one(&self.pool)
90            .await
91            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
92        Ok(n as u64)
93    }
94
95    async fn migrate(pool: &SqlitePool, dim: usize) -> Result<(), ClaudeError> {
96        sqlx::query(
97            "CREATE TABLE IF NOT EXISTS driver_decisions (\
98                id              TEXT PRIMARY KEY,\
99                goal_id         TEXT NOT NULL,\
100                turn_index      INTEGER NOT NULL,\
101                tool            TEXT NOT NULL,\
102                input_summary   TEXT NOT NULL,\
103                choice_kind     TEXT NOT NULL,\
104                choice_message  TEXT,\
105                rationale       TEXT NOT NULL,\
106                decided_at      INTEGER NOT NULL,\
107                full_input_json TEXT NOT NULL,\
108                schema_version  INTEGER NOT NULL DEFAULT 1\
109            )",
110        )
111        .execute(pool)
112        .await
113        .map_err(|e| ClaudeError::Binding(e.to_string()))?;
114        sqlx::query("CREATE INDEX IF NOT EXISTS idx_dd_goal_id ON driver_decisions(goal_id)")
115            .execute(pool)
116            .await
117            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
118        sqlx::query("CREATE INDEX IF NOT EXISTS idx_dd_decided_at ON driver_decisions(decided_at)")
119            .execute(pool)
120            .await
121            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
122
123        // dim mismatch detection — if vec table exists, compare via a sample row.
124        let exists: Option<(String,)> = sqlx::query_as(
125            "SELECT name FROM sqlite_master \
126             WHERE type='table' AND name='driver_decisions_vec'",
127        )
128        .fetch_optional(pool)
129        .await
130        .map_err(|e| ClaudeError::Binding(e.to_string()))?;
131
132        if exists.is_none() {
133            let sql = format!(
134                "CREATE VIRTUAL TABLE driver_decisions_vec USING vec0(embedding FLOAT[{dim}])"
135            );
136            sqlx::query(&sql)
137                .execute(pool)
138                .await
139                .map_err(|e| ClaudeError::Binding(e.to_string()))?;
140        } else {
141            let sample: Option<(Vec<u8>,)> =
142                sqlx::query_as("SELECT embedding FROM driver_decisions_vec LIMIT 1")
143                    .fetch_optional(pool)
144                    .await
145                    .ok()
146                    .flatten();
147            if let Some((bytes,)) = sample {
148                let existing_dim = bytes.len() / 4;
149                if existing_dim != dim {
150                    return Err(ClaudeError::Binding(format!(
151                        "decision-memory dim mismatch: schema={existing_dim}, embedder={dim}; \
152                         drop the table or reset the file"
153                    )));
154                }
155            }
156        }
157
158        sqlx::query(&format!("PRAGMA user_version = {SCHEMA_VERSION}"))
159            .execute(pool)
160            .await
161            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
162        Ok(())
163    }
164}
165
166fn choice_kind_label(c: &DecisionChoice) -> (&'static str, Option<String>) {
167    match c {
168        DecisionChoice::Allow => ("allow", None),
169        DecisionChoice::Deny { message } => ("deny", Some(message.clone())),
170        DecisionChoice::Observe { note } => ("observe", Some(note.clone())),
171    }
172}
173
174fn parse_choice(kind: &str, message: Option<String>) -> DecisionChoice {
175    match kind {
176        "allow" => DecisionChoice::Allow,
177        "deny" => DecisionChoice::Deny {
178            message: message.unwrap_or_default(),
179        },
180        "observe" => DecisionChoice::Observe {
181            note: message.unwrap_or_default(),
182        },
183        _ => DecisionChoice::Allow,
184    }
185}
186
187#[async_trait]
188impl DecisionMemory for SqliteVecDecisionMemory {
189    async fn record(&self, decision: &Decision) -> Result<(), ClaudeError> {
190        let text = decision_to_text(decision);
191        let mut vecs = match self.embedder.embed(&[text.as_str()]).await {
192            Ok(v) => v,
193            Err(e) => {
194                tracing::warn!(target: "decision-memory", "embed record failed: {e}");
195                return Ok(());
196            }
197        };
198        if vecs.is_empty() {
199            return Ok(());
200        }
201        let v = vecs.remove(0);
202        if v.len() != self.dim {
203            tracing::warn!(
204                target: "decision-memory",
205                "embed dim mismatch: got {}, expected {}",
206                v.len(),
207                self.dim
208            );
209            return Ok(());
210        }
211        let bytes = vector::pack_f32(&v);
212
213        let (choice_kind, choice_message) = choice_kind_label(&decision.choice);
214        let full_input_json =
215            serde_json::to_string(&decision.input).unwrap_or_else(|_| "null".into());
216
217        // Transactional insert — use a manual rowid linkage by reading
218        // last_insert_rowid() in the same transaction.
219        let mut tx = self
220            .pool
221            .begin()
222            .await
223            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
224
225        let inserted = sqlx::query(
226            "INSERT INTO driver_decisions (\
227                id, goal_id, turn_index, tool, input_summary, \
228                choice_kind, choice_message, rationale, decided_at, full_input_json\
229             ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) \
230             ON CONFLICT(id) DO NOTHING",
231        )
232        .bind(decision.id.0.to_string())
233        .bind(decision.goal_id.0.to_string())
234        .bind(decision.turn_index as i64)
235        .bind(&decision.tool)
236        .bind(&text)
237        .bind(choice_kind)
238        .bind(choice_message)
239        .bind(&decision.rationale)
240        .bind(decision.decided_at.timestamp())
241        .bind(&full_input_json)
242        .execute(&mut *tx)
243        .await
244        .map_err(|e| ClaudeError::Binding(e.to_string()))?;
245
246        if inserted.rows_affected() == 0 {
247            // Duplicate id — keep existing embedding row untouched.
248            tx.commit()
249                .await
250                .map_err(|e| ClaudeError::Binding(e.to_string()))?;
251            return Ok(());
252        }
253
254        let rowid: (i64,) = sqlx::query_as("SELECT rowid FROM driver_decisions WHERE id = ?")
255            .bind(decision.id.0.to_string())
256            .fetch_one(&mut *tx)
257            .await
258            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
259
260        sqlx::query("INSERT INTO driver_decisions_vec(rowid, embedding) VALUES (?, ?)")
261            .bind(rowid.0)
262            .bind(bytes)
263            .execute(&mut *tx)
264            .await
265            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
266
267        tx.commit()
268            .await
269            .map_err(|e| ClaudeError::Binding(e.to_string()))?;
270        Ok(())
271    }
272
273    async fn recall(&self, req: &PermissionRequest, k: usize) -> Vec<Decision> {
274        if k == 0 {
275            return Vec::new();
276        }
277        let text = request_to_text(req);
278        let mut vecs = match self.embedder.embed(&[text.as_str()]).await {
279            Ok(v) => v,
280            Err(e) => {
281                tracing::warn!(target: "decision-memory", "embed recall failed: {e}");
282                return Vec::new();
283            }
284        };
285        if vecs.is_empty() {
286            return Vec::new();
287        }
288        let v = vecs.remove(0);
289        if v.len() != self.dim {
290            return Vec::new();
291        }
292        let bytes = vector::pack_f32(&v);
293
294        let goal_filter: Option<String> = match &self.namespace {
295            Namespace::PerGoal(g) => Some(g.0.to_string()),
296            Namespace::Global => None,
297        };
298
299        let rows = sqlx::query_as::<
300            _,
301            (
302                String,
303                String,
304                i64,
305                String,
306                String,
307                Option<String>,
308                String,
309                i64,
310                String,
311            ),
312        >(
313            "SELECT d.id, d.goal_id, d.turn_index, d.tool, \
314                    d.choice_kind, d.choice_message, d.rationale, \
315                    d.decided_at, d.full_input_json \
316             FROM driver_decisions_vec v \
317             JOIN driver_decisions d ON d.rowid = v.rowid \
318             WHERE v.embedding MATCH ?1 \
319               AND v.k = ?2 \
320               AND (?3 IS NULL OR d.goal_id = ?3) \
321             ORDER BY v.distance",
322        )
323        .bind(bytes)
324        .bind(k as i64)
325        .bind(goal_filter)
326        .fetch_all(&self.pool)
327        .await;
328
329        let rows = match rows {
330            Ok(r) => r,
331            Err(e) => {
332                tracing::warn!(target: "decision-memory", "recall query failed: {e}");
333                return Vec::new();
334            }
335        };
336
337        let mut out = Vec::with_capacity(rows.len());
338        for (
339            id,
340            goal_id,
341            turn_index,
342            tool,
343            choice_kind,
344            choice_msg,
345            rationale,
346            decided_at,
347            input_json,
348        ) in rows
349        {
350            let id = match Uuid::parse_str(&id) {
351                Ok(u) => DecisionId(u),
352                Err(_) => continue,
353            };
354            let goal_id = match Uuid::parse_str(&goal_id) {
355                Ok(u) => GoalId(u),
356                Err(_) => continue,
357            };
358            let input: serde_json::Value =
359                serde_json::from_str(&input_json).unwrap_or(serde_json::Value::Null);
360            let decided_at = Utc
361                .timestamp_opt(decided_at, 0)
362                .single()
363                .unwrap_or_else(Utc::now);
364            out.push(Decision {
365                id,
366                goal_id,
367                turn_index: turn_index as u32,
368                tool,
369                input,
370                choice: parse_choice(&choice_kind, choice_msg),
371                rationale,
372                decided_at,
373            });
374        }
375        out
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::mock::MockEmbedder;
    use chrono::Utc;
    use nexo_driver_types::DecisionId;
    use serde_json::json;

    /// Opens a fresh in-memory store backed by the mock embedder.
    async fn fresh_store() -> SqliteVecDecisionMemory {
        let embedder = Arc::new(MockEmbedder::new());
        SqliteVecDecisionMemory::open_memory(embedder)
            .await
            .unwrap()
    }

    /// Builds an `Allow` decision for `tool` with a new id and a new goal.
    fn dec(tool: &str, input: serde_json::Value) -> Decision {
        let id = DecisionId::new();
        let goal_id = GoalId::new();
        Decision {
            id,
            goal_id,
            turn_index: 0,
            tool: tool.into(),
            input,
            choice: DecisionChoice::Allow,
            rationale: "ok".into(),
            decided_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn open_creates_schema_and_count_zero() {
        let store = fresh_store().await;
        let rows = store.count().await.unwrap();
        assert_eq!(rows, 0);
    }

    #[tokio::test]
    async fn record_persists_and_count_increments() {
        let store = fresh_store().await;
        let first = dec("Edit", json!({"file": "x"}));
        store.record(&first).await.unwrap();
        let second = dec("Bash", json!({"cmd": "ls"}));
        store.record(&second).await.unwrap();
        let rows = store.count().await.unwrap();
        assert_eq!(rows, 2);
    }

    #[tokio::test]
    async fn recall_returns_at_most_k() {
        let store = fresh_store().await;
        for i in 0..5 {
            let decision = dec("Edit", json!({"file": format!("f{i}.rs")}));
            store.record(&decision).await.unwrap();
        }
        let req = PermissionRequest {
            goal_id: GoalId::new(),
            tool_use_id: "tu".into(),
            tool_name: "Edit".into(),
            input: json!({"file": "f0.rs"}),
            metadata: serde_json::Map::new(),
        };
        let hits = store.recall(&req, 3).await;
        assert!(hits.len() <= 3);
        assert!(!hits.is_empty(), "expected at least one hit");
    }

    #[tokio::test]
    async fn record_idempotent_on_duplicate_id() {
        let store = fresh_store().await;
        let decision = dec("Edit", json!({"a": 1}));
        // Same decision (same id) recorded twice must leave a single row.
        store.record(&decision).await.unwrap();
        store.record(&decision).await.unwrap();
        let rows = store.count().await.unwrap();
        assert_eq!(rows, 1);
    }
}