Skip to main content

next_plaid/
update.rs

1//! Index update functionality for adding new documents.
2//!
3//! This module provides functions to incrementally update an existing PLAID index
4//! with new documents, matching fast-plaid's behavior:
5//! - Buffer mechanism for small updates
6//! - Centroid expansion for outliers
7//! - Cluster threshold updates
8
9use std::collections::HashMap;
10use std::fs;
11use std::fs::File;
12use std::io::{BufReader, BufWriter};
13use std::path::Path;
14
15use serde::{Deserialize, Serialize};
16
17use ndarray::{s, Array1, Array2, Axis};
18use rayon::prelude::*;
19
20use crate::codec::ResidualCodec;
21use crate::error::Error;
22use crate::error::Result;
23use crate::index::Metadata;
24use crate::kmeans::compute_kmeans;
25use crate::kmeans::ComputeKmeansConfig;
26use crate::utils::quantile;
27
/// Number of documents compressed per chunk when the caller does not pass an
/// explicit batch size (same default as fast-plaid).
const DEFAULT_BATCH_SIZE: usize = 50 * 1_000;
30
31/// Configuration for index updates.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct UpdateConfig {
34    /// Batch size for processing documents (default: 50,000)
35    pub batch_size: usize,
36    /// Number of K-means iterations for centroid expansion (default: 4)
37    pub kmeans_niters: usize,
38    /// Max points per centroid for K-means (default: 256)
39    pub max_points_per_centroid: usize,
40    /// Number of samples for K-means (default: auto-calculated)
41    pub n_samples_kmeans: Option<usize>,
42    /// Random seed (default: 42)
43    pub seed: u64,
44    /// If index has fewer docs than this, rebuild from scratch (default: 999)
45    pub start_from_scratch: usize,
46    /// Buffer size before triggering centroid expansion (default: 100)
47    pub buffer_size: usize,
48    /// Force CPU execution for K-means even when CUDA feature is enabled.
49    #[serde(default)]
50    pub force_cpu: bool,
51}
52
53impl Default for UpdateConfig {
54    fn default() -> Self {
55        Self {
56            batch_size: DEFAULT_BATCH_SIZE,
57            kmeans_niters: 4,
58            max_points_per_centroid: 256,
59            n_samples_kmeans: None,
60            seed: 42,
61            start_from_scratch: 999,
62            buffer_size: 100,
63            force_cpu: false,
64        }
65    }
66}
67
68impl UpdateConfig {
69    /// Convert to ComputeKmeansConfig for centroid expansion.
70    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
71        ComputeKmeansConfig {
72            kmeans_niters: self.kmeans_niters,
73            max_points_per_centroid: self.max_points_per_centroid,
74            seed: self.seed,
75            n_samples_kmeans: self.n_samples_kmeans,
76            num_partitions: None,
77            force_cpu: self.force_cpu,
78        }
79    }
80}
81
82// ============================================================================
83// Buffer Management
84// ============================================================================
85
86/// Load buffered embeddings from buffer.npy.
87///
88/// Returns an empty vector if buffer.npy doesn't exist.
89/// Uses buffer_lengths.json to split the flattened array back into per-document arrays.
90pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
91    use ndarray_npy::ReadNpyExt;
92
93    let buffer_path = index_path.join("buffer.npy");
94    let lengths_path = index_path.join("buffer_lengths.json");
95
96    if !buffer_path.exists() {
97        return Ok(Vec::new());
98    }
99
100    // Load the flattened embeddings array
101    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
102        Ok(arr) => arr,
103        Err(_) => return Ok(Vec::new()),
104    };
105
106    // Load lengths to split back into per-document arrays
107    if lengths_path.exists() {
108        let lengths: Vec<i64> =
109            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;
110
111        let mut result = Vec::with_capacity(lengths.len());
112        let mut offset = 0;
113
114        for &len in &lengths {
115            let len_usize = len as usize;
116            if offset + len_usize > flat.nrows() {
117                break;
118            }
119            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
120            result.push(doc_emb);
121            offset += len_usize;
122        }
123
124        return Ok(result);
125    }
126
127    // Fallback: if no lengths file, return as single document (legacy behavior)
128    Ok(vec![flat])
129}
130
131/// Save embeddings to buffer.npy.
132///
133/// Also saves buffer_info.json with the number of documents for deletion tracking.
134pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
135    use ndarray_npy::WriteNpyExt;
136
137    let buffer_path = index_path.join("buffer.npy");
138
139    // For simplicity, concatenate all embeddings into one array
140    // and store the lengths separately
141    if embeddings.is_empty() {
142        return Ok(());
143    }
144
145    let dim = embeddings[0].ncols();
146    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();
147
148    let mut flat = Array2::<f32>::zeros((total_rows, dim));
149    let mut offset = 0;
150    let mut lengths = Vec::new();
151
152    for emb in embeddings {
153        let n = emb.nrows();
154        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
155        lengths.push(n as i64);
156        offset += n;
157    }
158
159    flat.write_npy(File::create(&buffer_path)?)?;
160
161    // Save lengths
162    let lengths_path = index_path.join("buffer_lengths.json");
163    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;
164
165    // Save buffer info for deletion tracking (number of documents)
166    let info_path = index_path.join("buffer_info.json");
167    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
168    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;
169
170    Ok(())
171}
172
173/// Load buffer info (number of buffered documents).
174///
175/// Returns 0 if buffer_info.json doesn't exist.
176pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
177    let info_path = index_path.join("buffer_info.json");
178    if !info_path.exists() {
179        return Ok(0);
180    }
181
182    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;
183
184    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
185}
186
187/// Clear buffer files.
188pub fn clear_buffer(index_path: &Path) -> Result<()> {
189    let buffer_path = index_path.join("buffer.npy");
190    let lengths_path = index_path.join("buffer_lengths.json");
191    let info_path = index_path.join("buffer_info.json");
192
193    if buffer_path.exists() {
194        fs::remove_file(&buffer_path)?;
195    }
196    if lengths_path.exists() {
197        fs::remove_file(&lengths_path)?;
198    }
199    if info_path.exists() {
200        fs::remove_file(&info_path)?;
201    }
202
203    Ok(())
204}
205
206/// Load embeddings stored for rebuild (embeddings.npy + embeddings_lengths.json).
207///
208/// This function loads raw embeddings that were saved for start-from-scratch rebuilds.
209/// The embeddings are stored in a flat 2D array with a separate lengths file.
210pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
211    use ndarray_npy::ReadNpyExt;
212
213    let emb_path = index_path.join("embeddings.npy");
214    let lengths_path = index_path.join("embeddings_lengths.json");
215
216    if !emb_path.exists() {
217        return Ok(Vec::new());
218    }
219
220    // Load flat embeddings array
221    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;
222
223    // Load lengths to split back into per-document arrays
224    if lengths_path.exists() {
225        let lengths: Vec<i64> =
226            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;
227
228        let mut result = Vec::with_capacity(lengths.len());
229        let mut offset = 0;
230
231        for &len in &lengths {
232            let len_usize = len as usize;
233            if offset + len_usize > flat.nrows() {
234                break;
235            }
236            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
237            result.push(doc_emb);
238            offset += len_usize;
239        }
240
241        return Ok(result);
242    }
243
244    // Fallback: if no lengths file, return as single document
245    Ok(vec![flat])
246}
247
248/// Save embeddings for potential rebuild (start-from-scratch mode).
249///
250/// Stores embeddings in embeddings.npy (flat array) + embeddings_lengths.json.
251/// This matches fast-plaid's behavior of storing raw embeddings when the index
252/// is below the start_from_scratch threshold.
253pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
254    use ndarray_npy::WriteNpyExt;
255
256    if embeddings.is_empty() {
257        return Ok(());
258    }
259
260    let dim = embeddings[0].ncols();
261    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();
262
263    let mut flat = Array2::<f32>::zeros((total_rows, dim));
264    let mut offset = 0;
265    let mut lengths = Vec::with_capacity(embeddings.len());
266
267    for emb in embeddings {
268        let n = emb.nrows();
269        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
270        lengths.push(n as i64);
271        offset += n;
272    }
273
274    // Save flat embeddings
275    let emb_path = index_path.join("embeddings.npy");
276    flat.write_npy(File::create(&emb_path)?)?;
277
278    // Save lengths for reconstruction
279    let lengths_path = index_path.join("embeddings_lengths.json");
280    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;
281
282    Ok(())
283}
284
285/// Clear embeddings.npy and embeddings_lengths.json.
286pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
287    let emb_path = index_path.join("embeddings.npy");
288    let lengths_path = index_path.join("embeddings_lengths.json");
289
290    if emb_path.exists() {
291        fs::remove_file(&emb_path)?;
292    }
293    if lengths_path.exists() {
294        fs::remove_file(&lengths_path)?;
295    }
296    Ok(())
297}
298
/// Report whether raw embeddings have been saved for a start-from-scratch
/// rebuild, i.e. whether `embeddings.npy` is present in `index_path`.
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    let emb_path = index_path.join("embeddings.npy");
    // Same check `Path::exists` performs: the metadata lookup succeeds.
    fs::metadata(emb_path).is_ok()
}
303
304// ============================================================================
305// Cluster Threshold Management
306// ============================================================================
307
308/// Load cluster threshold from cluster_threshold.npy.
309pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
310    use ndarray_npy::ReadNpyExt;
311
312    let thresh_path = index_path.join("cluster_threshold.npy");
313    if !thresh_path.exists() {
314        return Err(Error::Update("cluster_threshold.npy not found".into()));
315    }
316
317    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
318    Ok(arr[0])
319}
320
321/// Update cluster_threshold.npy with weighted average.
322pub fn update_cluster_threshold(
323    index_path: &Path,
324    new_residual_norms: &Array1<f32>,
325    old_total_embeddings: usize,
326) -> Result<()> {
327    use ndarray_npy::{ReadNpyExt, WriteNpyExt};
328
329    let new_count = new_residual_norms.len();
330    if new_count == 0 {
331        return Ok(());
332    }
333
334    let new_threshold = quantile(new_residual_norms, 0.75);
335
336    let thresh_path = index_path.join("cluster_threshold.npy");
337    let final_threshold = if thresh_path.exists() {
338        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
339        let old_threshold = old_arr[0];
340        let total = old_total_embeddings + new_count;
341        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
342            / total as f32
343    } else {
344        new_threshold
345    };
346
347    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;
348
349    Ok(())
350}
351
352// ============================================================================
353// Centroid Expansion
354// ============================================================================
355
356/// Find outlier embeddings that are far from all existing centroids.
357///
358/// Returns indices of embeddings where min L2² distance > threshold².
359///
360/// Uses batch matrix multiplication for efficiency:
361/// ||a - b||² = ||a||² + ||b||² - 2*a·b
362fn find_outliers(
363    flat_embeddings: &Array2<f32>,
364    centroids: &Array2<f32>,
365    threshold_sq: f32,
366) -> Vec<usize> {
367    let n = flat_embeddings.nrows();
368    let k = centroids.nrows();
369
370    if n == 0 || k == 0 {
371        return Vec::new();
372    }
373
374    // Pre-compute squared norms for embeddings and centroids
375    let emb_norms_sq: Vec<f32> = flat_embeddings
376        .axis_iter(Axis(0))
377        .into_par_iter()
378        .map(|row| row.dot(&row))
379        .collect();
380
381    let centroid_norms_sq: Vec<f32> = centroids
382        .axis_iter(Axis(0))
383        .into_par_iter()
384        .map(|row| row.dot(&row))
385        .collect();
386
387    // Batch matrix multiplication: [n, d] @ [d, k] -> [n, k]
388    // This computes dot products: similarities[i, j] = embeddings[i] · centroids[j]
389    // Process in batches to limit memory usage
390    let batch_size = (2 * 1024 * 1024 * 1024 / (k * std::mem::size_of::<f32>())).clamp(1, 4096);
391
392    let mut outlier_indices = Vec::new();
393
394    for batch_start in (0..n).step_by(batch_size) {
395        let batch_end = (batch_start + batch_size).min(n);
396        let batch = flat_embeddings.slice(s![batch_start..batch_end, ..]);
397
398        // Compute dot products: [batch, k]
399        let dot_products = batch.dot(&centroids.t());
400
401        // Find min L2² distance for each embedding in batch
402        for (batch_idx, row) in dot_products.axis_iter(Axis(0)).enumerate() {
403            let global_idx = batch_start + batch_idx;
404            let emb_norm_sq = emb_norms_sq[global_idx];
405
406            // L2² = ||a||² + ||b||² - 2*a·b
407            // Find minimum over all centroids
408            let min_dist_sq = row
409                .iter()
410                .zip(centroid_norms_sq.iter())
411                .map(|(&dot, &c_norm_sq)| emb_norm_sq + c_norm_sq - 2.0 * dot)
412                .fold(f32::INFINITY, f32::min);
413
414            if min_dist_sq > threshold_sq {
415                outlier_indices.push(global_idx);
416            }
417        }
418    }
419
420    outlier_indices
421}
422
423/// Expand centroids by clustering embeddings far from existing centroids.
424///
425/// This implements fast-plaid's update_centroids() function:
426/// 1. Flatten all new embeddings
427/// 2. Find outliers (distance > cluster_threshold²)
428/// 3. Cluster outliers to get new centroids
429/// 4. Append new centroids to centroids.npy
430/// 5. Extend ivf_lengths.npy with zeros
431/// 6. Update metadata.json num_partitions
432///
433/// Returns the number of new centroids added.
434pub fn update_centroids(
435    index_path: &Path,
436    new_embeddings: &[Array2<f32>],
437    cluster_threshold: f32,
438    config: &UpdateConfig,
439) -> Result<usize> {
440    use ndarray_npy::{ReadNpyExt, WriteNpyExt};
441
442    let centroids_path = index_path.join("centroids.npy");
443    if !centroids_path.exists() {
444        return Ok(0);
445    }
446
447    // Load existing centroids
448    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;
449
450    // Flatten all new embeddings
451    let dim = existing_centroids.ncols();
452    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();
453
454    if total_tokens == 0 {
455        return Ok(0);
456    }
457
458    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
459    let mut offset = 0;
460
461    for emb in new_embeddings {
462        let n = emb.nrows();
463        flat_embeddings
464            .slice_mut(s![offset..offset + n, ..])
465            .assign(emb);
466        offset += n;
467    }
468
469    // Find outliers
470    let threshold_sq = cluster_threshold * cluster_threshold;
471    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);
472
473    let num_outliers = outlier_indices.len();
474    if num_outliers == 0 {
475        return Ok(0);
476    }
477
478    // Extract outlier embeddings
479    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
480    for (i, &idx) in outlier_indices.iter().enumerate() {
481        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
482    }
483
484    // Compute number of new centroids
485    // k_update = max(1, ceil(num_outliers / max_points_per_centroid) * 4)
486    let target_k =
487        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
488    let k_update = target_k.min(num_outliers); // Can't have more centroids than points
489
490    // Cluster outliers to get new centroids
491    let kmeans_config = ComputeKmeansConfig {
492        kmeans_niters: config.kmeans_niters,
493        max_points_per_centroid: config.max_points_per_centroid,
494        seed: config.seed,
495        n_samples_kmeans: config.n_samples_kmeans,
496        num_partitions: Some(k_update),
497        force_cpu: config.force_cpu,
498    };
499
500    // Convert outliers to vector of single-token "documents" for compute_kmeans
501    let outlier_docs: Vec<Array2<f32>> = outlier_indices
502        .iter()
503        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
504        .collect();
505
506    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
507    let k_new = new_centroids.nrows();
508
509    // Concatenate centroids
510    let new_num_centroids = existing_centroids.nrows() + k_new;
511    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
512    final_centroids
513        .slice_mut(s![..existing_centroids.nrows(), ..])
514        .assign(&existing_centroids);
515    final_centroids
516        .slice_mut(s![existing_centroids.nrows().., ..])
517        .assign(&new_centroids);
518
519    // Save updated centroids
520    final_centroids.write_npy(File::create(&centroids_path)?)?;
521
522    // Extend ivf_lengths.npy with zeros for new centroids
523    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
524    if ivf_lengths_path.exists() {
525        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
526        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
527        new_lengths
528            .slice_mut(s![..old_lengths.len()])
529            .assign(&old_lengths);
530        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
531    }
532
533    // Update metadata.json num_partitions
534    let meta_path = index_path.join("metadata.json");
535    if meta_path.exists() {
536        let mut meta: serde_json::Value =
537            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;
538
539        if let Some(obj) = meta.as_object_mut() {
540            obj.insert("num_partitions".to_string(), new_num_centroids.into());
541        }
542
543        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
544    }
545
546    Ok(k_new)
547}
548
549// ============================================================================
550// Low-Level Index Update
551// ============================================================================
552
553/// Update an existing index with new documents.
554///
555/// # Arguments
556///
557/// * `embeddings` - List of new document embeddings, each of shape `[num_tokens, dim]`
558/// * `index_path` - Path to the existing index directory
559/// * `codec` - The loaded ResidualCodec for compression
560/// * `batch_size` - Optional batch size for processing (default: 50,000)
561/// * `update_threshold` - Whether to update the cluster threshold
562/// * `force_cpu` - Force CPU execution even when CUDA is available
563///
564/// # Returns
565///
566/// The number of new documents added
567pub fn update_index(
568    embeddings: &[Array2<f32>],
569    index_path: &str,
570    codec: &ResidualCodec,
571    batch_size: Option<usize>,
572    update_threshold: bool,
573    force_cpu: bool,
574) -> Result<usize> {
575    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
576    let index_dir = Path::new(index_path);
577
578    // Load existing metadata (infers num_documents from doclens if not present)
579    let metadata_path = index_dir.join("metadata.json");
580    let metadata = Metadata::load_from_path(index_dir)?;
581
582    let num_existing_chunks = metadata.num_chunks;
583    let old_num_documents = metadata.num_documents;
584    let old_total_embeddings = metadata.num_embeddings;
585    let num_centroids = codec.num_centroids();
586    let embedding_dim = codec.embedding_dim();
587    let nbits = metadata.nbits;
588
589    // Determine starting chunk index
590    let mut start_chunk_idx = num_existing_chunks;
591    let mut append_to_last = false;
592    let mut current_emb_offset = old_total_embeddings;
593
594    // Check if we should append to the last chunk (if it has < 2000 documents)
595    if start_chunk_idx > 0 {
596        let last_idx = start_chunk_idx - 1;
597        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));
598
599        if last_meta_path.exists() {
600            let last_meta: serde_json::Value =
601                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
602                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
603                )?))?;
604
605            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
606                if nd < 2000 {
607                    start_chunk_idx = last_idx;
608                    append_to_last = true;
609
610                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
611                        current_emb_offset = off as usize;
612                    } else {
613                        let embs_in_last = last_meta
614                            .get("num_embeddings")
615                            .and_then(|x| x.as_u64())
616                            .unwrap_or(0) as usize;
617                        current_emb_offset = old_total_embeddings - embs_in_last;
618                    }
619                }
620            }
621        }
622    }
623
624    // Process new documents
625    let num_new_documents = embeddings.len();
626    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;
627
628    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
629    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
630    let mut all_residual_norms: Vec<f32> = Vec::new();
631
632    let packed_dim = embedding_dim * nbits / 8;
633
634    for i in 0..n_new_chunks {
635        let global_chunk_idx = start_chunk_idx + i;
636        let chk_offset = i * batch_size;
637        let chk_end = (chk_offset + batch_size).min(num_new_documents);
638        let chunk_docs = &embeddings[chk_offset..chk_end];
639
640        // Collect document lengths
641        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
642        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;
643
644        // Concatenate all embeddings in the chunk for batch processing
645        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
646        let mut offset = 0;
647        for doc in chunk_docs {
648            let n = doc.nrows();
649            batch_embeddings
650                .slice_mut(s![offset..offset + n, ..])
651                .assign(doc);
652            offset += n;
653        }
654
655        // BATCH: Compress all embeddings at once
656        // Use CPU-only version when force_cpu is set to avoid CUDA initialization overhead
657        let batch_codes = if force_cpu {
658            codec.compress_into_codes_cpu(&batch_embeddings)
659        } else {
660            codec.compress_into_codes(&batch_embeddings)
661        };
662
663        // BATCH: Compute residuals using parallel subtraction
664        let mut batch_residuals = batch_embeddings;
665        {
666            let centroids = &codec.centroids;
667            batch_residuals
668                .axis_iter_mut(Axis(0))
669                .into_par_iter()
670                .zip(batch_codes.as_slice().unwrap().par_iter())
671                .for_each(|(mut row, &code)| {
672                    let centroid = centroids.row(code);
673                    row.iter_mut()
674                        .zip(centroid.iter())
675                        .for_each(|(r, c)| *r -= c);
676                });
677        }
678
679        // Collect residual norms if updating threshold
680        if update_threshold {
681            for row in batch_residuals.axis_iter(Axis(0)) {
682                let norm = row.dot(&row).sqrt();
683                all_residual_norms.push(norm);
684            }
685        }
686
687        // BATCH: Quantize all residuals at once
688        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
689
690        // Convert to lists for chunk saving
691        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
692        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();
693
694        // Split codes back into per-document arrays for IVF building
695        let mut code_offset = 0;
696        for &len in &chk_doclens {
697            let len_usize = len as usize;
698            let codes: Vec<usize> = batch_codes
699                .slice(s![code_offset..code_offset + len_usize])
700                .iter()
701                .copied()
702                .collect();
703            new_codes_accumulated.push(codes);
704            new_doclens_accumulated.push(len);
705            code_offset += len_usize;
706        }
707
708        // Handle appending to last chunk
709        if i == 0 && append_to_last {
710            use ndarray_npy::ReadNpyExt;
711
712            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
713
714            if old_doclens_path.exists() {
715                let old_doclens: Vec<i64> =
716                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;
717
718                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
719                let old_residuals_path =
720                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
721
722                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
723                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;
724
725                // Prepend old data
726                let mut combined_codes: Vec<usize> =
727                    old_codes.iter().map(|&x| x as usize).collect();
728                combined_codes.extend(chk_codes_list);
729                chk_codes_list = combined_codes;
730
731                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
732                combined_residuals.extend(chk_residuals_list);
733                chk_residuals_list = combined_residuals;
734
735                let mut combined_doclens = old_doclens;
736                combined_doclens.extend(chk_doclens);
737                chk_doclens = combined_doclens;
738            }
739        }
740
741        // Save chunk data
742        {
743            use ndarray_npy::WriteNpyExt;
744
745            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
746            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
747            codes_arr.write_npy(File::create(&codes_path)?)?;
748
749            let num_tokens = chk_codes_list.len();
750            let residuals_arr =
751                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
752                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
753            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
754            residuals_arr.write_npy(File::create(&residuals_path)?)?;
755        }
756
757        // Save doclens
758        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
759        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;
760
761        // Save chunk metadata
762        let chk_meta = serde_json::json!({
763            "num_documents": chk_doclens.len(),
764            "num_embeddings": chk_codes_list.len(),
765            "embedding_offset": current_emb_offset,
766        });
767        current_emb_offset += chk_codes_list.len();
768
769        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
770        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
771    }
772
773    // Update cluster threshold if requested
774    if update_threshold && !all_residual_norms.is_empty() {
775        let norms = Array1::from_vec(all_residual_norms);
776        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
777    }
778
779    // Build new partial IVF
780    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
781    let mut pid_counter = old_num_documents as i64;
782
783    for doc_codes in &new_codes_accumulated {
784        for &code in doc_codes {
785            partition_pids_map
786                .entry(code)
787                .or_default()
788                .push(pid_counter);
789        }
790        pid_counter += 1;
791    }
792
793    // Load old IVF and merge
794    {
795        use ndarray_npy::{ReadNpyExt, WriteNpyExt};
796
797        let ivf_path = index_dir.join("ivf.npy");
798        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
799
800        let old_ivf: Array1<i64> = if ivf_path.exists() {
801            Array1::read_npy(File::open(&ivf_path)?)?
802        } else {
803            Array1::zeros(0)
804        };
805
806        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
807            Array1::read_npy(File::open(&ivf_lengths_path)?)?
808        } else {
809            Array1::zeros(num_centroids)
810        };
811
812        // Compute old offsets
813        let mut old_offsets = vec![0i64];
814        for &len in old_ivf_lengths.iter() {
815            old_offsets.push(old_offsets.last().unwrap() + len as i64);
816        }
817
818        // Merge IVF
819        let mut new_ivf_data: Vec<i64> = Vec::new();
820        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);
821
822        for centroid_id in 0..num_centroids {
823            // Get old PIDs for this centroid
824            let old_start = old_offsets[centroid_id] as usize;
825            let old_len = if centroid_id < old_ivf_lengths.len() {
826                old_ivf_lengths[centroid_id] as usize
827            } else {
828                0
829            };
830
831            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
832                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
833            } else {
834                Vec::new()
835            };
836
837            // Add new PIDs
838            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
839                pids.extend(new_pids);
840            }
841
842            // Deduplicate and sort
843            pids.sort_unstable();
844            pids.dedup();
845
846            new_ivf_lengths.push(pids.len() as i32);
847            new_ivf_data.extend(pids);
848        }
849
850        // Save updated IVF
851        let new_ivf = Array1::from_vec(new_ivf_data);
852        new_ivf.write_npy(File::create(&ivf_path)?)?;
853
854        let new_lengths = Array1::from_vec(new_ivf_lengths);
855        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
856    }
857
858    // Update global metadata
859    let new_total_chunks = start_chunk_idx + n_new_chunks;
860    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
861    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
862    let total_num_documents = old_num_documents + num_new_documents;
863
864    let new_avg_doclen = if total_num_documents > 0 {
865        let old_sum = metadata.avg_doclen * old_num_documents as f64;
866        (old_sum + new_tokens_count as f64) / total_num_documents as f64
867    } else {
868        0.0
869    };
870
871    let new_metadata = Metadata {
872        num_chunks: new_total_chunks,
873        nbits,
874        num_partitions: num_centroids,
875        num_embeddings,
876        avg_doclen: new_avg_doclen,
877        num_documents: total_num_documents,
878        next_plaid_compatible: true,
879    };
880
881    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;
882
883    // Clear merged files to force regeneration on next load.
884    // This ensures the merged files are consistent with the new chunk data.
885    crate::mmap::clear_merged_files(index_dir)?;
886
887    Ok(num_new_documents)
888}
889
#[cfg(test)]
mod tests {
    // Unit tests for the update helpers: `UpdateConfig` defaults, outlier detection, and
    // the on-disk buffer written by `save_buffer` and read by `load_buffer` /
    // `load_buffer_info`.
    use super::*;

    #[test]
    fn test_update_config_default() {
        // Pull out only the fields whose defaults this test pins down.
        let UpdateConfig {
            batch_size,
            buffer_size,
            start_from_scratch,
            ..
        } = UpdateConfig::default();

        assert_eq!(batch_size, 50_000);
        assert_eq!(buffer_size, 100);
        assert_eq!(start_from_scratch, 999);
    }

    #[test]
    fn test_find_outliers() {
        // Two 2-D centroids: (0, 0) and (1, 1).
        let centroid_coords = vec![0.0, 0.0, 1.0, 1.0];
        let centroids = Array2::from_shape_vec((2, 2), centroid_coords).unwrap();

        // Row 0 sits next to (0, 0), row 1 next to (1, 1); row 2 at (5, 5) is far from both.
        let point_coords = vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0];
        let embeddings = Array2::from_shape_vec((3, 2), point_coords).unwrap();

        // With a threshold of 1.0, comparing squared or plain distances gives the same
        // verdict for every point above, since 1.0 squared is still 1.0.
        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        // Exactly one outlier is expected: row index 2, the point at (5, 5).
        let found: Vec<_> = outliers.iter().copied().collect();
        assert_eq!(found, vec![2]);
    }

    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let tmp = TempDir::new().unwrap();

        // Three documents with 3, 2 and 5 rows of 4 columns each. Cells are filled with
        // consecutive values that continue across documents (0..12, 12..20, 20..40).
        let row_counts = [3usize, 2, 5];
        let mut next = 0usize;
        let embeddings: Vec<Array2<f32>> = row_counts
            .iter()
            .map(|&rows| {
                let start = next;
                next += rows * 4;
                Array2::from_shape_vec((rows, 4), (start..next).map(|x| x as f32).collect())
                    .unwrap()
            })
            .collect();

        save_buffer(tmp.path(), &embeddings).unwrap();

        // The buffer must come back as one entry per document, not a single merged entry.
        let loaded = load_buffer(tmp.path()).unwrap();
        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");

        // Each document keeps its own row count.
        let row_messages = [
            "First doc should have 3 rows",
            "Second doc should have 2 rows",
            "Third doc should have 5 rows",
        ];
        let expectations = row_counts.iter().zip(row_messages.iter());
        for (doc, (&rows, message)) in loaded.iter().zip(expectations) {
            assert_eq!(doc.nrows(), rows, "{}", message);
        }

        // Contents must round-trip exactly, in the original document order.
        for (got, expected) in loaded.iter().zip(embeddings.iter()) {
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let tmp = TempDir::new().unwrap();

        // Five documents with 2, 3, 4, 5 and 6 rows of 4 columns; cell (r, c) holds r * 4 + c.
        let embeddings: Vec<Array2<f32>> = (2..7)
            .map(|rows| Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32))
            .collect();

        save_buffer(tmp.path(), &embeddings).unwrap();

        // The count recorded in buffer_info.json must agree with what load_buffer returns.
        let info_count = load_buffer_info(tmp.path()).unwrap();
        let loaded = load_buffer(tmp.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}