1use rusqlite::{params, Connection};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::db::{apply_pragmas, DbResult};
11
/// A single hit returned by an embedding-similarity search.
///
/// Instances are built by `VectorStore::search_similar`, which fills `score`
/// with the cosine similarity to the query and `rank` with the 1-based
/// position in the sorted result list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticResult {
    /// `file_path` value stored alongside the matching chunk.
    pub file_path: String,
    /// Text of the matching chunk exactly as it was stored.
    pub chunk_text: String,
    /// Similarity between the query embedding and this chunk's embedding;
    /// higher means more similar.
    pub score: f32,
    /// 1-based position of this hit in the result list (1 is the best match).
    pub rank: usize,
}
20
/// SQLite-backed store for text-chunk embeddings.
///
/// All operations go through the single wrapped connection and target the
/// `embeddings` table.
pub struct VectorStore {
    /// Connection to the database that holds the `embeddings` table.
    conn: Connection,
}
25
26impl VectorStore {
27 pub fn open(db_path: &Path) -> DbResult<Self> {
30 if let Some(parent) = db_path.parent() {
31 std::fs::create_dir_all(parent)?;
32 }
33
34 let conn = Connection::open(db_path)?;
35 apply_pragmas(&conn)?;
36
37 conn.execute_batch(
39 "CREATE TABLE IF NOT EXISTS embeddings (
40 id INTEGER PRIMARY KEY AUTOINCREMENT,
41 file_path TEXT NOT NULL,
42 chunk_text TEXT NOT NULL,
43 embedding BLOB NOT NULL,
44 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
45 UNIQUE(file_path, chunk_text)
46 );
47 CREATE INDEX IF NOT EXISTS idx_embeddings_file_path ON embeddings(file_path);",
48 )?;
49
50 Ok(Self { conn })
51 }
52
53 pub fn from_connection(conn: Connection) -> Self {
56 Self { conn }
57 }
58
59 pub fn store_embedding(
61 &self,
62 file_path: &str,
63 chunk_text: &str,
64 embedding: &[f32],
65 ) -> DbResult<()> {
66 let blob = embedding_to_blob(embedding);
67
68 self.conn.execute(
69 "INSERT INTO embeddings (file_path, chunk_text, embedding, updated_at)
70 VALUES (?1, ?2, ?3, datetime('now'))
71 ON CONFLICT(file_path, chunk_text) DO UPDATE SET
72 embedding = excluded.embedding,
73 updated_at = datetime('now')",
74 params![file_path, chunk_text, blob],
75 )?;
76
77 Ok(())
78 }
79
80 pub fn search_similar(
82 &self,
83 query_embedding: &[f32],
84 limit: usize,
85 ) -> DbResult<Vec<SemanticResult>> {
86 let mut stmt = self.conn.prepare(
87 "SELECT file_path, chunk_text, embedding FROM embeddings",
88 )?;
89
90 let rows = stmt.query_map([], |row| {
91 let file_path: String = row.get(0)?;
92 let chunk_text: String = row.get(1)?;
93 let blob: Vec<u8> = row.get(2)?;
94 Ok((file_path, chunk_text, blob))
95 })?;
96
97 let mut scored: Vec<(String, String, f32)> = Vec::new();
98
99 for row in rows {
100 let (file_path, chunk_text, blob) = row?;
101 let stored_embedding = blob_to_embedding(&blob);
102 let score = cosine_similarity(query_embedding, &stored_embedding);
103 scored.push((file_path, chunk_text, score));
104 }
105
106 scored.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
108
109 let results: Vec<SemanticResult> = scored
110 .into_iter()
111 .take(limit)
112 .enumerate()
113 .map(|(rank, (file_path, chunk_text, score))| SemanticResult {
114 file_path,
115 chunk_text,
116 score,
117 rank: rank + 1,
118 })
119 .collect();
120
121 Ok(results)
122 }
123
124 pub fn clear_embeddings(&self) -> DbResult<()> {
126 self.conn.execute("DELETE FROM embeddings", [])?;
127 Ok(())
128 }
129
130 pub fn get_embedding_count(&self) -> DbResult<usize> {
132 let count: i64 = self.conn.query_row(
133 "SELECT COUNT(*) FROM embeddings",
134 [],
135 |row| row.get(0),
136 )?;
137 Ok(count as usize)
138 }
139
140 pub fn get_indexed_file_count(&self) -> DbResult<usize> {
142 let count: i64 = self.conn.query_row(
143 "SELECT COUNT(DISTINCT file_path) FROM embeddings",
144 [],
145 |row| row.get(0),
146 )?;
147 Ok(count as usize)
148 }
149
150 pub fn delete_file_embeddings(&self, file_path: &str) -> DbResult<()> {
152 self.conn.execute(
153 "DELETE FROM embeddings WHERE file_path = ?1",
154 params![file_path],
155 )?;
156 Ok(())
157 }
158}
159
/// Computes the cosine similarity of `a` and `b`.
///
/// For well-defined inputs the result lies in `[-1.0, 1.0]`: `1.0` for
/// vectors pointing the same way, `0.0` for orthogonal vectors and `-1.0`
/// for opposite vectors.
///
/// Returns `0.0` whenever the similarity is undefined:
/// - the slices have different lengths,
/// - the slices are empty,
/// - either vector has zero magnitude,
/// - any component is NaN or infinite.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate in f64: squaring large f32 components overflows f32 (turning
    // the result into NaN) and squaring tiny ones underflows to zero. Every
    // product of two f32 values is representable in f64.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_nan() {
        // Only reachable with NaN or infinite components.
        return 0.0;
    }

    // Rounding can land a hair outside the mathematical range; pin it.
    similarity.max(-1.0).min(1.0) as f32
}
185
/// Serializes an embedding into a byte blob: each `f32` becomes its four
/// little-endian bytes, in order.
fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    blob.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    blob
}
194
/// Decodes a byte blob into an embedding by reading consecutive 4-byte
/// little-endian `f32` values.
///
/// Trailing bytes that do not form a complete 4-byte group are ignored.
fn blob_to_embedding(blob: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blob.len() / 4);
    for word in blob.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an empty store backed by a database file inside a fresh
    /// temporary directory.
    ///
    /// The `TempDir` guard is returned with the store because dropping it
    /// deletes the directory; callers keep it bound for the whole test.
    fn temp_store() -> (tempfile::TempDir, VectorStore) {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test_embeddings.db");
        let store = VectorStore::open(&db_path).unwrap();
        (dir, store)
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "Identical vectors should have similarity 1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "Orthogonal vectors should have similarity 0.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-6,
            "Opposite vectors should have similarity -1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_real_vectors() {
        let a = vec![0.5, 0.3, 0.8, 0.1];
        let b = vec![0.4, 0.35, 0.75, 0.15];
        let sim = cosine_similarity(&a, &b);
        assert!(
            sim > 0.99,
            "Very similar vectors should have high similarity, got {}",
            sim
        );
    }

    #[test]
    fn test_embedding_blob_roundtrip() {
        let original = vec![0.1, 0.2, -0.3, 1.5, 0.0, -0.001];
        let blob = embedding_to_blob(&original);
        let recovered = blob_to_embedding(&blob);
        assert_eq!(original.len(), recovered.len());
        for (a, b) in original.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-7, "Values should match: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_embedding_blob_size() {
        let embedding = vec![0.0f32; 768];
        let blob = embedding_to_blob(&embedding);
        // Four bytes per f32 component.
        assert_eq!(blob.len(), 768 * 4);
    }

    #[test]
    fn test_vector_store_open_and_count() {
        let (_dir, store) = temp_store();
        assert_eq!(store.get_embedding_count().unwrap(), 0);
    }

    #[test]
    fn test_vector_store_store_and_retrieve() {
        let (_dir, store) = temp_store();

        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        store
            .store_embedding("docs/readme.md", "# Introduction", &embedding)
            .unwrap();

        assert_eq!(store.get_embedding_count().unwrap(), 1);
        assert_eq!(store.get_indexed_file_count().unwrap(), 1);
    }

    #[test]
    fn test_vector_store_upsert() {
        let (_dir, store) = temp_store();

        let embedding1 = vec![0.1, 0.2, 0.3, 0.4];
        let embedding2 = vec![0.5, 0.6, 0.7, 0.8];

        store
            .store_embedding("docs/readme.md", "# Introduction", &embedding1)
            .unwrap();
        store
            .store_embedding("docs/readme.md", "# Introduction", &embedding2)
            .unwrap();

        // Same (file_path, chunk_text) key: the second write must replace the
        // first rather than add a row.
        assert_eq!(store.get_embedding_count().unwrap(), 1);

        // The stored vector must be the replacement. The two test vectors have
        // a cosine similarity of roughly 0.97, so a stale row would fail this.
        let results = store.search_similar(&embedding2, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-6,
            "Upsert should overwrite the stored embedding, got score {}",
            results[0].score
        );
    }

    #[test]
    fn test_vector_store_search_ranking() {
        let (_dir, store) = temp_store();

        let embedding_a = vec![1.0, 0.0, 0.0, 0.0];
        let embedding_b = vec![0.0, 1.0, 0.0, 0.0];
        let embedding_c = vec![0.9, 0.1, 0.0, 0.0];

        store
            .store_embedding("file_a.md", "File A content", &embedding_a)
            .unwrap();
        store
            .store_embedding("file_b.md", "File B content", &embedding_b)
            .unwrap();
        store
            .store_embedding("file_c.md", "File C content", &embedding_c)
            .unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = store.search_similar(&query, 3).unwrap();

        assert_eq!(results.len(), 3);

        // Exact match first.
        assert_eq!(results[0].file_path, "file_a.md");
        assert_eq!(results[0].rank, 1);
        assert!((results[0].score - 1.0).abs() < 1e-6);

        // Near match second.
        assert_eq!(results[1].file_path, "file_c.md");
        assert_eq!(results[1].rank, 2);

        // Orthogonal vector last.
        assert_eq!(results[2].file_path, "file_b.md");
        assert_eq!(results[2].rank, 3);
    }

    #[test]
    fn test_vector_store_search_limit() {
        let (_dir, store) = temp_store();

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0, 1.0];
            store
                .store_embedding(
                    &format!("file_{}.md", i),
                    &format!("Content {}", i),
                    &embedding,
                )
                .unwrap();
        }

        let query = vec![5.0, 0.0, 0.0, 1.0];
        let results = store.search_similar(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_store_clear() {
        let (_dir, store) = temp_store();

        store.store_embedding("file.md", "chunk1", &[0.1, 0.2]).unwrap();
        store.store_embedding("file.md", "chunk2", &[0.3, 0.4]).unwrap();
        assert_eq!(store.get_embedding_count().unwrap(), 2);

        store.clear_embeddings().unwrap();
        assert_eq!(store.get_embedding_count().unwrap(), 0);
    }

    #[test]
    fn test_vector_store_delete_file() {
        let (_dir, store) = temp_store();

        store.store_embedding("file_a.md", "chunk1", &[0.1, 0.2]).unwrap();
        store.store_embedding("file_a.md", "chunk2", &[0.3, 0.4]).unwrap();
        store.store_embedding("file_b.md", "chunk1", &[0.5, 0.6]).unwrap();

        assert_eq!(store.get_embedding_count().unwrap(), 3);

        store.delete_file_embeddings("file_a.md").unwrap();
        assert_eq!(store.get_embedding_count().unwrap(), 1);
        assert_eq!(store.get_indexed_file_count().unwrap(), 1);
    }

    #[test]
    fn test_vector_store_multiple_chunks_per_file() {
        let (_dir, store) = temp_store();

        store
            .store_embedding("readme.md", "# Introduction\nWelcome", &[0.1, 0.2, 0.3])
            .unwrap();
        store
            .store_embedding("readme.md", "## Setup\nRun npm install", &[0.4, 0.5, 0.6])
            .unwrap();
        store
            .store_embedding("readme.md", "## Usage\nRun npm start", &[0.7, 0.8, 0.9])
            .unwrap();

        assert_eq!(store.get_embedding_count().unwrap(), 3);
        assert_eq!(store.get_indexed_file_count().unwrap(), 1);
    }

    #[test]
    fn test_search_empty_store() {
        let (_dir, store) = temp_store();

        let results = store.search_similar(&[0.1, 0.2], 5).unwrap();
        assert!(results.is_empty());
    }
}