Skip to main content

cml_rs/
embedding_store.rs

1//! SQLite-based embedding store with FTS5 hybrid search
2//!
3//! This module provides a high-performance embedding lookup table that combines:
4//! - Full-text search (FTS5) for keyword matching
5//! - Vector similarity search for semantic matching
6//! - Hierarchical parent-child relationships
7//!
//! The hybrid approach aims to beat either method alone:
//! - FTS5 gives high precision on exact keyword matches
//! - Vector search gives high recall on semantically similar text
//! - Merging and reranking both result lists combines those two strengths
12
13use crate::chunker::Chunk;
14use rusqlite::{params, Connection, Result as SqlResult};
15use std::path::Path;
16
/// Number of `f32` components in every embedding vector the store accepts or
/// returns (the output size of sentence-transformers MiniLM-L6-v2).
pub const EMBEDDING_DIM: usize = 384;
19
20/// SQLite embedding store with FTS5 hybrid search
21pub struct EmbeddingStore {
22    conn: Connection,
23}
24
25impl EmbeddingStore {
26    /// Create a new embedding store (in-memory for testing)
27    pub fn new_in_memory() -> SqlResult<Self> {
28        let conn = Connection::open_in_memory()?;
29        Self::init_schema(&conn)?;
30        Ok(Self { conn })
31    }
32
33    /// Open or create an embedding store from file
34    pub fn open(path: &Path) -> SqlResult<Self> {
35        let conn = Connection::open(path)?;
36        Self::init_schema(&conn)?;
37        Ok(Self { conn })
38    }
39
40    /// Initialize database schema
41    fn init_schema(conn: &Connection) -> SqlResult<()> {
42        // Main chunks table
43        conn.execute(
44            "CREATE TABLE IF NOT EXISTS chunks (
45                id TEXT PRIMARY KEY,
46                parent_id TEXT,
47                content_hash TEXT NOT NULL,
48                profile TEXT NOT NULL,
49                element_type TEXT NOT NULL,
50                content TEXT NOT NULL,
51                token_count INTEGER NOT NULL,
52                metadata JSON,
53                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
54            )",
55            [],
56        )?;
57
58        // Embeddings table (384-dim f32 vectors as BLOB)
59        conn.execute(
60            "CREATE TABLE IF NOT EXISTS embeddings (
61                chunk_id TEXT PRIMARY KEY,
62                embedding BLOB NOT NULL,
63                norm REAL NOT NULL,
64                FOREIGN KEY (chunk_id) REFERENCES chunks(id) ON DELETE CASCADE
65            )",
66            [],
67        )?;
68
69        // FTS5 virtual table for full-text search
70        conn.execute(
71            "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
72                id,
73                content,
74                element_type,
75                metadata,
76                content='chunks',
77                content_rowid='rowid'
78            )",
79            [],
80        )?;
81
82        // Indexes for fast lookups
83        conn.execute(
84            "CREATE INDEX IF NOT EXISTS idx_parent ON chunks(parent_id)",
85            [],
86        )?;
87
88        conn.execute(
89            "CREATE INDEX IF NOT EXISTS idx_profile ON chunks(profile, element_type)",
90            [],
91        )?;
92
93        conn.execute(
94            "CREATE INDEX IF NOT EXISTS idx_content_hash ON chunks(content_hash)",
95            [],
96        )?;
97
98        // Trigger to keep FTS5 in sync
99        conn.execute(
100            "CREATE TRIGGER IF NOT EXISTS chunks_fts_insert AFTER INSERT ON chunks BEGIN
101                INSERT INTO chunks_fts(rowid, id, content, element_type, metadata)
102                VALUES (new.rowid, new.id, new.content, new.element_type, new.metadata);
103            END",
104            [],
105        )?;
106
107        conn.execute(
108            "CREATE TRIGGER IF NOT EXISTS chunks_fts_delete AFTER DELETE ON chunks BEGIN
109                DELETE FROM chunks_fts WHERE rowid = old.rowid;
110            END",
111            [],
112        )?;
113
114        conn.execute(
115            "CREATE TRIGGER IF NOT EXISTS chunks_fts_update AFTER UPDATE ON chunks BEGIN
116                UPDATE chunks_fts SET
117                    id = new.id,
118                    content = new.content,
119                    element_type = new.element_type,
120                    metadata = new.metadata
121                WHERE rowid = new.rowid;
122            END",
123            [],
124        )?;
125
126        Ok(())
127    }
128
129    /// Insert a chunk with its embedding
130    pub fn insert_chunk(&mut self, chunk: &Chunk, embedding: &[f32]) -> SqlResult<()> {
131        if embedding.len() != EMBEDDING_DIM {
132            return Err(rusqlite::Error::InvalidParameterCount(
133                EMBEDDING_DIM,
134                embedding.len(),
135            ));
136        }
137
138        let metadata_json = serde_json::to_string(&chunk.metadata)
139            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
140
141        // Insert chunk
142        self.conn.execute(
143            "INSERT INTO chunks (id, parent_id, content_hash, profile, element_type, content, token_count, metadata)
144             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
145            params![
146                chunk.id,
147                chunk.parent_id,
148                chunk.content_hash,
149                chunk.profile,
150                chunk.element_type,
151                chunk.content,
152                chunk.token_count,
153                metadata_json,
154            ],
155        )?;
156
157        // Convert embedding to BLOB (4 bytes per f32)
158        let embedding_blob = embedding
159            .iter()
160            .flat_map(|f| f.to_le_bytes())
161            .collect::<Vec<u8>>();
162
163        // Calculate L2 norm
164        let norm = Self::l2_norm(embedding);
165
166        // Insert embedding
167        self.conn.execute(
168            "INSERT INTO embeddings (chunk_id, embedding, norm) VALUES (?1, ?2, ?3)",
169            params![chunk.id, embedding_blob, norm],
170        )?;
171
172        Ok(())
173    }
174
175    /// Get a chunk by ID
176    pub fn get_chunk(&self, id: &str) -> SqlResult<Option<Chunk>> {
177        let mut stmt = self.conn.prepare(
178            "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
179             FROM chunks WHERE id = ?1",
180        )?;
181
182        let mut rows = stmt.query(params![id])?;
183
184        if let Some(row) = rows.next()? {
185            let metadata_json: String = row.get(7)?;
186            let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
187                rusqlite::Error::FromSqlConversionFailure(
188                    7,
189                    rusqlite::types::Type::Text,
190                    Box::new(e),
191                )
192            })?;
193
194            Ok(Some(Chunk {
195                id: row.get(0)?,
196                parent_id: row.get(1)?,
197                content_hash: row.get(2)?,
198                profile: row.get(3)?,
199                element_type: row.get(4)?,
200                content: row.get(5)?,
201                token_count: row.get(6)?,
202                metadata,
203            }))
204        } else {
205            Ok(None)
206        }
207    }
208
209    /// Get embedding for a chunk
210    pub fn get_embedding(&self, chunk_id: &str) -> SqlResult<Option<Vec<f32>>> {
211        let mut stmt = self
212            .conn
213            .prepare("SELECT embedding FROM embeddings WHERE chunk_id = ?1")?;
214
215        let mut rows = stmt.query(params![chunk_id])?;
216
217        if let Some(row) = rows.next()? {
218            let blob: Vec<u8> = row.get(0)?;
219            let embedding = Self::blob_to_embedding(&blob)?;
220            Ok(Some(embedding))
221        } else {
222            Ok(None)
223        }
224    }
225
226    /// FTS5 keyword search
227    pub fn search_keywords(&self, query: &str, limit: usize) -> SqlResult<Vec<ChunkMatch>> {
228        let mut stmt = self.conn.prepare(
229            "SELECT c.id, c.content, c.element_type, c.profile, rank
230             FROM chunks_fts
231             JOIN chunks c ON chunks_fts.rowid = c.rowid
232             WHERE chunks_fts MATCH ?1
233             ORDER BY rank
234             LIMIT ?2",
235        )?;
236
237        let mut rows = stmt.query(params![query, limit as i64])?;
238        let mut matches = Vec::new();
239
240        while let Some(row) = rows.next()? {
241            matches.push(ChunkMatch {
242                id: row.get(0)?,
243                content: row.get(1)?,
244                element_type: row.get(2)?,
245                profile: row.get(3)?,
246                score: row.get::<_, f64>(4)? as f32,
247                match_type: MatchType::Keyword,
248            });
249        }
250
251        Ok(matches)
252    }
253
254    /// Vector similarity search (brute force for now, fast enough for <100K chunks)
255    pub fn search_similar(
256        &self,
257        query_embedding: &[f32],
258        limit: usize,
259    ) -> SqlResult<Vec<ChunkMatch>> {
260        if query_embedding.len() != EMBEDDING_DIM {
261            return Err(rusqlite::Error::InvalidParameterCount(
262                EMBEDDING_DIM,
263                query_embedding.len(),
264            ));
265        }
266
267        let query_norm = Self::l2_norm(query_embedding);
268
269        let mut stmt = self.conn.prepare(
270            "SELECT c.id, c.content, c.element_type, c.profile, e.embedding, e.norm
271             FROM chunks c
272             JOIN embeddings e ON c.id = e.chunk_id",
273        )?;
274
275        let mut rows = stmt.query([])?;
276        let mut matches = Vec::new();
277
278        while let Some(row) = rows.next()? {
279            let id: String = row.get(0)?;
280            let content: String = row.get(1)?;
281            let element_type: String = row.get(2)?;
282            let profile: String = row.get(3)?;
283            let embedding_blob: Vec<u8> = row.get(4)?;
284            let norm: f32 = row.get(5)?;
285
286            let embedding = Self::blob_to_embedding(&embedding_blob)?;
287
288            // Cosine similarity = dot product / (norm1 * norm2)
289            let dot_product: f32 = query_embedding
290                .iter()
291                .zip(&embedding)
292                .map(|(a, b)| a * b)
293                .sum();
294
295            let similarity = dot_product / (query_norm * norm);
296
297            matches.push(ChunkMatch {
298                id,
299                content,
300                element_type,
301                profile,
302                score: similarity,
303                match_type: MatchType::Vector,
304            });
305        }
306
307        // Sort by similarity (descending) and take top N
308        matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
309        matches.truncate(limit);
310
311        Ok(matches)
312    }
313
314    /// Hybrid search: Combine FTS5 + vector similarity
315    pub fn hybrid_search(
316        &self,
317        keywords: &str,
318        query_embedding: &[f32],
319        limit: usize,
320    ) -> SqlResult<Vec<ChunkMatch>> {
321        // Get keyword matches (precision)
322        let keyword_matches = self.search_keywords(keywords, limit * 2)?;
323
324        // Get vector matches (recall)
325        let vector_matches = self.search_similar(query_embedding, limit * 2)?;
326
327        // Combine and rerank
328        let mut combined = Self::merge_and_rerank(keyword_matches, vector_matches);
329        combined.truncate(limit);
330
331        Ok(combined)
332    }
333
334    /// Merge keyword and vector matches, rerank by combined score
335    fn merge_and_rerank(
336        keyword_matches: Vec<ChunkMatch>,
337        vector_matches: Vec<ChunkMatch>,
338    ) -> Vec<ChunkMatch> {
339        use std::collections::HashMap;
340
341        let mut matches_by_id: HashMap<String, ChunkMatch> = HashMap::new();
342        let mut scores: HashMap<String, (f32, f32)> = HashMap::new(); // (keyword_score, vector_score)
343
344        // Collect keyword scores and matches
345        for m in keyword_matches {
346            scores.entry(m.id.clone()).or_insert((0.0, 0.0)).0 = m.score.abs(); // FTS5 rank is negative
347            matches_by_id.insert(m.id.clone(), m);
348        }
349
350        // Collect vector scores and matches
351        for m in vector_matches {
352            scores.entry(m.id.clone()).or_insert((0.0, 0.0)).1 = m.score;
353            matches_by_id.entry(m.id.clone()).or_insert(m);
354        }
355
356        // Rerank: combined_score = 0.3 * keyword + 0.7 * vector (favor semantic)
357        let mut combined: Vec<_> = scores
358            .into_iter()
359            .filter_map(|(id, (kw_score, vec_score))| {
360                let combined_score = 0.3 * kw_score + 0.7 * vec_score;
361                matches_by_id.get(&id).map(|m| {
362                    let mut new_match = m.clone();
363                    new_match.score = combined_score;
364                    new_match.match_type = MatchType::Hybrid;
365                    new_match
366                })
367            })
368            .collect();
369
370        combined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
371        combined
372    }
373
374    /// Get all child chunks of a parent
375    pub fn get_children(&self, parent_id: &str) -> SqlResult<Vec<Chunk>> {
376        let mut stmt = self.conn.prepare(
377            "SELECT id, parent_id, content_hash, profile, element_type, content, token_count, metadata
378             FROM chunks WHERE parent_id = ?1
379             ORDER BY id",
380        )?;
381
382        let mut rows = stmt.query(params![parent_id])?;
383        let mut children = Vec::new();
384
385        while let Some(row) = rows.next()? {
386            let metadata_json: String = row.get(7)?;
387            let metadata = serde_json::from_str(&metadata_json).map_err(|e| {
388                rusqlite::Error::FromSqlConversionFailure(
389                    7,
390                    rusqlite::types::Type::Text,
391                    Box::new(e),
392                )
393            })?;
394
395            children.push(Chunk {
396                id: row.get(0)?,
397                parent_id: row.get(1)?,
398                content_hash: row.get(2)?,
399                profile: row.get(3)?,
400                element_type: row.get(4)?,
401                content: row.get(5)?,
402                token_count: row.get(6)?,
403                metadata,
404            });
405        }
406
407        Ok(children)
408    }
409
410    /// Count total chunks
411    pub fn count_chunks(&self) -> SqlResult<usize> {
412        let count: i64 = self
413            .conn
414            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))?;
415        Ok(count as usize)
416    }
417
418    /// Calculate L2 norm of a vector
419    fn l2_norm(vec: &[f32]) -> f32 {
420        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
421    }
422
423    /// Convert BLOB to f32 vector
424    fn blob_to_embedding(blob: &[u8]) -> SqlResult<Vec<f32>> {
425        if blob.len() != EMBEDDING_DIM * 4 {
426            return Err(rusqlite::Error::InvalidColumnType(
427                0,
428                "Embedding BLOB".to_string(),
429                rusqlite::types::Type::Blob,
430            ));
431        }
432
433        let embedding = blob
434            .chunks_exact(4)
435            .map(|chunk| {
436                let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
437                f32::from_le_bytes(bytes)
438            })
439            .collect();
440
441        Ok(embedding)
442    }
443}
444
445/// Search result match
446#[derive(Debug, Clone, PartialEq)]
447pub struct ChunkMatch {
448    pub id: String,
449    pub content: String,
450    pub element_type: String,
451    pub profile: String,
452    pub score: f32,
453    pub match_type: MatchType,
454}
455
/// The retrieval strategy that produced a search hit.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MatchType {
    /// Found by FTS5 full-text search.
    Keyword,
    /// Found by embedding cosine similarity.
    Vector,
    /// Produced by merging keyword and vector results.
    Hybrid,
}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id_generator::ElementId;
    use std::collections::HashMap;

    /// Build a parentless `code:api` function chunk with empty metadata.
    fn create_test_chunk(id: &str, content: &str) -> Chunk {
        Chunk {
            id: id.to_string(),
            parent_id: None,
            content_hash: ElementId::new(id, content).content_hash,
            profile: "code:api".to_string(),
            element_type: "function".to_string(),
            content: content.to_string(),
            token_count: content.len() / 4,
            metadata: HashMap::new(),
        }
    }

    /// A constant embedding (every component 0.1) of the expected dimension.
    fn create_test_embedding() -> Vec<f32> {
        vec![0.1_f32; EMBEDDING_DIM]
    }

    /// Open an in-memory store and insert each chunk, in order, with the test embedding.
    fn store_with(chunks: &[Chunk]) -> EmbeddingStore {
        let mut store = EmbeddingStore::new_in_memory().expect("in-memory store");
        let embedding = create_test_embedding();
        for chunk in chunks {
            store
                .insert_chunk(chunk, &embedding)
                .expect("insert test chunk");
        }
        store
    }

    #[test]
    fn test_create_store() {
        assert!(EmbeddingStore::new_in_memory().is_ok());
    }

    #[test]
    fn test_insert_and_get_chunk() {
        let store = store_with(&[create_test_chunk("test.id", "Test content")]);

        let found = store
            .get_chunk("test.id")
            .unwrap()
            .expect("chunk should have been stored");
        assert_eq!(found.content, "Test content");
    }

    #[test]
    fn test_get_embedding() {
        let store = store_with(&[create_test_chunk("test.id", "Test content")]);

        let vector = store
            .get_embedding("test.id")
            .unwrap()
            .expect("embedding should have been stored");
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_fts_search() {
        let store = store_with(&[
            create_test_chunk("test.1", "Vector push method"),
            create_test_chunk("test.2", "HashMap insert function"),
        ]);

        let hits = store.search_keywords("vector", 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "test.1");
    }

    #[test]
    fn test_vector_similarity() {
        let store = store_with(&[create_test_chunk("test.id", "Test content")]);

        // Querying with the stored vector itself should give similarity ~1.0.
        let hits = store.search_similar(&create_test_embedding(), 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_hybrid_search() {
        let store = store_with(&[
            create_test_chunk("test.1", "Vector push method adds items"),
            create_test_chunk("test.2", "HashMap insert stores key-value pairs"),
        ]);

        let hits = store
            .hybrid_search("vector", &create_test_embedding(), 10)
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].match_type, MatchType::Hybrid);
    }

    #[test]
    fn test_parent_child_relationship() {
        let parent = create_test_chunk("parent.id", "Parent content");
        let child = Chunk {
            parent_id: Some("parent.id".to_string()),
            ..create_test_chunk("parent.id#0", "Child content")
        };
        let store = store_with(&[parent, child]);

        let children = store.get_children("parent.id").unwrap();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].id, "parent.id#0");
    }

    #[test]
    fn test_count_chunks() {
        let mut store = EmbeddingStore::new_in_memory().unwrap();
        assert_eq!(store.count_chunks().unwrap(), 0);

        let embedding = create_test_embedding();
        for (id, content) in [("test.1", "Content 1"), ("test.2", "Content 2")] {
            store
                .insert_chunk(&create_test_chunk(id, content), &embedding)
                .unwrap();
        }

        assert_eq!(store.count_chunks().unwrap(), 2);
    }
}