// cml_rs/embedding_store.rs
//! SQLite-based embedding store with FTS5 hybrid search.
//!
//! This module provides an embedding lookup table that combines:
//! - Full-text search (FTS5) for keyword matching
//! - Vector similarity search for semantic matching
//! - Hierarchical parent-child relationships
//!
//! The two search methods are complementary, so combining them typically
//! outperforms either one alone:
//! - FTS5 provides high precision on exact keywords
//! - Vector search provides high recall on semantic similarity
//! - Hybrid search merges both result lists and reranks them by a weighted score
use std::collections::HashMap;
use std::path::Path;

use rusqlite::{params, Connection, OptionalExtension, Result as SqlResult, Row};

use crate::chunker::Chunk;
18/// Dimension of embedding vectors (matches sentence-transformers MiniLM-L6-v2).
19pub const EMBEDDING_DIM: usize = 384;
20
21/// SQLite embedding store with FTS5 hybrid search
22pub struct EmbeddingStore {
23    conn: Connection,
24}
25
26impl EmbeddingStore {
27    /// Create a new embedding store (in-memory for testing)
28    pub fn new_in_memory() -> SqlResult<Self> {
29        let conn = Connection::open_in_memory()?;
30        Self::init_schema(&conn)?;
31        Ok(Self { conn })
32    }
33
34    /// Open or create an embedding store from file
35    pub fn open(path: &Path) -> SqlResult<Self> {
36        let conn = Connection::open(path)?;
37        Self::init_schema(&conn)?;
38        Ok(Self { conn })
39    }
40
41    /// Initialize database schema
42    fn init_schema(conn: &Connection) -> SqlResult<()> {
43        // Main chunks table
44        conn.execute(
45            "CREATE TABLE IF NOT EXISTS chunks (
46                id TEXT PRIMARY KEY,
47                parent_id TEXT,
48                content_hash TEXT NOT NULL,
49                profile TEXT NOT NULL,
50                element_type TEXT NOT NULL,
51                content TEXT NOT NULL,
52                token_count INTEGER NOT NULL,
53                metadata JSON,
54                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
55            )",
56            [],
57        )?;
58
59        // Embeddings table (384-dim f32 vectors as BLOB)
60        conn.execute(
61            "CREATE TABLE IF NOT EXISTS embeddings (
62                chunk_id TEXT PRIMARY KEY,
63                embedding BLOB NOT NULL,
64                norm REAL NOT NULL,
65                FOREIGN KEY (chunk_id) REFERENCES chunks(id) ON DELETE CASCADE
66            )",
67            [],
68        )?;
69
70        // FTS5 virtual table for full-text search
71        conn.execute(
72            "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
73                id,
74                content,
75                element_type,
76                metadata,
77                content='chunks',
78                content_rowid='rowid'
79            )",
80            [],
81        )?;
82
83        // Indexes for fast lookups
84        conn.execute(
85            "CREATE INDEX IF NOT EXISTS idx_parent ON chunks(parent_id)",
86            [],
87        )?;
88
89        conn.execute(
90            "CREATE INDEX IF NOT EXISTS idx_profile ON chunks(profile, element_type)",
91            [],
92        )?;
93
94        conn.execute(
95            "CREATE INDEX IF NOT EXISTS idx_content_hash ON chunks(content_hash)",
96            [],
97        )?;
98
99        // Trigger to keep FTS5 in sync
100        conn.execute(
101            "CREATE TRIGGER IF NOT EXISTS chunks_fts_insert AFTER INSERT ON chunks BEGIN
102                INSERT INTO chunks_fts(rowid, id, content, element_type, metadata)
103                VALUES (new.rowid, new.id, new.content, new.element_type, new.metadata);
104            END",
105            [],
106        )?;
107
108        conn.execute(
109            "CREATE TRIGGER IF NOT EXISTS chunks_fts_delete AFTER DELETE ON chunks BEGIN
110                DELETE FROM chunks_fts WHERE rowid = old.rowid;
111            END",
112            [],
113        )?;
114
115        conn.execute(
116            "CREATE TRIGGER IF NOT EXISTS chunks_fts_update AFTER UPDATE ON chunks BEGIN
117                UPDATE chunks_fts SET
118                    id = new.id,
119                    content = new.content,
120                    element_type = new.element_type,
121                    metadata = new.metadata
122                WHERE rowid = new.rowid;
123            END",
124            [],
125        )?;
126
127        Ok(())
128    }
129
130    /// Insert a chunk with its embedding
131    pub fn insert_chunk(&mut self, chunk: &Chunk, embedding: &[f32]) -> SqlResult<()> {
132        if embedding.len() != EMBEDDING_DIM {
133            return Err(rusqlite::Error::InvalidParameterCount(
134                EMBEDDING_DIM,
135                embedding.len(),
136            ));
137        }
138
139        let metadata_json = serde_json::to_string(&chunk.metadata)
140            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
141
142        // Insert chunk
143        self.conn.execute(
144            "INSERT INTO chunks (id, parent_id, content_hash, profile, element_type, content, token_count, metadata)
145             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
146            params![
147                chunk.id,
148                chunk.parent_id,
149                chunk.content_hash,
150                chunk.profile,
151                chunk.element_type,
152                chunk.content,
153                chunk.token_count,
154                metadata_json,
155            ],
156        )?;
157
158        // Convert embedding to BLOB (4 bytes per f32)
159        let embedding_blob = embedding
160            .iter()
161            .flat_map(|f| f.to_le_bytes())
162            .collect::<Vec<u8>>();
163
164        // Calculate L2 norm
165        let norm = Self::l2_norm(embedding);
166
167        // Insert embedding
168        self.conn.execute(
169            "INSERT INTO embeddings (chunk_id, embedding, norm) VALUES (?1, ?2, ?3)",
170            params![chunk.id, embedding_blob, norm],
171        )?;
172
173        Ok(())
174    }
175
176    /// Get a chunk by ID
177    pub fn get_chunk(&self, id: &str) -> SqlResult<Option<Chunk>> {
178        let mut stmt = self.conn.prepare(
179            "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
180             FROM chunks WHERE id = ?1",
181        )?;
182
183        let mut rows = stmt.query(params![id])?;
184
185        if let Some(row) = rows.next()? {
186            let metadata_json: String = row.get(7)?;
187            let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
188                rusqlite::Error::FromSqlConversionFailure(
189                    7,
190                    rusqlite::types::Type::Text,
191                    Box::new(e),
192                )
193            })?;
194
195            Ok(Some(Chunk {
196                id: row.get(0)?,
197                parent_id: row.get(1)?,
198                content_hash: row.get(2)?,
199                profile: row.get(3)?,
200                element_type: row.get(4)?,
201                content: row.get(5)?,
202                token_count: row.get(6)?,
203                metadata,
204            }))
205        } else {
206            Ok(None)
207        }
208    }
209
210    /// Get embedding for a chunk
211    pub fn get_embedding(&self, chunk_id: &str) -> SqlResult<Option<Vec<f32>>> {
212        let mut stmt = self
213            .conn
214            .prepare("SELECT embedding FROM embeddings WHERE chunk_id = ?1")?;
215
216        let mut rows = stmt.query(params![chunk_id])?;
217
218        if let Some(row) = rows.next()? {
219            let blob: Vec<u8> = row.get(0)?;
220            let embedding = Self::blob_to_embedding(&blob)?;
221            Ok(Some(embedding))
222        } else {
223            Ok(None)
224        }
225    }
226
227    /// FTS5 keyword search
228    pub fn search_keywords(&self, query: &str, limit: usize) -> SqlResult<Vec<ChunkMatch>> {
229        let mut stmt = self.conn.prepare(
230            "SELECT c.id, c.content, c.element_type, c.profile, rank
231             FROM chunks_fts
232             JOIN chunks c ON chunks_fts.rowid = c.rowid
233             WHERE chunks_fts MATCH ?1
234             ORDER BY rank
235             LIMIT ?2",
236        )?;
237
238        let mut rows = stmt.query(params![query, limit as i64])?;
239        let mut matches = Vec::new();
240
241        while let Some(row) = rows.next()? {
242            matches.push(ChunkMatch {
243                id: row.get(0)?,
244                content: row.get(1)?,
245                element_type: row.get(2)?,
246                profile: row.get(3)?,
247                score: row.get::<_, f64>(4)? as f32,
248                match_type: MatchType::Keyword,
249            });
250        }
251
252        Ok(matches)
253    }
254
255    /// Vector similarity search (brute force for now, fast enough for <100K chunks)
256    pub fn search_similar(
257        &self,
258        query_embedding: &[f32],
259        limit: usize,
260    ) -> SqlResult<Vec<ChunkMatch>> {
261        if query_embedding.len() != EMBEDDING_DIM {
262            return Err(rusqlite::Error::InvalidParameterCount(
263                EMBEDDING_DIM,
264                query_embedding.len(),
265            ));
266        }
267
268        let query_norm = Self::l2_norm(query_embedding);
269
270        let mut stmt = self.conn.prepare(
271            "SELECT c.id, c.content, c.element_type, c.profile, e.embedding, e.norm
272             FROM chunks c
273             JOIN embeddings e ON c.id = e.chunk_id",
274        )?;
275
276        let mut rows = stmt.query([])?;
277        let mut matches = Vec::new();
278
279        while let Some(row) = rows.next()? {
280            let id: String = row.get(0)?;
281            let content: String = row.get(1)?;
282            let element_type: String = row.get(2)?;
283            let profile: String = row.get(3)?;
284            let embedding_blob: Vec<u8> = row.get(4)?;
285            let norm: f32 = row.get(5)?;
286
287            let embedding = Self::blob_to_embedding(&embedding_blob)?;
288
289            // Cosine similarity = dot product / (norm1 * norm2)
290            let dot_product: f32 = query_embedding
291                .iter()
292                .zip(&embedding)
293                .map(|(a, b)| a * b)
294                .sum();
295
296            let similarity = dot_product / (query_norm * norm);
297
298            matches.push(ChunkMatch {
299                id,
300                content,
301                element_type,
302                profile,
303                score: similarity,
304                match_type: MatchType::Vector,
305            });
306        }
307
308        // Sort by similarity (descending) and take top N
309        matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
310        matches.truncate(limit);
311
312        Ok(matches)
313    }
314
315    /// Hybrid search: Combine FTS5 + vector similarity
316    pub fn hybrid_search(
317        &self,
318        keywords: &str,
319        query_embedding: &[f32],
320        limit: usize,
321    ) -> SqlResult<Vec<ChunkMatch>> {
322        // Get keyword matches (precision)
323        let keyword_matches = self.search_keywords(keywords, limit * 2)?;
324
325        // Get vector matches (recall)
326        let vector_matches = self.search_similar(query_embedding, limit * 2)?;
327
328        // Combine and rerank
329        let mut combined = Self::merge_and_rerank(keyword_matches, vector_matches);
330        combined.truncate(limit);
331
332        Ok(combined)
333    }
334
335    /// Merge keyword and vector matches, rerank by combined score
336    fn merge_and_rerank(
337        keyword_matches: Vec<ChunkMatch>,
338        vector_matches: Vec<ChunkMatch>,
339    ) -> Vec<ChunkMatch> {
340        use std::collections::HashMap;
341
342        let mut matches_by_id: HashMap<String, ChunkMatch> = HashMap::new();
343        let mut scores: HashMap<String, (f32, f32)> = HashMap::new(); // (keyword_score, vector_score)
344
345        // Collect keyword scores and matches
346        for m in keyword_matches {
347            scores.entry(m.id.clone()).or_insert((0.0, 0.0)).0 = m.score.abs(); // FTS5 rank is negative
348            matches_by_id.insert(m.id.clone(), m);
349        }
350
351        // Collect vector scores and matches
352        for m in vector_matches {
353            scores.entry(m.id.clone()).or_insert((0.0, 0.0)).1 = m.score;
354            matches_by_id.entry(m.id.clone()).or_insert(m);
355        }
356
357        // Rerank: combined_score = 0.3 * keyword + 0.7 * vector (favor semantic)
358        let mut combined: Vec<_> = scores
359            .into_iter()
360            .filter_map(|(id, (kw_score, vec_score))| {
361                let combined_score = 0.3 * kw_score + 0.7 * vec_score;
362                matches_by_id.get(&id).map(|m| {
363                    let mut new_match = m.clone();
364                    new_match.score = combined_score;
365                    new_match.match_type = MatchType::Hybrid;
366                    new_match
367                })
368            })
369            .collect();
370
371        combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
372        combined
373    }
374
375    /// Get all child chunks of a parent
376    pub fn get_children(&self, parent_id: &str) -> SqlResult<Vec<Chunk>> {
377        let mut stmt = self.conn.prepare(
378            "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
379             FROM chunks WHERE parent_id = ?1
380             ORDER BY id",
381        )?;
382
383        let mut rows = stmt.query(params![parent_id])?;
384        let mut children = Vec::new();
385
386        while let Some(row) = rows.next()? {
387            let metadata_json: String = row.get(7)?;
388            let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
389                rusqlite::Error::FromSqlConversionFailure(
390                    7,
391                    rusqlite::types::Type::Text,
392                    Box::new(e),
393                )
394            })?;
395
396            children.push(Chunk {
397                id: row.get(0)?,
398                parent_id: row.get(1)?,
399                content_hash: row.get(2)?,
400                profile: row.get(3)?,
401                element_type: row.get(4)?,
402                content: row.get(5)?,
403                token_count: row.get(6)?,
404                metadata,
405            });
406        }
407
408        Ok(children)
409    }
410
411    /// Count total chunks
412    pub fn count_chunks(&self) -> SqlResult<usize> {
413        let count: i64 = self
414            .conn
415            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))?;
416        Ok(count as usize)
417    }
418
419    /// Calculate L2 norm of a vector
420    fn l2_norm(vec: &[f32]) -> f32 {
421        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
422    }
423
424    /// Convert BLOB to f32 vector
425    fn blob_to_embedding(blob: &[u8]) -> SqlResult<Vec<f32>> {
426        if blob.len() != EMBEDDING_DIM * 4 {
427            return Err(rusqlite::Error::InvalidColumnType(
428                0,
429                "Embedding BLOB".to_string(),
430                rusqlite::types::Type::Blob,
431            ));
432        }
433
434        let embedding = blob
435            .chunks_exact(4)
436            .map(|chunk| {
437                let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
438                f32::from_le_bytes(bytes)
439            })
440            .collect();
441
442        Ok(embedding)
443    }
444}
445
/// A single search hit returned by the [`EmbeddingStore`] search methods.
///
/// How [`score`](ChunkMatch::score) is read depends on [`match_type`](ChunkMatch::match_type):
/// - [`MatchType::Keyword`]: the raw FTS5 `rank` value (lower is better).
/// - [`MatchType::Vector`]: cosine similarity with the query embedding (higher is better).
/// - [`MatchType::Hybrid`]: a weighted blend of keyword and vector evidence (higher is better).
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkMatch {
    /// Identifier of the matched chunk.
    pub id: String,
    /// Full text of the matched chunk.
    pub content: String,
    /// Element type recorded for the chunk (e.g. `"function"`).
    pub element_type: String,
    /// Profile the chunk was indexed under (e.g. `"code:api"`).
    pub profile: String,
    /// Relevance score; its interpretation depends on `match_type`.
    pub score: f32,
    /// Which search strategy produced this hit.
    pub match_type: MatchType,
}

/// Which search strategy produced a [`ChunkMatch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatchType {
    /// Full-text (FTS5) keyword match.
    Keyword,
    /// Embedding cosine-similarity match.
    Vector,
    /// Fused keyword + vector result.
    Hybrid,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id_generator::ElementId;
    use std::collections::HashMap;

    /// Build a minimal top-level `code:api` function chunk.
    fn create_test_chunk(id: &str, content: &str) -> Chunk {
        Chunk {
            id: id.to_string(),
            parent_id: None,
            content_hash: ElementId::new(id, content).content_hash,
            profile: "code:api".to_string(),
            element_type: "function".to_string(),
            content: content.to_string(),
            token_count: content.len() / 4,
            metadata: HashMap::new(),
        }
    }

    /// Constant embedding; any two of these have cosine similarity 1.
    fn create_test_embedding() -> Vec<f32> {
        vec![0.1; EMBEDDING_DIM]
    }

    /// Unit embedding along `axis` (1.0 there, 0.0 elsewhere).
    fn one_hot_embedding(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0; EMBEDDING_DIM];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn test_create_store() {
        assert!(EmbeddingStore::new_in_memory().is_ok());
    }

    #[test]
    fn test_insert_and_get_chunk() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let chunk = create_test_chunk("test.id", "Test content");

        store.insert_chunk(&chunk, &create_test_embedding()).unwrap();

        let retrieved = store.get_chunk("test.id").unwrap().expect("chunk was inserted");
        assert_eq!(retrieved.content, "Test content");
        assert_eq!(retrieved.profile, "code:api");
        assert_eq!(retrieved.parent_id, None);
    }

    #[test]
    fn test_get_missing_returns_none() {
        let store = EmbeddingStore::new_in_memory().unwrap();
        assert!(store.get_chunk("missing").unwrap().is_none());
        assert!(store.get_embedding("missing").unwrap().is_none());
    }

    #[test]
    fn test_get_embedding() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let chunk = create_test_chunk("test.id", "Test content");
        let embedding = create_test_embedding();

        store.insert_chunk(&chunk, &embedding).unwrap();

        let retrieved = store
            .get_embedding("test.id")
            .unwrap()
            .expect("embedding was inserted");
        // Stored as raw little-endian f32 bytes, so the round trip is exact.
        assert_eq!(retrieved, embedding);
    }

    #[test]
    fn test_insert_rejects_wrong_dimension() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let chunk = create_test_chunk("test.id", "Test content");
        let short = vec![0.1_f32; EMBEDDING_DIM - 1];

        assert!(store.insert_chunk(&chunk, &short).is_err());
        assert_eq!(store.count_chunks().unwrap(), 0);
    }

    #[test]
    fn test_fts_search() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let embedding = create_test_embedding();

        store
            .insert_chunk(&create_test_chunk("test.1", "Vector push method"), &embedding)
            .unwrap();
        store
            .insert_chunk(&create_test_chunk("test.2", "HashMap insert function"), &embedding)
            .unwrap();

        let results = store.search_keywords("vector", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "test.1");
        assert_eq!(results[0].match_type, MatchType::Keyword);
    }

    #[test]
    fn test_vector_similarity() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let chunk = create_test_chunk("test.id", "Test content");
        let embedding = create_test_embedding();

        store.insert_chunk(&chunk, &embedding).unwrap();

        // Querying with the stored embedding itself should give similarity ~1.0.
        let results = store.search_similar(&embedding, 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_vector_search_ranks_closest_first() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        store
            .insert_chunk(&create_test_chunk("test.x", "Along x"), &one_hot_embedding(0))
            .unwrap();
        store
            .insert_chunk(&create_test_chunk("test.y", "Along y"), &one_hot_embedding(1))
            .unwrap();

        let results = store.search_similar(&one_hot_embedding(0), 10).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "test.x");
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert!(results[1].score.abs() < 1e-6);
        assert!(results.iter().all(|m| m.match_type == MatchType::Vector));

        assert_eq!(store.search_similar(&one_hot_embedding(0), 1).unwrap().len(), 1);
    }

    #[test]
    fn test_hybrid_search() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let embedding = create_test_embedding();

        store
            .insert_chunk(&create_test_chunk("test.1", "Vector push method adds items"), &embedding)
            .unwrap();
        store
            .insert_chunk(
                &create_test_chunk("test.2", "HashMap insert stores key-value pairs"),
                &embedding,
            )
            .unwrap();

        let results = store.hybrid_search("vector", &embedding, 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|m| m.match_type == MatchType::Hybrid));
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
        // Equal semantic similarity: the keyword hit must win the tie.
        assert_eq!(results[0].id, "test.1");
    }

    #[test]
    fn test_parent_child_relationship() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let embedding = create_test_embedding();

        let parent = create_test_chunk("parent.id", "Parent content");
        let mut child = create_test_chunk("parent.id#0", "Child content");
        child.parent_id = Some("parent.id".to_string());

        store.insert_chunk(&parent, &embedding).unwrap();
        store.insert_chunk(&child, &embedding).unwrap();

        let children = store.get_children("parent.id").unwrap();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].id, "parent.id#0");
        assert!(store.get_children("parent.id#0").unwrap().is_empty());
    }

    #[test]
    fn test_count_chunks() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        let embedding = create_test_embedding();

        assert_eq!(store.count_chunks().unwrap(), 0);

        store
            .insert_chunk(&create_test_chunk("test.1", "Content 1"), &embedding)
            .unwrap();
        store
            .insert_chunk(&create_test_chunk("test.2", "Content 2"), &embedding)
            .unwrap();

        assert_eq!(store.count_chunks().unwrap(), 2);
    }
}