agentroot_core/db/
vectors.rs

1//! Vector storage operations
2//!
3//! Stores embeddings as BLOBs and computes cosine similarity in Rust.
4
5use super::Database;
6use crate::error::Result;
7use chrono::Utc;
8use rusqlite::params;
9
/// Outcome of a cached-embedding lookup.
#[derive(Debug, Clone)]
pub enum CacheLookupResult {
    /// The cache held an embedding for the requested chunk; it is carried here.
    Hit(Vec<f32>),
    /// Nothing is cached for the requested chunk, so the embedding must be computed.
    Miss,
    /// The dimensions stored for the model differ from the expected ones, so the
    /// embedding must be recomputed.
    ModelMismatch,
}
20
21impl Database {
22    /// Ensure vector storage table exists
23    pub fn ensure_vec_table(&self, _dimensions: usize) -> Result<()> {
24        self.conn.execute(
25            "CREATE TABLE IF NOT EXISTS embeddings (
26                hash_seq TEXT PRIMARY KEY,
27                embedding BLOB NOT NULL
28            )",
29            [],
30        )?;
31        // Note: No index needed on hash_seq - SQLite automatically indexes PRIMARY KEY
32        Ok(())
33    }
34
35    /// Insert embedding for a document chunk
36    pub fn insert_embedding(
37        &self,
38        hash: &str,
39        seq: u32,
40        pos: usize,
41        model: &str,
42        embedding: &[f32],
43    ) -> Result<()> {
44        let now = Utc::now().to_rfc3339();
45        let hash_seq = format!("{}_{}", hash, seq);
46        let embedding_bytes = embedding_to_bytes(embedding);
47
48        self.conn.execute("BEGIN IMMEDIATE", [])?;
49        let result = (|| {
50            self.conn.execute(
51                "INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, created_at)
52                 VALUES (?1, ?2, ?3, ?4, ?5)",
53                params![hash, seq, pos, model, now],
54            )?;
55            self.conn.execute(
56                "INSERT OR REPLACE INTO embeddings (hash_seq, embedding) VALUES (?1, ?2)",
57                params![hash_seq, embedding_bytes],
58            )?;
59            Ok(())
60        })();
61
62        if result.is_ok() {
63            self.conn.execute("COMMIT", [])?;
64        } else {
65            let _ = self.conn.execute("ROLLBACK", []);
66        }
67        result
68    }
69
70    /// Check if vector index exists and has data
71    pub fn has_vector_index(&self) -> bool {
72        self.conn
73            .query_row("SELECT COUNT(*) FROM content_vectors", [], |row| {
74                row.get::<_, i64>(0)
75            })
76            .map(|count| count > 0)
77            .unwrap_or(false)
78    }
79
80    /// Get all embeddings for similarity search
81    pub fn get_all_embeddings(&self) -> Result<Vec<(String, Vec<f32>)>> {
82        let mut stmt = self
83            .conn
84            .prepare("SELECT hash_seq, embedding FROM embeddings")?;
85
86        let results = stmt
87            .query_map([], |row| {
88                let hash_seq: String = row.get(0)?;
89                let embedding_bytes: Vec<u8> = row.get(1)?;
90                let embedding = bytes_to_embedding(&embedding_bytes);
91                Ok((hash_seq, embedding))
92            })?
93            .collect::<std::result::Result<Vec<_>, _>>()?;
94
95        Ok(results)
96    }
97
98    /// Get embeddings for specific hashes (for filtered search)
99    pub fn get_embeddings_for_collection(
100        &self,
101        collection: &str,
102    ) -> Result<Vec<(String, Vec<f32>)>> {
103        let mut stmt = self.conn.prepare(
104            "SELECT e.hash_seq, e.embedding
105             FROM embeddings e
106             JOIN content_vectors cv ON e.hash_seq = cv.hash || '_' || cv.seq
107             JOIN documents d ON d.hash = cv.hash AND d.active = 1
108             WHERE d.collection = ?1",
109        )?;
110
111        let results = stmt
112            .query_map(params![collection], |row| {
113                let hash_seq: String = row.get(0)?;
114                let embedding_bytes: Vec<u8> = row.get(1)?;
115                let embedding = bytes_to_embedding(&embedding_bytes);
116                Ok((hash_seq, embedding))
117            })?
118            .collect::<std::result::Result<Vec<_>, _>>()?;
119
120        Ok(results)
121    }
122
123    /// Get hashes that need embedding
124    pub fn get_hashes_needing_embedding(&self) -> Result<Vec<(String, String)>> {
125        let mut stmt = self.conn.prepare(
126            "SELECT c.hash, c.doc FROM content c
127             JOIN documents d ON d.hash = c.hash AND d.active = 1
128             WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)",
129        )?;
130
131        let results = stmt
132            .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
133            .collect::<std::result::Result<Vec<_>, _>>()?;
134
135        Ok(results)
136    }
137
138    /// Count hashes needing embedding
139    pub fn count_hashes_needing_embedding(&self) -> Result<usize> {
140        let count: i64 = self.conn.query_row(
141            "SELECT COUNT(DISTINCT c.hash) FROM content c
142             JOIN documents d ON d.hash = c.hash AND d.active = 1
143             WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)",
144            [],
145            |row| row.get(0),
146        )?;
147        Ok(count as usize)
148    }
149
150    /// Delete embeddings for a hash
151    pub fn delete_embeddings(&self, hash: &str) -> Result<usize> {
152        let pattern = format!("{}_*", hash);
153
154        self.conn.execute("BEGIN IMMEDIATE", [])?;
155        let result = (|| {
156            self.conn
157                .execute("DELETE FROM content_vectors WHERE hash = ?1", params![hash])?;
158            // Use GLOB instead of LIKE to avoid issues with special characters.
159            // GLOB uses * and ? as wildcards, which won't appear in SHA-256 hex hashes.
160            let rows = self.conn.execute(
161                "DELETE FROM embeddings WHERE hash_seq GLOB ?1",
162                params![pattern],
163            )?;
164            Ok(rows)
165        })();
166
167        if result.is_ok() {
168            self.conn.execute("COMMIT", [])?;
169        } else {
170            let _ = self.conn.execute("ROLLBACK", []);
171        }
172        result
173    }
174
175    /// Get all hashes for embedding (for force re-embedding)
176    pub fn get_all_hashes_for_embedding(&self) -> Result<Vec<(String, String)>> {
177        let mut stmt = self.conn.prepare(
178            "SELECT c.hash, c.doc FROM content c
179             JOIN documents d ON d.hash = c.hash AND d.active = 1",
180        )?;
181
182        let results = stmt
183            .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
184            .collect::<std::result::Result<Vec<_>, _>>()?;
185
186        Ok(results)
187    }
188
189    /// Check if model dimensions are compatible with expected dimensions
190    pub fn check_model_compatibility(&self, model: &str, expected_dims: usize) -> Result<bool> {
191        match self.get_model_dimensions(model)? {
192            Some(stored_dims) => Ok(stored_dims == expected_dims),
193            None => Ok(true), // No stored model = compatible (will be registered)
194        }
195    }
196
197    /// Look up a cached embedding by chunk hash (performs dimension check)
198    pub fn get_cached_embedding(
199        &self,
200        chunk_hash: &str,
201        model: &str,
202        expected_dims: usize,
203    ) -> Result<CacheLookupResult> {
204        if !self.check_model_compatibility(model, expected_dims)? {
205            return Ok(CacheLookupResult::ModelMismatch);
206        }
207        self.get_cached_embedding_fast(chunk_hash, model)
208    }
209
210    /// Look up a cached embedding by chunk hash (skips dimension check - caller must verify compatibility)
211    pub fn get_cached_embedding_fast(
212        &self,
213        chunk_hash: &str,
214        model: &str,
215    ) -> Result<CacheLookupResult> {
216        let result = self.conn.query_row(
217            "SELECT embedding FROM chunk_embeddings WHERE chunk_hash = ?1 AND model = ?2",
218            params![chunk_hash, model],
219            |row| {
220                let bytes: Vec<u8> = row.get(0)?;
221                Ok(bytes_to_embedding(&bytes))
222            },
223        );
224
225        match result {
226            Ok(embedding) => Ok(CacheLookupResult::Hit(embedding)),
227            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(CacheLookupResult::Miss),
228            Err(e) => Err(e.into()),
229        }
230    }
231
232    /// Insert a chunk embedding with cache support
233    pub fn insert_chunk_embedding(
234        &self,
235        doc_hash: &str,
236        seq: u32,
237        pos: usize,
238        chunk_hash: &str,
239        model: &str,
240        embedding: &[f32],
241    ) -> Result<()> {
242        let now = Utc::now().to_rfc3339();
243        let hash_seq = format!("{}_{}", doc_hash, seq);
244        let embedding_bytes = embedding_to_bytes(embedding);
245
246        self.conn.execute("BEGIN IMMEDIATE", [])?;
247        let result = (|| {
248            self.conn.execute(
249                "INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, chunk_hash, created_at)
250                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
251                params![doc_hash, seq, pos, model, chunk_hash, now],
252            )?;
253            self.conn.execute(
254                "INSERT OR REPLACE INTO embeddings (hash_seq, embedding) VALUES (?1, ?2)",
255                params![hash_seq, embedding_bytes],
256            )?;
257            self.conn.execute(
258                "INSERT OR REPLACE INTO chunk_embeddings (chunk_hash, model, embedding, created_at)
259                 VALUES (?1, ?2, ?3, ?4)",
260                params![chunk_hash, model, &embedding_bytes, now],
261            )?;
262            Ok(())
263        })();
264
265        if result.is_ok() {
266            self.conn.execute("COMMIT", [])?;
267        } else {
268            let _ = self.conn.execute("ROLLBACK", []);
269        }
270        result
271    }
272
273    /// Get chunk hashes for a document
274    pub fn get_chunk_hashes_for_doc(&self, doc_hash: &str) -> Result<Vec<(u32, String)>> {
275        let mut stmt = self.conn.prepare(
276            "SELECT seq, chunk_hash FROM content_vectors WHERE hash = ?1 AND chunk_hash IS NOT NULL"
277        )?;
278
279        let results = stmt
280            .query_map(params![doc_hash], |row| Ok((row.get(0)?, row.get(1)?)))?
281            .collect::<std::result::Result<Vec<_>, _>>()?;
282
283        Ok(results)
284    }
285
286    /// Clean up orphaned chunk embeddings (not referenced by any document)
287    pub fn cleanup_orphaned_chunk_embeddings(&self) -> Result<usize> {
288        let count = self.conn.execute(
289            "DELETE FROM chunk_embeddings WHERE chunk_hash NOT IN (
290                SELECT DISTINCT chunk_hash FROM content_vectors WHERE chunk_hash IS NOT NULL
291            )",
292            [],
293        )?;
294        Ok(count)
295    }
296
297    /// Register model with its dimensions
298    pub fn register_model(&self, model: &str, dimensions: usize) -> Result<()> {
299        let now = Utc::now().to_rfc3339();
300
301        self.conn.execute(
302            "INSERT INTO model_metadata (model, dimensions, created_at, last_used_at)
303             VALUES (?1, ?2, ?3, ?3)
304             ON CONFLICT(model) DO UPDATE SET last_used_at = ?3",
305            params![model, dimensions as i64, now],
306        )?;
307
308        Ok(())
309    }
310
311    /// Get stored model dimensions
312    pub fn get_model_dimensions(&self, model: &str) -> Result<Option<usize>> {
313        let result = self.conn.query_row(
314            "SELECT dimensions FROM model_metadata WHERE model = ?1",
315            params![model],
316            |row| row.get::<_, i64>(0),
317        );
318
319        match result {
320            Ok(dims) => Ok(Some(dims as usize)),
321            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
322            Err(e) => Err(e.into()),
323        }
324    }
325
326    /// Count cached chunk embeddings
327    pub fn count_cached_embeddings(&self, model: &str) -> Result<usize> {
328        let count: i64 = self.conn.query_row(
329            "SELECT COUNT(*) FROM chunk_embeddings WHERE model = ?1",
330            params![model],
331            |row| row.get(0),
332        )?;
333        Ok(count as usize)
334    }
335}
336
/// Serialize an embedding into bytes, four little-endian bytes per `f32`.
///
/// The output length is always `embedding.len() * 4`; an empty embedding
/// produces an empty vector.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
341
/// Deserialize little-endian bytes back into an `f32` embedding.
///
/// Every complete group of four bytes becomes one value. Trailing bytes that do
/// not fill a whole group are ignored.
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
349
/// Cosine similarity of two embeddings.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when
/// either vector has zero magnitude. Otherwise returns the dot product divided
/// by the product of the two Euclidean norms.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = a.len() == b.len() && !a.is_empty();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let mag_a = magnitude(a);
    let mag_b = magnitude(b);

    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing and then deserializing must reproduce the input exactly.
    #[test]
    fn test_embedding_roundtrip() {
        let original = vec![1.0f32, 2.0, 3.0, -1.5];
        let restored = bytes_to_embedding(&embedding_to_bytes(&original));
        assert_eq!(original, restored);
    }

    /// Two identical vectors have similarity 1 (within tolerance).
    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [1.0f32, 0.0, 0.0];
        let delta = (cosine_similarity(&a, &b) - 1.0).abs();
        assert!(delta < 0.0001);
    }

    /// Two orthogonal vectors have similarity 0 (within tolerance).
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let magnitude = cosine_similarity(&a, &b).abs();
        assert!(magnitude < 0.0001);
    }
}
394}