1use crate::chunker::Chunk;
14use rusqlite::{params, Connection, Result as SqlResult};
15use std::path::Path;
16
/// Number of `f32` components every embedding stored in or searched against
/// [`EmbeddingStore`] must have.
///
/// Embeddings are persisted as little-endian `f32` BLOBs, i.e. `EMBEDDING_DIM * 4` bytes.
pub const EMBEDDING_DIM: usize = 384;
19
20pub struct EmbeddingStore {
22 conn: Connection,
23}
24
25impl EmbeddingStore {
26 pub fn new_in_memory() -> SqlResult<Self> {
28 let conn = Connection::open_in_memory()?;
29 Self::init_schema(&conn)?;
30 Ok(Self { conn })
31 }
32
33 pub fn open(path: &Path) -> SqlResult<Self> {
35 let conn = Connection::open(path)?;
36 Self::init_schema(&conn)?;
37 Ok(Self { conn })
38 }
39
40 fn init_schema(conn: &Connection) -> SqlResult<()> {
42 conn.execute(
44 "CREATE TABLE IF NOT EXISTS chunks (
45 id TEXT PRIMARY KEY,
46 parent_id TEXT,
47 content_hash TEXT NOT NULL,
48 profile TEXT NOT NULL,
49 element_type TEXT NOT NULL,
50 content TEXT NOT NULL,
51 token_count INTEGER NOT NULL,
52 metadata JSON,
53 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
54 )",
55 [],
56 )?;
57
58 conn.execute(
60 "CREATE TABLE IF NOT EXISTS embeddings (
61 chunk_id TEXT PRIMARY KEY,
62 embedding BLOB NOT NULL,
63 norm REAL NOT NULL,
64 FOREIGN KEY (chunk_id) REFERENCES chunks(id) ON DELETE CASCADE
65 )",
66 [],
67 )?;
68
69 conn.execute(
71 "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
72 id,
73 content,
74 element_type,
75 metadata,
76 content='chunks',
77 content_rowid='rowid'
78 )",
79 [],
80 )?;
81
82 conn.execute(
84 "CREATE INDEX IF NOT EXISTS idx_parent ON chunks(parent_id)",
85 [],
86 )?;
87
88 conn.execute(
89 "CREATE INDEX IF NOT EXISTS idx_profile ON chunks(profile, element_type)",
90 [],
91 )?;
92
93 conn.execute(
94 "CREATE INDEX IF NOT EXISTS idx_content_hash ON chunks(content_hash)",
95 [],
96 )?;
97
98 conn.execute(
100 "CREATE TRIGGER IF NOT EXISTS chunks_fts_insert AFTER INSERT ON chunks BEGIN
101 INSERT INTO chunks_fts(rowid, id, content, element_type, metadata)
102 VALUES (new.rowid, new.id, new.content, new.element_type, new.metadata);
103 END",
104 [],
105 )?;
106
107 conn.execute(
108 "CREATE TRIGGER IF NOT EXISTS chunks_fts_delete AFTER DELETE ON chunks BEGIN
109 DELETE FROM chunks_fts WHERE rowid = old.rowid;
110 END",
111 [],
112 )?;
113
114 conn.execute(
115 "CREATE TRIGGER IF NOT EXISTS chunks_fts_update AFTER UPDATE ON chunks BEGIN
116 UPDATE chunks_fts SET
117 id = new.id,
118 content = new.content,
119 element_type = new.element_type,
120 metadata = new.metadata
121 WHERE rowid = new.rowid;
122 END",
123 [],
124 )?;
125
126 Ok(())
127 }
128
129 pub fn insert_chunk(&mut self, chunk: &Chunk, embedding: &[f32]) -> SqlResult<()> {
131 if embedding.len() != EMBEDDING_DIM {
132 return Err(rusqlite::Error::InvalidParameterCount(
133 EMBEDDING_DIM,
134 embedding.len(),
135 ));
136 }
137
138 let metadata_json = serde_json::to_string(&chunk.metadata)
139 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
140
141 self.conn.execute(
143 "INSERT INTO chunks (id, parent_id, content_hash, profile, element_type, content, token_count, metadata)
144 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
145 params![
146 chunk.id,
147 chunk.parent_id,
148 chunk.content_hash,
149 chunk.profile,
150 chunk.element_type,
151 chunk.content,
152 chunk.token_count,
153 metadata_json,
154 ],
155 )?;
156
157 let embedding_blob = embedding
159 .iter()
160 .flat_map(|f| f.to_le_bytes())
161 .collect::<Vec<u8>>();
162
163 let norm = Self::l2_norm(embedding);
165
166 self.conn.execute(
168 "INSERT INTO embeddings (chunk_id, embedding, norm) VALUES (?1, ?2, ?3)",
169 params![chunk.id, embedding_blob, norm],
170 )?;
171
172 Ok(())
173 }
174
175 pub fn get_chunk(&self, id: &str) -> SqlResult<Option<Chunk>> {
177 let mut stmt = self.conn.prepare(
178 "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
179 FROM chunks WHERE id = ?1",
180 )?;
181
182 let mut rows = stmt.query(params![id])?;
183
184 if let Some(row) = rows.next()? {
185 let metadata_json: String = row.get(7)?;
186 let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
187 rusqlite::Error::FromSqlConversionFailure(
188 7,
189 rusqlite::types::Type::Text,
190 Box::new(e),
191 )
192 })?;
193
194 Ok(Some(Chunk {
195 id: row.get(0)?,
196 parent_id: row.get(1)?,
197 content_hash: row.get(2)?,
198 profile: row.get(3)?,
199 element_type: row.get(4)?,
200 content: row.get(5)?,
201 token_count: row.get(6)?,
202 metadata,
203 }))
204 } else {
205 Ok(None)
206 }
207 }
208
209 pub fn get_embedding(&self, chunk_id: &str) -> SqlResult<Option<Vec<f32>>> {
211 let mut stmt = self
212 .conn
213 .prepare("SELECT embedding FROM embeddings WHERE chunk_id = ?1")?;
214
215 let mut rows = stmt.query(params![chunk_id])?;
216
217 if let Some(row) = rows.next()? {
218 let blob: Vec<u8> = row.get(0)?;
219 let embedding = Self::blob_to_embedding(&blob)?;
220 Ok(Some(embedding))
221 } else {
222 Ok(None)
223 }
224 }
225
226 pub fn search_keywords(&self, query: &str, limit: usize) -> SqlResult<Vec<ChunkMatch>> {
228 let mut stmt = self.conn.prepare(
229 "SELECT c.id, c.content, c.element_type, c.profile, rank
230 FROM chunks_fts
231 JOIN chunks c ON chunks_fts.rowid = c.rowid
232 WHERE chunks_fts MATCH ?1
233 ORDER BY rank
234 LIMIT ?2",
235 )?;
236
237 let mut rows = stmt.query(params![query, limit as i64])?;
238 let mut matches = Vec::new();
239
240 while let Some(row) = rows.next()? {
241 matches.push(ChunkMatch {
242 id: row.get(0)?,
243 content: row.get(1)?,
244 element_type: row.get(2)?,
245 profile: row.get(3)?,
246 score: row.get::<_, f64>(4)? as f32,
247 match_type: MatchType::Keyword,
248 });
249 }
250
251 Ok(matches)
252 }
253
254 pub fn search_similar(
256 &self,
257 query_embedding: &[f32],
258 limit: usize,
259 ) -> SqlResult<Vec<ChunkMatch>> {
260 if query_embedding.len() != EMBEDDING_DIM {
261 return Err(rusqlite::Error::InvalidParameterCount(
262 EMBEDDING_DIM,
263 query_embedding.len(),
264 ));
265 }
266
267 let query_norm = Self::l2_norm(query_embedding);
268
269 let mut stmt = self.conn.prepare(
270 "SELECT c.id, c.content, c.element_type, c.profile, e.embedding, e.norm
271 FROM chunks c
272 JOIN embeddings e ON c.id = e.chunk_id",
273 )?;
274
275 let mut rows = stmt.query([])?;
276 let mut matches = Vec::new();
277
278 while let Some(row) = rows.next()? {
279 let id: String = row.get(0)?;
280 let content: String = row.get(1)?;
281 let element_type: String = row.get(2)?;
282 let profile: String = row.get(3)?;
283 let embedding_blob: Vec<u8> = row.get(4)?;
284 let norm: f32 = row.get(5)?;
285
286 let embedding = Self::blob_to_embedding(&embedding_blob)?;
287
288 let dot_product: f32 = query_embedding
290 .iter()
291 .zip(&embedding)
292 .map(|(a, b)| a * b)
293 .sum();
294
295 let similarity = dot_product / (query_norm * norm);
296
297 matches.push(ChunkMatch {
298 id,
299 content,
300 element_type,
301 profile,
302 score: similarity,
303 match_type: MatchType::Vector,
304 });
305 }
306
307 matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
309 matches.truncate(limit);
310
311 Ok(matches)
312 }
313
314 pub fn hybrid_search(
316 &self,
317 keywords: &str,
318 query_embedding: &[f32],
319 limit: usize,
320 ) -> SqlResult<Vec<ChunkMatch>> {
321 let keyword_matches = self.search_keywords(keywords, limit * 2)?;
323
324 let vector_matches = self.search_similar(query_embedding, limit * 2)?;
326
327 let mut combined = Self::merge_and_rerank(keyword_matches, vector_matches);
329 combined.truncate(limit);
330
331 Ok(combined)
332 }
333
334 fn merge_and_rerank(
336 keyword_matches: Vec<ChunkMatch>,
337 vector_matches: Vec<ChunkMatch>,
338 ) -> Vec<ChunkMatch> {
339 use std::collections::HashMap;
340
341 let mut matches_by_id: HashMap<String, ChunkMatch> = HashMap::new();
342 let mut scores: HashMap<String, (f32, f32)> = HashMap::new(); for m in keyword_matches {
346 scores.entry(m.id.clone()).or_insert((0.0, 0.0)).0 = m.score.abs(); matches_by_id.insert(m.id.clone(), m);
348 }
349
350 for m in vector_matches {
352 scores.entry(m.id.clone()).or_insert((0.0, 0.0)).1 = m.score;
353 matches_by_id.entry(m.id.clone()).or_insert(m);
354 }
355
356 let mut combined: Vec<_> = scores
358 .into_iter()
359 .filter_map(|(id, (kw_score, vec_score))| {
360 let combined_score = 0.3 * kw_score + 0.7 * vec_score;
361 matches_by_id.get(&id).map(|m| {
362 let mut new_match = m.clone();
363 new_match.score = combined_score;
364 new_match.match_type = MatchType::Hybrid;
365 new_match
366 })
367 })
368 .collect();
369
370 combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
371 combined
372 }
373
374 pub fn get_children(&self, parent_id: &str) -> SqlResult<Vec<Chunk>> {
376 let mut stmt = self.conn.prepare(
377 "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
378 FROM chunks WHERE parent_id = ?1
379 ORDER BY id",
380 )?;
381
382 let mut rows = stmt.query(params![parent_id])?;
383 let mut children = Vec::new();
384
385 while let Some(row) = rows.next()? {
386 let metadata_json: String = row.get(7)?;
387 let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
388 rusqlite::Error::FromSqlConversionFailure(
389 7,
390 rusqlite::types::Type::Text,
391 Box::new(e),
392 )
393 })?;
394
395 children.push(Chunk {
396 id: row.get(0)?,
397 parent_id: row.get(1)?,
398 content_hash: row.get(2)?,
399 profile: row.get(3)?,
400 element_type: row.get(4)?,
401 content: row.get(5)?,
402 token_count: row.get(6)?,
403 metadata,
404 });
405 }
406
407 Ok(children)
408 }
409
410 pub fn count_chunks(&self) -> SqlResult<usize> {
412 let count: i64 = self
413 .conn
414 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))?;
415 Ok(count as usize)
416 }
417
418 fn l2_norm(vec: &[f32]) -> f32 {
420 vec.iter().map(|x| x * x).sum::<f32>().sqrt()
421 }
422
423 fn blob_to_embedding(blob: &[u8]) -> SqlResult<Vec<f32>> {
425 if blob.len() != EMBEDDING_DIM * 4 {
426 return Err(rusqlite::Error::InvalidColumnType(
427 0,
428 "Embedding BLOB".to_string(),
429 rusqlite::types::Type::Blob,
430 ));
431 }
432
433 let embedding = blob
434 .chunks_exact(4)
435 .map(|chunk| {
436 let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
437 f32::from_le_bytes(bytes)
438 })
439 .collect();
440
441 Ok(embedding)
442 }
443}
444
/// A single search hit produced by the [`EmbeddingStore`] search methods.
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkMatch {
    /// Identifier of the matching chunk (`chunks.id`).
    pub id: String,
    /// Full text of the matching chunk.
    pub content: String,
    /// Element type recorded for the chunk.
    pub element_type: String,
    /// Profile recorded for the chunk.
    pub profile: String,
    /// Relevance score. Its scale depends on `match_type`: for keyword hits it is the raw
    /// FTS5 `rank` (lower is a better match); for vector hits a cosine similarity; for
    /// hybrid hits a weighted blend of the two.
    pub score: f32,
    /// Which search produced this match.
    pub match_type: MatchType,
}
455
/// Identifies which retrieval strategy produced a search match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchType {
    /// Full-text (FTS5) keyword search.
    Keyword,
    /// Embedding cosine-similarity search.
    Vector,
    /// Weighted blend of keyword and vector scores.
    Hybrid,
}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id_generator::ElementId;
    use std::collections::HashMap;

    /// Builds a parentless `code:api` function chunk with empty metadata.
    fn create_test_chunk(id: &str, content: &str) -> Chunk {
        let content_hash = ElementId::new(id, content).content_hash;
        Chunk {
            id: id.to_string(),
            parent_id: None,
            content_hash,
            profile: "code:api".to_string(),
            element_type: "function".to_string(),
            content: content.to_string(),
            token_count: content.len() / 4,
            metadata: HashMap::new(),
        }
    }

    /// Constant fixture embedding of the required dimension.
    fn create_test_embedding() -> Vec<f32> {
        vec![0.1_f32; EMBEDDING_DIM]
    }

    /// Opens an in-memory store and inserts `chunks` in order, all sharing the fixture embedding.
    fn seeded_store(chunks: &[Chunk]) -> EmbeddingStore {
        let mut store = EmbeddingStore::new_in_memory().expect("in-memory store should open");
        let shared_embedding = create_test_embedding();
        for chunk in chunks {
            store
                .insert_chunk(chunk, &shared_embedding)
                .expect("fixture insert should succeed");
        }
        store
    }

    #[test]
    fn test_create_store() {
        assert!(EmbeddingStore::new_in_memory().is_ok());
    }

    #[test]
    fn test_insert_and_get_chunk() {
        let store = seeded_store(&[create_test_chunk("test.id", "Test content")]);

        let fetched = store
            .get_chunk("test.id")
            .unwrap()
            .expect("inserted chunk should be retrievable");
        assert_eq!(fetched.content, "Test content");
    }

    #[test]
    fn test_get_embedding() {
        let store = seeded_store(&[create_test_chunk("test.id", "Test content")]);

        let stored = store.get_embedding("test.id").unwrap();
        assert_eq!(stored.map(|v| v.len()), Some(EMBEDDING_DIM));
    }

    #[test]
    fn test_fts_search() {
        let store = seeded_store(&[
            create_test_chunk("test.1", "Vector push method"),
            create_test_chunk("test.2", "HashMap insert function"),
        ]);

        let hits = store.search_keywords("vector", 10).unwrap();
        let hit_ids: Vec<&str> = hits.iter().map(|hit| hit.id.as_str()).collect();
        assert_eq!(hit_ids, ["test.1"]);
    }

    #[test]
    fn test_vector_similarity() {
        let store = seeded_store(&[create_test_chunk("test.id", "Test content")]);

        let hits = store.search_similar(&create_test_embedding(), 10).unwrap();
        match hits.as_slice() {
            [only] => assert!((only.score - 1.0).abs() < 0.01),
            other => panic!("expected exactly one hit, got {}", other.len()),
        }
    }

    #[test]
    fn test_hybrid_search() {
        let store = seeded_store(&[
            create_test_chunk("test.1", "Vector push method adds items"),
            create_test_chunk("test.2", "HashMap insert stores key-value pairs"),
        ]);

        let hits = store
            .hybrid_search("vector", &create_test_embedding(), 10)
            .unwrap();
        let top = hits.first().expect("hybrid search should return results");
        assert_eq!(top.match_type, MatchType::Hybrid);
    }

    #[test]
    fn test_parent_child_relationship() {
        let parent = create_test_chunk("parent.id", "Parent content");
        let child = Chunk {
            parent_id: Some("parent.id".to_string()),
            ..create_test_chunk("parent.id#0", "Child content")
        };
        let store = seeded_store(&[parent, child]);

        let children = store.get_children("parent.id").unwrap();
        let child_ids: Vec<&str> = children.iter().map(|c| c.id.as_str()).collect();
        assert_eq!(child_ids, ["parent.id#0"]);
    }

    #[test]
    fn test_count_chunks() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        assert_eq!(store.count_chunks().unwrap(), 0);

        let shared_embedding = create_test_embedding();
        for (id, content) in [("test.1", "Content 1"), ("test.2", "Content 2")] {
            store
                .insert_chunk(&create_test_chunk(id, content), &shared_embedding)
                .unwrap();
        }

        assert_eq!(store.count_chunks().unwrap(), 2);
    }
}