use std::collections::{BTreeMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use ndarray::{Array1, Array2};
use crate::error::{Error, Result};
use crate::index::Metadata;
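
/// Deletes the given document IDs from the index at `index_path`, rewriting
/// the affected chunks, rebuilding the IVF, and compacting the raw-embedding
/// buffer files. Returns the number of documents actually removed.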
pub fn delete_from_index(doc_ids: &[i64], index_path: &str) -> Result<usize> {
    delete_from_index_impl(doc_ids, index_path, true)
}
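
/// Same as [`delete_from_index`], but leaves the raw-embedding buffer files
/// untouched (useful when the caller will rewrite them itself).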
pub fn delete_from_index_keep_buffer(doc_ids: &[i64], index_path: &str) -> Result<usize> {
    delete_from_index_impl(doc_ids, index_path, false)
}
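
/// Shared implementation. Pass 1 rewrites each chunk's doclens, codes, and
/// residuals without the deleted documents; pass 2 rebuilds the IVF from the
/// surviving codes; finally the top-level metadata is updated.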
fn delete_from_index_impl(doc_ids: &[i64], index_path: &str, clean_buffer: bool) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let index_dir = Path::new(index_path);
    let metadata_path = index_dir.join("metadata.json");
    let metadata = Metadata::load_from_path(index_dir)?;
    let original_num_documents = metadata.num_documents;
    let num_chunks = metadata.num_chunks;
    let nbits = metadata.nbits;
    let num_partitions = metadata.num_partitions;

    let ids_to_delete: HashSet<i64> = doc_ids.iter().copied().collect();

    let mut final_num_documents: usize = 0;
    let mut total_embeddings: usize = 0;
    let mut current_doc_offset: i64 = 0;
    let mut docs_actually_deleted: usize = 0;

    // Pass 1: rewrite each chunk's doclens, codes, and residuals, dropping
    // the embeddings that belong to deleted documents.
    for chunk_idx in 0..num_chunks {
        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        let doclens: Vec<i64> = serde_json::from_reader(BufReader::new(
            File::open(&doclens_path)
                .map_err(|e| Error::Delete(format!("Failed to open doclens: {}", e)))?,
        ))?;

        let mut new_doclens: Vec<i64> = Vec::new();
        let mut embs_to_keep_mask: Vec<bool> = Vec::new();
        for (i, &len) in doclens.iter().enumerate() {
            let doc_id = current_doc_offset + i as i64;
            if !ids_to_delete.contains(&doc_id) {
                new_doclens.push(len);
                embs_to_keep_mask.extend(std::iter::repeat_n(true, len as usize));
            } else {
                docs_actually_deleted += 1;
                embs_to_keep_mask.extend(std::iter::repeat_n(false, len as usize));
            }
        }
        final_num_documents += new_doclens.len();

        // Only rewrite this chunk's files if something was actually deleted.
        if new_doclens.len() < doclens.len() {
            serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &new_doclens)?;

            let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
            let codes: Array1<i64> = Array1::read_npy(
                File::open(&codes_path)
                    .map_err(|e| Error::Delete(format!("Failed to open codes: {}", e)))?,
            )?;
            let new_codes: Array1<i64> = codes
                .iter()
                .zip(embs_to_keep_mask.iter())
                .filter_map(|(&code, &keep)| if keep { Some(code) } else { None })
                .collect();
            new_codes.write_npy(File::create(&codes_path)?)?;

            let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
            let residuals: Array2<u8> = Array2::read_npy(
                File::open(&residuals_path)
                    .map_err(|e| Error::Delete(format!("Failed to open residuals: {}", e)))?,
            )?;
            let packed_dim = residuals.ncols();
            let kept_count = embs_to_keep_mask.iter().filter(|&&k| k).count();
            let mut new_residuals = Array2::<u8>::zeros((kept_count, packed_dim));
            let mut new_idx = 0;
            for (old_idx, &keep) in embs_to_keep_mask.iter().enumerate() {
                if keep {
                    new_residuals
                        .row_mut(new_idx)
                        .assign(&residuals.row(old_idx));
                    new_idx += 1;
                }
            }
            new_residuals.write_npy(File::create(&residuals_path)?)?;

            let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
            let mut chunk_meta: serde_json::Value = serde_json::from_reader(BufReader::new(
                File::open(&chunk_meta_path)
                    .map_err(|e| Error::Delete(format!("Failed to open chunk metadata: {}", e)))?,
            ))?;
            if let Some(obj) = chunk_meta.as_object_mut() {
                obj.insert("num_documents".to_string(), new_doclens.len().into());
                obj.insert("num_embeddings".to_string(), new_codes.len().into());
            }
            serde_json::to_writer_pretty(
                BufWriter::new(File::create(&chunk_meta_path)?),
                &chunk_meta,
            )?;
        }

        total_embeddings += new_doclens.iter().sum::<i64>() as usize;
        current_doc_offset += doclens.len() as i64;
    }

    // Pass 2: reload the (now rewritten) codes from every chunk so the IVF
    // can be rebuilt against the compacted document IDs.
    let mut all_codes: Vec<i64> = Vec::with_capacity(total_embeddings);
    for chunk_idx in 0..num_chunks {
        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
        let chunk_codes: Array1<i64> =
            Array1::read_npy(File::open(&codes_path).map_err(|e| {
                Error::Delete(format!("Failed to read codes for IVF rebuild: {}", e))
            })?)?;
        all_codes.extend(chunk_codes.iter());
    }

    // Map each centroid code to the (renumbered) documents that use it.
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;
    let mut doc_id: i64 = 0;
    for chunk_idx in 0..num_chunks {
        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        let doclens: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
        for &len in &doclens {
            for _ in 0..len {
                if emb_idx < all_codes.len() {
                    let code = all_codes[emb_idx] as usize;
                    code_to_docs.entry(code).or_default().push(doc_id);
                }
                emb_idx += 1;
            }
            doc_id += 1;
        }
    }

    // Flatten the mapping into ivf/ivf_lengths, one entry per partition,
    // with each partition's document list deduplicated.
    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_partitions];
    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs: Vec<i64> = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }
    let ivf = Array1::from_vec(ivf_data);
    let ivf_lengths = Array1::from_vec(ivf_lengths);
    ivf.write_npy(File::create(index_dir.join("ivf.npy"))?)?;
    ivf_lengths.write_npy(File::create(index_dir.join("ivf_lengths.npy"))?)?;

    // Update the top-level metadata to reflect the compacted index.
    let final_avg_doclen = if final_num_documents > 0 {
        total_embeddings as f64 / final_num_documents as f64
    } else {
        0.0
    };
    let final_metadata = Metadata {
        num_chunks,
        nbits,
        num_partitions,
        num_embeddings: total_embeddings,
        avg_doclen: final_avg_doclen,
        num_documents: final_num_documents,
        embedding_dim: metadata.embedding_dim,
        next_plaid_compatible: metadata.next_plaid_compatible,
    };
    serde_json::to_writer_pretty(
        BufWriter::new(File::create(&metadata_path)?),
        &final_metadata,
    )?;

    crate::mmap::clear_merged_files(index_dir)?;
    if clean_buffer {
        clean_embeddings_files(index_dir, &ids_to_delete, original_num_documents)?;
    }
    Ok(docs_actually_deleted)
}
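
/// Compacts the raw-embedding files (`embeddings.npy` / `buffer.npy` and
/// their length sidecars), dropping the rows of deleted documents. Buffer
/// document IDs are recovered from their position at the tail of the
/// original index.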
fn clean_embeddings_files(
    index_dir: &Path,
    ids_to_delete: &HashSet<i64>,
    original_num_documents: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    // Compact embeddings.npy: keep only the rows of surviving documents.
    let emb_path = index_dir.join("embeddings.npy");
    let emb_lengths_path = index_dir.join("embeddings_lengths.json");
    if emb_path.exists() && emb_lengths_path.exists() {
        let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&emb_lengths_path)?))?;
        let dim = flat.ncols();
        let mut new_embeddings: Vec<f32> = Vec::new();
        let mut new_lengths: Vec<i64> = Vec::new();
        let mut offset = 0;
        for (doc_id, &len) in lengths.iter().enumerate() {
            let len_usize = len as usize;
            if !ids_to_delete.contains(&(doc_id as i64)) {
                for row_idx in offset..offset + len_usize {
                    if row_idx < flat.nrows() {
                        new_embeddings.extend(flat.row(row_idx).iter());
                    }
                }
                new_lengths.push(len);
            }
            offset += len_usize;
        }
        if !new_lengths.is_empty() {
            let new_total_rows = new_embeddings.len() / dim;
            let new_flat = Array2::from_shape_vec((new_total_rows, dim), new_embeddings)
                .map_err(|e| Error::Delete(format!("Failed to reshape embeddings: {}", e)))?;
            new_flat.write_npy(File::create(&emb_path)?)?;
            serde_json::to_writer(
                BufWriter::new(File::create(&emb_lengths_path)?),
                &new_lengths,
            )?;
        } else {
            // Every document was deleted: remove the files entirely.
            std::fs::remove_file(&emb_path).ok();
            std::fs::remove_file(&emb_lengths_path).ok();
        }
    }

    // Compact buffer.npy the same way. Buffer documents sit at the tail of
    // the index, so their IDs start at original_num_documents - buffer_len.
    let buffer_path = index_dir.join("buffer.npy");
    let buffer_lengths_path = index_dir.join("buffer_lengths.json");
    let buffer_info_path = index_dir.join("buffer_info.json");
    if buffer_path.exists() && buffer_lengths_path.exists() {
        let flat: Array2<f32> = Array2::read_npy(File::open(&buffer_path)?)?;
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&buffer_lengths_path)?))?;
        let dim = flat.ncols();
        let mut new_embeddings: Vec<f32> = Vec::new();
        let mut new_lengths: Vec<i64> = Vec::new();
        let mut offset = 0;
        let buffer_len = lengths.len();
        let buffer_start_doc_id = (original_num_documents as i64) - (buffer_len as i64);
        for (i, &len) in lengths.iter().enumerate() {
            let len_usize = len as usize;
            let doc_id = buffer_start_doc_id + i as i64;
            if !ids_to_delete.contains(&doc_id) {
                for row_idx in offset..offset + len_usize {
                    if row_idx < flat.nrows() {
                        new_embeddings.extend(flat.row(row_idx).iter());
                    }
                }
                new_lengths.push(len);
            }
            offset += len_usize;
        }
        if !new_lengths.is_empty() {
            let new_total_rows = new_embeddings.len() / dim;
            let new_flat = Array2::from_shape_vec((new_total_rows, dim), new_embeddings)
                .map_err(|e| Error::Delete(format!("Failed to reshape buffer: {}", e)))?;
            new_flat.write_npy(File::create(&buffer_path)?)?;
            serde_json::to_writer(
                BufWriter::new(File::create(&buffer_lengths_path)?),
                &new_lengths,
            )?;
            let buffer_info = serde_json::json!({ "num_docs": new_lengths.len() });
            serde_json::to_writer(
                BufWriter::new(File::create(&buffer_info_path)?),
                &buffer_info,
            )?;
        } else {
            std::fs::remove_file(&buffer_path).ok();
            std::fs::remove_file(&buffer_lengths_path).ok();
            std::fs::remove_file(&buffer_info_path).ok();
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
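
    // Builds a small 10-document index, deletes three documents, and checks
    // the document count, IVF entry validity, and that search still works.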
    #[test]
    fn test_delete_from_index() {
        use crate::index::{IndexConfig, MmapIndex};
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        // Build 10 documents of 5-7 normalized 64-dim token embeddings each.
        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..10 {
            let num_tokens = 5 + (i % 3);
            let mut doc = Array2::<f32>::zeros((num_tokens, 64));
            for j in 0..num_tokens {
                for k in 0..64 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
            force_cpu: false,
        };
        let index = MmapIndex::create_with_kmeans(&embeddings, index_path, &config).unwrap();
        let original_num_docs = index.metadata.num_documents;
        assert_eq!(original_num_docs, 10);

        let deleted = delete_from_index(&[2, 5, 7], index_path).unwrap();
        assert_eq!(deleted, 3);

        let index_after = MmapIndex::load(index_path).unwrap();
        assert_eq!(index_after.metadata.num_documents, 7);

        // Every IVF entry must reference a surviving (renumbered) document.
        let num_docs = index_after.metadata.num_documents as i64;
        for &doc_id in index_after.ivf.iter() {
            assert!(
                doc_id >= 0 && doc_id < num_docs,
                "Invalid doc ID {} in IVF (should be in range [0, {}))",
                doc_id,
                num_docs
            );
        }

        // The compacted index should still be searchable.
        let query = embeddings[0].clone();
        let results = index_after
            .search(&query, &crate::search::SearchParameters::default(), None)
            .unwrap();
        assert!(
            !results.passage_ids.is_empty(),
            "Search should return results"
        );
    }
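
    // Deleting a mix of existing and nonexistent IDs should only count the
    // documents that were actually removed.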
    #[test]
    fn test_delete_nonexistent_docs() {
        use crate::index::{IndexConfig, MmapIndex};
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 + j as f32 + k as f32) * 0.01;
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
            force_cpu: false,
        };
        MmapIndex::create_with_kmeans(&embeddings, index_path, &config).unwrap();

        let deleted = delete_from_index(&[2, 100, 200], index_path).unwrap();
        assert_eq!(deleted, 1);

        let index_after = MmapIndex::load(index_path).unwrap();
        assert_eq!(index_after.metadata.num_documents, 4);
    }
}