use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;

use serde::{Deserialize, Serialize};

use ndarray::{s, Array1, Array2, Axis};
use rayon::prelude::*;

use crate::codec::ResidualCodec;
use crate::error::Error;
use crate::error::Result;
use crate::index::Metadata;
use crate::kmeans::compute_kmeans;
use crate::kmeans::ComputeKmeansConfig;
use crate::utils::quantile;

const DEFAULT_BATCH_SIZE: usize = 50_000;

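/// Configuration for incremental index updates (buffering, centroid
/// expansion, and re-compression of new documents).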
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    /// Number of documents compressed and written per chunk.
    pub batch_size: usize,
    /// Number of k-means iterations when training new centroids.
    pub kmeans_niters: usize,
    /// Target maximum number of points assigned to a centroid.
    pub max_points_per_centroid: usize,
    /// Optional cap on the number of samples used for k-means.
    pub n_samples_kmeans: Option<usize>,
    /// Seed for k-means initialization.
    pub seed: u64,
    pub start_from_scratch: usize,
    /// Number of documents buffered before they are folded into the index.
    pub buffer_size: usize,
    /// Force CPU execution for compression and clustering.
    #[serde(default)]
    pub force_cpu: bool,
}

impl Default for UpdateConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            seed: 42,
            start_from_scratch: 999,
            buffer_size: 100,
            force_cpu: false,
        }
    }
}

impl UpdateConfig {
    /// Maps these settings onto a [`ComputeKmeansConfig`]; `num_partitions`
    /// is left open so the caller can choose the cluster count.
    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
        ComputeKmeansConfig {
            kmeans_niters: self.kmeans_niters,
            max_points_per_centroid: self.max_points_per_centroid,
            seed: self.seed,
            n_samples_kmeans: self.n_samples_kmeans,
            num_partitions: None,
            force_cpu: self.force_cpu,
        }
    }
}

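/// Loads buffered document embeddings from `buffer.npy`, splitting the flat
/// matrix back into per-document arrays via `buffer_lengths.json`.
///
/// Returns an empty vector when the buffer is missing or unreadable, and a
/// single flat document when no lengths file is present.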
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");

    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    // An unreadable buffer is treated as empty rather than surfaced as an error.
    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    // Without a lengths file, fall back to treating the whole matrix as one document.
    Ok(vec![flat])
}

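/// Writes per-document embeddings to the buffer: a flat `buffer.npy`, the
/// per-document row counts in `buffer_lengths.json`, and the document count
/// in `buffer_info.json`. Empty input is a no-op.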
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    let buffer_path = index_path.join("buffer.npy");

    if embeddings.is_empty() {
        return Ok(());
    }

    // Concatenate all documents into one flat matrix, remembering row counts.
    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::new();

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    flat.write_npy(File::create(&buffer_path)?)?;

    let lengths_path = index_path.join("buffer_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;

    Ok(())
}

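/// Returns the buffered document count recorded in `buffer_info.json`, or 0
/// when the file does not exist.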
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if !info_path.exists() {
        return Ok(0);
    }

    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;

    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
}

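/// Deletes `buffer.npy`, `buffer_lengths.json`, and `buffer_info.json` if present.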
pub fn clear_buffer(index_path: &Path) -> Result<()> {
    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");
    let info_path = index_path.join("buffer_info.json");

    if buffer_path.exists() {
        fs::remove_file(&buffer_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    if info_path.exists() {
        fs::remove_file(&info_path)?;
    }

    Ok(())
}

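/// Loads raw document embeddings from `embeddings.npy`, split per document
/// via `embeddings_lengths.json`. Behaves like [`load_buffer`], but for the
/// persistent embedding store.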
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    Ok(vec![flat])
}

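/// Writes per-document embeddings as a flat `embeddings.npy` plus an
/// `embeddings_lengths.json` sidecar recording the rows per document.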
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::with_capacity(embeddings.len());

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    let lengths_path = index_path.join("embeddings_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    Ok(())
}

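/// Deletes `embeddings.npy` and `embeddings_lengths.json` if present.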
pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if emb_path.exists() {
        fs::remove_file(&emb_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    Ok(())
}

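/// Returns `true` when a raw `embeddings.npy` store exists at `index_path`.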
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    index_path.join("embeddings.npy").exists()
}

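/// Reads the clustering threshold from `cluster_threshold.npy`, returning an
/// [`Error::Update`] if the file is missing.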
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    Ok(arr[0])
}

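/// Blends the stored cluster threshold with the 0.75 quantile of
/// `new_residual_norms`, weighting by embedding counts:
/// `(old * n_old + new * n_new) / (n_old + n_new)`. Creates
/// `cluster_threshold.npy` if it does not exist yet.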
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    // Blend with the existing threshold, weighted by the embedding counts.
    let thresh_path = index_path.join("cluster_threshold.npy");
    let final_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        let old_threshold = old_arr[0];
        let total = old_total_embeddings + new_count;
        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
            / total as f32
    } else {
        new_threshold
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}

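// Tuning constants for the blocked outlier scan below: embeddings are
// processed in tiles of OUTLIER_EMBEDDING_TILE rows against
// OUTLIER_CENTROID_TILE centroids at a time, rayon tasks take at least
// OUTLIER_BLOCKS_MIN_LEN blocks, and distances within a small relative band
// of the threshold are re-verified in f64.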
const OUTLIER_CENTROID_TILE: usize = 8;
const OUTLIER_EMBEDDING_TILE: usize = 64;
const OUTLIER_BLOCKS_MIN_LEN: usize = 4;
const OUTLIER_THRESHOLD_RECHECK_REL_EPS: f32 = 1e-5;

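/// Squared L2 norm of `row`, accumulated in four independent lanes so the
/// compiler can vectorize the loop.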
#[inline]
fn squared_norm(row: &[f32]) -> f32 {
    let mut sum0 = 0.0f32;
    let mut sum1 = 0.0f32;
    let mut sum2 = 0.0f32;
    let mut sum3 = 0.0f32;

    let mut i = 0;
    while i + 4 <= row.len() {
        sum0 += row[i] * row[i];
        sum1 += row[i + 1] * row[i + 1];
        sum2 += row[i + 2] * row[i + 2];
        sum3 += row[i + 3] * row[i + 3];
        i += 4;
    }

    let mut total = sum0 + sum1 + sum2 + sum3;
    while i < row.len() {
        total += row[i] * row[i];
        i += 1;
    }

    total
}

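/// Exact minimum squared distance from `row` to any centroid, accumulated in
/// f64. Used to re-check rows whose fast-path distance is too close to the
/// outlier threshold to trust in f32.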
#[inline]
fn min_distance_sq_precise(row: &[f32], centroids_flat: &[f32], dim: usize) -> f32 {
    let mut min_dist_sq = f32::INFINITY;

    for centroid in centroids_flat.chunks_exact(dim) {
        let mut dist_sq = 0.0f64;
        let mut d = 0;
        while d < dim {
            let diff = row[d] as f64 - centroid[d] as f64;
            dist_sq += diff * diff;
            d += 1;
        }

        min_dist_sq = min_dist_sq.min(dist_sq as f32);
    }

    min_dist_sq
}

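/// Returns the indices of rows in `flat_embeddings` whose squared distance to
/// the nearest centroid exceeds `threshold_sq`.
///
/// The scan parallelizes over blocks of rows and uses the expansion
/// `||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c` over centroid tiles; borderline
/// results are recomputed with [`min_distance_sq_precise`] before a row is
/// declared an outlier.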
#[allow(clippy::needless_range_loop)]
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    let n = flat_embeddings.nrows();
    let k = centroids.nrows();
    let dim = flat_embeddings.ncols();

    if n == 0 || k == 0 {
        return Vec::new();
    }

    // Borrow contiguous slices, copying only if the arrays are not in standard layout.
    let embeddings_owned;
    let embeddings_flat = if let Some(slice) = flat_embeddings.as_slice_memory_order() {
        slice
    } else {
        embeddings_owned = flat_embeddings.as_standard_layout().to_owned();
        embeddings_owned
            .as_slice_memory_order()
            .expect("standard-layout embeddings should be contiguous")
    };
    let centroids_owned;
    let centroids_flat = if let Some(slice) = centroids.as_slice_memory_order() {
        slice
    } else {
        centroids_owned = centroids.as_standard_layout().to_owned();
        centroids_owned
            .as_slice_memory_order()
            .expect("standard-layout centroids should be contiguous")
    };

    let centroid_norms_sq: Vec<f32> = centroids_flat
        .par_chunks_exact(dim)
        .map(squared_norm)
        .collect();

    // Each parallel task handles OUTLIER_EMBEDDING_TILE rows.
    let row_stride = dim * OUTLIER_EMBEDDING_TILE;
    embeddings_flat
        .par_chunks(row_stride)
        .with_min_len(OUTLIER_BLOCKS_MIN_LEN)
        .enumerate()
        .flat_map_iter(|(block_idx, rows_block)| {
            let row_count = rows_block.len() / dim;
            let row_offset = block_idx * OUTLIER_EMBEDDING_TILE;

            let mut min_dists = vec![f32::INFINITY; row_count];
            let emb_norms: Vec<f32> = rows_block.chunks_exact(dim).map(squared_norm).collect();

            // Main tiled loop: accumulate dot products for OUTLIER_CENTROID_TILE
            // centroids at a time, then expand ||x - c||^2 from the norms.
            let mut centroid_idx = 0;
            while centroid_idx + OUTLIER_CENTROID_TILE <= k {
                let centroid_bases: [usize; OUTLIER_CENTROID_TILE] =
                    std::array::from_fn(|j| (centroid_idx + j) * dim);

                let mut dots = [[0.0f32; OUTLIER_CENTROID_TILE]; OUTLIER_EMBEDDING_TILE];

                let mut dim_idx = 0;
                while dim_idx < dim {
                    let centroid_vals: [f32; OUTLIER_CENTROID_TILE] =
                        std::array::from_fn(|j| centroids_flat[centroid_bases[j] + dim_idx]);

                    for row_idx in 0..row_count {
                        let x = rows_block[row_idx * dim + dim_idx];
                        for j in 0..OUTLIER_CENTROID_TILE {
                            dots[row_idx][j] += x * centroid_vals[j];
                        }
                    }

                    dim_idx += 1;
                }

                for row_idx in 0..row_count {
                    let emb_norm_sq = emb_norms[row_idx];
                    for j in 0..OUTLIER_CENTROID_TILE {
                        let dist_sq = emb_norm_sq + centroid_norms_sq[centroid_idx + j]
                            - 2.0 * dots[row_idx][j];
                        min_dists[row_idx] = min_dists[row_idx].min(dist_sq);
                    }
                }

                centroid_idx += OUTLIER_CENTROID_TILE;
            }

            // Scalar tail for the remaining centroids.
            while centroid_idx < k {
                let centroid = &centroids_flat[centroid_idx * dim..(centroid_idx + 1) * dim];
                for row_idx in 0..row_count {
                    let row = &rows_block[row_idx * dim..(row_idx + 1) * dim];
                    let mut dot = 0.0f32;
                    let mut dim_idx = 0;
                    while dim_idx < dim {
                        dot += row[dim_idx] * centroid[dim_idx];
                        dim_idx += 1;
                    }

                    let dist_sq = emb_norms[row_idx] + centroid_norms_sq[centroid_idx] - 2.0 * dot;
                    min_dists[row_idx] = min_dists[row_idx].min(dist_sq);
                }

                centroid_idx += 1;
            }

            min_dists
                .into_iter()
                .enumerate()
                .filter_map(move |(row_idx, min_dist_sq)| {
                    // Distances within a small relative band of the threshold
                    // are recomputed precisely before deciding outlier status.
                    let final_min_dist_sq = if (min_dist_sq - threshold_sq).abs()
                        <= threshold_sq.abs().max(1.0) * OUTLIER_THRESHOLD_RECHECK_REL_EPS
                    {
                        let row = &rows_block[row_idx * dim..(row_idx + 1) * dim];
                        min_distance_sq_precise(row, centroids_flat, dim)
                    } else {
                        min_dist_sq
                    };

                    (final_min_dist_sq > threshold_sq).then_some(row_offset + row_idx)
                })
        })
        .collect()
}

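/// Expands the centroid set using embeddings that are farther than
/// `cluster_threshold` from every existing centroid.
///
/// Outliers are clustered with k-means (roughly 4 clusters per
/// `max_points_per_centroid` outliers, capped at the outlier count), the new
/// centroids are appended to `centroids.npy`, `ivf_lengths.npy` is padded
/// with empty posting lists, and `num_partitions` in `metadata.json` is
/// updated. Returns the number of centroids added, or 0 when `centroids.npy`
/// is absent or no outliers are found.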
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;

    let dim = existing_centroids.ncols();
    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();

    if total_tokens == 0 {
        return Ok(0);
    }

    // Flatten all new document embeddings into one matrix.
    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);

    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
    for (i, &idx) in outlier_indices.iter().enumerate() {
        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
    }

    // Size the new cluster count from max_points_per_centroid, oversampling by 4x.
    let target_k =
        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
    let k_update = target_k.min(num_outliers);

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed,
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: Some(k_update),
        force_cpu: config.force_cpu,
    };

    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    // Append the new centroids after the existing ones.
    let new_num_centroids = existing_centroids.nrows() + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..existing_centroids.nrows(), ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![existing_centroids.nrows().., ..])
        .assign(&new_centroids);

    final_centroids.write_npy(File::create(&centroids_path)?)?;

    // Extend ivf_lengths.npy: the new centroids start with empty posting lists.
    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Keep num_partitions in metadata.json in sync with the centroid count.
    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
    }

    Ok(k_new)
}

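/// Appends pre-computed document embeddings to an existing index.
///
/// Each batch of up to `batch_size` documents is compressed with `codec` into
/// centroid codes plus packed residuals and written as a chunk
/// (`{i}.codes.npy`, `{i}.residuals.npy`, `doclens.{i}.json`,
/// `{i}.metadata.json`). A trailing chunk with fewer than 2000 documents is
/// merged with the first new batch. The IVF posting lists, the cluster
/// threshold (when `update_threshold` is set), and `metadata.json` are then
/// refreshed, and stale merged files are cleared.
///
/// Returns the number of documents added. A minimal sketch of the call shape
/// (the codec is assumed to come from elsewhere; paths and the embedding
/// helper are illustrative, not part of this module):
///
/// ```ignore
/// let new_docs: Vec<Array2<f32>> = embed_documents(&texts); // hypothetical helper
/// let added = update_index(
///     &new_docs,        // one (tokens x dim) matrix per document
///     "path/to/index",  // directory holding metadata.json and friends
///     &codec,           // the ResidualCodec the index was built with
///     None,             // batch_size: defaults to DEFAULT_BATCH_SIZE
///     true,             // update_threshold
///     false,            // force_cpu
/// )?;
/// ```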
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
    force_cpu: bool,
) -> Result<usize> {
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
    let index_dir = Path::new(index_path);

    let metadata_path = index_dir.join("metadata.json");
    let metadata = Metadata::load_from_path(index_dir)?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    // If the last chunk is small (< 2000 documents), append to it instead of
    // starting a new chunk.
    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));

        if last_meta_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                        current_emb_offset = off as usize;
                    } else {
                        let embs_in_last = last_meta
                            .get("num_embeddings")
                            .and_then(|x| x.as_u64())
                            .unwrap_or(0) as usize;
                        current_emb_offset = old_total_embeddings - embs_in_last;
                    }
                }
            }
        }
    }

    let num_new_documents = embeddings.len();
    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        // Flatten the chunk's documents into one token matrix.
        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // Assign each token embedding to its nearest centroid.
        let batch_codes = if force_cpu {
            codec.compress_into_codes_cpu(&batch_embeddings)
        } else {
            codec.compress_into_codes(&batch_embeddings)
        };

        // Turn the embeddings into residuals against their centroids, in place.
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        if update_threshold {
            for row in batch_residuals.axis_iter(Axis(0)) {
                let norm = row.dot(&row).sqrt();
                all_residual_norms.push(norm);
            }
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        // Record per-document code lists for the IVF update below.
        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            let codes: Vec<usize> = batch_codes
                .slice(s![code_offset..code_offset + len_usize])
                .iter()
                .copied()
                .collect();
            new_codes_accumulated.push(codes);
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        // When appending to the last chunk, prepend its existing data.
        if i == 0 && append_to_last {
            use ndarray_npy::ReadNpyExt;

            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));

            if old_doclens_path.exists() {
                let old_doclens: Vec<i64> =
                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;

                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
                let old_residuals_path =
                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));

                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;

                let mut combined_codes: Vec<usize> =
                    old_codes.iter().map(|&x| x as usize).collect();
                combined_codes.extend(chk_codes_list);
                chk_codes_list = combined_codes;

                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
                combined_residuals.extend(chk_residuals_list);
                chk_residuals_list = combined_residuals;

                let mut combined_doclens = old_doclens;
                combined_doclens.extend(chk_doclens);
                chk_doclens = combined_doclens;
            }
        }

        // Write the chunk's codes and packed residuals.
        {
            use ndarray_npy::WriteNpyExt;

            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
            codes_arr.write_npy(File::create(&codes_path)?)?;

            let num_tokens = chk_codes_list.len();
            let residuals_arr =
                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
            residuals_arr.write_npy(File::create(&residuals_path)?)?;
        }

        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;

        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": chk_codes_list.len(),
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += chk_codes_list.len();

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
    }

    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Map each centroid to the new document ids whose tokens it received.
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();

    for (pid_counter, doc_codes) in (old_num_documents as i64..).zip(new_codes_accumulated.iter()) {
        for &code in doc_codes {
            partition_pids_map
                .entry(code)
                .or_default()
                .push(pid_counter);
        }
    }

    // Rebuild the IVF: merge the old posting lists with the new document ids.
    {
        use ndarray_npy::{ReadNpyExt, WriteNpyExt};

        let ivf_path = index_dir.join("ivf.npy");
        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

        let old_ivf: Array1<i64> = if ivf_path.exists() {
            Array1::read_npy(File::open(&ivf_path)?)?
        } else {
            Array1::zeros(0)
        };

        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
            Array1::read_npy(File::open(&ivf_lengths_path)?)?
        } else {
            Array1::zeros(num_centroids)
        };

        let mut old_offsets = vec![0i64];
        for &len in old_ivf_lengths.iter() {
            old_offsets.push(old_offsets.last().unwrap() + len as i64);
        }

        let mut new_ivf_data: Vec<i64> = Vec::new();
        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

        for centroid_id in 0..num_centroids {
            let old_start = old_offsets[centroid_id] as usize;
            let old_len = if centroid_id < old_ivf_lengths.len() {
                old_ivf_lengths[centroid_id] as usize
            } else {
                0
            };

            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
            } else {
                Vec::new()
            };

            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
                pids.extend(new_pids);
            }

            // Posting lists are kept sorted and deduplicated.
            pids.sort_unstable();
            pids.dedup();

            new_ivf_lengths.push(pids.len() as i32);
            new_ivf_data.extend(pids);
        }

        let new_ivf = Array1::from_vec(new_ivf_data);
        new_ivf.write_npy(File::create(&ivf_path)?)?;

        let new_lengths = Array1::from_vec(new_ivf_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Refresh the global metadata.
    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
        embedding_dim,
        next_plaid_compatible: true,
    };

    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;

    // Merged files are stale after an update and must be regenerated.
    crate::mmap::clear_merged_files(index_dir)?;

    Ok(num_new_documents)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.buffer_size, 100);
        assert_eq!(config.start_from_scratch, 999);
    }

    #[test]
    fn test_find_outliers() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        let embeddings =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }

    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((5, 4), (20..40).map(|x| x as f32).collect()).unwrap(),
        ];

        save_buffer(dir.path(), &embeddings).unwrap();

        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);
        assert_eq!(loaded[2], embeddings[2]);
    }

    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        let embeddings: Vec<Array2<f32>> = (0..5)
            .map(|i| {
                let rows = i + 2;
                Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32)
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}