1use anyhow::Result;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
6use tracing::warn;
7
8use crate::backend::{self, BackendIndex, PruningStrategy};
9#[cfg(feature = "bm25")]
10use crate::bm25::BM25Scorer;
11use crate::embedding::EmbeddingProvider;
12use crate::hnsw::search::SearchParams;
13use crate::hnsw::simd::{inner_product_distance, l2_distance};
14use crate::index::{DistanceMetric, IndexMeta, IndexPaths};
15#[cfg(feature = "bm25")]
16use crate::passages::Passage;
17use crate::passages::{PassageManager, load_id_map};
18use crate::search_result::SearchResult;
19
/// Optional overrides applied when opening a searcher through
/// [`LeannSearcher::open_with_options`].
///
/// The [`Default`] value requests no override and no warm-up.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SearcherOptions {
    /// When `Some`, replaces the recompute setting derived from the index
    /// metadata; `None` keeps the metadata-derived value.
    pub recompute_embeddings: Option<bool>,
    /// When `true`, a warm-up embedding request is issued right after the
    /// index has been opened.
    pub enable_warmup: bool,
}
28
/// Handle to an opened LEANN index: the loaded metadata, passage store,
/// backend index and (when one could be created) an embedding provider.
// NOTE(review): `dead_code` is allowed for the whole struct. In the code
// visible in this file `meta`, `meta_path` and `bm25` are stored but never
// read back; confirm whether other modules rely on them.
#[allow(dead_code)]
pub struct LeannSearcher {
    /// Index metadata loaded from the metadata file.
    meta: IndexMeta,
    /// Passage store used to resolve labels to text/metadata and to apply
    /// metadata filters.
    passages: PassageManager,
    /// Backend index read from the index file.
    index: BackendIndex,
    /// Label -> string id table; empty when no id-map file exists.
    id_map: Vec<String>,
    /// Distance metric reported by the metadata.
    distance_metric: DistanceMetric,
    /// Whether candidate embeddings are recomputed during search. Initialised
    /// from the metadata and overridable through `SearcherOptions`.
    recompute_embeddings: bool,
    /// Embedding provider built from the metadata; `None` when it could not
    /// be created.
    provider: Option<Arc<dyn EmbeddingProvider>>,
    /// Always initialised to `None` by `open`.
    #[cfg(feature = "bm25")]
    bm25: Option<BM25Scorer>,
    /// Location of the metadata file this searcher was opened from.
    meta_path: std::path::PathBuf,
}
43
44impl LeannSearcher {
45 pub fn open(index_path: &Path) -> Result<Self> {
47 let index_path = if index_path.is_relative() {
48 std::env::current_dir()?.join(index_path)
49 } else {
50 index_path.to_path_buf()
51 };
52
53 let paths = IndexPaths::new(&index_path);
54 let meta_path = paths.meta_path();
55
56 if !meta_path.exists() {
57 anyhow::bail!("LEANN metadata file not found at {}", meta_path.display());
58 }
59
60 let meta = IndexMeta::load(&meta_path)?;
61 let distance_metric = meta.distance_metric();
62 let recompute = meta.requires_recompute();
63
64 let passages = PassageManager::load(&meta.passage_sources, Some(&meta_path))?;
66
67 let index_file = paths.index_file_path();
69 if !index_file.exists() {
70 anyhow::bail!("Index file not found at {}", index_file.display());
71 }
72 let index = backend::read_backend_index(&meta.backend_name, &index_file)?;
73
74 let id_map_path = paths.id_map_path();
76 let id_map = if id_map_path.exists() {
77 load_id_map(&id_map_path)?
78 } else {
79 Vec::new()
80 };
81
82 let provider = Self::create_provider_from_meta(&meta);
84
85 Ok(Self {
86 meta,
87 passages,
88 index,
89 id_map,
90 distance_metric,
91 recompute_embeddings: recompute,
92 provider,
93 #[cfg(feature = "bm25")]
94 bm25: None,
95 meta_path,
96 })
97 }
98
99 pub fn open_with_options(index_path: &Path, options: &SearcherOptions) -> Result<Self> {
104 let mut searcher = Self::open(index_path)?;
105
106 if let Some(recompute) = options.recompute_embeddings {
108 searcher.recompute_embeddings = recompute;
109 }
110
111 if options.enable_warmup {
113 searcher.warmup()?;
114 }
115
116 Ok(searcher)
117 }
118
119 pub fn warmup(&self) -> Result<()> {
124 if let Some(ref provider) = self.provider {
125 match provider.compute_embeddings(&["__LEANN_WARMUP__".to_string()], None) {
126 Ok(_) => {}
127 Err(e) => {
128 warn!("Warmup embedding request failed (provider may not be running): {e}");
129 }
130 }
131 }
132 Ok(())
133 }
134
135 #[cfg(feature = "embedding-remote")]
137 fn create_provider_from_meta(meta: &IndexMeta) -> Option<Arc<dyn EmbeddingProvider>> {
138 use crate::embedding::{EmbeddingMode, create_embedding_provider};
139
140 let mode = EmbeddingMode::from_str_lossy(&meta.embedding_mode);
141 match create_embedding_provider(&mode, &meta.embedding_model, &meta.embedding_options) {
142 Ok(provider) => Some(Arc::from(provider)),
143 Err(e) => {
144 warn!("Could not create embedding provider from meta: {e}");
145 None
146 }
147 }
148 }
149
150 #[cfg(not(feature = "embedding-remote"))]
151 fn create_provider_from_meta(_meta: &IndexMeta) -> Option<Arc<dyn EmbeddingProvider>> {
152 None
153 }
154
155 pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
157 self.search_with_params(query, top_k, &SearchConfig::default())
158 }
159
160 pub fn search_with_params(
162 &self,
163 query: &str,
164 top_k: usize,
165 config: &SearchConfig,
166 ) -> Result<Vec<SearchResult>> {
167 let top_k = top_k.min(self.passages.len());
168
169 #[cfg(feature = "bm25")]
171 if config.gemma == 0.0 {
172 let results = self.bm25_search(query, top_k)?;
173 if let Some(ref filters) = config.metadata_filters {
174 return Ok(self.passages.filter_search_results(&results, filters));
175 }
176 return Ok(results);
177 }
178 #[cfg(not(feature = "bm25"))]
179 if config.gemma == 0.0 {
180 anyhow::bail!("BM25 search requires the `bm25` feature");
181 }
182
183 if config.use_grep {
185 let results = self.grep_search(query, top_k)?;
186 if let Some(ref filters) = config.metadata_filters {
187 return Ok(self.passages.filter_search_results(&results, filters));
188 }
189 return Ok(results);
190 }
191
192 let results = self.vector_search(query, top_k, config)?;
194 Ok(results)
195 }
196
197 fn vector_search(
198 &self,
199 query: &str,
200 top_k: usize,
201 config: &SearchConfig,
202 ) -> Result<Vec<SearchResult>> {
203 let provider = self.provider.as_ref().ok_or_else(|| {
204 anyhow::anyhow!(
205 "No embedding provider available. Ensure the index was built with a supported \
206 embedding mode (ollama, openai, gemini) and the `embedding-remote` feature is enabled."
207 )
208 })?;
209
210 let query_embedding = provider.compute_embeddings(&[query.to_string()], None)?;
212 let query_vec: Vec<f32> = query_embedding.row(0).to_vec();
213
214 let query_vec = if self.distance_metric == DistanceMetric::Cosine {
216 let norm: f32 = query_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
217 if norm > 0.0 {
218 query_vec.iter().map(|x| x / norm).collect()
219 } else {
220 query_vec
221 }
222 } else {
223 query_vec
224 };
225
226 let pruning_strategy = config
227 .pruning_strategy
228 .as_deref()
229 .map(|s| match s {
230 "local" => PruningStrategy::Local,
231 "proportional" => PruningStrategy::Proportional,
232 _ => PruningStrategy::Global,
233 })
234 .unwrap_or(PruningStrategy::Global);
235
236 let params = SearchParams {
237 ef_search: config.complexity,
238 beam_size: config.beam_width,
239 prune_ratio: config.prune_ratio,
240 recompute_embeddings: self.recompute_embeddings,
241 batch_size: config.batch_size,
242 pruning_strategy,
243 ..Default::default()
244 };
245
246 let (labels, distances) = if self.recompute_embeddings || self.index.is_pruned() {
248 let provider = Arc::clone(provider);
250 let passages = &self.passages;
251 let distance_metric = self.distance_metric;
252
253 backend::search_backend_recompute(
254 &self.index,
255 &query_vec,
256 top_k,
257 ¶ms,
258 |node_ids, q, out| {
259 let mut texts = Vec::new();
260 let mut found_indices = Vec::new();
261
262 for (idx, &nid) in node_ids.iter().enumerate() {
263 if let Ok(passage) = passages.get_passage_by_index(nid)
264 && !passage.text.is_empty()
265 {
266 texts.push(passage.text);
267 found_indices.push(idx);
268 }
269 }
270
271 for d in out.iter_mut().take(node_ids.len()) {
272 *d = 1e9;
273 }
274
275 if texts.is_empty() {
276 return;
277 }
278
279 if let Ok(embeddings) = provider.compute_embeddings(&texts, None) {
280 for (i, &original_idx) in found_indices.iter().enumerate() {
281 let emb = embeddings.row(i);
282 let emb_slice = emb.as_slice().unwrap();
283 let dist = match distance_metric {
284 DistanceMetric::L2 => l2_distance(q, emb_slice),
285 _ => inner_product_distance(q, emb_slice),
286 };
287 out[original_idx] = dist;
288 }
289 }
290 },
291 )
292 } else {
293 backend::search_backend(&self.index, &query_vec, top_k, ¶ms)
295 };
296
297 let mut results = Vec::new();
299 for (label, dist) in labels.iter().zip(distances.iter()) {
300 let string_id = self.map_label(*label);
301 match self.passages.get_passage_by_index(*label) {
302 Ok(passage) => {
303 results.push(SearchResult::with_metadata(
304 string_id,
305 *dist as f64,
306 passage.text,
307 passage.metadata,
308 ));
309 }
310 Err(e) => {
311 warn!("Passage not found for label {}: {}", label, e);
312 }
313 }
314 }
315
316 if let Some(ref filters) = config.metadata_filters {
318 let filtered = self.passages.filter_search_results(&results, filters);
319 return Ok(filtered);
320 }
321
322 #[cfg(feature = "bm25")]
324 if config.gemma < 1.0 {
325 let bm25_results = self.bm25_search(query, top_k)?;
326 let bm25_weight = 1.0 - config.gemma;
327
328 let mut hybrid_scores: HashMap<String, f64> = HashMap::new();
329
330 for r in &results {
331 if let Some(s) = hybrid_scores.get_mut(&r.id) {
332 *s += config.gemma * r.score;
333 } else {
334 hybrid_scores.insert(r.id.clone(), config.gemma * r.score);
335 }
336 }
337 for r in &bm25_results {
338 if let Some(s) = hybrid_scores.get_mut(&r.id) {
339 *s += bm25_weight * r.score;
340 } else {
341 hybrid_scores.insert(r.id.clone(), bm25_weight * r.score);
342 }
343 }
344
345 let mut sorted: Vec<(String, f64)> = hybrid_scores.into_iter().collect();
346 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
347 sorted.truncate(top_k);
348
349 let result_lookup: HashMap<&str, usize> = results
351 .iter()
352 .enumerate()
353 .map(|(i, r)| (r.id.as_str(), i))
354 .collect();
355
356 let mut hybrid_results = Vec::new();
357 for (id, score) in sorted {
358 let (text, metadata) = match result_lookup.get(id.as_str()) {
359 Some(&idx) => (results[idx].text.clone(), results[idx].metadata.clone()),
360 None => (String::new(), HashMap::new()),
361 };
362 hybrid_results.push(SearchResult::with_metadata(id, score, text, metadata));
363 }
364
365 return Ok(hybrid_results);
366 }
367
368 Ok(results)
369 }
370
371 fn map_label(&self, label: usize) -> String {
372 if !self.id_map.is_empty() && label < self.id_map.len() {
373 self.id_map[label].clone()
374 } else {
375 label.to_string()
376 }
377 }
378
379 #[cfg(feature = "bm25")]
380 fn bm25_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
381 let mut scorer = BM25Scorer::default();
382
383 let mut documents = Vec::new();
384 let mut passage_map: HashMap<String, Passage> = HashMap::new();
385 for file_path in self.passages.passage_files() {
386 let file = std::fs::File::open(file_path)?;
387 let reader = std::io::BufReader::new(file);
388 use std::io::BufRead;
389 for line in reader.lines() {
390 let line = line?;
391 if let Ok(passage) = serde_json::from_str::<Passage>(&line) {
392 documents.push((passage.id.clone(), passage.text.clone()));
393 passage_map.insert(passage.id.clone(), passage);
394 }
395 }
396 }
397
398 scorer.fit(&documents);
399 let mut results = scorer.search(query, top_k);
400
401 for result in &mut results {
403 if let Some(passage) = passage_map.get(&result.id) {
404 result.text.clone_from(&passage.text);
405 result.metadata.clone_from(&passage.metadata);
406 }
407 }
408
409 Ok(results)
410 }
411
412 fn grep_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
413 let pattern = regex::RegexBuilder::new(®ex::escape(query))
414 .case_insensitive(true)
415 .build()?;
416
417 let mut matches = Vec::new();
418 for file_path in self.passages.passage_files() {
419 let file = std::fs::File::open(file_path)?;
420 let reader = std::io::BufReader::new(file);
421 use std::io::BufRead;
422 for line in reader.lines() {
423 let line = line?;
424 if pattern.is_match(&line)
425 && let Ok(passage) = serde_json::from_str::<crate::passages::Passage>(&line)
426 {
427 let count = pattern.find_iter(&passage.text).count();
428 matches.push(SearchResult::with_metadata(
429 passage.id,
430 count as f64,
431 passage.text,
432 passage.metadata,
433 ));
434 }
435 }
436 }
437
438 matches.sort_by(|a, b| {
439 b.score
440 .partial_cmp(&a.score)
441 .unwrap_or(std::cmp::Ordering::Equal)
442 });
443 matches.truncate(top_k);
444 Ok(matches)
445 }
446
    /// Hook for releasing resources held by the searcher.
    ///
    /// Currently a no-op. It is also called from the `Drop` implementation.
    pub fn cleanup(&mut self) {
        // Intentionally empty: nothing to release in the code visible here.
    }
450}
451
/// Per-query search settings consumed by `LeannSearcher::search_with_params`.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Passed to the backend as the `ef_search` search parameter.
    pub complexity: usize,
    /// Passed to the backend as the `beam_size` search parameter.
    pub beam_width: usize,
    /// Passed to the backend as the `prune_ratio` search parameter.
    pub prune_ratio: f64,
    /// When set, hits are passed through the passage manager's metadata
    /// filter before being returned.
    // NOTE(review): presumably field name -> (operator -> value); confirm
    // against `PassageManager::filter_search_results`.
    pub metadata_filters: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
    /// Passed to the backend as the `batch_size` search parameter.
    pub batch_size: usize,
    /// When `true` (and `gemma` is not `0.0`), a literal text search is run
    /// instead of a vector search.
    pub use_grep: bool,
    /// Weight of the vector score in hybrid search. `0.0` selects BM25-only
    /// search; values below `1.0` fuse vector and BM25 hits when the `bm25`
    /// feature is enabled; `1.0` is pure vector search.
    // NOTE(review): presumably a misspelling of "gamma"; the name is part of
    // the public interface, so it is left unchanged.
    pub gemma: f64,
    /// Pruning strategy name: `"local"` or `"proportional"`; anything else,
    /// including `None`, selects the global strategy.
    pub pruning_strategy: Option<String>,
    /// Extra provider options. Not read anywhere in this file.
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
468
469impl Default for SearchConfig {
470 fn default() -> Self {
471 Self {
472 complexity: 64,
473 beam_width: 1,
474 prune_ratio: 0.0,
475 metadata_filters: None,
476 batch_size: 0,
477 use_grep: false,
478 gemma: 1.0,
479 pruning_strategy: None,
480 provider_options: None,
481 }
482 }
483}
484
/// Runs [`LeannSearcher::cleanup`] when the searcher goes out of scope.
impl Drop for LeannSearcher {
    fn drop(&mut self) {
        // `cleanup` is currently empty, so dropping has no extra effect.
        self.cleanup();
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// The default options request neither a recompute override nor a warm-up.
    #[test]
    fn test_searcher_options_default() {
        let SearcherOptions {
            recompute_embeddings,
            enable_warmup,
        } = SearcherOptions::default();

        assert!(!enable_warmup);
        assert!(recompute_embeddings.is_none());
    }
}