use anyhow::Result;
use ndarray::Array2;
use std::collections::HashMap;
use std::path::Path;
use tracing::info;
use crate::embedding::EmbeddingProvider;
use crate::hnsw::build::build_hnsw_with_threads;
use crate::hnsw::csr::convert_to_csr;
use crate::hnsw::graph::{HnswConfig, VectorStorage};
use crate::hnsw::io::{write_hnsw_compact, write_hnsw_standard};
use crate::hnsw::simd::normalize_l2_inplace;
use crate::index::{DistanceMetric, IndexMeta, IndexPaths, PassageSource};
use crate::passages::{Passage, write_id_map, write_passages};
/// Accumulates text passages and the settings needed to embed them and
/// write an HNSW index (graph file, passage files, id map and metadata).
///
/// Configure it through the `with_*` methods, add passages with `add_text`,
/// then call `build_index` or `build_index_from_embeddings`.
pub struct LeannBuilder {
    // --- Embedding description (copied into the index metadata) ---
    /// Name of the embedding model, recorded verbatim in the metadata.
    embedding_model: String,
    /// Embedding mode string, recorded verbatim in the metadata.
    embedding_mode: String,
    /// Free-form provider options, recorded verbatim in the metadata.
    embedding_options: HashMap<String, serde_json::Value>,
    /// Embedding width; when `None` it is discovered at build time.
    dimensions: Option<usize>,

    // --- Graph construction ---
    /// HNSW parameters (M, efConstruction, metric, compact/recompute flags).
    config: HnswConfig,
    /// Number of threads handed to the graph builder (always at least 1).
    num_threads: usize,

    // --- Data ---
    /// Passages queued for indexing, in insertion order.
    chunks: Vec<Passage>,
}
impl LeannBuilder {
pub fn new(embedding_model: &str, dimensions: Option<usize>, embedding_mode: &str) -> Self {
Self {
embedding_model: embedding_model.to_string(),
dimensions,
embedding_mode: embedding_mode.to_string(),
config: HnswConfig::default(),
num_threads: std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1),
chunks: Vec::new(),
embedding_options: HashMap::new(),
}
}
pub fn with_config(mut self, config: HnswConfig) -> Self {
self.config = config;
self
}
pub fn with_m(mut self, m: usize) -> Self {
self.config.m = m;
self
}
pub fn with_ef_construction(mut self, ef: usize) -> Self {
self.config.ef_construction = ef;
self
}
pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
self.config.distance_metric = metric;
self
}
pub fn with_compact(mut self, compact: bool) -> Self {
self.config.is_compact = compact;
self
}
pub fn with_recompute(mut self, recompute: bool) -> Self {
self.config.is_recompute = recompute;
self
}
pub fn with_num_threads(mut self, n: usize) -> Self {
self.num_threads = n.max(1);
self
}
pub fn with_embedding_options(mut self, options: HashMap<String, serde_json::Value>) -> Self {
self.embedding_options = options;
self
}
pub fn add_text(&mut self, text: &str, metadata: HashMap<String, serde_json::Value>) {
let id = metadata
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| self.chunks.len().to_string());
self.chunks.push(Passage {
id,
text: text.to_string(),
metadata,
});
}
pub fn build_index(
&mut self,
index_path: &Path,
provider: &dyn EmbeddingProvider,
) -> Result<()> {
if self.chunks.is_empty() {
anyhow::bail!("No chunks added");
}
self.chunks.retain(|c| !c.text.trim().is_empty());
if self.chunks.is_empty() {
anyhow::bail!("All provided chunks are empty or invalid");
}
if self.dimensions.is_none() {
let dummy = provider.compute_embeddings(&["dummy".to_string()])?;
self.dimensions = Some(dummy.ncols());
}
let dimensions = self.dimensions.unwrap();
let paths = IndexPaths::new(index_path);
std::fs::create_dir_all(&paths.base_dir)?;
let texts: Vec<String> = self.chunks.iter().map(|c| c.text.clone()).collect();
let ids: Vec<String> = self.chunks.iter().map(|c| c.id.clone()).collect();
let (write_result, embed_result) = std::thread::scope(|s| {
let passages_path = paths.passages_path();
let offset_path = paths.offset_path();
let id_map_path = paths.id_map_path();
let chunks = &self.chunks;
let ids_ref = &ids;
let writer = s.spawn(move || -> Result<()> {
info!("Writing {} passages to disk", chunks.len());
write_passages(chunks, &passages_path, &offset_path)?;
write_id_map(ids_ref, &id_map_path)?;
info!("Passage writing complete");
Ok(())
});
info!("Computing embeddings for {} chunks", texts.len());
let emb = provider.compute_embeddings(&texts);
let wr = writer.join().expect("passage writer thread panicked");
(wr, emb)
});
write_result?;
let mut embeddings = embed_result?;
if self.config.distance_metric == DistanceMetric::Cosine {
normalize_l2_inplace(&mut embeddings);
}
info!(
"Building HNSW graph (M={}, efConstruction={})",
self.config.m, self.config.ef_construction
);
let mut graph = build_hnsw_with_threads(&embeddings, &self.config, self.num_threads)?;
if !self.config.is_recompute {
let flat: Vec<f32> = embeddings.iter().copied().collect();
let storage_bytes = flat
.iter()
.flat_map(|f| f.to_le_bytes())
.collect::<Vec<u8>>();
let fourcc = match self.config.distance_metric {
DistanceMetric::L2 => u32::from_le_bytes(*b"IxFl"),
_ => u32::from_le_bytes(*b"IxFI"),
};
graph.vector_storage = VectorStorage::Raw {
fourcc,
data: storage_bytes,
};
}
let graph = if self.config.is_compact {
info!("Converting to compact CSR format");
convert_to_csr(&graph)?
} else {
graph
};
let index_file = paths.index_file_path();
let mut file = std::fs::File::create(&index_file)?;
if graph.is_compact() {
write_hnsw_compact(&mut file, &graph)?;
} else {
write_hnsw_standard(&mut file, &graph)?;
}
let meta = IndexMeta {
version: "1.0".to_string(),
backend_name: "hnsw".to_string(),
embedding_model: self.embedding_model.clone(),
dimensions,
backend_kwargs: {
let mut kwargs = HashMap::new();
kwargs.insert("M".to_string(), serde_json::json!(self.config.m));
kwargs.insert(
"efConstruction".to_string(),
serde_json::json!(self.config.ef_construction),
);
kwargs.insert(
"distance_metric".to_string(),
serde_json::json!(match self.config.distance_metric {
DistanceMetric::L2 => "l2",
DistanceMetric::Cosine => "cosine",
DistanceMetric::Mips => "mips",
}),
);
kwargs.insert(
"is_compact".to_string(),
serde_json::json!(self.config.is_compact),
);
kwargs.insert(
"is_recompute".to_string(),
serde_json::json!(self.config.is_recompute),
);
kwargs
},
embedding_mode: self.embedding_mode.clone(),
passage_sources: vec![PassageSource {
source_type: "jsonl".to_string(),
path: paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
index_path: paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
path_relative: Some(
paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
),
index_path_relative: Some(
paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
),
}],
embedding_options: self.embedding_options.clone(),
is_compact: Some(self.config.is_compact),
is_pruned: Some(self.config.is_recompute),
total_passages: Some(self.chunks.len()),
built_from_precomputed_embeddings: None,
embeddings_source: None,
};
meta.save(&paths.meta_path())?;
info!("Index built successfully at {}", index_path.display());
Ok(())
}
pub fn build_index_from_embeddings(
&mut self,
index_path: &Path,
ids: &[String],
embeddings: &Array2<f32>,
) -> Result<()> {
let dimensions = embeddings.ncols();
self.dimensions = Some(dimensions);
if ids.len() != embeddings.nrows() {
anyhow::bail!(
"Mismatch: {} IDs vs {} embeddings",
ids.len(),
embeddings.nrows()
);
}
if self.chunks.is_empty() {
for id in ids {
self.add_text(&format!("Document {}", id), {
let mut m = HashMap::new();
m.insert("id".to_string(), serde_json::json!(id));
m.insert("from_embeddings".to_string(), serde_json::json!(true));
m
});
}
}
let paths = IndexPaths::new(index_path);
std::fs::create_dir_all(&paths.base_dir)?;
write_passages(&self.chunks, &paths.passages_path(), &paths.offset_path())?;
write_id_map(ids, &paths.id_map_path())?;
let mut emb = embeddings.to_owned();
if self.config.distance_metric == DistanceMetric::Cosine {
normalize_l2_inplace(&mut emb);
}
let graph = build_hnsw_with_threads(&emb, &self.config, self.num_threads)?;
if self.config.is_compact {
let csr = convert_to_csr(&graph)?;
let mut file = std::fs::File::create(paths.index_file_path())?;
write_hnsw_compact(&mut file, &csr)?;
} else {
let mut file = std::fs::File::create(paths.index_file_path())?;
write_hnsw_standard(&mut file, &graph)?;
}
let meta = IndexMeta {
version: "1.0".to_string(),
backend_name: "hnsw".to_string(),
embedding_model: self.embedding_model.clone(),
dimensions,
backend_kwargs: HashMap::new(),
embedding_mode: self.embedding_mode.clone(),
passage_sources: vec![PassageSource {
source_type: "jsonl".to_string(),
path: paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
index_path: paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
path_relative: None,
index_path_relative: None,
}],
embedding_options: HashMap::new(),
is_compact: Some(self.config.is_compact),
is_pruned: Some(self.config.is_recompute),
total_passages: Some(self.chunks.len()),
built_from_precomputed_embeddings: Some(true),
embeddings_source: None,
};
meta.save(&paths.meta_path())?;
Ok(())
}
}