
next_plaid/update.rs

//! Index update functionality for adding new documents.
//!
//! This module provides functions to incrementally update an existing PLAID index
//! with new documents, matching fast-plaid's behavior:
//! - Buffer mechanism for small updates
//! - Centroid expansion for outliers
//! - Cluster threshold updates

use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;

use serde::{Deserialize, Serialize};

use ndarray::{s, Array1, Array2, Axis};
use rayon::prelude::*;

use crate::codec::ResidualCodec;
use crate::error::Error;
use crate::error::Result;
use crate::index::Metadata;
use crate::kmeans::compute_kmeans;
use crate::kmeans::ComputeKmeansConfig;
use crate::utils::quantile;

/// Default batch size for processing documents (matches fast-plaid).
const DEFAULT_BATCH_SIZE: usize = 50_000;

/// Configuration for index updates.
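///
/// # Examples
///
/// A minimal sketch of overriding one field while keeping the rest of the
/// defaults (the `next_plaid::update` path is assumed from this file's location):
///
/// ```no_run
/// use next_plaid::update::UpdateConfig;
///
/// // Trigger centroid expansion only after 200 buffered documents.
/// let config = UpdateConfig {
///     buffer_size: 200,
///     ..UpdateConfig::default()
/// };
/// assert_eq!(config.batch_size, 50_000);
/// ```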
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    /// Batch size for processing documents (default: 50,000)
    pub batch_size: usize,
    /// Number of K-means iterations for centroid expansion (default: 4)
    pub kmeans_niters: usize,
    /// Max points per centroid for K-means (default: 256)
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means (default: auto-calculated)
    pub n_samples_kmeans: Option<usize>,
    /// Random seed (default: 42)
    pub seed: u64,
    /// If index has fewer docs than this, rebuild from scratch (default: 999)
    pub start_from_scratch: usize,
    /// Buffer size before triggering centroid expansion (default: 100)
    pub buffer_size: usize,
}

impl Default for UpdateConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            seed: 42,
            start_from_scratch: 999,
            buffer_size: 100,
        }
    }
}

impl UpdateConfig {
    /// Convert to ComputeKmeansConfig for centroid expansion.
    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
        ComputeKmeansConfig {
            kmeans_niters: self.kmeans_niters,
            max_points_per_centroid: self.max_points_per_centroid,
            seed: self.seed,
            n_samples_kmeans: self.n_samples_kmeans,
            num_partitions: None,
        }
    }
}

// ============================================================================
// Buffer Management
// ============================================================================

/// Load buffered embeddings from buffer.npy.
///
/// Returns an empty vector if buffer.npy doesn't exist or can't be read.
/// Uses buffer_lengths.json to split the flattened array back into per-document arrays.
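///
/// # Examples
///
/// A round-trip sketch with [`save_buffer`] (the `next_plaid::update` path and
/// the `tempfile` dev-dependency are assumed for illustration):
///
/// ```no_run
/// use ndarray::Array2;
/// use next_plaid::update::{load_buffer, save_buffer};
///
/// let dir = tempfile::TempDir::new().unwrap();
/// // Two documents with 3 and 2 token embeddings of dimension 4.
/// let docs = vec![Array2::<f32>::zeros((3, 4)), Array2::<f32>::zeros((2, 4))];
/// save_buffer(dir.path(), &docs).unwrap();
///
/// let loaded = load_buffer(dir.path()).unwrap();
/// assert_eq!(loaded.len(), 2);
/// ```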
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");

    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    // Load the flattened embeddings array
    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    // Load lengths to split back into per-document arrays
    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    // Fallback: if no lengths file, return as single document (legacy behavior)
    Ok(vec![flat])
}

/// Save embeddings to buffer.npy.
///
/// Also saves buffer_lengths.json with per-document lengths, and buffer_info.json
/// with the number of documents for deletion tracking.
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    let buffer_path = index_path.join("buffer.npy");

    // For simplicity, concatenate all embeddings into one array
    // and store the lengths separately
    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::new();

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    flat.write_npy(File::create(&buffer_path)?)?;

    // Save lengths
    let lengths_path = index_path.join("buffer_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    // Save buffer info for deletion tracking (number of documents)
    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;

    Ok(())
}

/// Load buffer info (number of buffered documents).
///
/// Returns 0 if buffer_info.json doesn't exist.
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if !info_path.exists() {
        return Ok(0);
    }

    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;

    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
}

/// Clear buffer files.
pub fn clear_buffer(index_path: &Path) -> Result<()> {
    let buffer_path = index_path.join("buffer.npy");
    let lengths_path = index_path.join("buffer_lengths.json");
    let info_path = index_path.join("buffer_info.json");

    if buffer_path.exists() {
        fs::remove_file(&buffer_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    if info_path.exists() {
        fs::remove_file(&info_path)?;
    }

    Ok(())
}

/// Load embeddings stored for rebuild (embeddings.npy + embeddings_lengths.json).
///
/// This function loads raw embeddings that were saved for start-from-scratch rebuilds.
/// The embeddings are stored in a flat 2D array with a separate lengths file.
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    // Load flat embeddings array
    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    // Load lengths to split back into per-document arrays
    if lengths_path.exists() {
        let lengths: Vec<i64> =
            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

        let mut result = Vec::with_capacity(lengths.len());
        let mut offset = 0;

        for &len in &lengths {
            let len_usize = len as usize;
            if offset + len_usize > flat.nrows() {
                break;
            }
            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
            result.push(doc_emb);
            offset += len_usize;
        }

        return Ok(result);
    }

    // Fallback: if no lengths file, return as single document
    Ok(vec![flat])
}

/// Save embeddings for potential rebuild (start-from-scratch mode).
///
/// Stores embeddings in embeddings.npy (flat array) + embeddings_lengths.json.
/// This matches fast-plaid's behavior of storing raw embeddings when the index
/// is below the start_from_scratch threshold.
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut offset = 0;
    let mut lengths = Vec::with_capacity(embeddings.len());

    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    // Save flat embeddings
    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    // Save lengths for reconstruction
    let lengths_path = index_path.join("embeddings_lengths.json");
    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;

    Ok(())
}

/// Clear embeddings.npy and embeddings_lengths.json.
pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
    let emb_path = index_path.join("embeddings.npy");
    let lengths_path = index_path.join("embeddings_lengths.json");

    if emb_path.exists() {
        fs::remove_file(&emb_path)?;
    }
    if lengths_path.exists() {
        fs::remove_file(&lengths_path)?;
    }
    Ok(())
}

/// Check if embeddings.npy exists for start-from-scratch mode.
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    index_path.join("embeddings.npy").exists()
}

// ============================================================================
// Cluster Threshold Management
// ============================================================================

/// Load cluster threshold from cluster_threshold.npy.
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    Ok(arr[0])
}

/// Update cluster_threshold.npy with weighted average.
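///
/// The candidate threshold is the 0.75 quantile of `new_residual_norms`. If a
/// threshold already exists on disk, the two are blended by embedding count:
///
/// ```text
/// final = (old * old_total_embeddings + new * new_count)
///         / (old_total_embeddings + new_count)
/// ```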
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    let thresh_path = index_path.join("cluster_threshold.npy");
    let final_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        let old_threshold = old_arr[0];
        let total = old_total_embeddings + new_count;
        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
            / total as f32
    } else {
        new_threshold
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}

// ============================================================================
// Centroid Expansion
// ============================================================================

/// Find outlier embeddings that are far from all existing centroids.
///
/// Returns indices of embeddings where min L2² distance > threshold².
///
/// Uses batch matrix multiplication for efficiency:
/// ||a - b||² = ||a||² + ||b||² - 2*a·b
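///
/// For example, with a = (5, 5) and b = (1, 1): ||a||² = 50, ||b||² = 2 and
/// a·b = 10, so ||a - b||² = 50 + 2 - 20 = 32, which matches ||(4, 4)||².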
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    let n = flat_embeddings.nrows();
    let k = centroids.nrows();

    if n == 0 || k == 0 {
        return Vec::new();
    }

    // Pre-compute squared norms for embeddings and centroids
    let emb_norms_sq: Vec<f32> = flat_embeddings
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    let centroid_norms_sq: Vec<f32> = centroids
        .axis_iter(Axis(0))
        .into_par_iter()
        .map(|row| row.dot(&row))
        .collect();

    // Batch matrix multiplication: [n, d] @ [d, k] -> [n, k]
    // This computes dot products: similarities[i, j] = embeddings[i] · centroids[j]
    // Process in batches to limit memory usage
    let batch_size = (2 * 1024 * 1024 * 1024 / (k * std::mem::size_of::<f32>())).clamp(1, 4096);
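    // For example, with k = 1,048,576 centroids this caps each batch at 512 rows,
    // keeping the [batch, k] f32 dot-product matrix at roughly 2 GiB.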

    let mut outlier_indices = Vec::new();

    for batch_start in (0..n).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(n);
        let batch = flat_embeddings.slice(s![batch_start..batch_end, ..]);

        // Compute dot products: [batch, k]
        let dot_products = batch.dot(&centroids.t());

        // Find min L2² distance for each embedding in batch
        for (batch_idx, row) in dot_products.axis_iter(Axis(0)).enumerate() {
            let global_idx = batch_start + batch_idx;
            let emb_norm_sq = emb_norms_sq[global_idx];

            // L2² = ||a||² + ||b||² - 2*a·b
            // Find minimum over all centroids
            let min_dist_sq = row
                .iter()
                .zip(centroid_norms_sq.iter())
                .map(|(&dot, &c_norm_sq)| emb_norm_sq + c_norm_sq - 2.0 * dot)
                .fold(f32::INFINITY, f32::min);

            if min_dist_sq > threshold_sq {
                outlier_indices.push(global_idx);
            }
        }
    }

    outlier_indices
}

/// Expand centroids by clustering embeddings far from existing centroids.
///
/// This implements fast-plaid's update_centroids() function:
/// 1. Flatten all new embeddings
/// 2. Find outliers (distance > cluster_threshold²)
/// 3. Cluster outliers to get new centroids
/// 4. Append new centroids to centroids.npy
/// 5. Extend ivf_lengths.npy with zeros
/// 6. Update metadata.json num_partitions
///
/// Returns the number of new centroids added.
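///
/// # Examples
///
/// A sketch of expanding centroids for a batch of new documents, assuming the
/// index directory already contains `centroids.npy` and `cluster_threshold.npy`
/// (the paths and the `next_plaid::update` module location are illustrative):
///
/// ```no_run
/// use std::path::Path;
/// use ndarray::Array2;
/// use next_plaid::update::{load_cluster_threshold, update_centroids, UpdateConfig};
///
/// let index_path = Path::new("./index");
/// let new_docs = vec![Array2::<f32>::zeros((10, 128))];
///
/// let threshold = load_cluster_threshold(index_path).unwrap();
/// let added =
///     update_centroids(index_path, &new_docs, threshold, &UpdateConfig::default()).unwrap();
/// println!("added {added} new centroids");
/// ```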
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    // Load existing centroids
    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;

    // Flatten all new embeddings
    let dim = existing_centroids.ncols();
    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();

    if total_tokens == 0 {
        return Ok(0);
    }

    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    // Find outliers
    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);

    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    // Extract outlier embeddings
    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
    for (i, &idx) in outlier_indices.iter().enumerate() {
        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
    }

    // Compute number of new centroids
    // k_update = max(1, ceil(num_outliers / max_points_per_centroid) * 4)
    let target_k =
        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
    let k_update = target_k.min(num_outliers); // Can't have more centroids than points
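    // For example, 1,000 outliers with max_points_per_centroid = 256 gives
    // ceil(1000 / 256) * 4 = 16 new centroids (well under the 1,000-point cap).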

    // Cluster outliers to get new centroids
    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed,
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: Some(k_update),
    };

    // Convert outliers to vector of single-token "documents" for compute_kmeans
    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    // Concatenate centroids
    let new_num_centroids = existing_centroids.nrows() + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..existing_centroids.nrows(), ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![existing_centroids.nrows().., ..])
        .assign(&new_centroids);

    // Save updated centroids
    final_centroids.write_npy(File::create(&centroids_path)?)?;

    // Extend ivf_lengths.npy with zeros for new centroids
    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Update metadata.json num_partitions
    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
    }

    Ok(k_new)
}

// ============================================================================
// Low-Level Index Update
// ============================================================================

/// Update an existing index with new documents.
///
/// # Arguments
///
/// * `embeddings` - List of new document embeddings, each of shape `[num_tokens, dim]`
/// * `index_path` - Path to the existing index directory
/// * `codec` - The loaded ResidualCodec for compression
/// * `batch_size` - Optional batch size for processing (default: 50,000)
/// * `update_threshold` - Whether to update the cluster threshold
///
/// # Returns
///
/// The number of new documents added
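///
/// # Examples
///
/// A sketch of appending freshly encoded documents; how the `ResidualCodec` is
/// obtained is outside this module, and the index path and `next_plaid::*`
/// module locations shown are illustrative:
///
/// ```no_run
/// use ndarray::Array2;
/// use next_plaid::codec::ResidualCodec;
/// use next_plaid::update::update_index;
///
/// fn append_docs(codec: &ResidualCodec, docs: &[Array2<f32>]) -> next_plaid::error::Result<usize> {
///     // Use the default batch size and refresh the cluster threshold
///     // from the new residual norms.
///     update_index(docs, "./index", codec, None, true)
/// }
/// ```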
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
) -> Result<usize> {
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
    let index_dir = Path::new(index_path);

    // Load existing metadata (infers num_documents from doclens if not present)
    let metadata_path = index_dir.join("metadata.json");
    let metadata = Metadata::load_from_path(index_dir)?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    // Determine starting chunk index
    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    // Check if we should append to the last chunk (if it has < 2000 documents)
    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));

        if last_meta_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                        current_emb_offset = off as usize;
                    } else {
                        let embs_in_last = last_meta
                            .get("num_embeddings")
                            .and_then(|x| x.as_u64())
                            .unwrap_or(0) as usize;
                        current_emb_offset = old_total_embeddings - embs_in_last;
                    }
                }
            }
        }
    }

    // Process new documents
    let num_new_documents = embeddings.len();
    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        // Collect document lengths
        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        // Concatenate all embeddings in the chunk for batch processing
        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // BATCH: Compress all embeddings at once
        let batch_codes = codec.compress_into_codes(&batch_embeddings);

        // BATCH: Compute residuals using parallel subtraction
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        // Collect residual norms if updating threshold
        if update_threshold {
            for row in batch_residuals.axis_iter(Axis(0)) {
                let norm = row.dot(&row).sqrt();
                all_residual_norms.push(norm);
            }
        }

        // BATCH: Quantize all residuals at once
        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        // Convert to lists for chunk saving
        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        // Split codes back into per-document arrays for IVF building
        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            let codes: Vec<usize> = batch_codes
                .slice(s![code_offset..code_offset + len_usize])
                .iter()
                .copied()
                .collect();
            new_codes_accumulated.push(codes);
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        // Handle appending to last chunk
        if i == 0 && append_to_last {
            use ndarray_npy::ReadNpyExt;

            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));

            if old_doclens_path.exists() {
                let old_doclens: Vec<i64> =
                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;

                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
                let old_residuals_path =
                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));

                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;

                // Prepend old data
                let mut combined_codes: Vec<usize> =
                    old_codes.iter().map(|&x| x as usize).collect();
                combined_codes.extend(chk_codes_list);
                chk_codes_list = combined_codes;

                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
                combined_residuals.extend(chk_residuals_list);
                chk_residuals_list = combined_residuals;

                let mut combined_doclens = old_doclens;
                combined_doclens.extend(chk_doclens);
                chk_doclens = combined_doclens;
            }
        }

        // Save chunk data
        {
            use ndarray_npy::WriteNpyExt;

            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
            codes_arr.write_npy(File::create(&codes_path)?)?;

            let num_tokens = chk_codes_list.len();
            let residuals_arr =
                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
            residuals_arr.write_npy(File::create(&residuals_path)?)?;
        }

        // Save doclens
        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;

        // Save chunk metadata
        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": chk_codes_list.len(),
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += chk_codes_list.len();

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
    }

    // Update cluster threshold if requested
    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Build new partial IVF
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
    let mut pid_counter = old_num_documents as i64;

    for doc_codes in &new_codes_accumulated {
        for &code in doc_codes {
            partition_pids_map
                .entry(code)
                .or_default()
                .push(pid_counter);
        }
        pid_counter += 1;
    }

    // Load old IVF and merge
    {
        use ndarray_npy::{ReadNpyExt, WriteNpyExt};

        let ivf_path = index_dir.join("ivf.npy");
        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

        let old_ivf: Array1<i64> = if ivf_path.exists() {
            Array1::read_npy(File::open(&ivf_path)?)?
        } else {
            Array1::zeros(0)
        };

        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
            Array1::read_npy(File::open(&ivf_lengths_path)?)?
        } else {
            Array1::zeros(num_centroids)
        };

        // Compute old offsets
        let mut old_offsets = vec![0i64];
        for &len in old_ivf_lengths.iter() {
            old_offsets.push(old_offsets.last().unwrap() + len as i64);
        }

        // Merge IVF
        let mut new_ivf_data: Vec<i64> = Vec::new();
        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

        for centroid_id in 0..num_centroids {
            // Get old PIDs for this centroid
            let old_start = old_offsets[centroid_id] as usize;
            let old_len = if centroid_id < old_ivf_lengths.len() {
                old_ivf_lengths[centroid_id] as usize
            } else {
                0
            };

            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
            } else {
                Vec::new()
            };

            // Add new PIDs
            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
                pids.extend(new_pids);
            }

            // Deduplicate and sort
            pids.sort_unstable();
            pids.dedup();

            new_ivf_lengths.push(pids.len() as i32);
            new_ivf_data.extend(pids);
        }

        // Save updated IVF
        let new_ivf = Array1::from_vec(new_ivf_data);
        new_ivf.write_npy(File::create(&ivf_path)?)?;

        let new_lengths = Array1::from_vec(new_ivf_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    // Update global metadata
    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
        next_plaid_compatible: true,
    };

    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;

    // Clear merged files to force regeneration on next load.
    // This ensures the merged files are consistent with the new chunk data.
    crate::mmap::clear_merged_files(index_dir)?;

    Ok(num_new_documents)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.buffer_size, 100);
        assert_eq!(config.start_from_scratch, 999);
    }

    #[test]
    fn test_find_outliers() {
        // Create centroids at (0,0), (1,1)
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        // Create embeddings: one close to (0,0), one close to (1,1), one far away at (5,5)
        let embeddings =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        // Threshold of 1.0 squared = 1.0
        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        // Only the point at (5,5) should be an outlier
        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }

    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        // Create 3 documents with different numbers of embeddings
        let embeddings = vec![
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((2, 4), (12..20).map(|x| x as f32).collect()).unwrap(),
            Array2::from_shape_vec((5, 4), (20..40).map(|x| x as f32).collect()).unwrap(),
        ];

        // Save buffer
        save_buffer(dir.path(), &embeddings).unwrap();

        // Load buffer and verify we get 3 documents, not 1
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        // Verify content matches
        assert_eq!(loaded[0], embeddings[0]);
        assert_eq!(loaded[1], embeddings[1]);
        assert_eq!(loaded[2], embeddings[2]);
    }

    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        // Create 5 documents
        let embeddings: Vec<Array2<f32>> = (0..5)
            .map(|i| {
                let rows = i + 2; // 2, 3, 4, 5, 6 rows
                Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32)
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        // Verify buffer_info.json matches actual document count
        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
962    }
963}