1use super::ast_chunker::{compute_chunk_hash, SemanticChunk, SemanticChunker};
4use super::chunker::{chunk_by_chars, CHUNK_OVERLAP_CHARS, CHUNK_SIZE_CHARS};
5use crate::db::{CacheLookupResult, Database};
6use crate::error::Result;
7use crate::llm::Embedder;
8use std::path::Path;
9
// Number of chunk texts sent to the embedder per `embed_batch` call.
const BATCH_SIZE: usize = 32;
11
/// Progress snapshot handed to the optional callback of [`embed_documents`],
/// emitted once after each document is fully processed.
#[derive(Debug, Clone)]
pub struct EmbedProgress {
    /// Total documents selected for embedding in this run.
    pub total_docs: usize,
    /// Documents fully processed so far (including the current one).
    pub processed_docs: usize,
    /// Chunks discovered so far across all processed documents.
    pub total_chunks: usize,
    /// Chunks whose embeddings have been stored so far.
    pub processed_chunks: usize,
    /// Of the processed chunks, how many came from the embedding cache.
    pub cached_chunks: usize,
    /// Of the processed chunks, how many were freshly computed.
    pub computed_chunks: usize,
}
22
/// Final summary returned by [`embed_documents`] once a run completes.
#[derive(Debug, Clone, Default)]
pub struct EmbedStats {
    /// Documents that were candidates for embedding.
    pub total_documents: usize,
    /// Documents actually processed (equals `total_documents` on success).
    pub embedded_documents: usize,
    /// Chunks produced across all processed documents.
    pub total_chunks: usize,
    /// Chunks whose embeddings were stored (cached + computed).
    pub embedded_chunks: usize,
    /// Embedded chunks served from the cache.
    pub cached_chunks: usize,
    /// Embedded chunks computed by the embedder.
    pub computed_chunks: usize,
}
33
34impl EmbedStats {
35 pub fn cache_hit_rate(&self) -> f64 {
36 if self.embedded_chunks == 0 {
37 return 0.0;
38 }
39 self.cached_chunks as f64 / self.embedded_chunks as f64 * 100.0
40 }
41}
42
/// Internal work item for a single chunk of one document, carrying either a
/// cached embedding or the text that still needs to be embedded.
struct ChunkToEmbed {
    // Chunk sequence number within its document.
    seq: u32,
    // Text as formatted for the embedder (title prefix included).
    text: String,
    // Character offset of the chunk within the source document.
    position: usize,
    // Content hash used as the embedding-cache key.
    chunk_hash: String,
    // Populated on a cache hit; `None` means the embedding must be computed.
    cached_embedding: Option<Vec<f32>>,
}
51
/// Embed every document that still needs embeddings (or all documents when
/// `force` is set), storing one vector per semantic chunk.
///
/// Chunks are first looked up in the embedding cache; only misses are sent
/// to the embedder, in batches of [`BATCH_SIZE`]. The cache is bypassed when
/// `force` is set or the registered model/dimensions do not match `model`.
///
/// # Arguments
/// * `db` - source of documents and sink for the chunk embeddings
/// * `embedder` - backend that turns text into fixed-size vectors
/// * `model` - model identifier recorded alongside each embedding
/// * `force` - re-embed all documents and skip cache lookups
/// * `progress` - optional callback invoked once per processed document
///
/// # Errors
/// Propagates any database or embedder failure.
pub async fn embed_documents(
    db: &Database,
    embedder: &dyn Embedder,
    model: &str,
    force: bool,
    progress: Option<Box<dyn Fn(EmbedProgress) + Send + Sync>>,
) -> Result<EmbedStats> {
    // Select the work set: everything, or only content without vectors yet.
    let docs = if force {
        db.get_all_content_with_paths()?
    } else {
        db.get_content_needing_embedding_with_paths()?
    };

    if docs.is_empty() {
        return Ok(EmbedStats::default());
    }

    // Make sure the vector table matches this embedder's dimensionality.
    let dimensions = embedder.dimensions();
    db.ensure_vec_table(dimensions)?;

    // Cached embeddings are only trustworthy if the previously registered
    // model and dimensions agree with the current ones.
    let cache_enabled = !force && db.check_model_compatibility(model, dimensions)?;
    db.register_model(model, dimensions)?;

    let total_docs = docs.len();
    let mut stats = EmbedStats {
        total_documents: total_docs,
        ..Default::default()
    };

    let chunker = SemanticChunker::new();

    for (doc_idx, (hash, content, path)) in docs.iter().enumerate() {
        let title = db.get_document_title_by_hash(hash)?;

        // Prefer language-aware chunking when we know the file path;
        // otherwise fall back to plain character-based chunking.
        let semantic_chunks = if let Some(p) = path {
            chunker.chunk(content, Path::new(p))?
        } else {
            fallback_to_semantic_chunks(content)
        };

        stats.total_chunks += semantic_chunks.len();

        let mut chunks_to_embed: Vec<ChunkToEmbed> = Vec::new();

        for (seq, chunk) in semantic_chunks.iter().enumerate() {
            let formatted_text = format_doc_for_embedding(&chunk.text, title.as_deref());

            // Cache lookup keyed by chunk hash; a model mismatch is treated
            // the same as a miss.
            let cached = if cache_enabled {
                match db.get_cached_embedding_fast(&chunk.chunk_hash, model)? {
                    CacheLookupResult::Hit(emb) => Some(emb),
                    CacheLookupResult::Miss | CacheLookupResult::ModelMismatch => None,
                }
            } else {
                None
            };

            chunks_to_embed.push(ChunkToEmbed {
                seq: seq as u32,
                text: formatted_text,
                position: chunk.position,
                chunk_hash: chunk.chunk_hash.clone(),
                cached_embedding: cached,
            });
        }

        // Split cache hits from misses so only the misses hit the embedder.
        let (cached, to_compute): (Vec<_>, Vec<_>) = chunks_to_embed
            .into_iter()
            .partition(|c| c.cached_embedding.is_some());

        // Store cache hits directly.
        for chunk in cached {
            let embedding = chunk.cached_embedding.unwrap();
            db.insert_chunk_embedding(
                hash,
                chunk.seq,
                chunk.position,
                &chunk.chunk_hash,
                model,
                &embedding,
            )?;
            stats.embedded_chunks += 1;
            stats.cached_chunks += 1;
        }

        // Compute the misses in fixed-size batches.
        for batch in to_compute.chunks(BATCH_SIZE) {
            let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
            let embeddings = embedder.embed_batch(&texts).await?;

            // NOTE(review): `zip` silently drops trailing chunks if the
            // embedder returns fewer vectors than texts — confirm that
            // `embed_batch` guarantees a 1:1 input/output mapping.
            for (chunk, embedding) in batch.iter().zip(embeddings.iter()) {
                db.insert_chunk_embedding(
                    hash,
                    chunk.seq,
                    chunk.position,
                    &chunk.chunk_hash,
                    model,
                    embedding,
                )?;
                stats.embedded_chunks += 1;
                stats.computed_chunks += 1;
            }
        }

        stats.embedded_documents += 1;

        // Report per-document progress to the caller, if requested.
        if let Some(ref cb) = progress {
            cb(EmbedProgress {
                total_docs,
                processed_docs: doc_idx + 1,
                total_chunks: stats.total_chunks,
                processed_chunks: stats.embedded_chunks,
                cached_chunks: stats.cached_chunks,
                computed_chunks: stats.computed_chunks,
            });
        }
    }

    Ok(stats)
}
177
178fn fallback_to_semantic_chunks(content: &str) -> Vec<SemanticChunk> {
180 let char_chunks = chunk_by_chars(content, CHUNK_SIZE_CHARS, CHUNK_OVERLAP_CHARS);
181
182 char_chunks
183 .into_iter()
184 .map(|c| {
185 let chunk_hash = compute_chunk_hash(&c.text, "", "");
186 SemanticChunk {
187 text: c.text,
188 chunk_type: super::ast_chunker::ChunkType::Text,
189 chunk_hash,
190 position: c.position,
191 token_count: c.token_count,
192 metadata: super::ast_chunker::ChunkMetadata::default(),
193 }
194 })
195 .collect()
196}
197
/// Build the text actually sent to the embedder: the document title (or the
/// literal "none" when absent) followed by the chunk body.
fn format_doc_for_embedding(text: &str, title: Option<&str>) -> String {
    let title_part = match title {
        Some(t) => t,
        None => "none",
    };
    let mut out = String::with_capacity(title_part.len() + text.len() + 16);
    out.push_str("title: ");
    out.push_str(title_part);
    out.push_str(" | text: ");
    out.push_str(text);
    out
}
201
impl Database {
    /// Return `(hash, doc)` pairs for every active document's content.
    ///
    /// NOTE(review): unlike the `_with_paths` variants below this has no
    /// GROUP BY, so content shared by several active documents is returned
    /// once per document — confirm callers expect that.
    pub fn get_all_content(&self) -> Result<Vec<(String, String)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Return `(hash, doc, path)` for every active document's content,
    /// deduplicated by content hash.
    ///
    /// NOTE(review): `d.path` is a bare column under `GROUP BY c.hash`, so
    /// when multiple active documents share a hash, SQLite picks an
    /// arbitrary path — confirm this is acceptable for chunking.
    pub fn get_all_content_with_paths(&self) -> Result<Vec<(String, String, Option<String>)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc, d.path FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1
             GROUP BY c.hash",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Like [`get_all_content_with_paths`], but restricted to content that
    /// has no rows in `content_vectors` yet (i.e. still needs embedding).
    pub fn get_content_needing_embedding_with_paths(
        &self,
    ) -> Result<Vec<(String, String, Option<String>)>> {
        let mut stmt = self.conn.prepare(
            "SELECT c.hash, c.doc, d.path FROM content c
             JOIN documents d ON d.hash = c.hash AND d.active = 1
             WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)
             GROUP BY c.hash",
        )?;
        let results = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Look up the title of an active document by content hash.
    ///
    /// Returns `Ok(None)` when no active document has that hash; any other
    /// query failure is propagated. With multiple matches, `LIMIT 1` returns
    /// an arbitrary row's title.
    pub fn get_document_title_by_hash(&self, hash: &str) -> Result<Option<String>> {
        let result = self.conn.query_row(
            "SELECT title FROM documents WHERE hash = ?1 AND active = 1 LIMIT 1",
            rusqlite::params![hash],
            |row| row.get(0),
        );
        match result {
            Ok(title) => Ok(Some(title)),
            // No matching row is a normal "not found", not an error.
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(e.into()),
        }
    }
}