use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;

use serde::{Deserialize, Serialize};

use ndarray::{s, Array1, Array2, Axis};
use rayon::prelude::*;

use crate::codec::ResidualCodec;
use crate::error::Error;
use crate::error::Result;
use crate::index::Metadata;
use crate::kmeans::compute_kmeans;
use crate::kmeans::ComputeKmeansConfig;
use crate::utils::quantile;

const DEFAULT_BATCH_SIZE: usize = 50_000;

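/// Configuration for incremental index updates: chunk sizing, k-means
/// parameters used when clustering new embeddings, buffering, and whether
/// to force the CPU code path.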
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    pub batch_size: usize,
    pub kmeans_niters: usize,
    pub max_points_per_centroid: usize,
    pub n_samples_kmeans: Option<usize>,
    pub seed: u64,
    pub start_from_scratch: usize,
    pub buffer_size: usize,
    #[serde(default)]
    pub force_cpu: bool,
}

impl Default for UpdateConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            seed: 42,
            start_from_scratch: 999,
            buffer_size: 100,
            force_cpu: false,
        }
    }
}

impl UpdateConfig {
    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
        ComputeKmeansConfig {
            kmeans_niters: self.kmeans_niters,
            max_points_per_centroid: self.max_points_per_centroid,
            seed: self.seed,
            n_samples_kmeans: self.n_samples_kmeans,
            num_partitions: None,
            force_cpu: self.force_cpu,
        }
    }
}

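/// Loads buffered document embeddings from `buffer.npy`, splitting the flat
/// array back into per-document matrices using `buffer_lengths.json`.
/// Returns an empty vector if no buffer exists.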
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");

    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    Ok(vec![flat])
}

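/// Saves per-document embeddings to the buffer as a single flat `buffer.npy`,
/// along with `buffer_lengths.json` (rows per document) and `buffer_info.json`
/// (document count) so the buffer can be split back apart on load.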
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    let buffer_path = index_path.join("buffer.npy");

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::new();

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    flat.write_npy(File::create(&buffer_path)?)?;

    let lengths_path = index_path.join("buffer_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;

    Ok(())
}

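/// Returns the number of documents recorded in `buffer_info.json`, or 0 if
/// the file does not exist.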
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if !info_path.exists() {
        return Ok(0);
    }

    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;

    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
}

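/// Removes the buffer files (`buffer.npy`, `buffer_lengths.json`,
/// `buffer_info.json`) if they exist.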
pub fn clear_buffer(index_path: &Path) -> Result<()> {
    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");
    let info_path = index_path.join("buffer_info.json");

    if buffer_path.exists() {
        fs::remove_file(&buffer_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    if info_path.exists() {
        fs::remove_file(&info_path)?;
    }

    Ok(())
}

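/// Loads raw document embeddings persisted in `embeddings.npy`, splitting the
/// flat array into per-document matrices via `embeddings_lengths.json`.
/// Returns an empty vector if the file does not exist.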
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    Ok(vec![flat])
}

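/// Saves per-document embeddings as a flat `embeddings.npy` plus
/// `embeddings_lengths.json` with the number of rows per document.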
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::with_capacity(embeddings.len());

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    let lengths_path = index_path.join("embeddings_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    Ok(())
}

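/// Removes `embeddings.npy` and `embeddings_lengths.json` if they exist.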
pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if emb_path.exists() {
        fs::remove_file(&emb_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    Ok(())
}

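/// Returns true if raw embeddings are persisted alongside the index.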
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    index_path.join("embeddings.npy").exists()
}

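/// Loads the stored clustering distance threshold from
/// `cluster_threshold.npy`, erroring if the file is missing.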
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    Ok(arr[0])
}

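/// Updates `cluster_threshold.npy` with the 75th percentile of the new
/// residual norms, blended with the existing threshold by weighting with the
/// old and new embedding counts.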
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    let thresh_path = index_path.join("cluster_threshold.npy");
    let final_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        let old_threshold = old_arr[0];
        let total = old_total_embeddings + new_count;
        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
            / total as f32
    } else {
        new_threshold
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}

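/// Returns the indices of embeddings whose squared distance to the nearest
/// centroid exceeds `threshold_sq`. Distances are computed batch-wise via the
/// expansion ||e - c||^2 = ||e||^2 + ||c||^2 - 2 e·c.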
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    let n = flat_embeddings.nrows();
    let k = centroids.nrows();

    if n == 0 || k == 0 {
        return Vec::new();
    }

    let emb_norms_sq: Vec<f32> = flat_embeddings
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    let centroid_norms_sq: Vec<f32> = centroids
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    // Size batches so the (batch_size x k) dot-product buffer stays around 2 GiB,
    // capped at 4096 rows per batch.
    let batch_size = (2 * 1024 * 1024 * 1024 / (k * std::mem::size_of::<f32>())).clamp(1, 4096);

    let mut outlier_indices = Vec::new();

    for batch_start in (0..n).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(n);
        let batch = flat_embeddings.slice(s![batch_start..batch_end, ..]);

        let dot_products = batch.dot(&centroids.t());

        for (batch_idx, row) in dot_products.axis_iter(Axis(0)).enumerate() {
            let global_idx = batch_start + batch_idx;
            let emb_norm_sq = emb_norms_sq[global_idx];

            let min_dist_sq = row
                .iter()
                .zip(centroid_norms_sq.iter())
                .map(|(&dot, &c_norm_sq)| emb_norm_sq + c_norm_sq - 2.0 * dot)
                .fold(f32::INFINITY, f32::min);

            if min_dist_sq > threshold_sq {
                outlier_indices.push(global_idx);
            }
        }
    }

    outlier_indices
}

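/// Grows the centroid set with clusters fitted on outlier embeddings (those
/// farther than `cluster_threshold` from every existing centroid), then
/// extends `centroids.npy`, `ivf_lengths.npy`, and `metadata.json` to match.
/// Returns the number of centroids added.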
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;

    let dim = existing_centroids.ncols();
    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();

    if total_tokens == 0 {
        return Ok(0);
    }

    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);

    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
    for (i, &idx) in outlier_indices.iter().enumerate() {
        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
    }

    // Aim for roughly max_points_per_centroid / 4 outliers per new centroid,
    // but never more centroids than outliers.
    let target_k =
        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
    let k_update = target_k.min(num_outliers);

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed,
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: Some(k_update),
        force_cpu: config.force_cpu,
    };

    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    let new_num_centroids = existing_centroids.nrows() + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..existing_centroids.nrows(), ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![existing_centroids.nrows().., ..])
        .assign(&new_centroids);

    final_centroids.write_npy(File::create(&centroids_path)?)?;

    // Extend the IVF lengths with empty entries for the new centroids.
    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Keep the top-level metadata's num_partitions in sync with the centroid count.
    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
    }

    Ok(k_new)
}

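/// Appends new document embeddings to an existing index: compresses them with
/// the codec, writes new code/residual/doclen chunks (appending to the last
/// chunk if it is small), optionally refreshes the cluster threshold, merges
/// the new postings into the IVF, and rewrites `metadata.json`.
/// Returns the number of documents added.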
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
    force_cpu: bool,
) -> Result<usize> {
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
    let index_dir = Path::new(index_path);

    let metadata_path = index_dir.join("metadata.json");
    let metadata = Metadata::load_from_path(index_dir)?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    // If the last existing chunk is small, append to it instead of starting a new chunk.
    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));

        if last_meta_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                        current_emb_offset = off as usize;
                    } else {
                        let embs_in_last = last_meta
                            .get("num_embeddings")
                            .and_then(|x| x.as_u64())
                            .unwrap_or(0) as usize;
                        current_emb_offset = old_total_embeddings - embs_in_last;
                    }
                }
            }
        }
    }

    let num_new_documents = embeddings.len();
    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        // Flatten the chunk's documents into one (tokens x dim) matrix.
        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // Assign each token embedding to its nearest centroid.
        let batch_codes = if force_cpu {
            codec.compress_into_codes_cpu(&batch_embeddings)
        } else {
            codec.compress_into_codes(&batch_embeddings)
        };

        // Compute residuals in place by subtracting the assigned centroid from each row.
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        if update_threshold {
            for row in batch_residuals.axis_iter(Axis(0)) {
                let norm = row.dot(&row).sqrt();
                all_residual_norms.push(norm);
            }
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        // Record per-document code lists for the IVF update below.
        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            let codes: Vec<usize> = batch_codes
                .slice(s![code_offset..code_offset + len_usize])
                .iter()
                .copied()
                .collect();
            new_codes_accumulated.push(codes);
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        // When appending to the last chunk, prepend its existing codes, residuals,
        // and doclens so the rewritten chunk files stay complete.
        if i == 0 && append_to_last {
            use ndarray_npy::ReadNpyExt;

            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));

            if old_doclens_path.exists() {
                let old_doclens: Vec<i64> =
                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;

                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
                let old_residuals_path =
                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));

                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;

                let mut combined_codes: Vec<usize> =
                    old_codes.iter().map(|&x| x as usize).collect();
                combined_codes.extend(chk_codes_list);
                chk_codes_list = combined_codes;

                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
                combined_residuals.extend(chk_residuals_list);
                chk_residuals_list = combined_residuals;

                let mut combined_doclens = old_doclens;
                combined_doclens.extend(chk_doclens);
                chk_doclens = combined_doclens;
            }
        }

        // Write the chunk's codes and packed residuals.
        {
            use ndarray_npy::WriteNpyExt;

            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
            codes_arr.write_npy(File::create(&codes_path)?)?;

            let num_tokens = chk_codes_list.len();
            let residuals_arr =
                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
            residuals_arr.write_npy(File::create(&residuals_path)?)?;
        }

        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;

        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": chk_codes_list.len(),
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += chk_codes_list.len();

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
    }

    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Map each centroid to the new document ids whose tokens were assigned to it.
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
    let mut pid_counter = old_num_documents as i64;

    for doc_codes in &new_codes_accumulated {
        for &code in doc_codes {
            partition_pids_map
                .entry(code)
                .or_default()
                .push(pid_counter);
        }
        pid_counter += 1;
    }

    // Merge the new postings into the IVF.
    {
        use ndarray_npy::{ReadNpyExt, WriteNpyExt};

        let ivf_path = index_dir.join("ivf.npy");
        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

        let old_ivf: Array1<i64> = if ivf_path.exists() {
            Array1::read_npy(File::open(&ivf_path)?)?
        } else {
            Array1::zeros(0)
        };

        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
            Array1::read_npy(File::open(&ivf_lengths_path)?)?
        } else {
            Array1::zeros(num_centroids)
        };

        // Prefix-sum the old lengths to get each centroid's posting-list offset.
        let mut old_offsets = vec![0i64];
        for &len in old_ivf_lengths.iter() {
            old_offsets.push(old_offsets.last().unwrap() + len as i64);
        }

        let mut new_ivf_data: Vec<i64> = Vec::new();
        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

        for centroid_id in 0..num_centroids {
            let old_len = if centroid_id < old_ivf_lengths.len() {
                old_ivf_lengths[centroid_id] as usize
            } else {
                0
            };
            // Only index into old_offsets when this centroid has existing postings;
            // centroids beyond the old length table have none.
            let old_start = if old_len > 0 {
                old_offsets[centroid_id] as usize
            } else {
                0
            };

            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
            } else {
                Vec::new()
            };

            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
                pids.extend(new_pids);
            }

            pids.sort_unstable();
            pids.dedup();

            new_ivf_lengths.push(pids.len() as i32);
            new_ivf_data.extend(pids);
        }

        let new_ivf = Array1::from_vec(new_ivf_data);
        new_ivf.write_npy(File::create(&ivf_path)?)?;

        let new_lengths = Array1::from_vec(new_ivf_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Refresh the top-level metadata with the new totals.
    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
        next_plaid_compatible: true,
    };

    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;

    // The merged/mmap files are now out of date; clear them.
    crate::mmap::clear_merged_files(index_dir)?;

    Ok(num_new_documents)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.buffer_size, 100);
        assert_eq!(config.start_from_scratch, 999);
    }

    #[test]
    fn test_find_outliers() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let embeddings =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }

    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((5, 4), (20..40).map(|x| x as f32).collect()).unwrap(),
        ];

        save_buffer(dir.path(), &embeddings).unwrap();

        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);
        assert_eq!(loaded[2], embeddings[2]);
    }

    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings: Vec<Array2<f32>> = (0..5)
            .map(|i| {
                let rows = i + 2;
                Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32)
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}