agentroot_core/index/embedder.rs

//! Embedding pipeline with smart cache invalidation

use super::ast_chunker::{compute_chunk_hash, SemanticChunk, SemanticChunker};
use super::chunker::{chunk_by_chars, CHUNK_OVERLAP_CHARS, CHUNK_SIZE_CHARS};
use crate::db::{CacheLookupResult, Database};
use crate::error::Result;
use crate::llm::Embedder;
use std::path::Path;

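/// Maximum number of chunk texts sent to the embedder in a single `embed_batch` call.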
const BATCH_SIZE: usize = 32;

/// Progress snapshot passed to the optional callback after each document is embedded
#[derive(Debug, Clone)]
pub struct EmbedProgress {
    pub total_docs: usize,
    pub processed_docs: usize,
    pub total_chunks: usize,
    pub processed_chunks: usize,
    pub cached_chunks: usize,
    pub computed_chunks: usize,
}

/// Cumulative statistics for a full embedding run
#[derive(Debug, Clone, Default)]
pub struct EmbedStats {
    pub total_documents: usize,
    pub embedded_documents: usize,
    pub total_chunks: usize,
    pub embedded_chunks: usize,
    pub cached_chunks: usize,
    pub computed_chunks: usize,
}

impl EmbedStats {
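    /// Cache hit rate as a percentage of all embedded chunks.
    ///
    /// ```ignore
    /// // Illustrative values: 3 of 4 embedded chunks served from cache is a 75% hit rate.
    /// let stats = EmbedStats { embedded_chunks: 4, cached_chunks: 3, ..Default::default() };
    /// assert_eq!(stats.cache_hit_rate(), 75.0);
    /// ```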
    pub fn cache_hit_rate(&self) -> f64 {
        if self.embedded_chunks == 0 {
            return 0.0;
        }
        self.cached_chunks as f64 / self.embedded_chunks as f64 * 100.0
    }
}

/// Chunk ready for embedding with cache metadata
struct ChunkToEmbed {
    seq: u32,
    text: String,
    position: usize,
    chunk_hash: String,
    cached_embedding: Option<Vec<f32>>,
}

/// Generate embeddings for documents with smart caching
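///
/// When `force` is set, every document is re-embedded; otherwise only
/// documents without stored vectors are selected, and per-chunk embeddings
/// are reused from the cache when the registered model and dimensions match.
///
/// ```ignore
/// // Hypothetical call site; `db`, `embedder`, and the model name are
/// // placeholders for illustration.
/// let stats = embed_documents(&db, &*embedder, "my-embed-model", false, None).await?;
/// println!("cache hit rate: {:.1}%", stats.cache_hit_rate());
/// ```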
pub async fn embed_documents(
    db: &Database,
    embedder: &dyn Embedder,
    model: &str,
    force: bool,
    progress: Option<Box<dyn Fn(EmbedProgress) + Send + Sync>>,
) -> Result<EmbedStats> {
    let docs = if force {
        db.get_all_content_with_paths()?
    } else {
        db.get_content_needing_embedding_with_paths()?
    };

    if docs.is_empty() {
        return Ok(EmbedStats::default());
    }

    let dimensions = embedder.dimensions();
    db.ensure_vec_table(dimensions)?;

    // Check model compatibility once upfront
    let cache_enabled = !force && db.check_model_compatibility(model, dimensions)?;
    db.register_model(model, dimensions)?;

    let total_docs = docs.len();
    let mut stats = EmbedStats {
        total_documents: total_docs,
        ..Default::default()
    };

    let chunker = SemanticChunker::new();

    for (doc_idx, (hash, content, path)) in docs.iter().enumerate() {
        let title = db.get_document_title_by_hash(hash)?;

        // Use semantic chunking if we have a file path
        let semantic_chunks = if let Some(p) = path {
            chunker.chunk(content, Path::new(p))?
        } else {
            fallback_to_semantic_chunks(content)
        };

        stats.total_chunks += semantic_chunks.len();

        // Prepare chunks with cache lookups
        let mut chunks_to_embed: Vec<ChunkToEmbed> = Vec::new();

        for (seq, chunk) in semantic_chunks.iter().enumerate() {
            let formatted_text = format_doc_for_embedding(&chunk.text, title.as_deref());

            // Try to find cached embedding (using fast lookup since we checked compatibility upfront)
            let cached = if cache_enabled {
                match db.get_cached_embedding_fast(&chunk.chunk_hash, model)? {
                    CacheLookupResult::Hit(emb) => Some(emb),
                    CacheLookupResult::Miss | CacheLookupResult::ModelMismatch => None,
                }
            } else {
                None
            };

            chunks_to_embed.push(ChunkToEmbed {
                seq: seq as u32,
                text: formatted_text,
                position: chunk.position,
                chunk_hash: chunk.chunk_hash.clone(),
                cached_embedding: cached,
            });
        }

        // Separate cached from needing computation
        let (cached, to_compute): (Vec<_>, Vec<_>) = chunks_to_embed
            .into_iter()
            .partition(|c| c.cached_embedding.is_some());

        // Store cached embeddings
        for chunk in cached {
            let embedding = chunk.cached_embedding.unwrap();
            db.insert_chunk_embedding(
                hash,
                chunk.seq,
                chunk.position,
                &chunk.chunk_hash,
                model,
                &embedding,
            )?;
            stats.embedded_chunks += 1;
            stats.cached_chunks += 1;
        }

        // Batch embed new chunks
        for batch in to_compute.chunks(BATCH_SIZE) {
            let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
            let embeddings = embedder.embed_batch(&texts).await?;

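            // Assumes the embedder returns exactly one embedding per input
            // text; `zip` would silently drop any mismatched tail.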
            for (chunk, embedding) in batch.iter().zip(embeddings.iter()) {
                db.insert_chunk_embedding(
                    hash,
                    chunk.seq,
                    chunk.position,
                    &chunk.chunk_hash,
                    model,
                    embedding,
                )?;
                stats.embedded_chunks += 1;
                stats.computed_chunks += 1;
            }
        }

        stats.embedded_documents += 1;

        if let Some(ref cb) = progress {
            cb(EmbedProgress {
                total_docs,
                processed_docs: doc_idx + 1,
                total_chunks: stats.total_chunks,
                processed_chunks: stats.embedded_chunks,
                cached_chunks: stats.cached_chunks,
                computed_chunks: stats.computed_chunks,
            });
        }
    }

    Ok(stats)
}

/// Fallback: convert character-based chunks to semantic chunks with hashes
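/// (plain character chunks have no semantic context, so `compute_chunk_hash`
/// receives empty strings for its two context arguments)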
fn fallback_to_semantic_chunks(content: &str) -> Vec<SemanticChunk> {
    let char_chunks = chunk_by_chars(content, CHUNK_SIZE_CHARS, CHUNK_OVERLAP_CHARS);

    char_chunks
        .into_iter()
        .map(|c| {
            let chunk_hash = compute_chunk_hash(&c.text, "", "");
            SemanticChunk {
                text: c.text,
                chunk_type: super::ast_chunker::ChunkType::Text,
                chunk_hash,
                position: c.position,
                token_count: c.token_count,
                metadata: super::ast_chunker::ChunkMetadata::default(),
            }
        })
        .collect()
}

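/// Prefix the chunk text with its document title so each embedding carries
/// document-level context; untitled documents get the literal `none`.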
fn format_doc_for_embedding(text: &str, title: Option<&str>) -> String {
    format!("title: {} | text: {}", title.unwrap_or("none"), text)
}

impl Database {
    /// Get all content hashes and content
    pub fn get_all_content(&self) -> Result<Vec<(String, String)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Get all content with file paths
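    /// (one row per content hash, since several active documents may share
    /// identical content)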
    pub fn get_all_content_with_paths(&self) -> Result<Vec<(String, String, Option<String>)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc, d.path FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1
             GROUP BY c.hash",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Get content needing embedding with file paths
    pub fn get_content_needing_embedding_with_paths(
        &self,
    ) -> Result<Vec<(String, String, Option<String>)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc, d.path FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1
             WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)
             GROUP BY c.hash",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Get document title by hash
    pub fn get_document_title_by_hash(&self, hash: &str) -> Result<Option<String>> {
        let result = self.conn.query_row(
            "SELECT title FROM documents WHERE hash = ?1 AND active = 1 LIMIT 1",
            rusqlite::params![hash],
            |row| row.get(0),
        );
        match result {
            Ok(title) => Ok(Some(title)),
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(e.into()),
        }
    }
}
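
// Illustrative sanity checks for the helpers above; the sample strings and
// stat values are arbitrary examples.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_doc_for_embedding_includes_title() {
        assert_eq!(
            format_doc_for_embedding("hello", Some("Guide")),
            "title: Guide | text: hello"
        );
        assert_eq!(
            format_doc_for_embedding("hello", None),
            "title: none | text: hello"
        );
    }

    #[test]
    fn cache_hit_rate_handles_zero_embedded_chunks() {
        assert_eq!(EmbedStats::default().cache_hit_rate(), 0.0);
    }
}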