1use rusqlite::{params, Connection, Result as SqlResult};
14use crate::chunker::Chunk;
15use std::collections::HashMap;
16use std::path::Path;
17
/// Number of `f32` components every embedding must have.
///
/// Vectors of any other length are rejected by `insert_chunk` and
/// `search_similar`, and stored BLOBs must be exactly `EMBEDDING_DIM * 4` bytes.
pub const EMBEDDING_DIM: usize = 384;
20
/// SQLite-backed store for text chunks and their embedding vectors.
///
/// Chunks live in the `chunks` table, their vectors in `embeddings`, and an
/// FTS5 table (`chunks_fts`) mirrors the chunk text for keyword search.
pub struct EmbeddingStore {
    /// Connection to the underlying database (in-memory or file-backed).
    conn: Connection,
}
25
26impl EmbeddingStore {
27 pub fn new_in_memory() -> SqlResult<Self> {
29 let conn = Connection::open_in_memory()?;
30 Self::init_schema(&conn)?;
31 Ok(Self { conn })
32 }
33
34 pub fn open(path: &Path) -> SqlResult<Self> {
36 let conn = Connection::open(path)?;
37 Self::init_schema(&conn)?;
38 Ok(Self { conn })
39 }
40
41 fn init_schema(conn: &Connection) -> SqlResult<()> {
43 conn.execute(
45 "CREATE TABLE IF NOT EXISTS chunks (
46 id TEXT PRIMARY KEY,
47 parent_id TEXT,
48 content_hash TEXT NOT NULL,
49 profile TEXT NOT NULL,
50 element_type TEXT NOT NULL,
51 content TEXT NOT NULL,
52 token_count INTEGER NOT NULL,
53 metadata JSON,
54 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
55 )",
56 [],
57 )?;
58
59 conn.execute(
61 "CREATE TABLE IF NOT EXISTS embeddings (
62 chunk_id TEXT PRIMARY KEY,
63 embedding BLOB NOT NULL,
64 norm REAL NOT NULL,
65 FOREIGN KEY (chunk_id) REFERENCES chunks(id) ON DELETE CASCADE
66 )",
67 [],
68 )?;
69
70 conn.execute(
72 "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
73 id,
74 content,
75 element_type,
76 metadata,
77 content='chunks',
78 content_rowid='rowid'
79 )",
80 [],
81 )?;
82
83 conn.execute(
85 "CREATE INDEX IF NOT EXISTS idx_parent ON chunks(parent_id)",
86 [],
87 )?;
88
89 conn.execute(
90 "CREATE INDEX IF NOT EXISTS idx_profile ON chunks(profile, element_type)",
91 [],
92 )?;
93
94 conn.execute(
95 "CREATE INDEX IF NOT EXISTS idx_content_hash ON chunks(content_hash)",
96 [],
97 )?;
98
99 conn.execute(
101 "CREATE TRIGGER IF NOT EXISTS chunks_fts_insert AFTER INSERT ON chunks BEGIN
102 INSERT INTO chunks_fts(rowid, id, content, element_type, metadata)
103 VALUES (new.rowid, new.id, new.content, new.element_type, new.metadata);
104 END",
105 [],
106 )?;
107
108 conn.execute(
109 "CREATE TRIGGER IF NOT EXISTS chunks_fts_delete AFTER DELETE ON chunks BEGIN
110 DELETE FROM chunks_fts WHERE rowid = old.rowid;
111 END",
112 [],
113 )?;
114
115 conn.execute(
116 "CREATE TRIGGER IF NOT EXISTS chunks_fts_update AFTER UPDATE ON chunks BEGIN
117 UPDATE chunks_fts SET
118 id = new.id,
119 content = new.content,
120 element_type = new.element_type,
121 metadata = new.metadata
122 WHERE rowid = new.rowid;
123 END",
124 [],
125 )?;
126
127 Ok(())
128 }
129
130 pub fn insert_chunk(&mut self, chunk: &Chunk, embedding: &[f32]) -> SqlResult<()> {
132 if embedding.len() != EMBEDDING_DIM {
133 return Err(rusqlite::Error::InvalidParameterCount(
134 EMBEDDING_DIM,
135 embedding.len(),
136 ));
137 }
138
139 let metadata_json = serde_json::to_string(&chunk.metadata)
140 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
141
142 self.conn.execute(
144 "INSERT INTO chunks (id, parent_id, content_hash, profile, element_type, content, token_count, metadata)
145 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
146 params![
147 chunk.id,
148 chunk.parent_id,
149 chunk.content_hash,
150 chunk.profile,
151 chunk.element_type,
152 chunk.content,
153 chunk.token_count,
154 metadata_json,
155 ],
156 )?;
157
158 let embedding_blob = embedding
160 .iter()
161 .flat_map(|f| f.to_le_bytes())
162 .collect::<Vec<u8>>();
163
164 let norm = Self::l2_norm(embedding);
166
167 self.conn.execute(
169 "INSERT INTO embeddings (chunk_id, embedding, norm) VALUES (?1, ?2, ?3)",
170 params![chunk.id, embedding_blob, norm],
171 )?;
172
173 Ok(())
174 }
175
176 pub fn get_chunk(&self, id: &str) -> SqlResult<Option<Chunk>> {
178 let mut stmt = self.conn.prepare(
179 "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
180 FROM chunks WHERE id = ?1",
181 )?;
182
183 let mut rows = stmt.query(params![id])?;
184
185 if let Some(row) = rows.next()? {
186 let metadata_json: String = row.get(7)?;
187 let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
188 rusqlite::Error::FromSqlConversionFailure(
189 7,
190 rusqlite::types::Type::Text,
191 Box::new(e),
192 )
193 })?;
194
195 Ok(Some(Chunk {
196 id: row.get(0)?,
197 parent_id: row.get(1)?,
198 content_hash: row.get(2)?,
199 profile: row.get(3)?,
200 element_type: row.get(4)?,
201 content: row.get(5)?,
202 token_count: row.get(6)?,
203 metadata,
204 }))
205 } else {
206 Ok(None)
207 }
208 }
209
210 pub fn get_embedding(&self, chunk_id: &str) -> SqlResult<Option<Vec<f32>>> {
212 let mut stmt = self
213 .conn
214 .prepare("SELECT embedding FROM embeddings WHERE chunk_id = ?1")?;
215
216 let mut rows = stmt.query(params![chunk_id])?;
217
218 if let Some(row) = rows.next()? {
219 let blob: Vec<u8> = row.get(0)?;
220 let embedding = Self::blob_to_embedding(&blob)?;
221 Ok(Some(embedding))
222 } else {
223 Ok(None)
224 }
225 }
226
227 pub fn search_keywords(&self, query: &str, limit: usize) -> SqlResult<Vec<ChunkMatch>> {
229 let mut stmt = self.conn.prepare(
230 "SELECT c.id, c.content, c.element_type, c.profile, rank
231 FROM chunks_fts
232 JOIN chunks c ON chunks_fts.rowid = c.rowid
233 WHERE chunks_fts MATCH ?1
234 ORDER BY rank
235 LIMIT ?2",
236 )?;
237
238 let mut rows = stmt.query(params![query, limit as i64])?;
239 let mut matches = Vec::new();
240
241 while let Some(row) = rows.next()? {
242 matches.push(ChunkMatch {
243 id: row.get(0)?,
244 content: row.get(1)?,
245 element_type: row.get(2)?,
246 profile: row.get(3)?,
247 score: row.get::<_, f64>(4)? as f32,
248 match_type: MatchType::Keyword,
249 });
250 }
251
252 Ok(matches)
253 }
254
255 pub fn search_similar(
257 &self,
258 query_embedding: &[f32],
259 limit: usize,
260 ) -> SqlResult<Vec<ChunkMatch>> {
261 if query_embedding.len() != EMBEDDING_DIM {
262 return Err(rusqlite::Error::InvalidParameterCount(
263 EMBEDDING_DIM,
264 query_embedding.len(),
265 ));
266 }
267
268 let query_norm = Self::l2_norm(query_embedding);
269
270 let mut stmt = self.conn.prepare(
271 "SELECT c.id, c.content, c.element_type, c.profile, e.embedding, e.norm
272 FROM chunks c
273 JOIN embeddings e ON c.id = e.chunk_id",
274 )?;
275
276 let mut rows = stmt.query([])?;
277 let mut matches = Vec::new();
278
279 while let Some(row) = rows.next()? {
280 let id: String = row.get(0)?;
281 let content: String = row.get(1)?;
282 let element_type: String = row.get(2)?;
283 let profile: String = row.get(3)?;
284 let embedding_blob: Vec<u8> = row.get(4)?;
285 let norm: f32 = row.get(5)?;
286
287 let embedding = Self::blob_to_embedding(&embedding_blob)?;
288
289 let dot_product: f32 = query_embedding
291 .iter()
292 .zip(&embedding)
293 .map(|(a, b)| a * b)
294 .sum();
295
296 let similarity = dot_product / (query_norm * norm);
297
298 matches.push(ChunkMatch {
299 id,
300 content,
301 element_type,
302 profile,
303 score: similarity,
304 match_type: MatchType::Vector,
305 });
306 }
307
308 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
310 matches.truncate(limit);
311
312 Ok(matches)
313 }
314
315 pub fn hybrid_search(
317 &self,
318 keywords: &str,
319 query_embedding: &[f32],
320 limit: usize,
321 ) -> SqlResult<Vec<ChunkMatch>> {
322 let keyword_matches = self.search_keywords(keywords, limit * 2)?;
324
325 let vector_matches = self.search_similar(query_embedding, limit * 2)?;
327
328 let mut combined = Self::merge_and_rerank(keyword_matches, vector_matches);
330 combined.truncate(limit);
331
332 Ok(combined)
333 }
334
335 fn merge_and_rerank(
337 keyword_matches: Vec<ChunkMatch>,
338 vector_matches: Vec<ChunkMatch>,
339 ) -> Vec<ChunkMatch> {
340 use std::collections::HashMap;
341
342 let mut matches_by_id: HashMap<String, ChunkMatch> = HashMap::new();
343 let mut scores: HashMap<String, (f32, f32)> = HashMap::new(); for m in keyword_matches {
347 scores.entry(m.id.clone()).or_insert((0.0, 0.0)).0 = m.score.abs(); matches_by_id.insert(m.id.clone(), m);
349 }
350
351 for m in vector_matches {
353 scores.entry(m.id.clone()).or_insert((0.0, 0.0)).1 = m.score;
354 matches_by_id.entry(m.id.clone()).or_insert(m);
355 }
356
357 let mut combined: Vec<_> = scores
359 .into_iter()
360 .filter_map(|(id, (kw_score, vec_score))| {
361 let combined_score = 0.3 * kw_score + 0.7 * vec_score;
362 matches_by_id.get(&id).map(|m| {
363 let mut new_match = m.clone();
364 new_match.score = combined_score;
365 new_match.match_type = MatchType::Hybrid;
366 new_match
367 })
368 })
369 .collect();
370
371 combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
372 combined
373 }
374
375 pub fn get_children(&self, parent_id: &str) -> SqlResult<Vec<Chunk>> {
377 let mut stmt = self.conn.prepare(
378 "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
379 FROM chunks WHERE parent_id = ?1
380 ORDER BY id",
381 )?;
382
383 let mut rows = stmt.query(params![parent_id])?;
384 let mut children = Vec::new();
385
386 while let Some(row) = rows.next()? {
387 let metadata_json: String = row.get(7)?;
388 let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
389 rusqlite::Error::FromSqlConversionFailure(
390 7,
391 rusqlite::types::Type::Text,
392 Box::new(e),
393 )
394 })?;
395
396 children.push(Chunk {
397 id: row.get(0)?,
398 parent_id: row.get(1)?,
399 content_hash: row.get(2)?,
400 profile: row.get(3)?,
401 element_type: row.get(4)?,
402 content: row.get(5)?,
403 token_count: row.get(6)?,
404 metadata,
405 });
406 }
407
408 Ok(children)
409 }
410
411 pub fn count_chunks(&self) -> SqlResult<usize> {
413 let count: i64 = self
414 .conn
415 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))?;
416 Ok(count as usize)
417 }
418
419 fn l2_norm(vec: &[f32]) -> f32 {
421 vec.iter().map(|x| x * x).sum::<f32>().sqrt()
422 }
423
424 fn blob_to_embedding(blob: &[u8]) -> SqlResult<Vec<f32>> {
426 if blob.len() != EMBEDDING_DIM * 4 {
427 return Err(rusqlite::Error::InvalidColumnType(
428 0,
429 "Embedding BLOB".to_string(),
430 rusqlite::types::Type::Blob,
431 ));
432 }
433
434 let embedding = blob
435 .chunks_exact(4)
436 .map(|chunk| {
437 let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
438 f32::from_le_bytes(bytes)
439 })
440 .collect();
441
442 Ok(embedding)
443 }
444}
445
/// A single search result returned by the store's search methods.
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkMatch {
    /// Id of the matching chunk.
    pub id: String,
    /// Text content of the matching chunk.
    pub content: String,
    /// Value of the chunk's `element_type` column.
    pub element_type: String,
    /// Value of the chunk's `profile` column.
    pub profile: String,
    /// Relevance score; its meaning depends on `match_type`:
    /// the raw FTS5 `rank` for `Keyword`, cosine similarity for `Vector`,
    /// and `0.3 * |keyword| + 0.7 * vector` for `Hybrid`.
    pub score: f32,
    /// Which kind of search produced this result.
    pub match_type: MatchType,
}
456
/// Identifies which search strategy produced a `ChunkMatch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchType {
    /// Produced by full-text search (`search_keywords`).
    Keyword,
    /// Produced by embedding similarity search (`search_similar`).
    Vector,
    /// Produced by merging keyword and vector results (`hybrid_search`).
    Hybrid,
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id_generator::ElementId;
    use std::collections::HashMap;

    /// Builds a parentless `code:api` / `function` chunk with empty metadata.
    fn create_test_chunk(id: &str, content: &str) -> Chunk {
        let content_hash = ElementId::new(id, content).content_hash;
        Chunk {
            id: String::from(id),
            parent_id: None,
            content_hash,
            profile: String::from("code:api"),
            element_type: String::from("function"),
            content: String::from(content),
            token_count: content.len() / 4,
            metadata: HashMap::new(),
        }
    }

    /// A constant vector of the required dimension.
    fn create_test_embedding() -> Vec<f32> {
        vec![0.1; EMBEDDING_DIM]
    }

    #[test]
    fn test_create_store() {
        assert!(EmbeddingStore::new_in_memory().is_ok());
    }

    #[test]
    fn test_insert_and_get_chunk() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();
        let original = create_test_chunk("test.id", "Test content");

        store.insert_chunk(&original, &vector).unwrap();

        let fetched = store.get_chunk("test.id").unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Test content");
    }

    #[test]
    fn test_get_embedding() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();
        let original = create_test_chunk("test.id", "Test content");

        store.insert_chunk(&original, &vector).unwrap();

        let fetched = store.get_embedding("test.id").unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_fts_search() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();

        let first = create_test_chunk("test.1", "Vector push method");
        let second = create_test_chunk("test.2", "HashMap insert function");
        store.insert_chunk(&first, &vector).unwrap();
        store.insert_chunk(&second, &vector).unwrap();

        // Only the first chunk mentions the search term.
        let hits = store.search_keywords("vector", 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "test.1");
    }

    #[test]
    fn test_vector_similarity() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();
        let original = create_test_chunk("test.id", "Test content");

        store.insert_chunk(&original, &vector).unwrap();

        // Querying with the stored vector itself should score ~1.0.
        let hits = store.search_similar(&vector, 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_hybrid_search() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();

        let first = create_test_chunk("test.1", "Vector push method adds items");
        let second = create_test_chunk("test.2", "HashMap insert stores key-value pairs");
        store.insert_chunk(&first, &vector).unwrap();
        store.insert_chunk(&second, &vector).unwrap();

        let hits = store.hybrid_search("vector", &vector, 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].match_type, MatchType::Hybrid);
    }

    #[test]
    fn test_parent_child_relationship() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();

        let parent = create_test_chunk("parent.id", "Parent content");
        let child = Chunk {
            parent_id: Some(String::from("parent.id")),
            ..create_test_chunk("parent.id#0", "Child content")
        };

        store.insert_chunk(&parent, &vector).unwrap();
        store.insert_chunk(&child, &vector).unwrap();

        let children = store.get_children("parent.id").unwrap();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].id, "parent.id#0");
    }

    #[test]
    fn test_count_chunks() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let vector = create_test_embedding();

        assert_eq!(store.count_chunks().unwrap(), 0);

        let first = create_test_chunk("test.1", "Content 1");
        let second = create_test_chunk("test.2", "Content 2");
        store.insert_chunk(&first, &vector).unwrap();
        store.insert_chunk(&second, &vector).unwrap();

        assert_eq!(store.count_chunks().unwrap(), 2);
    }
}