Skip to main content

leann_core/
searcher.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
6use tracing::warn;
7
8use crate::backend::{self, BackendIndex, PruningStrategy};
9#[cfg(feature = "bm25")]
10use crate::bm25::BM25Scorer;
11use crate::embedding::EmbeddingProvider;
12use crate::hnsw::search::SearchParams;
13use crate::hnsw::simd::{inner_product_distance, l2_distance};
14use crate::index::{DistanceMetric, IndexMeta, IndexPaths};
15#[cfg(feature = "bm25")]
16use crate::passages::Passage;
17use crate::passages::{PassageManager, load_id_map};
18use crate::search_result::SearchResult;
19
/// Options for opening a LEANN searcher with non-default behavior.
///
/// Used by `LeannSearcher::open_with_options`. The `Default` value keeps every
/// setting from the index metadata and skips the warmup probe.
#[derive(Debug, Clone, Default)]
pub struct SearcherOptions {
    /// Override `recompute_embeddings` from meta.json. `None` = use meta default.
    pub recompute_embeddings: Option<bool>,
    /// If true, send a probe embedding request at construction to verify the provider.
    pub enable_warmup: bool,
}
28
29/// High-level searcher for LEANN indexes.
30#[allow(dead_code)]
31pub struct LeannSearcher {
32    meta: IndexMeta,
33    passages: PassageManager,
34    index: BackendIndex,
35    id_map: Vec<String>,
36    distance_metric: DistanceMetric,
37    recompute_embeddings: bool,
38    provider: Option<Arc<dyn EmbeddingProvider>>,
39    #[cfg(feature = "bm25")]
40    bm25: Option<BM25Scorer>,
41    meta_path: std::path::PathBuf,
42}
43
44impl LeannSearcher {
45    /// Open an existing LEANN index for searching.
46    pub fn open(index_path: &Path) -> Result<Self> {
47        let index_path = if index_path.is_relative() {
48            std::env::current_dir()?.join(index_path)
49        } else {
50            index_path.to_path_buf()
51        };
52
53        let paths = IndexPaths::new(&index_path);
54        let meta_path = paths.meta_path();
55
56        if !meta_path.exists() {
57            anyhow::bail!("LEANN metadata file not found at {}", meta_path.display());
58        }
59
60        let meta = IndexMeta::load(&meta_path)?;
61        let distance_metric = meta.distance_metric();
62        let recompute = meta.requires_recompute();
63
64        // Load passages
65        let passages = PassageManager::load(&meta.passage_sources, Some(&meta_path))?;
66
67        // Load backend index
68        let index_file = paths.index_file_path();
69        if !index_file.exists() {
70            anyhow::bail!("Index file not found at {}", index_file.display());
71        }
72        let index = backend::read_backend_index(&meta.backend_name, &index_file)?;
73
74        // Load ID map
75        let id_map_path = paths.id_map_path();
76        let id_map = if id_map_path.exists() {
77            load_id_map(&id_map_path)?
78        } else {
79            Vec::new()
80        };
81
82        // Construct embedding provider from meta
83        let provider = Self::create_provider_from_meta(&meta);
84
85        Ok(Self {
86            meta,
87            passages,
88            index,
89            id_map,
90            distance_metric,
91            recompute_embeddings: recompute,
92            provider,
93            #[cfg(feature = "bm25")]
94            bm25: None,
95            meta_path,
96        })
97    }
98
99    /// Open an existing LEANN index with custom options.
100    ///
101    /// This allows overriding `recompute_embeddings` from meta.json and
102    /// optionally warming up the embedding provider at construction time.
103    pub fn open_with_options(index_path: &Path, options: &SearcherOptions) -> Result<Self> {
104        let mut searcher = Self::open(index_path)?;
105
106        // Override recompute_embeddings if explicitly specified
107        if let Some(recompute) = options.recompute_embeddings {
108            searcher.recompute_embeddings = recompute;
109        }
110
111        // Warmup: send a probe embedding request to verify the provider responds
112        if options.enable_warmup {
113            searcher.warmup()?;
114        }
115
116        Ok(searcher)
117    }
118
119    /// Send a probe embedding request to verify the provider is reachable.
120    ///
121    /// This is useful for detecting misconfiguration early (e.g. Ollama not running)
122    /// rather than waiting until the first search call.
123    pub fn warmup(&self) -> Result<()> {
124        if let Some(ref provider) = self.provider {
125            match provider.compute_embeddings(&["__LEANN_WARMUP__".to_string()], None) {
126                Ok(_) => {}
127                Err(e) => {
128                    warn!("Warmup embedding request failed (provider may not be running): {e}");
129                }
130            }
131        }
132        Ok(())
133    }
134
135    /// Construct an embedding provider from index metadata.
136    #[cfg(feature = "embedding-remote")]
137    fn create_provider_from_meta(meta: &IndexMeta) -> Option<Arc<dyn EmbeddingProvider>> {
138        use crate::embedding::{EmbeddingMode, create_embedding_provider};
139
140        let mode = EmbeddingMode::from_str_lossy(&meta.embedding_mode);
141        match create_embedding_provider(&mode, &meta.embedding_model, &meta.embedding_options) {
142            Ok(provider) => Some(Arc::from(provider)),
143            Err(e) => {
144                warn!("Could not create embedding provider from meta: {e}");
145                None
146            }
147        }
148    }
149
150    #[cfg(not(feature = "embedding-remote"))]
151    fn create_provider_from_meta(_meta: &IndexMeta) -> Option<Arc<dyn EmbeddingProvider>> {
152        None
153    }
154
155    /// Search for nearest neighbors.
156    pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
157        self.search_with_params(query, top_k, &SearchConfig::default())
158    }
159
160    /// Search with full configuration.
161    pub fn search_with_params(
162        &self,
163        query: &str,
164        top_k: usize,
165        config: &SearchConfig,
166    ) -> Result<Vec<SearchResult>> {
167        let top_k = top_k.min(self.passages.len());
168
169        // Handle pure BM25 search
170        #[cfg(feature = "bm25")]
171        if config.gemma == 0.0 {
172            let results = self.bm25_search(query, top_k)?;
173            if let Some(ref filters) = config.metadata_filters {
174                return Ok(self.passages.filter_search_results(&results, filters));
175            }
176            return Ok(results);
177        }
178        #[cfg(not(feature = "bm25"))]
179        if config.gemma == 0.0 {
180            anyhow::bail!("BM25 search requires the `bm25` feature");
181        }
182
183        // Handle grep search
184        if config.use_grep {
185            let results = self.grep_search(query, top_k)?;
186            if let Some(ref filters) = config.metadata_filters {
187                return Ok(self.passages.filter_search_results(&results, filters));
188            }
189            return Ok(results);
190        }
191
192        // Vector search requires an embedding provider
193        let results = self.vector_search(query, top_k, config)?;
194        Ok(results)
195    }
196
197    fn vector_search(
198        &self,
199        query: &str,
200        top_k: usize,
201        config: &SearchConfig,
202    ) -> Result<Vec<SearchResult>> {
203        let provider = self.provider.as_ref().ok_or_else(|| {
204            anyhow::anyhow!(
205                "No embedding provider available. Ensure the index was built with a supported \
206                 embedding mode (ollama, openai, gemini) and the `embedding-remote` feature is enabled."
207            )
208        })?;
209
210        // Compute query embedding
211        let query_embedding = provider.compute_embeddings(&[query.to_string()], None)?;
212        let query_vec: Vec<f32> = query_embedding.row(0).to_vec();
213
214        // Normalize for cosine
215        let query_vec = if self.distance_metric == DistanceMetric::Cosine {
216            let norm: f32 = query_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
217            if norm > 0.0 {
218                query_vec.iter().map(|x| x / norm).collect()
219            } else {
220                query_vec
221            }
222        } else {
223            query_vec
224        };
225
226        let pruning_strategy = config
227            .pruning_strategy
228            .as_deref()
229            .map(|s| match s {
230                "local" => PruningStrategy::Local,
231                "proportional" => PruningStrategy::Proportional,
232                _ => PruningStrategy::Global,
233            })
234            .unwrap_or(PruningStrategy::Global);
235
236        let params = SearchParams {
237            ef_search: config.complexity,
238            beam_size: config.beam_width,
239            prune_ratio: config.prune_ratio,
240            recompute_embeddings: self.recompute_embeddings,
241            batch_size: config.batch_size,
242            pruning_strategy,
243            ..Default::default()
244        };
245
246        // Search
247        let (labels, distances) = if self.recompute_embeddings || self.index.is_pruned() {
248            // Recompute: look up passage texts, compute embeddings, compute distances locally
249            let provider = Arc::clone(provider);
250            let passages = &self.passages;
251            let distance_metric = self.distance_metric;
252
253            backend::search_backend_recompute(
254                &self.index,
255                &query_vec,
256                top_k,
257                &params,
258                |node_ids, q, out| {
259                    let mut texts = Vec::new();
260                    let mut found_indices = Vec::new();
261
262                    for (idx, &nid) in node_ids.iter().enumerate() {
263                        if let Ok(passage) = passages.get_passage_by_index(nid)
264                            && !passage.text.is_empty()
265                        {
266                            texts.push(passage.text);
267                            found_indices.push(idx);
268                        }
269                    }
270
271                    for d in out.iter_mut().take(node_ids.len()) {
272                        *d = 1e9;
273                    }
274
275                    if texts.is_empty() {
276                        return;
277                    }
278
279                    if let Ok(embeddings) = provider.compute_embeddings(&texts, None) {
280                        for (i, &original_idx) in found_indices.iter().enumerate() {
281                            let emb = embeddings.row(i);
282                            let emb_slice = emb.as_slice().unwrap();
283                            let dist = match distance_metric {
284                                DistanceMetric::L2 => l2_distance(q, emb_slice),
285                                _ => inner_product_distance(q, emb_slice),
286                            };
287                            out[original_idx] = dist;
288                        }
289                    }
290                },
291            )
292        } else {
293            // Non-recompute: use stored vectors
294            backend::search_backend(&self.index, &query_vec, top_k, &params)
295        };
296
297        // Map labels to passages and enrich results
298        let mut results = Vec::new();
299        for (label, dist) in labels.iter().zip(distances.iter()) {
300            let string_id = self.map_label(*label);
301            match self.passages.get_passage_by_index(*label) {
302                Ok(passage) => {
303                    results.push(SearchResult::with_metadata(
304                        string_id,
305                        *dist as f64,
306                        passage.text,
307                        passage.metadata,
308                    ));
309                }
310                Err(e) => {
311                    warn!("Passage not found for label {}: {}", label, e);
312                }
313            }
314        }
315
316        // Apply metadata filters
317        if let Some(ref filters) = config.metadata_filters {
318            let filtered = self.passages.filter_search_results(&results, filters);
319            return Ok(filtered);
320        }
321
322        // Handle hybrid search
323        #[cfg(feature = "bm25")]
324        if config.gemma < 1.0 {
325            let bm25_results = self.bm25_search(query, top_k)?;
326            let bm25_weight = 1.0 - config.gemma;
327
328            let mut hybrid_scores: HashMap<String, f64> = HashMap::new();
329
330            for r in &results {
331                if let Some(s) = hybrid_scores.get_mut(&r.id) {
332                    *s += config.gemma * r.score;
333                } else {
334                    hybrid_scores.insert(r.id.clone(), config.gemma * r.score);
335                }
336            }
337            for r in &bm25_results {
338                if let Some(s) = hybrid_scores.get_mut(&r.id) {
339                    *s += bm25_weight * r.score;
340                } else {
341                    hybrid_scores.insert(r.id.clone(), bm25_weight * r.score);
342                }
343            }
344
345            let mut sorted: Vec<(String, f64)> = hybrid_scores.into_iter().collect();
346            sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
347            sorted.truncate(top_k);
348
349            // Build lookup for text/metadata to avoid O(k·n) linear scans
350            let result_lookup: HashMap<&str, usize> = results
351                .iter()
352                .enumerate()
353                .map(|(i, r)| (r.id.as_str(), i))
354                .collect();
355
356            let mut hybrid_results = Vec::new();
357            for (id, score) in sorted {
358                let (text, metadata) = match result_lookup.get(id.as_str()) {
359                    Some(&idx) => (results[idx].text.clone(), results[idx].metadata.clone()),
360                    None => (String::new(), HashMap::new()),
361                };
362                hybrid_results.push(SearchResult::with_metadata(id, score, text, metadata));
363            }
364
365            return Ok(hybrid_results);
366        }
367
368        Ok(results)
369    }
370
371    fn map_label(&self, label: usize) -> String {
372        if !self.id_map.is_empty() && label < self.id_map.len() {
373            self.id_map[label].clone()
374        } else {
375            label.to_string()
376        }
377    }
378
379    #[cfg(feature = "bm25")]
380    fn bm25_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
381        let mut scorer = BM25Scorer::default();
382
383        let mut documents = Vec::new();
384        let mut passage_map: HashMap<String, Passage> = HashMap::new();
385        for file_path in self.passages.passage_files() {
386            let file = std::fs::File::open(file_path)?;
387            let reader = std::io::BufReader::new(file);
388            use std::io::BufRead;
389            for line in reader.lines() {
390                let line = line?;
391                if let Ok(passage) = serde_json::from_str::<Passage>(&line) {
392                    documents.push((passage.id.clone(), passage.text.clone()));
393                    passage_map.insert(passage.id.clone(), passage);
394                }
395            }
396        }
397
398        scorer.fit(&documents);
399        let mut results = scorer.search(query, top_k);
400
401        // Enrich results with passage text and metadata
402        for result in &mut results {
403            if let Some(passage) = passage_map.get(&result.id) {
404                result.text.clone_from(&passage.text);
405                result.metadata.clone_from(&passage.metadata);
406            }
407        }
408
409        Ok(results)
410    }
411
412    fn grep_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
413        let pattern = regex::RegexBuilder::new(&regex::escape(query))
414            .case_insensitive(true)
415            .build()?;
416
417        let mut matches = Vec::new();
418        for file_path in self.passages.passage_files() {
419            let file = std::fs::File::open(file_path)?;
420            let reader = std::io::BufReader::new(file);
421            use std::io::BufRead;
422            for line in reader.lines() {
423                let line = line?;
424                if pattern.is_match(&line)
425                    && let Ok(passage) = serde_json::from_str::<crate::passages::Passage>(&line)
426                {
427                    let count = pattern.find_iter(&passage.text).count();
428                    matches.push(SearchResult::with_metadata(
429                        passage.id,
430                        count as f64,
431                        passage.text,
432                        passage.metadata,
433                    ));
434                }
435            }
436        }
437
438        matches.sort_by(|a, b| {
439            b.score
440                .partial_cmp(&a.score)
441                .unwrap_or(std::cmp::Ordering::Equal)
442        });
443        matches.truncate(top_k);
444        Ok(matches)
445    }
446
447    pub fn cleanup(&mut self) {
448        // Cleanup resources (provider is Arc-dropped automatically)
449    }
450}
451
452/// Search configuration options.
453#[derive(Debug, Clone)]
454pub struct SearchConfig {
455    pub complexity: usize,
456    pub beam_width: usize,
457    pub prune_ratio: f64,
458    pub metadata_filters: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
459    pub batch_size: usize,
460    pub use_grep: bool,
461    /// Weight of vector search (0.0 = pure BM25, 1.0 = pure vector).
462    pub gemma: f64,
463    /// Pruning strategy: "global", "local", or "proportional".
464    pub pruning_strategy: Option<String>,
465    /// Provider options (e.g. prompt_template overrides) passed at query time.
466    pub provider_options: Option<HashMap<String, serde_json::Value>>,
467}
468
469impl Default for SearchConfig {
470    fn default() -> Self {
471        Self {
472            complexity: 64,
473            beam_width: 1,
474            prune_ratio: 0.0,
475            metadata_filters: None,
476            batch_size: 0,
477            use_grep: false,
478            gemma: 1.0,
479            pruning_strategy: None,
480            provider_options: None,
481        }
482    }
483}
484
485impl Drop for LeannSearcher {
486    fn drop(&mut self) {
487        self.cleanup();
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_searcher_options_default() {
        let opts = SearcherOptions::default();
        assert!(!opts.enable_warmup);
        assert!(opts.recompute_embeddings.is_none());
    }

    #[test]
    fn test_search_config_default() {
        let config = SearchConfig::default();
        assert_eq!(config.complexity, 64);
        assert_eq!(config.beam_width, 1);
        assert_eq!(config.prune_ratio, 0.0);
        assert_eq!(config.batch_size, 0);
        assert!(!config.use_grep);
        // gemma = 1.0 takes neither the BM25-only (== 0.0) nor the hybrid (< 1.0) path.
        assert_eq!(config.gemma, 1.0);
        assert!(config.metadata_filters.is_none());
        assert!(config.pruning_strategy.is_none());
        assert!(config.provider_options.is_none());
    }
}