Skip to main content

next_plaid/
update.rs

1//! Index update functionality for adding new documents.
2//!
3//! This module provides functions to incrementally update an existing PLAID index
4//! with new documents, matching fast-plaid's behavior:
5//! - Buffer mechanism for small updates
6//! - Centroid expansion for outliers
7//! - Cluster threshold updates
8
9use std::collections::HashMap;
10use std::fs;
11use std::fs::File;
12use std::io::{BufReader, BufWriter};
13use std::path::Path;
14
15use serde::{Deserialize, Serialize};
16
17use ndarray::{s, Array1, Array2, Axis};
18use rayon::prelude::*;
19
20use crate::codec::ResidualCodec;
21use crate::error::Error;
22use crate::error::Result;
23use crate::index::Metadata;
24use crate::kmeans::compute_kmeans;
25use crate::kmeans::ComputeKmeansConfig;
26use crate::utils::quantile;
27
28/// Default batch size for processing documents (matches fast-plaid).
29const DEFAULT_BATCH_SIZE: usize = 50_000;
30
31/// Configuration for index updates.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct UpdateConfig {
34    /// Batch size for processing documents (default: 50,000)
35    pub batch_size: usize,
36    /// Number of K-means iterations for centroid expansion (default: 4)
37    pub kmeans_niters: usize,
38    /// Max points per centroid for K-means (default: 256)
39    pub max_points_per_centroid: usize,
40    /// Number of samples for K-means (default: auto-calculated)
41    pub n_samples_kmeans: Option<usize>,
42    /// Random seed (default: 42)
43    pub seed: u64,
44    /// If index has fewer docs than this, rebuild from scratch (default: 999)
45    pub start_from_scratch: usize,
46    /// Buffer size before triggering centroid expansion (default: 100)
47    pub buffer_size: usize,
48    /// Force CPU execution for K-means even when CUDA feature is enabled.
49    #[serde(default)]
50    pub force_cpu: bool,
51}
52
53impl Default for UpdateConfig {
54    fn default() -> Self {
55        Self {
56            batch_size: DEFAULT_BATCH_SIZE,
57            kmeans_niters: 4,
58            max_points_per_centroid: 256,
59            n_samples_kmeans: None,
60            seed: 42,
61            start_from_scratch: 999,
62            buffer_size: 100,
63            force_cpu: false,
64        }
65    }
66}
67
68impl UpdateConfig {
69    /// Convert to ComputeKmeansConfig for centroid expansion.
70    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
71        ComputeKmeansConfig {
72            kmeans_niters: self.kmeans_niters,
73            max_points_per_centroid: self.max_points_per_centroid,
74            seed: self.seed,
75            n_samples_kmeans: self.n_samples_kmeans,
76            num_partitions: None,
77            force_cpu: self.force_cpu,
78        }
79    }
80}
81
82// ============================================================================
83// Buffer Management
84// ============================================================================
85
86/// Load buffered embeddings from buffer.npy.
87///
88/// Returns an empty vector if buffer.npy doesn't exist.
89/// Uses buffer_lengths.json to split the flattened array back into per-document arrays.
90pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
91    use ndarray_npy::ReadNpyExt;
92
93    let buffer_path = index_path.join("buffer.npy");
94    let lengths_path = index_path.join("buffer_lengths.json");
95
96    if !buffer_path.exists() {
97        return Ok(Vec::new());
98    }
99
100    // Load the flattened embeddings array
101    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
102        Ok(arr) => arr,
103        Err(_) => return Ok(Vec::new()),
104    };
105
106    // Load lengths to split back into per-document arrays
107    if lengths_path.exists() {
108        let lengths: Vec<i64> =
109            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;
110
111        let mut result = Vec::with_capacity(lengths.len());
112        let mut offset = 0;
113
114        for &len in &lengths {
115            let len_usize = len as usize;
116            if offset + len_usize > flat.nrows() {
117                break;
118            }
119            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
120            result.push(doc_emb);
121            offset += len_usize;
122        }
123
124        return Ok(result);
125    }
126
127    // Fallback: if no lengths file, return as single document (legacy behavior)
128    Ok(vec![flat])
129}
130
131/// Save embeddings to buffer.npy.
132///
133/// Also saves buffer_info.json with the number of documents for deletion tracking.
134pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
135    use ndarray_npy::WriteNpyExt;
136
137    let buffer_path = index_path.join("buffer.npy");
138
139    // For simplicity, concatenate all embeddings into one array
140    // and store the lengths separately
141    if embeddings.is_empty() {
142        return Ok(());
143    }
144
145    let dim = embeddings[0].ncols();
146    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();
147
148    let mut flat = Array2::<f32>::zeros((total_rows, dim));
149    let mut offset = 0;
150    let mut lengths = Vec::new();
151
152    for emb in embeddings {
153        let n = emb.nrows();
154        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
155        lengths.push(n as i64);
156        offset += n;
157    }
158
159    flat.write_npy(File::create(&buffer_path)?)?;
160
161    // Save lengths
162    let lengths_path = index_path.join("buffer_lengths.json");
163    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;
164
165    // Save buffer info for deletion tracking (number of documents)
166    let info_path = index_path.join("buffer_info.json");
167    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
168    serde_json::to_writer(BufWriter::new(File::create(&info_path)?), &buffer_info)?;
169
170    Ok(())
171}
172
173/// Load buffer info (number of buffered documents).
174///
175/// Returns 0 if buffer_info.json doesn't exist.
176pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
177    let info_path = index_path.join("buffer_info.json");
178    if !info_path.exists() {
179        return Ok(0);
180    }
181
182    let info: serde_json::Value = serde_json::from_reader(BufReader::new(File::open(&info_path)?))?;
183
184    Ok(info.get("num_docs").and_then(|v| v.as_u64()).unwrap_or(0) as usize)
185}
186
187/// Clear buffer files.
188pub fn clear_buffer(index_path: &Path) -> Result<()> {
189    let buffer_path = index_path.join("buffer.npy");
190    let lengths_path = index_path.join("buffer_lengths.json");
191    let info_path = index_path.join("buffer_info.json");
192
193    if buffer_path.exists() {
194        fs::remove_file(&buffer_path)?;
195    }
196    if lengths_path.exists() {
197        fs::remove_file(&lengths_path)?;
198    }
199    if info_path.exists() {
200        fs::remove_file(&info_path)?;
201    }
202
203    Ok(())
204}
205
206/// Load embeddings stored for rebuild (embeddings.npy + embeddings_lengths.json).
207///
208/// This function loads raw embeddings that were saved for start-from-scratch rebuilds.
209/// The embeddings are stored in a flat 2D array with a separate lengths file.
210pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
211    use ndarray_npy::ReadNpyExt;
212
213    let emb_path = index_path.join("embeddings.npy");
214    let lengths_path = index_path.join("embeddings_lengths.json");
215
216    if !emb_path.exists() {
217        return Ok(Vec::new());
218    }
219
220    // Load flat embeddings array
221    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;
222
223    // Load lengths to split back into per-document arrays
224    if lengths_path.exists() {
225        let lengths: Vec<i64> =
226            serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;
227
228        let mut result = Vec::with_capacity(lengths.len());
229        let mut offset = 0;
230
231        for &len in &lengths {
232            let len_usize = len as usize;
233            if offset + len_usize > flat.nrows() {
234                break;
235            }
236            let doc_emb = flat.slice(s![offset..offset + len_usize, ..]).to_owned();
237            result.push(doc_emb);
238            offset += len_usize;
239        }
240
241        return Ok(result);
242    }
243
244    // Fallback: if no lengths file, return as single document
245    Ok(vec![flat])
246}
247
248/// Save embeddings for potential rebuild (start-from-scratch mode).
249///
250/// Stores embeddings in embeddings.npy (flat array) + embeddings_lengths.json.
251/// This matches fast-plaid's behavior of storing raw embeddings when the index
252/// is below the start_from_scratch threshold.
253pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
254    use ndarray_npy::WriteNpyExt;
255
256    if embeddings.is_empty() {
257        return Ok(());
258    }
259
260    let dim = embeddings[0].ncols();
261    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();
262
263    let mut flat = Array2::<f32>::zeros((total_rows, dim));
264    let mut offset = 0;
265    let mut lengths = Vec::with_capacity(embeddings.len());
266
267    for emb in embeddings {
268        let n = emb.nrows();
269        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
270        lengths.push(n as i64);
271        offset += n;
272    }
273
274    // Save flat embeddings
275    let emb_path = index_path.join("embeddings.npy");
276    flat.write_npy(File::create(&emb_path)?)?;
277
278    // Save lengths for reconstruction
279    let lengths_path = index_path.join("embeddings_lengths.json");
280    serde_json::to_writer(BufWriter::new(File::create(&lengths_path)?), &lengths)?;
281
282    Ok(())
283}
284
285/// Clear embeddings.npy and embeddings_lengths.json.
286pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
287    let emb_path = index_path.join("embeddings.npy");
288    let lengths_path = index_path.join("embeddings_lengths.json");
289
290    if emb_path.exists() {
291        fs::remove_file(&emb_path)?;
292    }
293    if lengths_path.exists() {
294        fs::remove_file(&lengths_path)?;
295    }
296    Ok(())
297}
298
/// Report whether raw embeddings for a start-from-scratch rebuild are stored,
/// i.e. whether `embeddings.npy` exists inside `index_path`.
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    let emb_path = index_path.join("embeddings.npy");
    emb_path.as_path().exists()
}
303
304// ============================================================================
305// Cluster Threshold Management
306// ============================================================================
307
308/// Load cluster threshold from cluster_threshold.npy.
309pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
310    use ndarray_npy::ReadNpyExt;
311
312    let thresh_path = index_path.join("cluster_threshold.npy");
313    if !thresh_path.exists() {
314        return Err(Error::Update("cluster_threshold.npy not found".into()));
315    }
316
317    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
318    Ok(arr[0])
319}
320
321/// Update cluster_threshold.npy with weighted average.
322pub fn update_cluster_threshold(
323    index_path: &Path,
324    new_residual_norms: &Array1<f32>,
325    old_total_embeddings: usize,
326) -> Result<()> {
327    use ndarray_npy::{ReadNpyExt, WriteNpyExt};
328
329    let new_count = new_residual_norms.len();
330    if new_count == 0 {
331        return Ok(());
332    }
333
334    let new_threshold = quantile(new_residual_norms, 0.75);
335
336    let thresh_path = index_path.join("cluster_threshold.npy");
337    let final_threshold = if thresh_path.exists() {
338        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
339        let old_threshold = old_arr[0];
340        let total = old_total_embeddings + new_count;
341        (old_threshold * old_total_embeddings as f32 + new_threshold * new_count as f32)
342            / total as f32
343    } else {
344        new_threshold
345    };
346
347    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;
348
349    Ok(())
350}
351
352// ============================================================================
353// Centroid Expansion
354// ============================================================================
355
/// Number of centroids whose dot products are accumulated together in one tile
/// of the outlier kernel.
const OUTLIER_CENTROID_TILE: usize = 8;
/// Number of embedding rows per block; each block is one item of the parallel
/// iterator in the outlier kernel.
const OUTLIER_EMBEDDING_TILE: usize = 64;
/// Minimum number of embedding blocks per rayon split (passed to `with_min_len`).
const OUTLIER_BLOCKS_MIN_LEN: usize = 4;
/// Relative tolerance, scaled by `max(|threshold_sq|, 1.0)`. An f32 distance
/// within this tolerance of the threshold is recomputed in f64 before comparing.
const OUTLIER_THRESHOLD_RECHECK_REL_EPS: f32 = 1.0e-5;
360
/// Return the squared Euclidean norm `sum(x_i^2)` of `row`.
///
/// Element `i` of each full group of four goes into accumulator lane `i % 4`.
/// Independent lanes break the dependency chain between additions. The lanes
/// are combined as `((l0 + l1) + l2) + l3`, then the `len % 4` leftover elements
/// are added in order.
#[inline]
fn squared_norm(row: &[f32]) -> f32 {
    let quads = row.chunks_exact(4);
    let leftover = quads.remainder();

    let mut lanes = [0.0f32; 4];
    for quad in quads {
        lanes[0] += quad[0] * quad[0];
        lanes[1] += quad[1] * quad[1];
        lanes[2] += quad[2] * quad[2];
        lanes[3] += quad[3] * quad[3];
    }

    let lane_sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    leftover.iter().fold(lane_sum, |acc, &x| acc + x * x)
}
385
/// Return the smallest squared L2 distance between `row` and any centroid,
/// computed directly as `sum((row - c)^2)` with f64 accumulation.
///
/// `centroids_flat` holds the centroids back to back, `dim` values each; a
/// trailing partial centroid is ignored. Each per-centroid distance is rounded to
/// f32 before the minimum is taken. Returns `f32::INFINITY` when there are no
/// centroids.
///
/// The outlier search uses this to re-evaluate rows whose fast f32 estimate
/// (`|x|^2 + |c|^2 - 2 x.c`) lands close to the threshold, where rounding error
/// could flip the decision.
///
/// Panics if `dim` is 0, or if `row` is shorter than `dim` while at least one
/// centroid is present.
#[inline]
fn min_distance_sq_precise(row: &[f32], centroids_flat: &[f32], dim: usize) -> f32 {
    centroids_flat
        .chunks_exact(dim)
        .map(|centroid| {
            let exact = row[..dim]
                .iter()
                .zip(centroid)
                .fold(0.0f64, |acc, (&x, &c)| {
                    let diff = x as f64 - c as f64;
                    acc + diff * diff
                });
            exact as f32
        })
        .fold(f32::INFINITY, f32::min)
}
408
409/// Find outliers with a blocked raw-slice kernel over centroid tiles.
410///
411/// This avoids materializing the intermediate [batch, num_centroids] scores matrix,
412/// removes iterator overhead from the hot loop, and parallelizes across embedding blocks.
413/// Find embeddings whose minimum L2² distance to any centroid exceeds `threshold_sq`.
414///
415/// Uses a tiled blocked loop instead of a full [n, k] similarity matrix:
416///   - Outer: blocks of OUTLIER_EMBEDDING_TILE embeddings (parallelized via rayon)
417///   - Inner: tiles of OUTLIER_CENTROID_TILE centroids
418///
419/// This avoids allocating O(n*k) memory and keeps the working set in L1/L2 cache.
420/// The dot-product loop (`dim_idx`) is the innermost to maximize sequential memory
421/// access across the embedding dimension. Borderline results (within 1% of the
422/// threshold) are rechecked in f64 to avoid false positives from f32 rounding.
423#[allow(clippy::needless_range_loop)]
424fn find_outliers(
425    flat_embeddings: &Array2<f32>,
426    centroids: &Array2<f32>,
427    threshold_sq: f32,
428) -> Vec<usize> {
429    let n = flat_embeddings.nrows();
430    let k = centroids.nrows();
431    let dim = flat_embeddings.ncols();
432
433    if n == 0 || k == 0 {
434        return Vec::new();
435    }
436
437    let embeddings_owned;
438    let embeddings_flat = if let Some(slice) = flat_embeddings.as_slice_memory_order() {
439        slice
440    } else {
441        embeddings_owned = flat_embeddings.as_standard_layout().to_owned();
442        embeddings_owned
443            .as_slice_memory_order()
444            .expect("standard-layout embeddings should be contiguous")
445    };
446    let centroids_owned;
447    let centroids_flat = if let Some(slice) = centroids.as_slice_memory_order() {
448        slice
449    } else {
450        centroids_owned = centroids.as_standard_layout().to_owned();
451        centroids_owned
452            .as_slice_memory_order()
453            .expect("standard-layout centroids should be contiguous")
454    };
455
456    let centroid_norms_sq: Vec<f32> = centroids_flat
457        .par_chunks_exact(dim)
458        .map(squared_norm)
459        .collect();
460
461    let row_stride = dim * OUTLIER_EMBEDDING_TILE;
462    embeddings_flat
463        .par_chunks(row_stride)
464        .with_min_len(OUTLIER_BLOCKS_MIN_LEN)
465        .enumerate()
466        .flat_map_iter(|(block_idx, rows_block)| {
467            let row_count = rows_block.len() / dim;
468            let row_offset = block_idx * OUTLIER_EMBEDDING_TILE;
469
470            let mut min_dists = vec![f32::INFINITY; row_count];
471            let emb_norms: Vec<f32> = rows_block.chunks_exact(dim).map(squared_norm).collect();
472
473            let mut centroid_idx = 0;
474            while centroid_idx + OUTLIER_CENTROID_TILE <= k {
475                let centroid_bases: [usize; OUTLIER_CENTROID_TILE] =
476                    std::array::from_fn(|j| (centroid_idx + j) * dim);
477
478                let mut dots = [[0.0f32; OUTLIER_CENTROID_TILE]; OUTLIER_EMBEDDING_TILE];
479
480                let mut dim_idx = 0;
481                while dim_idx < dim {
482                    let centroid_vals: [f32; OUTLIER_CENTROID_TILE] =
483                        std::array::from_fn(|j| centroids_flat[centroid_bases[j] + dim_idx]);
484
485                    for row_idx in 0..row_count {
486                        let x = rows_block[row_idx * dim + dim_idx];
487                        for j in 0..OUTLIER_CENTROID_TILE {
488                            dots[row_idx][j] += x * centroid_vals[j];
489                        }
490                    }
491
492                    dim_idx += 1;
493                }
494
495                for row_idx in 0..row_count {
496                    let emb_norm_sq = emb_norms[row_idx];
497                    for j in 0..OUTLIER_CENTROID_TILE {
498                        let dist_sq = emb_norm_sq + centroid_norms_sq[centroid_idx + j]
499                            - 2.0 * dots[row_idx][j];
500                        min_dists[row_idx] = min_dists[row_idx].min(dist_sq);
501                    }
502                }
503
504                centroid_idx += OUTLIER_CENTROID_TILE;
505            }
506
507            while centroid_idx < k {
508                let centroid = &centroids_flat[centroid_idx * dim..(centroid_idx + 1) * dim];
509                for row_idx in 0..row_count {
510                    let row = &rows_block[row_idx * dim..(row_idx + 1) * dim];
511                    let mut dot = 0.0f32;
512                    let mut dim_idx = 0;
513                    while dim_idx < dim {
514                        dot += row[dim_idx] * centroid[dim_idx];
515                        dim_idx += 1;
516                    }
517
518                    let dist_sq = emb_norms[row_idx] + centroid_norms_sq[centroid_idx] - 2.0 * dot;
519                    min_dists[row_idx] = min_dists[row_idx].min(dist_sq);
520                }
521
522                centroid_idx += 1;
523            }
524
525            min_dists
526                .into_iter()
527                .enumerate()
528                .filter_map(move |(row_idx, min_dist_sq)| {
529                    let final_min_dist_sq = if (min_dist_sq - threshold_sq).abs()
530                        <= threshold_sq.abs().max(1.0) * OUTLIER_THRESHOLD_RECHECK_REL_EPS
531                    {
532                        let row = &rows_block[row_idx * dim..(row_idx + 1) * dim];
533                        min_distance_sq_precise(row, centroids_flat, dim)
534                    } else {
535                        min_dist_sq
536                    };
537
538                    (final_min_dist_sq > threshold_sq).then_some(row_offset + row_idx)
539                })
540        })
541        .collect()
542}
543
544/// Expand centroids by clustering embeddings far from existing centroids.
545///
546/// This implements fast-plaid's update_centroids() function:
547/// 1. Flatten all new embeddings
548/// 2. Find outliers (distance > cluster_threshold²)
549/// 3. Cluster outliers to get new centroids
550/// 4. Append new centroids to centroids.npy
551/// 5. Extend ivf_lengths.npy with zeros
552/// 6. Update metadata.json num_partitions
553///
554/// Returns the number of new centroids added.
555pub fn update_centroids(
556    index_path: &Path,
557    new_embeddings: &[Array2<f32>],
558    cluster_threshold: f32,
559    config: &UpdateConfig,
560) -> Result<usize> {
561    use ndarray_npy::{ReadNpyExt, WriteNpyExt};
562
563    let centroids_path = index_path.join("centroids.npy");
564    if !centroids_path.exists() {
565        return Ok(0);
566    }
567
568    // Load existing centroids
569    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;
570
571    // Flatten all new embeddings
572    let dim = existing_centroids.ncols();
573    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();
574
575    if total_tokens == 0 {
576        return Ok(0);
577    }
578
579    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
580    let mut offset = 0;
581
582    for emb in new_embeddings {
583        let n = emb.nrows();
584        flat_embeddings
585            .slice_mut(s![offset..offset + n, ..])
586            .assign(emb);
587        offset += n;
588    }
589
590    // Find outliers
591    let threshold_sq = cluster_threshold * cluster_threshold;
592    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);
593
594    let num_outliers = outlier_indices.len();
595    if num_outliers == 0 {
596        return Ok(0);
597    }
598
599    // Extract outlier embeddings
600    let mut outliers = Array2::<f32>::zeros((num_outliers, dim));
601    for (i, &idx) in outlier_indices.iter().enumerate() {
602        outliers.row_mut(i).assign(&flat_embeddings.row(idx));
603    }
604
605    // Compute number of new centroids
606    // k_update = max(1, ceil(num_outliers / max_points_per_centroid) * 4)
607    let target_k =
608        ((num_outliers as f64 / config.max_points_per_centroid as f64).ceil() as usize).max(1) * 4;
609    let k_update = target_k.min(num_outliers); // Can't have more centroids than points
610
611    // Cluster outliers to get new centroids
612    let kmeans_config = ComputeKmeansConfig {
613        kmeans_niters: config.kmeans_niters,
614        max_points_per_centroid: config.max_points_per_centroid,
615        seed: config.seed,
616        n_samples_kmeans: config.n_samples_kmeans,
617        num_partitions: Some(k_update),
618        force_cpu: config.force_cpu,
619    };
620
621    // Convert outliers to vector of single-token "documents" for compute_kmeans
622    let outlier_docs: Vec<Array2<f32>> = outlier_indices
623        .iter()
624        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
625        .collect();
626
627    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
628    let k_new = new_centroids.nrows();
629
630    // Concatenate centroids
631    let new_num_centroids = existing_centroids.nrows() + k_new;
632    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
633    final_centroids
634        .slice_mut(s![..existing_centroids.nrows(), ..])
635        .assign(&existing_centroids);
636    final_centroids
637        .slice_mut(s![existing_centroids.nrows().., ..])
638        .assign(&new_centroids);
639
640    // Save updated centroids
641    final_centroids.write_npy(File::create(&centroids_path)?)?;
642
643    // Extend ivf_lengths.npy with zeros for new centroids
644    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
645    if ivf_lengths_path.exists() {
646        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
647        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
648        new_lengths
649            .slice_mut(s![..old_lengths.len()])
650            .assign(&old_lengths);
651        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
652    }
653
654    // Update metadata.json num_partitions
655    let meta_path = index_path.join("metadata.json");
656    if meta_path.exists() {
657        let mut meta: serde_json::Value =
658            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;
659
660        if let Some(obj) = meta.as_object_mut() {
661            obj.insert("num_partitions".to_string(), new_num_centroids.into());
662        }
663
664        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &meta)?;
665    }
666
667    Ok(k_new)
668}
669
670// ============================================================================
671// Low-Level Index Update
672// ============================================================================
673
674/// Update an existing index with new documents.
675///
676/// # Arguments
677///
678/// * `embeddings` - List of new document embeddings, each of shape `[num_tokens, dim]`
679/// * `index_path` - Path to the existing index directory
680/// * `codec` - The loaded ResidualCodec for compression
681/// * `batch_size` - Optional batch size for processing (default: 50,000)
682/// * `update_threshold` - Whether to update the cluster threshold
683/// * `force_cpu` - Force CPU execution even when CUDA is available
684///
685/// # Returns
686///
687/// The number of new documents added
688pub fn update_index(
689    embeddings: &[Array2<f32>],
690    index_path: &str,
691    codec: &ResidualCodec,
692    batch_size: Option<usize>,
693    update_threshold: bool,
694    force_cpu: bool,
695) -> Result<usize> {
696    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
697    let index_dir = Path::new(index_path);
698
699    // Load existing metadata (infers num_documents from doclens if not present)
700    let metadata_path = index_dir.join("metadata.json");
701    let metadata = Metadata::load_from_path(index_dir)?;
702
703    let num_existing_chunks = metadata.num_chunks;
704    let old_num_documents = metadata.num_documents;
705    let old_total_embeddings = metadata.num_embeddings;
706    let num_centroids = codec.num_centroids();
707    let embedding_dim = codec.embedding_dim();
708    let nbits = metadata.nbits;
709
710    // Determine starting chunk index
711    let mut start_chunk_idx = num_existing_chunks;
712    let mut append_to_last = false;
713    let mut current_emb_offset = old_total_embeddings;
714
715    // Check if we should append to the last chunk (if it has < 2000 documents)
716    if start_chunk_idx > 0 {
717        let last_idx = start_chunk_idx - 1;
718        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));
719
720        if last_meta_path.exists() {
721            let last_meta: serde_json::Value =
722                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
723                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
724                )?))?;
725
726            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
727                if nd < 2000 {
728                    start_chunk_idx = last_idx;
729                    append_to_last = true;
730
731                    if let Some(off) = last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
732                        current_emb_offset = off as usize;
733                    } else {
734                        let embs_in_last = last_meta
735                            .get("num_embeddings")
736                            .and_then(|x| x.as_u64())
737                            .unwrap_or(0) as usize;
738                        current_emb_offset = old_total_embeddings - embs_in_last;
739                    }
740                }
741            }
742        }
743    }
744
745    // Process new documents
746    let num_new_documents = embeddings.len();
747    let n_new_chunks = (num_new_documents as f64 / batch_size as f64).ceil() as usize;
748
749    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::new();
750    let mut new_doclens_accumulated: Vec<i64> = Vec::new();
751    let mut all_residual_norms: Vec<f32> = Vec::new();
752
753    let packed_dim = embedding_dim * nbits / 8;
754
755    for i in 0..n_new_chunks {
756        let global_chunk_idx = start_chunk_idx + i;
757        let chk_offset = i * batch_size;
758        let chk_end = (chk_offset + batch_size).min(num_new_documents);
759        let chunk_docs = &embeddings[chk_offset..chk_end];
760
761        // Collect document lengths
762        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
763        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;
764
765        // Concatenate all embeddings in the chunk for batch processing
766        let mut batch_embeddings = ndarray::Array2::<f32>::zeros((total_tokens, embedding_dim));
767        let mut offset = 0;
768        for doc in chunk_docs {
769            let n = doc.nrows();
770            batch_embeddings
771                .slice_mut(s![offset..offset + n, ..])
772                .assign(doc);
773            offset += n;
774        }
775
776        // BATCH: Compress all embeddings at once
777        // Use CPU-only version when force_cpu is set to avoid CUDA initialization overhead
778        let batch_codes = if force_cpu {
779            codec.compress_into_codes_cpu(&batch_embeddings)
780        } else {
781            codec.compress_into_codes(&batch_embeddings)
782        };
783
784        // BATCH: Compute residuals using parallel subtraction
785        let mut batch_residuals = batch_embeddings;
786        {
787            let centroids = &codec.centroids;
788            batch_residuals
789                .axis_iter_mut(Axis(0))
790                .into_par_iter()
791                .zip(batch_codes.as_slice().unwrap().par_iter())
792                .for_each(|(mut row, &code)| {
793                    let centroid = centroids.row(code);
794                    row.iter_mut()
795                        .zip(centroid.iter())
796                        .for_each(|(r, c)| *r -= c);
797                });
798        }
799
800        // Collect residual norms if updating threshold
801        if update_threshold {
802            for row in batch_residuals.axis_iter(Axis(0)) {
803                let norm = row.dot(&row).sqrt();
804                all_residual_norms.push(norm);
805            }
806        }
807
808        // BATCH: Quantize all residuals at once
809        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
810
811        // Convert to lists for chunk saving
812        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
813        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();
814
815        // Split codes back into per-document arrays for IVF building
816        let mut code_offset = 0;
817        for &len in &chk_doclens {
818            let len_usize = len as usize;
819            let codes: Vec<usize> = batch_codes
820                .slice(s![code_offset..code_offset + len_usize])
821                .iter()
822                .copied()
823                .collect();
824            new_codes_accumulated.push(codes);
825            new_doclens_accumulated.push(len);
826            code_offset += len_usize;
827        }
828
829        // Handle appending to last chunk
830        if i == 0 && append_to_last {
831            use ndarray_npy::ReadNpyExt;
832
833            let old_doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
834
835            if old_doclens_path.exists() {
836                let old_doclens: Vec<i64> =
837                    serde_json::from_reader(BufReader::new(File::open(&old_doclens_path)?))?;
838
839                let old_codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
840                let old_residuals_path =
841                    index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
842
843                let old_codes: Array1<i64> = Array1::read_npy(File::open(&old_codes_path)?)?;
844                let old_residuals: Array2<u8> = Array2::read_npy(File::open(&old_residuals_path)?)?;
845
846                // Prepend old data
847                let mut combined_codes: Vec<usize> =
848                    old_codes.iter().map(|&x| x as usize).collect();
849                combined_codes.extend(chk_codes_list);
850                chk_codes_list = combined_codes;
851
852                let mut combined_residuals: Vec<u8> = old_residuals.iter().copied().collect();
853                combined_residuals.extend(chk_residuals_list);
854                chk_residuals_list = combined_residuals;
855
856                let mut combined_doclens = old_doclens;
857                combined_doclens.extend(chk_doclens);
858                chk_doclens = combined_doclens;
859            }
860        }
861
862        // Save chunk data
863        {
864            use ndarray_npy::WriteNpyExt;
865
866            let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
867            let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
868            codes_arr.write_npy(File::create(&codes_path)?)?;
869
870            let num_tokens = chk_codes_list.len();
871            let residuals_arr =
872                Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
873                    .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
874            let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
875            residuals_arr.write_npy(File::create(&residuals_path)?)?;
876        }
877
878        // Save doclens
879        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
880        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chk_doclens)?;
881
882        // Save chunk metadata
883        let chk_meta = serde_json::json!({
884            "num_documents": chk_doclens.len(),
885            "num_embeddings": chk_codes_list.len(),
886            "embedding_offset": current_emb_offset,
887        });
888        current_emb_offset += chk_codes_list.len();
889
890        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
891        serde_json::to_writer_pretty(BufWriter::new(File::create(&meta_path)?), &chk_meta)?;
892    }
893
894    // Update cluster threshold if requested
895    if update_threshold && !all_residual_norms.is_empty() {
896        let norms = Array1::from_vec(all_residual_norms);
897        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
898    }
899
900    // Build new partial IVF
901    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
902
903    for (pid_counter, doc_codes) in (old_num_documents as i64..).zip(new_codes_accumulated.iter()) {
904        for &code in doc_codes {
905            partition_pids_map
906                .entry(code)
907                .or_default()
908                .push(pid_counter);
909        }
910    }
911
912    // Load old IVF and merge
913    {
914        use ndarray_npy::{ReadNpyExt, WriteNpyExt};
915
916        let ivf_path = index_dir.join("ivf.npy");
917        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
918
919        let old_ivf: Array1<i64> = if ivf_path.exists() {
920            Array1::read_npy(File::open(&ivf_path)?)?
921        } else {
922            Array1::zeros(0)
923        };
924
925        let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
926            Array1::read_npy(File::open(&ivf_lengths_path)?)?
927        } else {
928            Array1::zeros(num_centroids)
929        };
930
931        // Compute old offsets
932        let mut old_offsets = vec![0i64];
933        for &len in old_ivf_lengths.iter() {
934            old_offsets.push(old_offsets.last().unwrap() + len as i64);
935        }
936
937        // Merge IVF
938        let mut new_ivf_data: Vec<i64> = Vec::new();
939        let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);
940
941        for centroid_id in 0..num_centroids {
942            // Get old PIDs for this centroid
943            let old_start = old_offsets[centroid_id] as usize;
944            let old_len = if centroid_id < old_ivf_lengths.len() {
945                old_ivf_lengths[centroid_id] as usize
946            } else {
947                0
948            };
949
950            let mut pids: Vec<i64> = if old_len > 0 && old_start + old_len <= old_ivf.len() {
951                old_ivf.slice(s![old_start..old_start + old_len]).to_vec()
952            } else {
953                Vec::new()
954            };
955
956            // Add new PIDs
957            if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
958                pids.extend(new_pids);
959            }
960
961            // Deduplicate and sort
962            pids.sort_unstable();
963            pids.dedup();
964
965            new_ivf_lengths.push(pids.len() as i32);
966            new_ivf_data.extend(pids);
967        }
968
969        // Save updated IVF
970        let new_ivf = Array1::from_vec(new_ivf_data);
971        new_ivf.write_npy(File::create(&ivf_path)?)?;
972
973        let new_lengths = Array1::from_vec(new_ivf_lengths);
974        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
975    }
976
977    // Update global metadata
978    let new_total_chunks = start_chunk_idx + n_new_chunks;
979    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
980    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
981    let total_num_documents = old_num_documents + num_new_documents;
982
983    let new_avg_doclen = if total_num_documents > 0 {
984        let old_sum = metadata.avg_doclen * old_num_documents as f64;
985        (old_sum + new_tokens_count as f64) / total_num_documents as f64
986    } else {
987        0.0
988    };
989
990    let new_metadata = Metadata {
991        num_chunks: new_total_chunks,
992        nbits,
993        num_partitions: num_centroids,
994        num_embeddings,
995        avg_doclen: new_avg_doclen,
996        num_documents: total_num_documents,
997        embedding_dim,
998        next_plaid_compatible: true,
999    };
1000
1001    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &new_metadata)?;
1002
1003    // Clear merged files to force regeneration on next load.
1004    // This ensures the merged files are consistent with the new chunk data.
1005    crate::mmap::clear_merged_files(index_dir)?;
1006
1007    Ok(num_new_documents)
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// `UpdateConfig::default()` must expose the documented default values.
    #[test]
    fn test_update_config_default() {
        let cfg = UpdateConfig::default();
        assert_eq!(cfg.batch_size, 50_000);
        assert_eq!(cfg.buffer_size, 100);
        assert_eq!(cfg.start_from_scratch, 999);
    }

    /// Only embeddings lying beyond the threshold from every centroid are reported.
    #[test]
    fn test_find_outliers() {
        // Two centroids in 2-D: (0, 0) and (1, 1).
        let centroid_values = vec![0.0, 0.0, 1.0, 1.0];
        let centroids = Array2::from_shape_vec((2, 2), centroid_values).unwrap();

        // Row 0 sits near (0, 0), row 1 near (1, 1), row 2 far away at (5, 5).
        let point_values = vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0];
        let embeddings = Array2::from_shape_vec((3, 2), point_values).unwrap();

        // Threshold argument of 1.0. Its square is also 1.0, so the expected result is
        // the same whether the helper treats it as a plain or a squared distance.
        let outliers = find_outliers(&embeddings, &centroids, 1.0);

        // Exactly one outlier is expected: the (5, 5) point at row index 2.
        assert_eq!(outliers.len(), 1);
        assert_eq!(outliers[0], 2);
    }

    /// Saving and reloading the buffer must keep every document as its own matrix.
    #[test]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let tmp = TempDir::new().unwrap();

        // Builds a `rows x 4` document whose cells are consecutive values starting at `start`.
        let make_doc = |rows: usize, start: i32| {
            let end = start + (rows as i32) * 4;
            let values: Vec<f32> = (start..end).map(|x| x as f32).collect();
            Array2::from_shape_vec((rows, 4), values).unwrap()
        };

        // Three documents with 3, 2 and 5 rows, covering the values 0..40 in order.
        let embeddings = vec![make_doc(3, 0), make_doc(2, 12), make_doc(5, 20)];

        save_buffer(tmp.path(), &embeddings).unwrap();

        // Reload: the buffer must come back as three separate documents, not one merged matrix.
        let restored = load_buffer(tmp.path()).unwrap();

        assert_eq!(restored.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(restored[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(restored[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(restored[2].nrows(), 5, "Third doc should have 5 rows");

        // Each reloaded document must equal its original element-for-element.
        for (got, want) in restored.iter().zip(embeddings.iter()) {
            assert_eq!(got, want);
        }
    }

    /// The document count in `buffer_info.json` must agree with what `load_buffer` returns.
    #[test]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let tmp = TempDir::new().unwrap();

        // Five documents with 2, 3, 4, 5 and 6 rows of 4 columns each.
        let mut embeddings: Vec<Array2<f32>> = Vec::with_capacity(5);
        for rows in 2..7 {
            embeddings.push(Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32));
        }

        save_buffer(tmp.path(), &embeddings).unwrap();

        // Read the count from the info file first, then the documents themselves.
        let reported = load_buffer_info(tmp.path()).unwrap();
        let docs = load_buffer(tmp.path()).unwrap();

        assert_eq!(reported, 5, "buffer_info should report 5 docs");
        assert_eq!(
            docs.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}