//! Index construction and serialization for the residual-quantized index,
//! plus the memory-mapped [`MmapIndex`] used at query time.

use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};
use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::utils::{quantile, quantiles};
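/// CPU path: assigns every embedding row to its nearest centroid and
/// subtracts that centroid in place, returning `(codes, residuals)`.
/// Rows are processed in parallel with rayon.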
fn compress_and_residuals_cpu(
embeddings: &Array2<f32>,
codec: &ResidualCodec,
) -> (Array1<usize>, Array2<f32>) {
use rayon::prelude::*;
let codes = codec.compress_into_codes_cpu(embeddings);
let mut residuals = embeddings.clone();
let centroids = &codec.centroids;
residuals
.axis_iter_mut(Axis(0))
.into_par_iter()
.zip(codes.as_slice().unwrap().par_iter())
.for_each(|(mut row, &code)| {
let centroid = centroids.row(code);
row.iter_mut()
.zip(centroid.iter())
.for_each(|(r, c)| *r -= c);
});
(codes, residuals)
}
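/// Build-time configuration for index creation.
///
/// Every field has a serde default, so configs can be deserialized from
/// partial JSON. A minimal construction sketch using struct-update syntax
/// (paths assume this crate is compiled as `next_plaid` and this module is
/// `index`):
///
/// ```no_run
/// use next_plaid::index::IndexConfig;
///
/// let config = IndexConfig {
///     nbits: 2,            // 2-bit residual quantization
///     batch_size: 10_000,  // documents encoded per chunk
///     ..Default::default()
/// };
/// assert_eq!(config.kmeans_niters, 4); // untouched fields keep their defaults
/// ```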
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
pub nbits: usize,
pub batch_size: usize,
pub seed: Option<u64>,
#[serde(default = "default_kmeans_niters")]
pub kmeans_niters: usize,
#[serde(default = "default_max_points_per_centroid")]
pub max_points_per_centroid: usize,
#[serde(default)]
pub n_samples_kmeans: Option<usize>,
#[serde(default = "default_start_from_scratch")]
pub start_from_scratch: usize,
#[serde(default)]
pub force_cpu: bool,
#[serde(default)]
pub fts_tokenizer: crate::text_search::FtsTokenizer,
}
fn default_start_from_scratch() -> usize {
999
}
fn default_kmeans_niters() -> usize {
4
}
fn default_max_points_per_centroid() -> usize {
256
}
impl Default for IndexConfig {
fn default() -> Self {
Self {
nbits: 4,
batch_size: 50_000,
seed: Some(42),
kmeans_niters: 4,
max_points_per_centroid: 256,
n_samples_kmeans: None,
start_from_scratch: 999,
force_cpu: false,
fts_tokenizer: crate::text_search::FtsTokenizer::default(),
}
}
}
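/// Index-level statistics persisted as `metadata.json` in the index
/// directory. The `#[serde(default)]` fields keep older indexes loadable.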
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
pub num_chunks: usize,
pub nbits: usize,
pub num_partitions: usize,
pub num_embeddings: usize,
pub avg_doclen: f64,
#[serde(default)]
pub num_documents: usize,
#[serde(default)]
pub embedding_dim: usize,
#[serde(default)]
pub next_plaid_compatible: bool,
}
impl Metadata {
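/// Reads `metadata.json` from the index directory. Older indexes may not
/// record `num_documents`; in that case it is backfilled by summing the
/// lengths of each chunk's `doclens.{i}.json`.
///
/// A usage sketch (the index path is a placeholder; crate name assumed to be
/// `next_plaid`):
///
/// ```no_run
/// use std::path::Path;
/// use next_plaid::index::Metadata;
///
/// let metadata = Metadata::load_from_path(Path::new("/tmp/my_index"))?;
/// println!("{} docs / {} embeddings", metadata.num_documents, metadata.num_embeddings);
/// # Ok::<(), next_plaid::error::Error>(())
/// ```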
pub fn load_from_path(index_path: &Path) -> Result<Self> {
let metadata_path = index_path.join("metadata.json");
let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
File::open(&metadata_path)
.map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
))?;
if metadata.num_documents == 0 {
let mut total_docs = 0usize;
for chunk_idx in 0..metadata.num_chunks {
let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
if let Ok(file) = File::open(&doclens_path) {
if let Ok(chunk_doclens) =
serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
{
total_docs += chunk_doclens.len();
}
}
}
metadata.num_documents = total_docs;
}
Ok(metadata)
}
}
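/// Per-chunk statistics persisted as `{chunk_idx}.metadata.json`.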
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
pub num_documents: usize,
pub num_embeddings: usize,
#[serde(default)]
pub embedding_offset: usize,
}
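/// The encoded form of one document chunk: a centroid code and a packed
/// residual per token, plus per-document token counts.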
#[derive(Debug, Clone)]
pub struct EncodedIndexChunk {
pub codes: Array1<i64>,
pub residuals: Array2<u8>,
pub doclens: Vec<i64>,
}
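/// The trained codec together with the quantization statistics derived by
/// [`prepare_codec_artifacts`], ready to be written alongside the index.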
pub struct PreparedCodecArtifacts {
pub codec: ResidualCodec,
pub cluster_threshold: f32,
pub bucket_cutoffs: Array1<f32>,
pub bucket_weights: Array1<f32>,
pub avg_res_per_dim: Array1<f32>,
}
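/// Derives codec artifacts from a held-out sample of the corpus: residuals
/// against the provided `centroids` yield the bucket cutoffs/weights
/// (quantiles of residual values), the per-dimension average residual, and a
/// cluster threshold (the 0.75 quantile of residual L2 norms).
///
/// A sketch with synthetic data (crate name assumed to be `next_plaid`):
///
/// ```no_run
/// use ndarray::Array2;
/// use next_plaid::index::{prepare_codec_artifacts, IndexConfig};
///
/// // 100 documents of 8 tokens x 128 dims each (all zeros, for illustration).
/// let docs: Vec<Array2<f32>> = (0..100).map(|_| Array2::zeros((8, 128))).collect();
/// let centroids = Array2::<f32>::zeros((16, 128));
/// let artifacts = prepare_codec_artifacts(&docs, centroids, &IndexConfig::default())?;
/// assert_eq!(artifacts.bucket_weights.len(), 1 << 4); // 2^nbits weights
/// # Ok::<(), next_plaid::error::Error>(())
/// ```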
pub fn prepare_codec_artifacts(
embeddings: &[Array2<f32>],
centroids: Array2<f32>,
config: &IndexConfig,
) -> Result<PreparedCodecArtifacts> {
let embedding_dim = centroids.ncols();
let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
let num_documents = embeddings.len();
if num_documents == 0 {
return Err(Error::IndexCreation("No documents provided".into()));
}
let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
.min(num_documents)
.max(1);
let mut rng = if let Some(seed) = config.seed {
use rand::SeedableRng;
rand_chacha::ChaCha8Rng::seed_from_u64(seed)
} else {
use rand::SeedableRng;
rand_chacha::ChaCha8Rng::from_entropy()
};
use rand::seq::SliceRandom;
let mut indices: Vec<usize> = (0..num_documents).collect();
indices.shuffle(&mut rng);
let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
let mut collected = 0;
for &idx in sample_indices.iter().rev() {
if collected >= heldout_size {
break;
}
let emb = &embeddings[idx];
let take = (heldout_size - collected).min(emb.nrows());
for row in emb.axis_iter(Axis(0)).take(take) {
heldout_embeddings.extend(row.iter());
}
collected += take;
}
let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
.map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
let avg_residual = Array1::zeros(embedding_dim);
let initial_codec =
ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
let heldout_codes = if config.force_cpu {
initial_codec.compress_into_codes_cpu(&heldout)
} else {
initial_codec.compress_into_codes(&heldout)
};
let mut residuals = heldout.clone();
for i in 0..heldout.nrows() {
let centroid = initial_codec.centroids.row(heldout_codes[i]);
for j in 0..embedding_dim {
residuals[[i, j]] -= centroid[j];
}
}
let distances: Array1<f32> = residuals
.axis_iter(Axis(0))
.map(|row| row.dot(&row).sqrt())
.collect();
let cluster_threshold = quantile(&distances, 0.75);
let avg_res_per_dim: Array1<f32> = residuals
.axis_iter(Axis(1))
.map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
.collect();
let n_options = 1 << config.nbits;
let quantile_values: Vec<f64> = (1..n_options)
.map(|i| i as f64 / n_options as f64)
.collect();
let weight_quantile_values: Vec<f64> = (0..n_options)
.map(|i| (i as f64 + 0.5) / n_options as f64)
.collect();
let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
let codec = ResidualCodec::new(
config.nbits,
centroids,
avg_res_per_dim.clone(),
Some(bucket_cutoffs.clone()),
Some(bucket_weights.clone()),
)?;
Ok(PreparedCodecArtifacts {
codec,
cluster_threshold,
bucket_cutoffs,
bucket_weights,
avg_res_per_dim,
})
}
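/// Encodes one chunk of documents: concatenates them into a single token
/// matrix, compresses each token into a centroid code plus a packed residual
/// (CUDA when available and not forced off, otherwise CPU), and records
/// per-document lengths.
///
/// A sketch chaining from [`prepare_codec_artifacts`] (crate name assumed to
/// be `next_plaid`):
///
/// ```no_run
/// # use ndarray::Array2;
/// # use next_plaid::index::{prepare_codec_artifacts, encode_index_chunk, IndexConfig};
/// # let docs: Vec<Array2<f32>> = vec![Array2::zeros((8, 128))];
/// # let centroids = Array2::<f32>::zeros((16, 128));
/// let artifacts = prepare_codec_artifacts(&docs, centroids, &IndexConfig::default())?;
/// let chunk = encode_index_chunk(&docs, &artifacts.codec, /* force_cpu */ true)?;
/// assert_eq!(chunk.doclens.len(), docs.len());
/// # Ok::<(), next_plaid::error::Error>(())
/// ```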
pub fn encode_index_chunk(
embeddings: &[Array2<f32>],
codec: &ResidualCodec,
force_cpu: bool,
) -> Result<EncodedIndexChunk> {
let embedding_dim = codec.embedding_dim();
let packed_dim = embedding_dim * codec.nbits / 8;
let doclens: Vec<i64> = embeddings.iter().map(|d| d.nrows() as i64).collect();
let total_tokens: usize = doclens.iter().sum::<i64>() as usize;
#[cfg(not(feature = "cuda"))]
let _ = force_cpu;
let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
let mut offset = 0;
for doc in embeddings {
let n = doc.nrows();
batch_embeddings
.slice_mut(s![offset..offset + n, ..])
.assign(doc);
offset += n;
}
let (batch_codes, batch_residuals) = {
#[cfg(feature = "cuda")]
{
let force_gpu = crate::is_force_gpu();
if !force_cpu {
if let Some(ctx) = crate::cuda::get_global_context() {
match crate::cuda::compress_and_residuals_cuda_batched(
&ctx,
&batch_embeddings.view(),
&codec.centroids_view(),
None,
) {
Ok(result) => result,
Err(e) => {
if force_gpu {
panic!(
"FORCE_GPU is set but CUDA compress_and_residuals failed: {}",
e
);
}
eprintln!(
"[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
e
);
compress_and_residuals_cpu(&batch_embeddings, codec)
}
}
} else if force_gpu {
panic!("FORCE_GPU is set but CUDA context is unavailable");
} else {
compress_and_residuals_cpu(&batch_embeddings, codec)
}
} else {
compress_and_residuals_cpu(&batch_embeddings, codec)
}
}
#[cfg(not(feature = "cuda"))]
{
compress_and_residuals_cpu(&batch_embeddings, codec)
}
};
let batch_packed = codec.quantize_residuals(&batch_residuals)?;
let (raw_residuals, residuals_offset) = batch_packed.into_raw_vec_and_offset();
// `into_raw_vec_and_offset` returns a `None` offset for an empty array;
// only a non-zero offset indicates an unexpected, non-standard layout.
if residuals_offset.unwrap_or(0) != 0 {
return Err(Error::Shape(format!(
"Unexpected residual packing offset: {:?}",
residuals_offset
)));
}
let residuals = Array2::from_shape_vec((batch_codes.len(), packed_dim), raw_residuals)
.map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
let codes: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
Ok(EncodedIndexChunk {
codes,
residuals,
doclens,
})
}
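/// Writes a complete index directory from already-encoded chunks: the codec
/// arrays (`centroids.npy`, bucket cutoffs/weights, average residual, cluster
/// threshold), per-chunk codes/residuals/doclens, the inverted file
/// (`ivf.npy` + `ivf_lengths.npy`, mapping each centroid to the sorted,
/// deduplicated documents that use it), and `metadata.json`. Together with
/// [`prepare_codec_artifacts`] and [`encode_index_chunk`] this forms the
/// chunk-at-a-time alternative to [`create_index_files`].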
pub fn write_index_from_encoded_chunks(
chunks: &[EncodedIndexChunk],
codec_artifacts: &PreparedCodecArtifacts,
index_path: &str,
config: &IndexConfig,
) -> Result<Metadata> {
use ndarray_npy::WriteNpyExt;
let index_dir = Path::new(index_path);
fs::create_dir_all(index_dir)?;
let embedding_dim = codec_artifacts.codec.embedding_dim();
let num_centroids = codec_artifacts.codec.num_centroids();
let total_embeddings: usize = chunks.iter().map(|c| c.codes.len()).sum();
let num_documents: usize = chunks.iter().map(|c| c.doclens.len()).sum();
let avg_doclen = if num_documents > 0 {
total_embeddings as f64 / num_documents as f64
} else {
0.0
};
let centroids_path = index_dir.join("centroids.npy");
codec_artifacts
.codec
.centroids_view()
.to_owned()
.write_npy(File::create(¢roids_path)?)?;
codec_artifacts
.bucket_cutoffs
.write_npy(File::create(index_dir.join("bucket_cutoffs.npy"))?)?;
codec_artifacts
.bucket_weights
.write_npy(File::create(index_dir.join("bucket_weights.npy"))?)?;
codec_artifacts
.avg_res_per_dim
.write_npy(File::create(index_dir.join("avg_residual.npy"))?)?;
Array1::from_vec(vec![codec_artifacts.cluster_threshold])
.write_npy(File::create(index_dir.join("cluster_threshold.npy"))?)?;
let n_chunks = chunks.len();
let plan = serde_json::json!({
"nbits": config.nbits,
"num_chunks": n_chunks,
});
let mut plan_file = File::create(index_dir.join("plan.json"))?;
writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
let mut current_offset = 0usize;
for (chunk_idx, chunk) in chunks.iter().enumerate() {
let chunk_meta = ChunkMetadata {
num_documents: chunk.doclens.len(),
num_embeddings: chunk.codes.len(),
embedding_offset: current_offset,
};
current_offset += chunk.codes.len();
serde_json::to_writer_pretty(
BufWriter::new(File::create(
index_dir.join(format!("{}.metadata.json", chunk_idx)),
)?),
&chunk_meta,
)?;
serde_json::to_writer(
BufWriter::new(File::create(
index_dir.join(format!("doclens.{}.json", chunk_idx)),
)?),
&chunk.doclens,
)?;
chunk.codes.write_npy(File::create(
index_dir.join(format!("{}.codes.npy", chunk_idx)),
)?)?;
chunk.residuals.write_npy(File::create(
index_dir.join(format!("{}.residuals.npy", chunk_idx)),
)?)?;
doc_lengths.extend_from_slice(&chunk.doclens);
all_codes.extend(chunk.codes.iter().map(|&x| x as usize));
}
// Invert the code stream into per-centroid document lists (the IVF).
let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
let mut emb_idx = 0;
for (doc_id, &len) in doc_lengths.iter().enumerate() {
for _ in 0..len {
let code = all_codes[emb_idx];
code_to_docs.entry(code).or_default().push(doc_id as i64);
emb_idx += 1;
}
}
let mut ivf_data: Vec<i64> = Vec::new();
let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
if let Some(docs) = code_to_docs.get(¢roid_id) {
let mut unique_docs = docs.clone();
unique_docs.sort_unstable();
unique_docs.dedup();
*ivf_len = unique_docs.len() as i32;
ivf_data.extend(unique_docs);
}
}
Array1::from_vec(ivf_data).write_npy(File::create(index_dir.join("ivf.npy"))?)?;
Array1::from_vec(ivf_lengths).write_npy(File::create(index_dir.join("ivf_lengths.npy"))?)?;
let metadata = Metadata {
num_chunks: n_chunks,
nbits: config.nbits,
num_partitions: num_centroids,
num_embeddings: total_embeddings,
avg_doclen,
num_documents,
embedding_dim,
next_plaid_compatible: true,
};
serde_json::to_writer_pretty(
BufWriter::new(File::create(index_dir.join("metadata.json"))?),
&metadata,
)?;
Ok(metadata)
}
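/// Single-pass index build: derives codec artifacts from a held-out sample
/// (the same procedure as [`prepare_codec_artifacts`]), then encodes and
/// writes documents in `config.batch_size` chunks. A second pass patches each
/// chunk's `embedding_offset` once all chunk sizes are known.
///
/// Callers normally use [`create_index_with_kmeans_files`] or
/// [`MmapIndex::create_with_kmeans`], which compute the centroids first.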
pub fn create_index_files(
embeddings: &[Array2<f32>],
centroids: Array2<f32>,
index_path: &str,
config: &IndexConfig,
) -> Result<Metadata> {
let index_dir = Path::new(index_path);
fs::create_dir_all(index_dir)?;
let num_documents = embeddings.len();
let embedding_dim = centroids.ncols();
let num_centroids = centroids.nrows();
if num_documents == 0 {
return Err(Error::IndexCreation("No documents provided".into()));
}
let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
let avg_doclen = total_embeddings as f64 / num_documents as f64;
let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
.min(num_documents)
.max(1);
let mut rng = if let Some(seed) = config.seed {
use rand::SeedableRng;
rand_chacha::ChaCha8Rng::seed_from_u64(seed)
} else {
use rand::SeedableRng;
rand_chacha::ChaCha8Rng::from_entropy()
};
use rand::seq::SliceRandom;
let mut indices: Vec<usize> = (0..num_documents).collect();
indices.shuffle(&mut rng);
let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
let mut collected = 0;
for &idx in sample_indices.iter().rev() {
if collected >= heldout_size {
break;
}
let emb = &embeddings[idx];
let take = (heldout_size - collected).min(emb.nrows());
for row in emb.axis_iter(Axis(0)).take(take) {
heldout_embeddings.extend(row.iter());
}
collected += take;
}
let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
.map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
let avg_residual = Array1::zeros(embedding_dim);
let initial_codec =
ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
let heldout_codes = if config.force_cpu {
initial_codec.compress_into_codes_cpu(&heldout)
} else {
initial_codec.compress_into_codes(&heldout)
};
let mut residuals = heldout.clone();
for i in 0..heldout.nrows() {
let centroid = initial_codec.centroids.row(heldout_codes[i]);
for j in 0..embedding_dim {
residuals[[i, j]] -= centroid[j];
}
}
let distances: Array1<f32> = residuals
.axis_iter(Axis(0))
.map(|row| row.dot(&row).sqrt())
.collect();
let cluster_threshold = quantile(&distances, 0.75);
let avg_res_per_dim: Array1<f32> = residuals
.axis_iter(Axis(1))
.map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
.collect();
let n_options = 1 << config.nbits;
let quantile_values: Vec<f64> = (1..n_options)
.map(|i| i as f64 / n_options as f64)
.collect();
let weight_quantile_values: Vec<f64> = (0..n_options)
.map(|i| (i as f64 + 0.5) / n_options as f64)
.collect();
let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
let codec = ResidualCodec::new(
config.nbits,
centroids.clone(),
avg_res_per_dim.clone(),
Some(bucket_cutoffs.clone()),
Some(bucket_weights.clone()),
)?;
use ndarray_npy::WriteNpyExt;
let centroids_path = index_dir.join("centroids.npy");
codec
.centroids_view()
.to_owned()
.write_npy(File::create(¢roids_path)?)?;
let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
let weights_path = index_dir.join("bucket_weights.npy");
bucket_weights.write_npy(File::create(&weights_path)?)?;
let avg_res_path = index_dir.join("avg_residual.npy");
avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
let threshold_path = index_dir.join("cluster_threshold.npy");
Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
let plan_path = index_dir.join("plan.json");
let plan = serde_json::json!({
"nbits": config.nbits,
"num_chunks": n_chunks,
});
let mut plan_file = File::create(&plan_path)?;
writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
for chunk_idx in 0..n_chunks {
let start = chunk_idx * config.batch_size;
let end = (start + config.batch_size).min(num_documents);
let chunk_docs = &embeddings[start..end];
let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
let mut offset = 0;
for doc in chunk_docs {
let n = doc.nrows();
batch_embeddings
.slice_mut(s![offset..offset + n, ..])
.assign(doc);
offset += n;
}
let (batch_codes, batch_residuals) = {
#[cfg(feature = "cuda")]
{
let force_gpu = crate::is_force_gpu();
if !config.force_cpu {
if let Some(ctx) = crate::cuda::get_global_context() {
match crate::cuda::compress_and_residuals_cuda_batched(
&ctx,
&batch_embeddings.view(),
&codec.centroids_view(),
None,
) {
Ok(result) => result,
Err(e) => {
if force_gpu {
panic!("FORCE_GPU is set but CUDA compress_and_residuals failed: {}", e);
}
eprintln!(
"[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
e
);
compress_and_residuals_cpu(&batch_embeddings, &codec)
}
}
} else if force_gpu {
panic!("FORCE_GPU is set but CUDA context is unavailable");
} else {
compress_and_residuals_cpu(&batch_embeddings, &codec)
}
} else {
compress_and_residuals_cpu(&batch_embeddings, &codec)
}
}
#[cfg(not(feature = "cuda"))]
{
compress_and_residuals_cpu(&batch_embeddings, &codec)
}
};
let batch_packed = codec.quantize_residuals(&batch_residuals)?;
doc_lengths.extend_from_slice(&chunk_doclens);
all_codes.extend(batch_codes.iter().copied());
let chunk_meta = ChunkMetadata {
num_documents: end - start,
num_embeddings: batch_codes.len(),
// Patched with the true offset in the second pass below.
embedding_offset: 0,
};
let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
batch_packed.write_npy(File::create(&residuals_path)?)?;
}
// Second pass: now that every chunk's size is known, patch each chunk's
// embedding_offset in place.
let mut current_offset = 0usize;
for chunk_idx in 0..n_chunks {
let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
let mut meta: serde_json::Value =
serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
if let Some(obj) = meta.as_object_mut() {
obj.insert("embedding_offset".to_string(), current_offset.into());
let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
current_offset += num_emb;
}
serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
}
// Invert the code stream into per-centroid document lists (the IVF).
let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
let mut emb_idx = 0;
for (doc_id, &len) in doc_lengths.iter().enumerate() {
for _ in 0..len {
let code = all_codes[emb_idx];
code_to_docs.entry(code).or_default().push(doc_id as i64);
emb_idx += 1;
}
}
let mut ivf_data: Vec<i64> = Vec::new();
let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
if let Some(docs) = code_to_docs.get(¢roid_id) {
let mut unique_docs: Vec<i64> = docs.clone();
unique_docs.sort_unstable();
unique_docs.dedup();
*ivf_len = unique_docs.len() as i32;
ivf_data.extend(unique_docs);
}
}
let ivf = Array1::from_vec(ivf_data);
let ivf_lengths = Array1::from_vec(ivf_lengths);
let ivf_path = index_dir.join("ivf.npy");
ivf.write_npy(File::create(&ivf_path)?)?;
let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
let metadata = Metadata {
num_chunks: n_chunks,
nbits: config.nbits,
num_partitions: num_centroids,
num_embeddings: total_embeddings,
avg_doclen,
num_documents,
embedding_dim,
next_plaid_compatible: true,
};
let metadata_path = index_dir.join("metadata.json");
serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
Ok(metadata)
}
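/// Runs k-means over the embeddings to obtain centroids, then builds the
/// index with [`create_index_files`]. For small corpora (at most
/// `config.start_from_scratch` documents) the raw embeddings are also saved
/// so that later updates can rebuild the index from scratch.
///
/// A sketch (crate name assumed to be `next_plaid`):
///
/// ```no_run
/// use ndarray::Array2;
/// use next_plaid::index::{create_index_with_kmeans_files, IndexConfig};
///
/// let docs: Vec<Array2<f32>> = (0..50).map(|_| Array2::zeros((8, 128))).collect();
/// let metadata = create_index_with_kmeans_files(&docs, "/tmp/my_index", &IndexConfig::default())?;
/// assert_eq!(metadata.num_documents, 50);
/// # Ok::<(), next_plaid::error::Error>(())
/// ```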
pub fn create_index_with_kmeans_files(
embeddings: &[Array2<f32>],
index_path: &str,
config: &IndexConfig,
) -> Result<Metadata> {
if embeddings.is_empty() {
return Err(Error::IndexCreation("No documents provided".into()));
}
#[cfg(feature = "cuda")]
if !config.force_cpu {
if crate::is_force_gpu() {
crate::cuda::get_global_context()
.expect("FORCE_GPU is set but CUDA context failed to initialize");
} else {
let _ = crate::cuda::get_global_context();
}
}
let kmeans_config = ComputeKmeansConfig {
kmeans_niters: config.kmeans_niters,
max_points_per_centroid: config.max_points_per_centroid,
seed: config.seed.unwrap_or(42),
n_samples_kmeans: config.n_samples_kmeans,
num_partitions: None,
force_cpu: config.force_cpu,
};
let centroids = compute_kmeans(embeddings, &kmeans_config)?;
let metadata = create_index_files(embeddings, centroids, index_path, config)?;
if embeddings.len() <= config.start_from_scratch {
let index_dir = std::path::Path::new(index_path);
crate::update::save_embeddings_npy(index_dir, embeddings)?;
}
Ok(metadata)
}
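/// A searchable index backed by memory-mapped codes and residuals, so large
/// indexes can be queried without loading every embedding into RAM. The IVF
/// arrays and document offset tables are held in memory; per-token codes and
/// residuals are mapped from the merged `.npy` files.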
pub struct MmapIndex {
pub path: String,
pub metadata: Metadata,
pub codec: ResidualCodec,
pub ivf: Array1<i64>,
pub ivf_lengths: Array1<i32>,
pub ivf_offsets: Array1<i64>,
pub doc_lengths: Array1<i64>,
pub doc_offsets: Array1<usize>,
pub mmap_codes: crate::mmap::MmapNpyArray1I64,
pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
impl MmapIndex {
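/// Opens an index directory: converts legacy (pre next-plaid) layouts if
/// needed, merges per-chunk codes/residuals into single mappable files, and
/// precomputes the IVF and document offset tables.
///
/// ```no_run
/// // Crate name assumed to be `next_plaid`; the path is a placeholder.
/// use next_plaid::index::MmapIndex;
///
/// let index = MmapIndex::load("/tmp/my_index")?;
/// println!("{} documents, dim {}", index.num_documents(), index.embedding_dim());
/// # Ok::<(), next_plaid::error::Error>(())
/// ```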
pub fn load(index_path: &str) -> Result<Self> {
use ndarray_npy::ReadNpyExt;
let index_dir = Path::new(index_path);
let mut metadata = Metadata::load_from_path(index_dir)?;
if !metadata.next_plaid_compatible {
eprintln!("Checking index format compatibility...");
let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
if converted {
eprintln!("Index converted to next-plaid compatible format.");
let merged_codes = index_dir.join("merged_codes.npy");
let merged_residuals = index_dir.join("merged_residuals.npy");
let codes_manifest = index_dir.join("merged_codes.manifest.json");
let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
for path in [
&merged_codes,
&merged_residuals,
&codes_manifest,
&residuals_manifest,
] {
if path.exists() {
let _ = fs::remove_file(path);
}
}
}
metadata.next_plaid_compatible = true;
let metadata_path = index_dir.join("metadata.json");
let file = File::create(&metadata_path)
.map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
eprintln!("Metadata updated with next_plaid_compatible: true");
}
let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;
let ivf_path = index_dir.join("ivf.npy");
let ivf: Array1<i64> = Array1::read_npy(
File::open(&ivf_path)
.map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
)
.map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
let ivf_lengths: Array1<i32> = Array1::read_npy(
File::open(&ivf_lengths_path)
.map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
)
.map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
let num_centroids = ivf_lengths.len();
let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
for i in 0..num_centroids {
ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
}
let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
for chunk_idx in 0..metadata.num_chunks {
let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
let chunk_doclens: Vec<i64> =
serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
doc_lengths_vec.extend(chunk_doclens);
}
let doc_lengths = Array1::from_vec(doc_lengths_vec);
let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
for i in 0..doc_lengths.len() {
doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
}
let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
let padding_needed = max_len.saturating_sub(last_len);
let merged_codes_path =
crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
let merged_residuals_path =
crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;
let (mmap_codes, mmap_residuals) = (
crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?,
crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?,
);
Ok(Self {
path: index_path.to_string(),
metadata,
codec,
ivf,
ivf_lengths,
ivf_offsets,
doc_lengths,
doc_offsets,
mmap_codes,
mmap_residuals,
})
}
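/// Returns the sorted, deduplicated document ids stored in the IVF lists of
/// the given centroids; out-of-range centroid indices are skipped.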
pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
let mut candidates: Vec<i64> = Vec::new();
for &idx in centroid_indices {
if idx < self.ivf_lengths.len() {
let start = self.ivf_offsets[idx] as usize;
let len = self.ivf_lengths[idx] as usize;
candidates.extend(self.ivf.slice(s![start..start + len]).iter());
}
}
candidates.sort_unstable();
candidates.dedup();
candidates
}
pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
if doc_id >= self.doc_lengths.len() {
return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
}
let start = self.doc_offsets[doc_id];
let end = self.doc_offsets[doc_id + 1];
let codes_slice = self.mmap_codes.slice(start, end);
let residuals_view = self.mmap_residuals.slice_rows(start, end);
let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
let residuals = residuals_view.to_owned();
self.codec.decompress(&residuals, &codes.view())
}
pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
doc_ids
.iter()
.map(|&doc_id| {
if doc_id >= self.doc_lengths.len() {
return vec![];
}
let start = self.doc_offsets[doc_id];
let end = self.doc_offsets[doc_id + 1];
self.mmap_codes.slice(start, end).to_vec()
})
.collect()
}
pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
let mut total_tokens = 0usize;
let mut lengths = Vec::with_capacity(doc_ids.len());
for &doc_id in doc_ids {
if doc_id >= self.doc_lengths.len() {
lengths.push(0);
} else {
let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
lengths.push(len);
total_tokens += len;
}
}
if total_tokens == 0 {
return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
}
let packed_dim = self.mmap_residuals.ncols();
let mut all_codes = Vec::with_capacity(total_tokens);
let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
let mut offset = 0;
for &doc_id in doc_ids {
if doc_id >= self.doc_lengths.len() {
continue;
}
let start = self.doc_offsets[doc_id];
let end = self.doc_offsets[doc_id + 1];
let len = end - start;
let codes_slice = self.mmap_codes.slice(start, end);
all_codes.extend(codes_slice.iter().map(|&c| c as usize));
let residuals_view = self.mmap_residuals.slice_rows(start, end);
all_residuals
.slice_mut(s![offset..offset + len, ..])
.assign(&residuals_view);
offset += len;
}
let codes_arr = Array1::from_vec(all_codes);
let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;
Ok((embeddings, lengths))
}
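/// Searches the index with a single query matrix (one row per query token).
/// `subset`, when given, restricts scoring to those document ids.
///
/// A sketch that takes the parameters as given, since their construction is
/// defined in `crate::search` (crate name assumed to be `next_plaid`):
///
/// ```no_run
/// use ndarray::Array2;
/// use next_plaid::index::MmapIndex;
/// use next_plaid::search::SearchParameters;
///
/// fn top_docs(index: &MmapIndex, params: &SearchParameters) -> next_plaid::error::Result<()> {
///     let query = Array2::<f32>::zeros((32, index.embedding_dim()));
///     let _result = index.search(&query, params, None)?;
///     // Fields of the result are defined in `crate::search::SearchResult`.
///     Ok(())
/// }
/// ```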
pub fn search(
&self,
query: &Array2<f32>,
params: &crate::search::SearchParameters,
subset: Option<&[i64]>,
) -> Result<crate::search::SearchResult> {
crate::search::search_one_mmap(self, query, params, subset)
}
pub fn search_batch(
&self,
queries: &[Array2<f32>],
params: &crate::search::SearchParameters,
parallel: bool,
subset: Option<&[i64]>,
) -> Result<Vec<crate::search::SearchResult>> {
crate::search::search_many_mmap(self, queries, params, parallel, subset)
}
pub fn num_documents(&self) -> usize {
self.doc_lengths.len()
}
pub fn num_embeddings(&self) -> usize {
self.metadata.num_embeddings
}
pub fn num_partitions(&self) -> usize {
self.metadata.num_partitions
}
pub fn avg_doclen(&self) -> f64 {
self.metadata.avg_doclen
}
pub fn embedding_dim(&self) -> usize {
self.codec.embedding_dim()
}
fn release_mmaps(&mut self) {
self.mmap_codes = crate::mmap::MmapNpyArray1I64::empty();
self.mmap_residuals = crate::mmap::MmapNpyArray2U8::empty();
self.codec.centroids = crate::codec::CentroidStore::Owned(Array2::zeros((0, 0)));
}
pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
crate::embeddings::reconstruct_embeddings(self, doc_ids)
}
pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
crate::embeddings::reconstruct_single(self, doc_id)
}
pub fn create_with_kmeans(
embeddings: &[Array2<f32>],
index_path: &str,
config: &IndexConfig,
) -> Result<Self> {
create_index_with_kmeans_files(embeddings, index_path, config)?;
Self::load(index_path)
}
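/// Appends documents to the index, returning the ids assigned to them.
/// Small batches are additionally tracked in a staging buffer; once the
/// buffer would reach `config.buffer_size`, the buffered documents are
/// deleted and re-indexed in one pass, growing the centroid set when enough
/// residuals exceed the stored cluster threshold. An index that is still at
/// or below `config.start_from_scratch` documents (and has its raw
/// embeddings saved) is instead rebuilt from scratch with fresh k-means.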
pub fn update(
&mut self,
embeddings: &[Array2<f32>],
config: &crate::update::UpdateConfig,
) -> Result<Vec<i64>> {
use crate::codec::ResidualCodec;
use crate::update::{
clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
update_centroids, update_index,
};
let path_str = self.path.clone();
let index_path = std::path::Path::new(&path_str);
let num_new_docs = embeddings.len();
self.release_mmaps();
if self.metadata.num_documents <= config.start_from_scratch {
let existing_embeddings = load_embeddings_npy(index_path)?;
if existing_embeddings.len() == self.metadata.num_documents {
let start_doc_id = existing_embeddings.len() as i64;
let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
.into_iter()
.chain(embeddings.iter().cloned())
.collect();
let index_config = IndexConfig {
nbits: self.metadata.nbits,
batch_size: config.batch_size,
seed: Some(config.seed),
kmeans_niters: config.kmeans_niters,
max_points_per_centroid: config.max_points_per_centroid,
n_samples_kmeans: config.n_samples_kmeans,
start_from_scratch: config.start_from_scratch,
force_cpu: config.force_cpu,
..Default::default()
};
*self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
if combined_embeddings.len() > config.start_from_scratch
&& embeddings_npy_exists(index_path)
{
clear_embeddings_npy(index_path)?;
}
return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
}
}
let buffer = load_buffer(index_path)?;
let buffer_len = buffer.len();
let total_new = embeddings.len() + buffer_len;
let start_doc_id: i64;
let mut codec = ResidualCodec::load_from_dir(index_path)?;
if total_new >= config.buffer_size {
let num_buffered = load_buffer_info(index_path)?;
if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
let start_del_idx = self.metadata.num_documents - num_buffered;
let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
.map(|i| i as i64)
.collect();
crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
self.metadata = Metadata::load_from_path(index_path)?;
}
start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
let combined: Vec<Array2<f32>> = buffer
.into_iter()
.chain(embeddings.iter().cloned())
.collect();
if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
let new_centroids =
update_centroids(index_path, &combined, cluster_threshold, config)?;
if new_centroids > 0 {
codec = ResidualCodec::load_from_dir(index_path)?;
}
}
clear_buffer(index_path)?;
update_index(
&combined,
&path_str,
&codec,
Some(config.batch_size),
true,
config.force_cpu,
)?;
} else {
start_doc_id = self.metadata.num_documents as i64;
let combined_buffer: Vec<Array2<f32>> = buffer
.into_iter()
.chain(embeddings.iter().cloned())
.collect();
save_buffer(index_path, &combined_buffer)?;
update_index(
embeddings,
&path_str,
&codec,
Some(config.batch_size),
false,
config.force_cpu,
)?;
}
*self = Self::load(&path_str)?;
Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
}
pub fn update_with_metadata(
&mut self,
embeddings: &[Array2<f32>],
config: &crate::update::UpdateConfig,
metadata: Option<&[serde_json::Value]>,
) -> Result<Vec<i64>> {
if let Some(meta) = metadata {
if meta.len() != embeddings.len() {
return Err(Error::Config(format!(
"Metadata length ({}) must match embeddings length ({})",
meta.len(),
embeddings.len()
)));
}
}
let doc_ids = self.update(embeddings, config)?;
if let Some(meta) = metadata {
crate::filtering::update(&self.path, meta, &doc_ids)?;
}
Ok(doc_ids)
}
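/// Convenience entry point: updates the index at `index_path` if its
/// `metadata.json` exists, otherwise creates it with k-means. Returns the
/// index plus the doc ids assigned to `embeddings`.
///
/// ```no_run
/// // Crate name assumed to be `next_plaid`; the path is a placeholder.
/// use ndarray::Array2;
/// use next_plaid::index::{IndexConfig, MmapIndex};
/// use next_plaid::update::UpdateConfig;
///
/// let docs: Vec<Array2<f32>> = (0..10).map(|_| Array2::zeros((8, 128))).collect();
/// let (index, doc_ids) = MmapIndex::update_or_create(
///     &docs,
///     "/tmp/my_index",
///     &IndexConfig::default(),
///     &UpdateConfig::default(),
/// )?;
/// assert_eq!(doc_ids.len(), 10);
/// # Ok::<(), next_plaid::error::Error>(())
/// ```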
pub fn update_or_create(
embeddings: &[Array2<f32>],
index_path: &str,
index_config: &IndexConfig,
update_config: &crate::update::UpdateConfig,
) -> Result<(Self, Vec<i64>)> {
let index_dir = std::path::Path::new(index_path);
let metadata_path = index_dir.join("metadata.json");
if metadata_path.exists() {
let mut index = Self::load(index_path)?;
let doc_ids = index.update(embeddings, update_config)?;
Ok((index, doc_ids))
} else {
let num_docs = embeddings.len();
let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
Ok((index, doc_ids))
}
}
pub fn update_or_create_with_metadata(
embeddings: &[Array2<f32>],
index_path: &str,
index_config: &IndexConfig,
update_config: &crate::update::UpdateConfig,
metadata: Option<&[serde_json::Value]>,
) -> Result<(Self, Vec<i64>)> {
if let Some(meta) = metadata {
if meta.len() != embeddings.len() {
return Err(Error::Config(format!(
"Metadata length ({}) must match embeddings length ({})",
meta.len(),
embeddings.len()
)));
}
}
let index_dir = std::path::Path::new(index_path);
let metadata_json_path = index_dir.join("metadata.json");
let (index, doc_ids) = if metadata_json_path.exists() {
let mut index = Self::load(index_path)?;
let doc_ids = index.update(embeddings, update_config)?;
(index, doc_ids)
} else {
let num_docs = embeddings.len();
let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
(index, doc_ids)
};
if let Some(meta) = metadata {
if crate::filtering::exists(index_path) {
crate::filtering::update(index_path, meta, &doc_ids)?;
} else {
crate::filtering::create(index_path, meta, &doc_ids)?;
}
crate::text_search::index(index_path, meta, &doc_ids, &index_config.fts_tokenizer)?;
}
Ok((index, doc_ids))
}
pub fn reload(&mut self) -> Result<()> {
let path = self.path.clone();
self.release_mmaps();
*self = Self::load(&path)?;
Ok(())
}
pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
self.delete_with_options(doc_ids, true)
}
pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
let path = self.path.clone();
self.release_mmaps();
let deleted = crate::delete::delete_from_index(doc_ids, &path)?;
if delete_metadata && deleted > 0 {
let index_path = std::path::Path::new(&path);
let db_path = index_path.join("metadata.db");
if db_path.exists() {
crate::filtering::delete(&path, doc_ids)?;
crate::text_search::rebuild(&path)?;
}
}
Ok(deleted)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_index_config_default() {
let config = IndexConfig::default();
assert_eq!(config.nbits, 4);
assert_eq!(config.batch_size, 50_000);
assert_eq!(config.seed, Some(42));
}
#[test]
fn test_update_or_create_new_index() {
use ndarray::Array2;
use tempfile::tempdir;
let temp_dir = tempdir().unwrap();
let index_path = temp_dir.path().to_str().unwrap();
let mut embeddings: Vec<Array2<f32>> = Vec::new();
for i in 0..5 {
let mut doc = Array2::<f32>::zeros((5, 32));
for j in 0..5 {
for k in 0..32 {
doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
}
}
for mut row in doc.rows_mut() {
let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
row.iter_mut().for_each(|x| *x /= norm);
}
}
embeddings.push(doc);
}
let index_config = IndexConfig {
nbits: 2,
batch_size: 50,
seed: Some(42),
kmeans_niters: 2,
..Default::default()
};
let update_config = crate::update::UpdateConfig::default();
let (index, doc_ids) =
MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
.expect("Failed to create index");
assert_eq!(index.metadata.num_documents, 5);
assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);
assert!(temp_dir.path().join("metadata.json").exists());
assert!(temp_dir.path().join("centroids.npy").exists());
}
#[test]
fn test_update_or_create_existing_index() {
use ndarray::Array2;
use tempfile::tempdir;
let temp_dir = tempdir().unwrap();
let index_path = temp_dir.path().to_str().unwrap();
let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
let mut embeddings = Vec::new();
for i in 0..count {
let mut doc = Array2::<f32>::zeros((5, 32));
for j in 0..5 {
for k in 0..32 {
doc[[j, k]] =
((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
}
}
for mut row in doc.rows_mut() {
let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
row.iter_mut().for_each(|x| *x /= norm);
}
}
embeddings.push(doc);
}
embeddings
};
let index_config = IndexConfig {
nbits: 2,
batch_size: 50,
seed: Some(42),
kmeans_niters: 2,
..Default::default()
};
let update_config = crate::update::UpdateConfig::default();
let embeddings1 = create_embeddings(5, 0);
let (index1, doc_ids1) =
MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
.expect("Failed to create index");
assert_eq!(index1.metadata.num_documents, 5);
assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);
drop(index1);
let embeddings2 = create_embeddings(3, 5);
let (index2, doc_ids2) =
MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
.expect("Failed to update index");
assert_eq!(index2.metadata.num_documents, 8);
assert_eq!(doc_ids2, vec![5, 6, 7]);
}
}