use anyhow::Result;
use ndarray::Array2;
use std::collections::HashMap;
use std::path::Path;
use tracing::info;
use crate::backend::{self, BackendConfig};
use crate::embedding::{EmbeddingMode, EmbeddingProvider};
use crate::hnsw::simd::normalize_l2_inplace;
use crate::index::{DistanceMetric, IndexMeta, IndexPaths, PassageSource};
use crate::passages::{Passage, write_id_map, write_passages};
/// Reports whether a model/mode pair is treated as producing unit-length
/// (L2-normalized) embedding vectors.
///
/// Both inputs are lowercased, and all comparisons are substring checks. A
/// pair qualifies when any of the following holds:
/// - it matches one of the catalogued (provider, model) pairs;
/// - "openai" appears in the mode or model, and the model contains an
///   OpenAI-style hint (`text-embedding`, `ada`, `3-small`, `3-large`);
/// - "voyage" appears in the mode or model;
/// - "cohere" appears in the mode or model, and the model contains `embed`.
pub fn is_normalized_embeddings_model(embedding_model: &str, embedding_mode: &str) -> bool {
    let model = embedding_model.to_lowercase();
    let mode = embedding_mode.to_lowercase();

    // True when either the mode or the model name mentions `provider`.
    let mentions = |provider: &str| mode.contains(provider) || model.contains(provider);

    // Explicitly catalogued (provider, model) pairs.
    const CATALOGUE: [(&str, &str); 12] = [
        ("openai", "text-embedding-ada-002"),
        ("openai", "text-embedding-3-small"),
        ("openai", "text-embedding-3-large"),
        ("voyage", "voyage-2"),
        ("voyage", "voyage-3"),
        ("voyage", "voyage-large-2"),
        ("voyage", "voyage-multilingual-2"),
        ("voyage", "voyage-code-2"),
        ("cohere", "embed-english-v3.0"),
        ("cohere", "embed-multilingual-v3.0"),
        ("cohere", "embed-english-light-v3.0"),
        ("cohere", "embed-multilingual-light-v3.0"),
    ];
    let catalogued = CATALOGUE
        .iter()
        .any(|&(provider, name)| mode.contains(provider) && model.contains(name));

    const OPENAI_HINTS: [&str; 4] = ["text-embedding", "ada", "3-small", "3-large"];
    let openai_family = mentions("openai") && OPENAI_HINTS.iter().any(|hint| model.contains(hint));
    let voyage_family = mentions("voyage");
    let cohere_family = mentions("cohere") && model.contains("embed");

    catalogued || openai_family || voyage_family || cohere_family
}
/// Collects passages and configuration, then builds an on-disk index from them.
pub struct LeannBuilder {
    // --- embedding configuration ---
    /// Embedding model name; recorded in the index metadata.
    embedding_model: String,
    /// Embedding mode string; recorded in the index metadata.
    embedding_mode: String,
    /// Embedding width, or `None` until it is determined at build time.
    dimensions: Option<usize>,
    /// Free-form embedding options (provider settings).
    embedding_options: HashMap<String, serde_json::Value>,

    // --- backend configuration ---
    /// Selected backend and its tuning parameters.
    backend_config: BackendConfig,
    /// `true` while the distance metric is the one chosen automatically for a
    /// normalized-embeddings model; cleared by an explicit metric choice.
    distance_metric_auto: bool,

    // --- content ---
    /// Passages queued for indexing, in insertion order.
    chunks: Vec<Passage>,
}
impl LeannBuilder {
pub fn new(embedding_model: &str, dimensions: Option<usize>, embedding_mode: &str) -> Self {
let mut backend_config = BackendConfig::hnsw_default();
let mut distance_metric_auto = false;
if is_normalized_embeddings_model(embedding_model, embedding_mode) {
info!(
"Detected normalized embeddings model '{}' (mode '{}'). \
Auto-setting distance_metric=cosine for optimal performance.",
embedding_model, embedding_mode
);
backend_config.set_distance_metric(DistanceMetric::Cosine);
distance_metric_auto = true;
}
Self {
embedding_model: embedding_model.to_string(),
dimensions,
embedding_mode: embedding_mode.to_string(),
backend_config,
chunks: Vec::new(),
embedding_options: HashMap::new(),
distance_metric_auto,
}
}
pub fn with_backend(mut self, name: &str) -> Result<Self> {
self.backend_config = BackendConfig::from_name(name)?;
Ok(self)
}
pub fn with_m(mut self, m: usize) -> Self {
self.backend_config.set_m(m);
self
}
pub fn with_ef_construction(mut self, ef: usize) -> Self {
self.backend_config.set_ef_construction(ef);
self
}
pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
self.backend_config.set_distance_metric(metric);
self.distance_metric_auto = false;
self
}
pub fn with_compact(mut self, compact: bool) -> Self {
self.backend_config.set_compact(compact);
self
}
pub fn with_recompute(mut self, recompute: bool) -> Self {
self.backend_config.set_recompute(recompute);
self
}
pub fn with_num_threads(mut self, n: usize) -> Self {
self.backend_config.set_num_threads(n);
self
}
pub fn with_embedding_options(mut self, options: HashMap<String, serde_json::Value>) -> Self {
self.embedding_options = options;
self
}
pub fn add_text(&mut self, text: &str, metadata: HashMap<String, serde_json::Value>) {
let id = metadata
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| self.chunks.len().to_string());
self.chunks.push(Passage {
id,
text: text.to_string(),
metadata,
});
}
pub fn build_index(
&mut self,
index_path: &Path,
provider: &dyn EmbeddingProvider,
) -> Result<()> {
if self.chunks.is_empty() {
anyhow::bail!("No chunks added");
}
self.chunks.retain(|c| !c.text.trim().is_empty());
if self.chunks.is_empty() {
anyhow::bail!("All provided chunks are empty or invalid");
}
if self.dimensions.is_none() {
let dummy = provider.compute_embeddings(&["dummy".to_string()])?;
self.dimensions = Some(dummy.ncols());
}
let dimensions = self.dimensions.unwrap();
let paths = IndexPaths::new(index_path);
std::fs::create_dir_all(&paths.base_dir)?;
let texts: Vec<String> = self.chunks.iter().map(|c| c.text.clone()).collect();
let ids: Vec<String> = self.chunks.iter().map(|c| c.id.clone()).collect();
let (write_result, embed_result) = std::thread::scope(|s| {
let passages_path = paths.passages_path();
let offset_path = paths.offset_path();
let id_map_path = paths.id_map_path();
let chunks = &self.chunks;
let ids_ref = &ids;
let writer = s.spawn(move || -> Result<()> {
info!("Writing {} passages to disk", chunks.len());
write_passages(chunks, &passages_path, &offset_path)?;
write_id_map(ids_ref, &id_map_path)?;
info!("Passage writing complete");
Ok(())
});
info!("Computing embeddings for {} chunks", texts.len());
let emb = provider.compute_embeddings(&texts);
let wr = writer.join().expect("passage writer thread panicked");
(wr, emb)
});
write_result?;
let mut embeddings = embed_result?;
if self.backend_config.distance_metric() == DistanceMetric::Cosine {
normalize_l2_inplace(&mut embeddings);
}
backend::build_backend(&self.backend_config, &embeddings, &paths.index_file_path())?;
let meta = IndexMeta {
version: "1.0".to_string(),
backend_name: self.backend_config.name().to_string(),
embedding_model: self.embedding_model.clone(),
dimensions,
backend_kwargs: self.backend_config.to_backend_kwargs(),
embedding_mode: self.embedding_mode.clone(),
passage_sources: vec![PassageSource {
source_type: "jsonl".to_string(),
path: paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
index_path: paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
path_relative: Some(
paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
),
index_path_relative: Some(
paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
),
}],
embedding_options: self.embedding_options.clone(),
is_compact: Some(self.backend_config.is_compact()),
is_pruned: Some(self.backend_config.is_recompute()),
total_passages: Some(self.chunks.len()),
built_from_precomputed_embeddings: None,
embeddings_source: None,
};
meta.save(&paths.meta_path())?;
info!("Index built successfully at {}", index_path.display());
Ok(())
}
pub fn build_index_from_embeddings(
&mut self,
index_path: &Path,
ids: &[String],
embeddings: &Array2<f32>,
) -> Result<()> {
let dimensions = embeddings.ncols();
self.dimensions = Some(dimensions);
if ids.len() != embeddings.nrows() {
anyhow::bail!(
"Mismatch: {} IDs vs {} embeddings",
ids.len(),
embeddings.nrows()
);
}
if self.chunks.is_empty() {
for id in ids {
self.add_text(&format!("Document {}", id), {
let mut m = HashMap::new();
m.insert("id".to_string(), serde_json::json!(id));
m.insert("from_embeddings".to_string(), serde_json::json!(true));
m
});
}
}
let paths = IndexPaths::new(index_path);
std::fs::create_dir_all(&paths.base_dir)?;
write_passages(&self.chunks, &paths.passages_path(), &paths.offset_path())?;
write_id_map(ids, &paths.id_map_path())?;
let mut emb = embeddings.to_owned();
if self.backend_config.distance_metric() == DistanceMetric::Cosine {
normalize_l2_inplace(&mut emb);
}
backend::build_backend(&self.backend_config, &emb, &paths.index_file_path())?;
let meta = IndexMeta {
version: "1.0".to_string(),
backend_name: self.backend_config.name().to_string(),
embedding_model: self.embedding_model.clone(),
dimensions,
backend_kwargs: HashMap::new(),
embedding_mode: self.embedding_mode.clone(),
passage_sources: vec![PassageSource {
source_type: "jsonl".to_string(),
path: paths
.passages_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
index_path: paths
.offset_path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string(),
path_relative: None,
index_path_relative: None,
}],
embedding_options: HashMap::new(),
is_compact: Some(self.backend_config.is_compact()),
is_pruned: Some(self.backend_config.is_recompute()),
total_passages: Some(self.chunks.len()),
built_from_precomputed_embeddings: Some(true),
embeddings_source: None,
};
meta.save(&paths.meta_path())?;
Ok(())
}
#[cfg(feature = "embedding-remote")]
pub fn create_embedding_provider(&self) -> Result<Box<dyn EmbeddingProvider>> {
let mode = EmbeddingMode::from_str_lossy(&self.embedding_mode);
crate::embedding::create_embedding_provider(
&mode,
&self.embedding_model,
&self.embedding_options,
)
}
}