use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;

use serde::{Deserialize, Serialize};

use ndarray::{s, Array1, Array2, Axis};
use rayon::prelude::*;

use crate::codec::ResidualCodec;
use crate::error::Error;
use crate::error::Result;
use crate::index::Metadata;
use crate::kmeans::compute_kmeans;
use crate::kmeans::ComputeKmeansConfig;
use crate::utils::quantile;

const DEFAULT_BATCH_SIZE: usize = 50_000;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    pub batch_size: usize,
    pub kmeans_niters: usize,
    pub max_points_per_centroid: usize,
    pub n_samples_kmeans: Option<usize>,
    pub seed: u64,
    pub start_from_scratch: usize,
    pub buffer_size: usize,
}

impl Default for UpdateConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            seed: 42,
            start_from_scratch: 999,
            buffer_size: 100,
        }
    }
}

impl UpdateConfig {
    /// Builds a `ComputeKmeansConfig` from these settings, leaving `num_partitions` unset.
    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
        ComputeKmeansConfig {
            kmeans_niters: self.kmeans_niters,
            max_points_per_centroid: self.max_points_per_centroid,
            seed: self.seed,
            n_samples_kmeans: self.n_samples_kmeans,
            num_partitions: None,
        }
    }
}

/// Loads buffered document embeddings from `buffer.npy`, splitting the flat
/// array back into per-document matrices via `buffer_lengths.json` when present.
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");

    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    Ok(vec![flat])
}

/// Persists document embeddings to `buffer.npy`, together with per-document
/// lengths (`buffer_lengths.json`) and a document count (`buffer_info.json`).
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    let buffer_path = index_path.join("buffer.npy");

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::new();

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    flat.write_npy(File::create(&buffer_path)?)?;

    let lengths_path = index_path.join("buffer_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;

    Ok(())
}

/// Returns the number of buffered documents recorded in `buffer_info.json`, or 0 if absent.
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if !info_path.exists() {
        return Ok(0);
    }

    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;

    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
}

/// Removes the buffer files (`buffer.npy`, `buffer_lengths.json`, `buffer_info.json`) if they exist.
pub fn clear_buffer(index_path: &Path) -> Result<()> {
    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");
    let info_path = index_path.join("buffer_info.json");

    if buffer_path.exists() {
        fs::remove_file(&buffer_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    if info_path.exists() {
        fs::remove_file(&info_path)?;
    }

    Ok(())
}

/// Loads raw document embeddings from `embeddings.npy`, splitting them per
/// document via `embeddings_lengths.json` when present.
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    Ok(vec![flat])
}

/// Persists raw document embeddings to `embeddings.npy` along with per-document
/// lengths in `embeddings_lengths.json`.
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::with_capacity(embeddings.len());

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    let lengths_path = index_path.join("embeddings_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    Ok(())
}

/// Removes `embeddings.npy` and `embeddings_lengths.json` if they exist.
pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if emb_path.exists() {
        fs::remove_file(&emb_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    Ok(())
}

/// Returns true if `embeddings.npy` exists in the index directory.
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    index_path.join("embeddings.npy").exists()
}

/// Loads the stored clustering threshold from `cluster_threshold.npy`.
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    Ok(arr[0])
}

/// Updates the stored clustering threshold with the 0.75 quantile of the new
/// residual norms, averaged with the old threshold weighted by embedding counts.
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    let thresh_path = index_path.join("cluster_threshold.npy");
    let final_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        let old_threshold = old_arr[0];
        let total = old_total_embeddings + new_count;
        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
            / total as f32
    } else {
        new_threshold
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}

/// Returns the indices of embeddings whose squared distance to the nearest
/// centroid exceeds `threshold_sq`, computed in memory-bounded batches.
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    let n = flat_embeddings.nrows();
    let k = centroids.nrows();

    if n == 0 || k == 0 {
        return Vec::new();
    }

    let emb_norms_sq: Vec<f32> = flat_embeddings
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    let centroid_norms_sq: Vec<f32> = centroids
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    let batch_size = (2 * 1024 * 1024 * 1024 / (k * std::mem::size_of::<f32>())).clamp(1, 4096);

    let mut outlier_indices = Vec::new();

    for batch_start in (0..n).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(n);
        let batch = flat_embeddings.slice(s![batch_start..batch_end, ..]);

        let dot_products = batch.dot(&centroids.t());

        for (batch_idx, row) in dot_products.axis_iter(Axis(0)).enumerate() {
            let global_idx = batch_start + batch_idx;
            let emb_norm_sq = emb_norms_sq[global_idx];

            // Squared Euclidean distance via ||e||^2 + ||c||^2 - 2 * e.c
            let min_dist_sq = row
                .iter()
                .zip(centroid_norms_sq.iter())
                .map(|(&dot, &c_norm_sq)| emb_norm_sq + c_norm_sq - 2.0 * dot)
                .fold(f32::INFINITY, f32::min);

            if min_dist_sq > threshold_sq {
                outlier_indices.push(global_idx);
            }
        }
    }

    outlier_indices
}

/// Finds new embeddings that lie farther than `cluster_threshold` from every
/// existing centroid, clusters those outliers with k-means, and appends the
/// resulting centroids to `centroids.npy`, `ivf_lengths.npy`, and the metadata.
/// Returns the number of centroids added.
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;

    let dim = existing_centroids.ncols();
    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();

    if total_tokens == 0 {
        return Ok(0);
    }

    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);

    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
    for (i, &idx) in outlier_indices.iter().enumerate() {
        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
    }

    let target_k =
        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
    let k_update = target_k.min(num_outliers);

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed,
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: Some(k_update),
    };

    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    let new_num_centroids = existing_centroids.nrows() + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..existing_centroids.nrows(), ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![existing_centroids.nrows().., ..])
        .assign(&new_centroids);

    final_centroids.write_npy(File::create(&centroids_path)?)?;

    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
    }

    Ok(k_new)
}

/// Compresses new document embeddings with `codec`, appends them to the index
/// chunks on disk, merges them into the IVF lists, and rewrites the top-level
/// metadata. Returns the number of documents added.
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
) -> Result<usize> {
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
    let index_dir = Path::new(index_path);

    let metadata_path = index_dir.join("metadata.json");
    let metadata = Metadata::load_from_path(index_dir)?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));

        if last_meta_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                        current_emb_offset = off as usize;
                    } else {
                        let embs_in_last = last_meta
                            .get("num_embeddings")
                            .and_then(|x| x.as_u64())
                            .unwrap_or(0) as usize;
                        current_emb_offset = old_total_embeddings - embs_in_last;
                    }
                }
            }
        }
    }

    let num_new_documents = embeddings.len();
    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        let batch_codes = codec.compress_into_codes(&batch_embeddings);

        // Subtract each embedding's assigned centroid in place to obtain residuals.
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        if update_threshold {
            for row in batch_residuals.axis_iter(Axis(0)) {
                let norm = row.dot(&row).sqrt();
                all_residual_norms.push(norm);
            }
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            let codes: Vec<usize> = batch_codes
                .slice(s![code_offset..code_offset + len_usize])
                .iter()
                .copied()
                .collect();
            new_codes_accumulated.push(codes);
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        if i == 0 && append_to_last {
            use ndarray_npy::ReadNpyExt;

            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));

            if old_doclens_path.exists() {
                let old_doclens: Vec<i64> =
                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;

                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
                let old_residuals_path =
                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));

                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;

                let mut combined_codes: Vec<usize> =
                    old_codes.iter().map(|&x| x as usize).collect();
                combined_codes.extend(chk_codes_list);
                chk_codes_list = combined_codes;

                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
                combined_residuals.extend(chk_residuals_list);
                chk_residuals_list = combined_residuals;

                let mut combined_doclens = old_doclens;
                combined_doclens.extend(chk_doclens);
                chk_doclens = combined_doclens;
            }
        }

        {
            use ndarray_npy::WriteNpyExt;

            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
            codes_arr.write_npy(File::create(&codes_path)?)?;

            let num_tokens = chk_codes_list.len();
            let residuals_arr =
                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
            residuals_arr.write_npy(File::create(&residuals_path)?)?;
        }

        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;

        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": chk_codes_list.len(),
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += chk_codes_list.len();

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
    }

    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Map each centroid (partition) to the document ids whose tokens it now holds.
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
    let mut pid_counter = old_num_documents as i64;

    for doc_codes in &new_codes_accumulated {
        for &code in doc_codes {
            partition_pids_map
                .entry(code)
                .or_default()
                .push(pid_counter);
        }
        pid_counter += 1;
    }

    {
        use ndarray_npy::{ReadNpyExt, WriteNpyExt};

        let ivf_path = index_dir.join("ivf.npy");
        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

        let old_ivf: Array1<i64> = if ivf_path.exists() {
            Array1::read_npy(File::open(&ivf_path)?)?
        } else {
            Array1::zeros(0)
        };

        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
            Array1::read_npy(File::open(&ivf_lengths_path)?)?
        } else {
            Array1::zeros(num_centroids)
        };

        let mut old_offsets = vec![0i64];
        for &len in old_ivf_lengths.iter() {
            old_offsets.push(old_offsets.last().unwrap() + len as i64);
        }

        let mut new_ivf_data: Vec<i64> = Vec::new();
        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

        for centroid_id in 0..num_centroids {
            let old_start = old_offsets[centroid_id] as usize;
            let old_len = if centroid_id < old_ivf_lengths.len() {
                old_ivf_lengths[centroid_id] as usize
            } else {
                0
            };

            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
            } else {
                Vec::new()
            };

            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
                pids.extend(new_pids);
            }

            pids.sort_unstable();
            pids.dedup();

            new_ivf_lengths.push(pids.len() as i32);
            new_ivf_data.extend(pids);
        }

        let new_ivf = Array1::from_vec(new_ivf_data);
        new_ivf.write_npy(File::create(&ivf_path)?)?;

        let new_lengths = Array1::from_vec(new_ivf_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
        next_plaid_compatible: true,
    };

    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;

    // Remove any previously merged files; they no longer reflect the updated chunks.
    crate::mmap::clear_merged_files(index_dir)?;

    Ok(num_new_documents)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.buffer_size, 100);
        assert_eq!(config.start_from_scratch, 999);
    }
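
    // Extra check (a sketch, not part of the original suite): `to_kmeans_config`
    // should forward the update settings unchanged. It assumes the
    // `ComputeKmeansConfig` fields constructed by name above are readable here.
    #[test]
    fn test_to_kmeans_config_forwards_settings() {
        let config = UpdateConfig {
            kmeans_niters: 7,
            seed: 123,
            ..UpdateConfig::default()
        };

        let kmeans_config = config.to_kmeans_config();
        assert_eq!(kmeans_config.kmeans_niters, 7);
        assert_eq!(kmeans_config.seed, 123);
        assert_eq!(kmeans_config.max_points_per_centroid, 256);
        assert_eq!(kmeans_config.n_samples_kmeans, None);
        assert_eq!(kmeans_config.num_partitions, None);
    }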

    #[test]
    fn test_find_outliers() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let embeddings =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }
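
    // Sketch of a complementary edge-case check: with no embeddings or no centroids,
    // `find_outliers` should take its early-return path and report nothing.
    #[test]
    fn test_find_outliers_empty_inputs() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let no_embeddings = Array2::<f32>::zeros((0, 2));
        assert!(find_outliers(&no_embeddings, &centroids, 1.0).is_empty());

        let embeddings = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
        let no_centroids = Array2::<f32>::zeros((0, 2));
        assert!(find_outliers(&embeddings, &no_centroids, 1.0).is_empty());
    }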

    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((5, 4), (20..40).map(|x| x as f32).collect()).unwrap(),
        ];

        save_buffer(dir.path(), &embeddings).unwrap();

        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);
        assert_eq!(loaded[2], embeddings[2]);
    }
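
    // Sketch of a check for the buffer clearing path: after `clear_buffer`,
    // `load_buffer` should see nothing and `load_buffer_info` should fall back to 0.
    #[test]
    fn test_clear_buffer_removes_files() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let embeddings = vec![Array2::from_shape_fn((3, 4), |(r, c)| (r * 4 + c) as f32)];

        save_buffer(dir.path(), &embeddings).unwrap();
        assert_eq!(load_buffer_info(dir.path()).unwrap(), 1);

        clear_buffer(dir.path()).unwrap();
        assert!(load_buffer(dir.path()).unwrap().is_empty());
        assert_eq!(load_buffer_info(dir.path()).unwrap(), 0);
    }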

    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings: Vec<Array2<f32>> = (0..5)
            .map(|i| {
                let rows = i + 2;
                Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32)
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
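
    // Two extra sketches based on behavior evident from the functions above:
    // (1) the embeddings.npy helpers round-trip per-document matrices just like the
    //     buffer helpers, and (2) `update_cluster_threshold` keeps a count-weighted
    //     running average, assuming `quantile` of a constant slice is that constant.
    #[test]
    fn test_embeddings_npy_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
        ];

        assert!(!embeddings_npy_exists(dir.path()));
        save_embeddings_npy(dir.path(), &embeddings).unwrap();
        assert!(embeddings_npy_exists(dir.path()));

        let loaded = load_embeddings_npy(dir.path()).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);

        clear_embeddings_npy(dir.path()).unwrap();
        assert!(!embeddings_npy_exists(dir.path()));
    }

    #[test]
    fn test_update_cluster_threshold_weighted_average() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        // No threshold stored yet: loading should fail, the first update creates it.
        assert!(load_cluster_threshold(dir.path()).is_err());

        // Two norms of 1.0: the stored threshold becomes 1.0.
        let first = Array1::from_vec(vec![1.0_f32, 1.0]);
        update_cluster_threshold(dir.path(), &first, 0).unwrap();
        assert!((load_cluster_threshold(dir.path()).unwrap() - 1.0).abs() < 1e-6);

        // Two norms of 3.0 merged over 2 old embeddings: (1.0 * 2 + 3.0 * 2) / 4 = 2.0.
        let second = Array1::from_vec(vec![3.0_f32, 3.0]);
        update_cluster_threshold(dir.path(), &second, 2).unwrap();
        assert!((load_cluster_threshold(dir.path()).unwrap() - 2.0).abs() < 1e-6);
    }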
}