use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::utils::{quantile, quantiles};

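/// Configuration used when building an index.
///
/// `nbits` is the number of bits used when quantizing residuals, `batch_size`
/// is the number of documents encoded per chunk, and `start_from_scratch` is
/// the document-count threshold below which raw embeddings are kept on disk so
/// that later updates can rebuild the index from scratch. Fields marked
/// `#[serde(default)]` fall back to the defaults below when missing from a
/// serialized config.
///
/// A minimal construction sketch (field values are illustrative only):
///
/// ```ignore
/// let config = IndexConfig {
///     nbits: 2,
///     batch_size: 10_000,
///     ..Default::default()
/// };
/// ```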
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    pub nbits: usize,
    pub batch_size: usize,
    pub seed: Option<u64>,
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
}

fn default_start_from_scratch() -> usize {
    999
}

fn default_kmeans_niters() -> usize {
    4
}

fn default_max_points_per_centroid() -> usize {
    256
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            nbits: 4,
            batch_size: 50_000,
            seed: Some(42),
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
        }
    }
}

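/// Index-level statistics persisted to `metadata.json` in the index directory.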
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    pub num_chunks: usize,
    pub nbits: usize,
    pub num_partitions: usize,
    pub num_embeddings: usize,
    pub avg_doclen: f64,
    pub num_documents: usize,
}

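/// Per-chunk statistics persisted to `{chunk}.metadata.json`. The
/// `embedding_offset` is the position of the chunk's first embedding in the
/// merged index; it is written as 0 initially and filled in once all chunks
/// have been encoded.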
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    pub num_documents: usize,
    pub num_embeddings: usize,
    #[serde(default)]
    pub embedding_offset: usize,
}

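/// Builds all on-disk index files from pre-computed `centroids`.
///
/// A held-out sample of embeddings is used to derive the residual bucket
/// cutoffs and weights. The function then writes the codec artifacts
/// (`centroids.npy`, `bucket_cutoffs.npy`, `bucket_weights.npy`,
/// `avg_residual.npy`, `cluster_threshold.npy`), encodes documents in chunks of
/// `config.batch_size` (`{chunk}.codes.npy`, `{chunk}.residuals.npy`,
/// `doclens.{chunk}.json`, `{chunk}.metadata.json`), builds the inverted file
/// (`ivf.npy`, `ivf_lengths.npy`) and finally writes `metadata.json`.
///
/// A minimal calling sketch, assuming centroids were computed elsewhere
/// (shapes and the path are illustrative only):
///
/// ```ignore
/// use ndarray::Array2;
///
/// let docs: Vec<Array2<f32>> = vec![Array2::zeros((5, 32)); 10];
/// let centroids: Array2<f32> = Array2::zeros((16, 32));
/// let metadata = create_index_files(&docs, centroids, "/tmp/index", &IndexConfig::default())?;
/// assert_eq!(metadata.num_documents, 10);
/// ```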
pub fn create_index_files(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let num_documents = embeddings.len();
    let embedding_dim = centroids.ncols();
    let num_centroids = centroids.nrows();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let avg_doclen = total_embeddings as f64 / num_documents as f64;

    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    let heldout_size = (0.05 * total_embeddings as f64).min(50_000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    let heldout_codes = initial_codec.compress_into_codes(&heldout);

    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    let cluster_threshold = quantile(&distances, 0.75);
198
199 let avg_res_per_dim: Array1<f32> = residuals
201 .axis_iter(Axis(1))
202 .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
203 .collect();
204
205 let n_options = 1 << config.nbits;
207 let quantile_values: Vec<f64> = (1..n_options)
208 .map(|i| i as f64 / n_options as f64)
209 .collect();
210 let weight_quantile_values: Vec<f64> = (0..n_options)
211 .map(|i| (i as f64 + 0.5) / n_options as f64)
212 .collect();
213
214 let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
216 let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
217 let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
218
219 let codec = ResidualCodec::new(
220 config.nbits,
221 centroids.clone(),
222 avg_res_per_dim.clone(),
223 Some(bucket_cutoffs.clone()),
224 Some(bucket_weights.clone()),
225 )?;
226
227 use ndarray_npy::WriteNpyExt;
229
230 let centroids_path = index_dir.join("centroids.npy");
231 codec
232 .centroids_view()
233 .to_owned()
234 .write_npy(File::create(¢roids_path)?)?;
235
236 let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
237 bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
238
239 let weights_path = index_dir.join("bucket_weights.npy");
240 bucket_weights.write_npy(File::create(&weights_path)?)?;
241
242 let avg_res_path = index_dir.join("avg_residual.npy");
243 avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
244
245 let threshold_path = index_dir.join("cluster_threshold.npy");
246 Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
247
248 let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
250
251 let plan_path = index_dir.join("plan.json");
253 let plan = serde_json::json!({
254 "nbits": config.nbits,
255 "num_chunks": n_chunks,
256 });
257 let mut plan_file = File::create(&plan_path)?;
258 writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
259
260 let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
261 let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
262
263 let progress = indicatif::ProgressBar::new(n_chunks as u64);
264 progress.set_message("Creating index...");
265
266 for chunk_idx in 0..n_chunks {
267 let start = chunk_idx * config.batch_size;
268 let end = (start + config.batch_size).min(num_documents);
269 let chunk_docs = &embeddings[start..end];
270
271 let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
273 let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
274
275 let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
277 let mut offset = 0;
278 for doc in chunk_docs {
279 let n = doc.nrows();
280 batch_embeddings
281 .slice_mut(s![offset..offset + n, ..])
282 .assign(doc);
283 offset += n;
284 }
285
286 let batch_codes = codec.compress_into_codes(&batch_embeddings);
288
289 let mut batch_residuals = batch_embeddings;
291 {
292 use rayon::prelude::*;
293 let centroids = &codec.centroids;
294 batch_residuals
295 .axis_iter_mut(Axis(0))
296 .into_par_iter()
297 .zip(batch_codes.as_slice().unwrap().par_iter())
298 .for_each(|(mut row, &code)| {
299 let centroid = centroids.row(code);
300 row.iter_mut()
301 .zip(centroid.iter())
302 .for_each(|(r, c)| *r -= c);
303 });
304 }
305
306 let batch_packed = codec.quantize_residuals(&batch_residuals)?;
308
309 for &len in &chunk_doclens {
311 doc_lengths.push(len);
312 }
313 all_codes.extend(batch_codes.iter().copied());
314
315 let chunk_meta = ChunkMetadata {
317 num_documents: end - start,
318 num_embeddings: batch_codes.len(),
319 embedding_offset: 0, };
321
322 let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
323 serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
324
325 let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
327 serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
328
329 let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
331 let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
332 chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
333
334 let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
336 batch_packed.write_npy(File::create(&residuals_path)?)?;
337
338 progress.inc(1);
339 }
340 progress.finish();
341
342 let mut current_offset = 0usize;
344 for chunk_idx in 0..n_chunks {
345 let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
346 let mut meta: serde_json::Value =
347 serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
348
349 if let Some(obj) = meta.as_object_mut() {
350 obj.insert("embedding_offset".to_string(), current_offset.into());
351 let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
352 current_offset += num_emb;
353 }
354
355 serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
356 }
357
358 let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
360 let mut emb_idx = 0;
361
362 for (doc_id, &len) in doc_lengths.iter().enumerate() {
363 for _ in 0..len {
364 let code = all_codes[emb_idx];
365 code_to_docs.entry(code).or_default().push(doc_id as i64);
366 emb_idx += 1;
367 }
368 }
369
370 let mut ivf_data: Vec<i64> = Vec::new();
372 let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
373
374 for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
375 if let Some(docs) = code_to_docs.get(¢roid_id) {
376 let mut unique_docs: Vec<i64> = docs.clone();
377 unique_docs.sort_unstable();
378 unique_docs.dedup();
379 *ivf_len = unique_docs.len() as i32;
380 ivf_data.extend(unique_docs);
381 }
382 }
383
384 let ivf = Array1::from_vec(ivf_data);
385 let ivf_lengths = Array1::from_vec(ivf_lengths);
386
387 let ivf_path = index_dir.join("ivf.npy");
388 ivf.write_npy(File::create(&ivf_path)?)?;
389
390 let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
391 ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
392
393 let metadata = Metadata {
395 num_chunks: n_chunks,
396 nbits: config.nbits,
397 num_partitions: num_centroids,
398 num_embeddings: total_embeddings,
399 avg_doclen,
400 num_documents,
401 };
402
403 let metadata_path = index_dir.join("metadata.json");
404 serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
405
406 Ok(metadata)
407}
408
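/// Runs k-means over the document embeddings to obtain centroids, then builds
/// the index files via [`create_index_files`]. If the collection holds at most
/// `config.start_from_scratch` documents, the raw embeddings are also saved so
/// that a later update can rebuild the index from scratch.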
pub fn create_index_with_kmeans_files(
    embeddings: &[Array2<f32>],
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    if embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed.unwrap_or(42),
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: None,
    };

    let centroids = compute_kmeans(embeddings, &kmeans_config)?;

    let metadata = create_index_files(embeddings, centroids, index_path, config)?;

    if embeddings.len() <= config.start_from_scratch {
        let index_dir = std::path::Path::new(index_path);
        crate::update::save_embeddings_npy(index_dir, embeddings)?;
    }

    Ok(metadata)
}
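
/// An index whose codes and residuals are memory-mapped from disk.
///
/// `ivf`, `ivf_lengths` and `ivf_offsets` map centroid ids to candidate
/// document ids, while `doc_lengths` and `doc_offsets` map document ids to
/// token ranges inside the merged, memory-mapped `mmap_codes` and
/// `mmap_residuals` arrays.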
pub struct MmapIndex {
    pub path: String,
    pub metadata: Metadata,
    pub codec: ResidualCodec,
    pub ivf: Array1<i64>,
    pub ivf_lengths: Array1<i32>,
    pub ivf_offsets: Array1<i64>,
    pub doc_lengths: Array1<i64>,
    pub doc_offsets: Array1<usize>,
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}

impl MmapIndex {
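    /// Loads an index previously written by [`create_index_files`] from
    /// `index_path`, merging the per-chunk code and residual files into single
    /// memory-mapped arrays.
    ///
    /// A minimal sketch (the path is illustrative only):
    ///
    /// ```ignore
    /// let index = MmapIndex::load("/tmp/index")?;
    /// println!("{} documents indexed", index.num_documents());
    /// ```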
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let metadata_path = index_dir.join("metadata.json");
        let metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

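    /// Returns the sorted, de-duplicated document ids stored in the IVF lists
    /// of the given centroids. Centroid indices that are out of range are
    /// ignored.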
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

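    /// Decompresses and returns the embedding matrix of a single document by
    /// slicing its token range out of the memory-mapped codes and residuals.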
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_lengths.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let start = self.doc_offsets[doc_id];
        let end = self.doc_offsets[doc_id + 1];

        let codes_slice = self.mmap_codes.slice(start, end);
        let residuals_view = self.mmap_residuals.slice_rows(start, end);

        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));

        let residuals = residuals_view.to_owned();

        self.codec.decompress(&residuals, &codes.view())
    }

    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        doc_ids
            .iter()
            .map(|&doc_id| {
                if doc_id >= self.doc_lengths.len() {
                    return vec![];
                }
                let start = self.doc_offsets[doc_id];
                let end = self.doc_offsets[doc_id + 1];
                self.mmap_codes.slice(start, end).to_vec()
            })
            .collect()
    }

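    /// Decompresses several documents in one pass, returning a single stacked
    /// embedding matrix together with the per-document token counts. Unknown
    /// document ids contribute a length of 0 and no rows.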
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }

    pub fn num_partitions(&self) -> usize {
        self.metadata.num_partitions
    }

    pub fn avg_doclen(&self) -> f64 {
        self.metadata.avg_doclen
    }

    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }

    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }

    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        create_index_with_kmeans_files(embeddings, index_path, config)?;

        Self::load(index_path)
    }

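    /// Appends new documents to the index and returns the document ids they
    /// were assigned.
    ///
    /// While the index holds at most `config.start_from_scratch` documents, the
    /// stored raw embeddings are combined with the new ones and the index is
    /// rebuilt from scratch. Past that threshold, new documents are first
    /// buffered; once the buffer reaches `config.buffer_size`, the buffered
    /// documents are deleted from the index, re-indexed together with the new
    /// ones (optionally growing the centroid set based on the stored cluster
    /// threshold), and the buffer is cleared.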
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::codec::ResidualCodec;
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        if self.metadata.num_documents <= config.start_from_scratch {
            let existing_embeddings = load_embeddings_npy(index_path)?;
            let start_doc_id = existing_embeddings.len() as i64;

            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            let index_config = IndexConfig {
                nbits: self.metadata.nbits,
                batch_size: config.batch_size,
                seed: Some(config.seed),
                kmeans_niters: config.kmeans_niters,
                max_points_per_centroid: config.max_points_per_centroid,
                n_samples_kmeans: config.n_samples_kmeans,
                start_from_scratch: config.start_from_scratch,
            };

            *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

            if combined_embeddings.len() > config.start_from_scratch
                && embeddings_npy_exists(index_path)
            {
                clear_embeddings_npy(index_path)?;
            }

            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
        }

        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        let start_doc_id: i64;

        let mut codec = ResidualCodec::load_from_dir(index_path)?;

        if total_new >= config.buffer_size {
            let num_buffered = load_buffer_info(index_path)?;

            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;

                let metadata_path = index_path.join("metadata.json");
                self.metadata = serde_json::from_reader(std::io::BufReader::new(
                    std::fs::File::open(&metadata_path)?,
                ))?;
            }

            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            clear_buffer(index_path)?;

            update_index(&combined, &path_str, &codec, Some(config.batch_size), true)?;
        } else {
            start_doc_id = self.metadata.num_documents as i64;

            let combined_buffer: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();
            save_buffer(index_path, &combined_buffer)?;

            update_index(
                embeddings,
                &path_str,
                &codec,
                Some(config.batch_size),
                false,
            )?;
        }

        *self = Self::load(&path_str)?;

        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }

    pub fn update_with_metadata(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
        metadata: Option<&[serde_json::Value]>,
    ) -> Result<Vec<i64>> {
        if let Some(meta) = metadata {
            if meta.len() != embeddings.len() {
                return Err(Error::Config(format!(
                    "Metadata length ({}) must match embeddings length ({})",
                    meta.len(),
                    embeddings.len()
                )));
            }
        }

        let doc_ids = self.update(embeddings, config)?;

        if let Some(meta) = metadata {
            crate::filtering::update(&self.path, meta, &doc_ids)?;
        }

        Ok(doc_ids)
    }

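    /// Loads and updates the index at `index_path` if a `metadata.json` already
    /// exists there, otherwise creates a new index with k-means. Returns the
    /// index together with the document ids assigned to `embeddings`.
    ///
    /// A minimal usage sketch (the path and shapes are illustrative only):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// // One document = one (num_tokens x dim) matrix of embeddings.
    /// let docs: Vec<Array2<f32>> = vec![Array2::zeros((5, 32)); 3];
    /// let (index, doc_ids) = MmapIndex::update_or_create(
    ///     &docs,
    ///     "/tmp/index",
    ///     &IndexConfig::default(),
    ///     &crate::update::UpdateConfig::default(),
    /// )?;
    /// assert_eq!(doc_ids.len(), docs.len());
    /// ```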
    pub fn update_or_create(
        embeddings: &[Array2<f32>],
        index_path: &str,
        index_config: &IndexConfig,
        update_config: &crate::update::UpdateConfig,
    ) -> Result<(Self, Vec<i64>)> {
        let index_dir = std::path::Path::new(index_path);
        let metadata_path = index_dir.join("metadata.json");

        if metadata_path.exists() {
            let mut index = Self::load(index_path)?;
            let doc_ids = index.update(embeddings, update_config)?;
            Ok((index, doc_ids))
        } else {
            let num_docs = embeddings.len();
            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
            Ok((index, doc_ids))
        }
    }

    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }

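    /// Deletes the given document ids and, when `delete_metadata` is true and a
    /// `metadata.db` exists, also removes their entries from the filtering
    /// store. The index is then reloaded from disk. Returns the number of
    /// documents actually deleted.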
    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
        let path = self.path.clone();

        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;

        if delete_metadata && deleted > 0 {
            let index_path = std::path::Path::new(&path);
            let db_path = index_path.join("metadata.db");
            if db_path.exists() {
                crate::filtering::delete(&path, doc_ids)?;
            }
        }

        *self = Self::load(&path)?;

        Ok(deleted)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
            let mut embeddings = Vec::new();
            for i in 0..count {
                let mut doc = Array2::<f32>::zeros((5, 32));
                for j in 0..5 {
                    for k in 0..32 {
                        doc[[j, k]] =
                            ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                    }
                }
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                embeddings.push(doc);
            }
            embeddings
        };

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let embeddings1 = create_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        let embeddings2 = create_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}