use anyhow::Result;
use std::collections::HashMap;
use std::io::Cursor;
use std::path::Path;
use tracing::warn;
use crate::bm25::BM25Scorer;
use crate::embedding::client::EmbeddingClient;
use crate::hnsw::graph::HnswGraph;
use crate::hnsw::io::read_hnsw_index;
use crate::hnsw::search::{SearchParams, search_hnsw_recompute};
use crate::index::{DistanceMetric, IndexMeta, IndexPaths};
use crate::passages::{PassageManager, load_id_map};
use crate::search_result::SearchResult;
/// Read-side handle over an on-disk LEANN index.
///
/// Instances are built by `LeannSearcher::open`, which loads the index
/// metadata, the passage store, the HNSW graph and (when present) the
/// label-to-ID map.
pub struct LeannSearcher {
    /// Index metadata as loaded from the meta file.
    // Kept on the struct even though nothing in this file reads it after `open`.
    #[allow(dead_code)]
    meta: IndexMeta,
    /// Passage store used to resolve string IDs into text and metadata.
    passages: PassageManager,
    /// HNSW graph read from the index file.
    graph: HnswGraph,
    /// Maps numeric graph labels (used as indices) to string passage IDs.
    /// Empty when the index has no ID-map file.
    id_map: Vec<String>,
    /// Distance metric reported by the index metadata.
    distance_metric: DistanceMetric,
    /// Value of `IndexMeta::requires_recompute()` at open time; forwarded to
    /// the HNSW search parameters.
    recompute_embeddings: bool,
    /// Slot for a BM25 scorer; `open` always leaves it as `None`.
    // Not read anywhere in this file, hence the lint suppression.
    #[allow(dead_code)]
    bm25: Option<BM25Scorer>,
    /// Absolute path of the metadata file this searcher was opened from.
    // Not read anywhere in this file, hence the lint suppression.
    #[allow(dead_code)]
    meta_path: std::path::PathBuf,
}
impl LeannSearcher {
/// Opens the LEANN index located at `index_path`.
///
/// A relative `index_path` is resolved against the current working
/// directory. The metadata file, the passage sources and the HNSW index file
/// are loaded in that order; the ID map is loaded only if its file exists,
/// otherwise an empty map is used.
///
/// # Errors
///
/// Returns an error when the current directory cannot be determined, when the
/// metadata or index file is missing, or when any of the loaders fails.
pub fn open(index_path: &Path) -> Result<Self> {
    let index_path = if index_path.is_absolute() {
        index_path.to_path_buf()
    } else {
        std::env::current_dir()?.join(index_path)
    };
    let paths = IndexPaths::new(&index_path);

    // Metadata first: it tells us where the passages live and how to search.
    let meta_path = paths.meta_path();
    if !meta_path.exists() {
        anyhow::bail!("LEANN metadata file not found at {}", meta_path.display());
    }
    let meta = IndexMeta::load(&meta_path)?;
    let distance_metric = meta.distance_metric();
    let recompute_embeddings = meta.requires_recompute();
    let passages = PassageManager::load(&meta.passage_sources, Some(&meta_path))?;

    // The whole index file is read into memory and parsed from a cursor.
    let index_file = paths.index_file_path();
    if !index_file.exists() {
        anyhow::bail!("HNSW index file not found at {}", index_file.display());
    }
    let mut index_bytes = Cursor::new(std::fs::read(&index_file)?);
    let graph = read_hnsw_index(&mut index_bytes)?;

    // The ID map is optional; without it labels are used verbatim as IDs.
    let id_map_path = paths.id_map_path();
    let id_map = match id_map_path.exists() {
        true => load_id_map(&id_map_path)?,
        false => Vec::new(),
    };

    Ok(Self {
        meta,
        passages,
        graph,
        id_map,
        distance_metric,
        recompute_embeddings,
        bm25: None,
        meta_path,
    })
}
/// Searches for `query` using the default [`SearchConfig`].
///
/// Equivalent to calling `search_with_params` with `SearchConfig::default()`.
pub fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
    let defaults = SearchConfig::default();
    self.search_with_params(query, top_k, &defaults)
}
/// Searches the index for `query`, returning at most `top_k` results.
///
/// The mode is chosen from `config`, checked in this order:
///
/// 1. `config.gemma == 0.0` — BM25 only.
/// 2. `config.use_grep` — literal, case-insensitive text match.
/// 3. otherwise — HNSW vector search; when `config.gemma < 1.0` (and no
///    metadata filters are set) the vector scores are blended with BM25
///    scores as `gemma * vector + (1 - gemma) * bm25`.
///
/// `top_k` is clamped to the number of stored passages. When
/// `config.metadata_filters` is set, the filter is applied to the hits of the
/// selected mode and the filtered list is returned directly (no hybrid
/// blending happens in that case).
///
/// # Errors
///
/// Returns an error if the keyword searches fail or if the query embedding
/// cannot be computed. Failures while computing node distances during the
/// graph traversal are logged and replaced by a large sentinel distance.
pub fn search_with_params(
    &self,
    query: &str,
    top_k: usize,
    config: &SearchConfig,
) -> Result<Vec<SearchResult>> {
    // Never ask for more hits than there are passages.
    let top_k = top_k.min(self.passages.len());

    // Keyword-only modes; BM25-only takes precedence over grep.
    let keyword_hits = if config.gemma == 0.0 {
        Some(self.bm25_search(query, top_k)?)
    } else if config.use_grep {
        Some(self.grep_search(query, top_k)?)
    } else {
        None
    };
    if let Some(hits) = keyword_hits {
        return Ok(match config.metadata_filters.as_ref() {
            Some(filters) => self.passages.filter_search_results(&hits, filters),
            None => hits,
        });
    }

    // Embed the query. The same client is reused for distance computation.
    let zmq_port = config.zmq_port.unwrap_or(5557);
    let client = EmbeddingClient::new(zmq_port);
    let query_embedding = client.compute_text_embeddings(&[query.to_string()])?;
    let mut query_vec: Vec<f32> = query_embedding.row(0).to_vec();
    if self.distance_metric == DistanceMetric::Cosine {
        // Unit-normalise for cosine; an all-zero vector is left untouched.
        let norm: f32 = query_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut query_vec {
                *x /= norm;
            }
        }
    }

    let params = SearchParams {
        ef_search: config.complexity,
        beam_size: config.beam_width,
        prune_ratio: config.prune_ratio,
        recompute_embeddings: self.recompute_embeddings,
        zmq_port: Some(zmq_port),
        batch_size: config.batch_size,
        ..Default::default()
    };

    // Both the recompute and non-recompute configurations go through the
    // same traversal routine; the flag travels inside `params`.
    let (labels, distances) = search_hnsw_recompute(
        &self.graph,
        &query_vec,
        top_k,
        &params,
        |node_ids, q, out| {
            let dists = client.compute_distances(node_ids, q).unwrap_or_else(|e| {
                // Best effort: keep traversing, but treat these nodes as far away.
                warn!(
                    "Distance computation failed for {} node(s): {}",
                    node_ids.len(),
                    e
                );
                vec![1e9; node_ids.len()]
            });
            // Guard against a reply longer than the output buffer.
            let n = dists.len().min(out.len());
            out[..n].copy_from_slice(&dists[..n]);
        },
    );

    // Resolve labels to passages; unknown IDs are logged and skipped.
    let mut results = Vec::with_capacity(labels.len());
    for (label, dist) in labels.iter().zip(distances.iter()) {
        let string_id = self.map_label(*label);
        match self.passages.get_passage(&string_id) {
            Ok(passage) => {
                results.push(SearchResult::with_metadata(
                    string_id,
                    *dist as f64,
                    passage.text,
                    passage.metadata.clone(),
                ));
            }
            Err(e) => {
                warn!("Passage not found for ID '{}': {}", string_id, e);
            }
        }
    }

    if let Some(filters) = config.metadata_filters.as_ref() {
        return Ok(self.passages.filter_search_results(&results, filters));
    }

    if config.gemma < 1.0 {
        // NOTE(review): the vector score here is the raw distance returned by
        // the graph search while the BM25 score is a relevance score — confirm
        // that blending them linearly is intended.
        let bm25_results = self.bm25_search(query, top_k)?;
        let bm25_weight = 1.0 - config.gemma;

        let mut hybrid_scores: HashMap<String, f64> = HashMap::new();
        for r in &results {
            *hybrid_scores.entry(r.id.clone()).or_default() += config.gemma * r.score;
        }
        for r in &bm25_results {
            *hybrid_scores.entry(r.id.clone()).or_default() += bm25_weight * r.score;
        }

        // Highest blended score first; ties broken by ID so the output does
        // not depend on hash-map iteration order.
        let mut ranked: Vec<(String, f64)> = hybrid_scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked.truncate(top_k);

        // Index the vector hits once instead of scanning them per ID.
        let mut vector_by_id: HashMap<&str, &SearchResult> = HashMap::new();
        for r in &results {
            vector_by_id.entry(r.id.as_str()).or_insert(r);
        }

        let mut hybrid_results = Vec::with_capacity(ranked.len());
        for (id, score) in ranked {
            let from_vector = vector_by_id.get(id.as_str()).copied();
            match from_vector {
                Some(hit) => {
                    hybrid_results.push(SearchResult::with_metadata(
                        id,
                        score,
                        hit.text.clone(),
                        hit.metadata.clone(),
                    ));
                }
                // BM25-only hit: fetch its text and metadata from the passage
                // store rather than returning an empty result.
                None => match self.passages.get_passage(&id) {
                    Ok(passage) => {
                        hybrid_results.push(SearchResult::with_metadata(
                            id,
                            score,
                            passage.text,
                            passage.metadata.clone(),
                        ));
                    }
                    Err(e) => {
                        warn!("Passage not found for ID '{}': {}", id, e);
                    }
                },
            }
        }
        return Ok(hybrid_results);
    }

    Ok(results)
}
/// Translates a numeric graph label into a string passage ID.
///
/// Uses `id_map[label]` when that entry exists; otherwise (empty map or
/// out-of-range label) falls back to the label's decimal representation.
fn map_label(&self, label: usize) -> String {
    self.id_map
        .get(label)
        .cloned()
        .unwrap_or_else(|| label.to_string())
}
fn bm25_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
let mut scorer = BM25Scorer::default();
let mut documents = Vec::new();
for file_path in self.passages.passage_files() {
let file = std::fs::File::open(file_path)?;
let reader = std::io::BufReader::new(file);
use std::io::BufRead;
for line in reader.lines() {
let line = line?;
if let Ok(passage) = serde_json::from_str::<crate::passages::Passage>(&line) {
documents.push((passage.id, passage.text));
}
}
}
scorer.fit(&documents);
Ok(scorer.search(query, top_k))
}
fn grep_search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
let pattern = regex::RegexBuilder::new(®ex::escape(query))
.case_insensitive(true)
.build()?;
let mut matches = Vec::new();
for file_path in self.passages.passage_files() {
let file = std::fs::File::open(file_path)?;
let reader = std::io::BufReader::new(file);
use std::io::BufRead;
for line in reader.lines() {
let line = line?;
if pattern.is_match(&line)
&& let Ok(passage) = serde_json::from_str::<crate::passages::Passage>(&line)
{
let count = pattern.find_iter(&passage.text).count();
matches.push(SearchResult::with_metadata(
passage.id,
count as f64,
passage.text,
passage.metadata,
));
}
}
}
matches.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
matches.truncate(top_k);
Ok(matches)
}
/// Releases resources held by the searcher.
///
/// Currently a no-op; it is also invoked from the `Drop` implementation.
pub fn cleanup(&mut self) {}
}
/// Tunable options for `LeannSearcher::search_with_params`.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Passed to the HNSW search as `ef_search`.
    pub complexity: usize,
    /// Passed to the HNSW search as `beam_size`.
    pub beam_width: usize,
    /// Passed to the HNSW search as `prune_ratio`.
    pub prune_ratio: f64,
    /// Optional metadata filters applied to the hits via the passage
    /// manager; `None` disables filtering.
    pub metadata_filters: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
    /// Passed to the HNSW search as `batch_size`.
    pub batch_size: usize,
    /// When `true` (and `gemma` is not `0.0`), performs a literal text match
    /// instead of a vector search.
    pub use_grep: bool,
    /// Weight of the vector score in hybrid ranking; BM25 gets `1.0 - gemma`.
    /// `0.0` means BM25 only, `1.0` means vector search only.
    pub gemma: f64,
    /// Port of the embedding server; the searcher falls back to 5557 when
    /// this is `None`.
    pub zmq_port: Option<u16>,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
complexity: 64,
beam_width: 1,
prune_ratio: 0.0,
metadata_filters: None,
batch_size: 0,
use_grep: false,
gemma: 1.0,
zmq_port: None,
}
}
}
impl Drop for LeannSearcher {
fn drop(&mut self) {
self.cleanup();
}
}