#[cfg(feature = "npy")]
use std::collections::HashMap;
use std::fs;
#[cfg(feature = "npy")]
use std::fs::File;
#[cfg(feature = "npy")]
use std::io::{BufReader, BufWriter};
use std::path::Path;

use serde::{Deserialize, Serialize};

#[cfg(feature = "npy")]
use ndarray::{s, Array1, Array2, Axis};
#[cfg(feature = "npy")]
use rayon::prelude::*;

#[cfg(feature = "npy")]
use crate::codec::ResidualCodec;
#[cfg(feature = "npy")]
use crate::error::Error;
use crate::error::Result;
#[cfg(feature = "npy")]
use crate::index::Metadata;
#[cfg(feature = "npy")]
use crate::kmeans::compute_kmeans;
use crate::kmeans::ComputeKmeansConfig;
#[cfg(feature = "npy")]
use crate::utils::quantile;

const DEFAULT_BATCH_SIZE: usize = 50_000;

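/// Configuration for incremental index updates. The k-means fields map onto
/// [`ComputeKmeansConfig`] via [`UpdateConfig::to_kmeans_config`], while
/// `start_from_scratch` and `buffer_size` are consumed by the surrounding
/// update workflow rather than by this module.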
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    pub batch_size: usize,
    pub kmeans_niters: usize,
    pub max_points_per_centroid: usize,
    pub n_samples_kmeans: Option<usize>,
    pub seed: u64,
    pub start_from_scratch: usize,
    pub buffer_size: usize,
}

impl Default for UpdateConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            seed: 42,
            start_from_scratch: 999,
            buffer_size: 100,
        }
    }
}

impl UpdateConfig {
    /// Maps these settings onto a [`ComputeKmeansConfig`], with
    /// `num_partitions` left as `None`.
    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
        ComputeKmeansConfig {
            kmeans_niters: self.kmeans_niters,
            max_points_per_centroid: self.max_points_per_centroid,
            seed: self.seed,
            n_samples_kmeans: self.n_samples_kmeans,
            num_partitions: None,
        }
    }
}

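/// Loads buffered document embeddings from `buffer.npy`, splitting the flat
/// array back into per-document matrices using the row counts stored in
/// `buffer_lengths.json`. Returns an empty vector when no buffer exists or
/// it cannot be read; if the lengths file is missing, the whole array is
/// returned as one document.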
#[cfg(feature = "npy")]
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");

    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    // Treat an unreadable buffer file as empty rather than failing.
    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    // No lengths file: treat the whole array as a single document.
    Ok(vec![flat])
}

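/// Saves document embeddings to the on-disk buffer: a flattened `buffer.npy`,
/// per-document row counts in `buffer_lengths.json`, and the document count
/// in `buffer_info.json`. Writes nothing when `embeddings` is empty.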
#[cfg(feature = "npy")]
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    let buffer_path = index_path.join("buffer.npy");

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::new();

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    flat.write_npy(File::create(&buffer_path)?)?;

    let lengths_path = index_path.join("buffer_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;

    Ok(())
}

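/// Returns the number of buffered documents recorded in `buffer_info.json`,
/// or 0 if the file does not exist.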
#[cfg(feature = "npy")]
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if !info_path.exists() {
        return Ok(0);
    }

    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;

    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
}

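/// Removes the buffer files (`buffer.npy`, `buffer_lengths.json`, and
/// `buffer_info.json`) if they exist.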
pub fn clear_buffer(index_path: &Path) -> Result<()> {
    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");
    let info_path = index_path.join("buffer_info.json");

    if buffer_path.exists() {
        fs::remove_file(&buffer_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    if info_path.exists() {
        fs::remove_file(&info_path)?;
    }

    Ok(())
}

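/// Loads raw document embeddings from `embeddings.npy`, splitting the flat
/// array into per-document matrices using `embeddings_lengths.json`. Returns
/// an empty vector if the file does not exist; without a lengths file the
/// whole array is returned as one document.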
#[cfg(feature = "npy")]
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    // No lengths file: treat the whole array as a single document.
    Ok(vec![flat])
}

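/// Saves raw document embeddings as a flattened `embeddings.npy` plus
/// per-document row counts in `embeddings_lengths.json`. Writes nothing when
/// `embeddings` is empty.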
#[cfg(feature = "npy")]
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::with_capacity(embeddings.len());

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    let lengths_path = index_path.join("embeddings_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    Ok(())
}

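/// Removes `embeddings.npy` and `embeddings_lengths.json` if they exist.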
pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if emb_path.exists() {
        fs::remove_file(&emb_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    Ok(())
}

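/// Returns `true` if `embeddings.npy` exists in the index directory.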
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    index_path.join("embeddings.npy").exists()
}

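/// Loads the clustering threshold from `cluster_threshold.npy`, or returns
/// an error if the file is missing.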
#[cfg(feature = "npy")]
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    Ok(arr[0])
}

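/// Updates `cluster_threshold.npy` with the 0.75 quantile of the new
/// residual norms, averaged with any existing threshold weighted by the old
/// and new embedding counts.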
#[cfg(feature = "npy")]
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    let thresh_path = index_path.join("cluster_threshold.npy");
    let final_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        let old_threshold = old_arr[0];
        // Weighted average of the old and new thresholds, weighted by how
        // many embeddings each one was computed from.
        let total = old_total_embeddings + new_count;
        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
            / total as f32
    } else {
        new_threshold
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}

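/// Returns the indices of embeddings whose squared distance to the nearest
/// centroid exceeds `threshold_sq`. Rows are scanned in parallel via rayon.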
#[cfg(feature = "npy")]
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    flat_embeddings
        .axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .filter_map(|(i, emb)| {
            // Squared distance to the nearest centroid.
            let min_dist_sq = centroids
                .axis_iter(Axis(0))
                .map(|c| {
                    emb.iter()
                        .zip(c.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f32>()
                })
                .fold(f32::INFINITY, f32::min);

            if min_dist_sq > threshold_sq {
                Some(i)
            } else {
                None
            }
        })
        .collect()
}

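/// Grows the centroid set by clustering outlier embeddings (those farther
/// than `cluster_threshold` from every existing centroid) with k-means, then
/// appends the new centroids to `centroids.npy`, pads `ivf_lengths.npy`, and
/// updates `num_partitions` in `metadata.json`. Returns the number of
/// centroids added (0 if `centroids.npy` is missing or no outliers are found).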
#[cfg(feature = "npy")]
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;

    let dim = existing_centroids.ncols();
    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();

    if total_tokens == 0 {
        return Ok(0);
    }

    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);

    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    // Aim for roughly four centroids per `max_points_per_centroid`-sized
    // group of outliers, capped at the number of outliers.
    let target_k =
        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
    let k_update = target_k.min(num_outliers);

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed,
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: Some(k_update),
    };

    // Each outlier becomes a one-row "document" for k-means.
    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    // Append the new centroids below the existing ones.
    let new_num_centroids = existing_centroids.nrows() + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..existing_centroids.nrows(), ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![existing_centroids.nrows().., ..])
        .assign(&new_centroids);

    final_centroids.write_npy(File::create(&centroids_path)?)?;

    // Pad the IVF lengths so the new (still empty) centroids have entries.
    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Keep `num_partitions` in the index metadata in sync.
    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
    }

    Ok(k_new)
}

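/// Compresses new document embeddings with `codec` and appends them to the
/// index at `index_path`: writes per-chunk codes, packed residuals, doclens,
/// and chunk metadata, merges the new document ids into the IVF lists,
/// optionally refreshes the cluster threshold, and rewrites the top-level
/// `metadata.json`. If the last existing chunk holds fewer than 2,000
/// documents, the first new batch is merged into it rather than starting a
/// new chunk. Returns the number of documents added.
///
/// A minimal sketch of a call (illustrative; assumes an index and codec are
/// already loaded, and `None` falls back to `DEFAULT_BATCH_SIZE`):
///
/// ```ignore
/// let added = update_index(&embeddings, "path/to/index", &codec, None, true)?;
/// ```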
#[cfg(feature = "npy")]
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
) -> Result<usize> {
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
    let index_dir = Path::new(index_path);

    let metadata_path = index_dir.join("metadata.json");
    let metadata: Metadata = serde_json::from_reader(BufReader::new(
        File::open(&metadata_path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
    ))?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    // If the last chunk is small (fewer than 2000 documents), append to it
    // instead of starting a new chunk.
    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));

        if last_meta_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                        current_emb_offset = off as usize;
                    } else {
                        let embs_in_last = last_meta
                            .get("num_embeddings")
                            .and_then(|x| x.as_u64())
                            .unwrap_or(0) as usize;
                        current_emb_offset = old_total_embeddings - embs_in_last;
                    }
                }
            }
        }
    }

    let num_new_documents = embeddings.len();
    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let progress = indicatif::ProgressBar::new(n_new_chunks as u64);
    progress.set_message("Updating index...");

    // Bytes per token of packed residuals.
    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        // Flatten the chunk's documents into one (tokens x dim) matrix.
        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        let batch_codes = codec.compress_into_codes(&batch_embeddings);

        // Compute residuals in place: subtract each token's assigned centroid.
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        if update_threshold {
            for row in batch_residuals.axis_iter(Axis(0)) {
                let norm = row.dot(&row).sqrt();
                all_residual_norms.push(norm);
            }
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        // Record per-document code lists for the IVF update below.
        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            let codes: Vec<usize> = batch_codes
                .slice(s![code_offset..code_offset + len_usize])
                .iter()
                .copied()
                .collect();
            new_codes_accumulated.push(codes);
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        // When appending to the last existing chunk, prepend its stored
        // codes, residuals, and doclens to the first new chunk.
        if i == 0 && append_to_last {
            use ndarray_npy::ReadNpyExt;

            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));

            if old_doclens_path.exists() {
                let old_doclens: Vec<i64> =
                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;

                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
                let old_residuals_path =
                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));

                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;

                let mut combined_codes: Vec<usize> =
                    old_codes.iter().map(|&x| x as usize).collect();
                combined_codes.extend(chk_codes_list);
                chk_codes_list = combined_codes;

                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
                combined_residuals.extend(chk_residuals_list);
                chk_residuals_list = combined_residuals;

                let mut combined_doclens = old_doclens;
                combined_doclens.extend(chk_doclens);
                chk_doclens = combined_doclens;
            }
        }

        // Write this chunk's codes and packed residuals.
        {
            use ndarray_npy::WriteNpyExt;

            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
            codes_arr.write_npy(File::create(&codes_path)?)?;

            let num_tokens = chk_codes_list.len();
            let residuals_arr =
                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
            residuals_arr.write_npy(File::create(&residuals_path)?)?;
        }

        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;

        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": chk_codes_list.len(),
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += chk_codes_list.len();

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;

        progress.inc(1);
    }
    progress.finish();

    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Map each centroid to the new document ids that touch it.
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
    let mut pid_counter = old_num_documents as i64;

    for doc_codes in &new_codes_accumulated {
        for &code in doc_codes {
            partition_pids_map
                .entry(code)
                .or_default()
                .push(pid_counter);
        }
        pid_counter += 1;
    }

    // Rebuild the IVF lists, merging old and new document ids per centroid.
    {
        use ndarray_npy::{ReadNpyExt, WriteNpyExt};

        let ivf_path = index_dir.join("ivf.npy");
        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

        let old_ivf: Array1<i64> = if ivf_path.exists() {
            Array1::read_npy(File::open(&ivf_path)?)?
        } else {
            Array1::zeros(0)
        };

        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
            Array1::read_npy(File::open(&ivf_lengths_path)?)?
        } else {
            Array1::zeros(num_centroids)
        };

        let mut old_offsets = vec![0i64];
        for &len in old_ivf_lengths.iter() {
            old_offsets.push(old_offsets.last().unwrap() + len as i64);
        }

        let mut new_ivf_data: Vec<i64> = Vec::new();
        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

        for centroid_id in 0..num_centroids {
            let old_start = old_offsets[centroid_id] as usize;
            let old_len = if centroid_id < old_ivf_lengths.len() {
                old_ivf_lengths[centroid_id] as usize
            } else {
                0
            };

            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
            } else {
                Vec::new()
            };

            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
                pids.extend(new_pids);
            }

            pids.sort_unstable();
            pids.dedup();

            new_ivf_lengths.push(pids.len() as i32);
            new_ivf_data.extend(pids);
        }

        let new_ivf = Array1::from_vec(new_ivf_data);
        new_ivf.write_npy(File::create(&ivf_path)?)?;

        let new_lengths = Array1::from_vec(new_ivf_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Refresh the top-level metadata with the new totals.
    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
    };

    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;

    Ok(num_new_documents)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.buffer_size, 100);
        assert_eq!(config.start_from_scratch, 999);
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_find_outliers() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let embeddings =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((5, 4), (20..40).map(|x| x as f32).collect()).unwrap(),
        ];

        save_buffer(dir.path(), &embeddings).unwrap();

        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);
        assert_eq!(loaded[2], embeddings[2]);
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings: Vec<Array2<f32>> = (0..5)
            .map(|i| {
                let rows = i + 2;
                Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32)
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}
957}