next_plaid/
update.rs

1//! Index update functionality for adding new documents.
2//!
3//! This module provides functions to incrementally update an existing PLAID index
4//! with new documents, matching fast-plaid's behavior:
5//! - Buffer mechanism for small updates
6//! - Centroid expansion for outliers
7//! - Cluster threshold updates
8
9#[cfg(feature = "npy")]
10use std::collections::HashMap;
11use std::fs;
12#[cfg(feature = "npy")]
13use std::fs::File;
14#[cfg(feature = "npy")]
15use std::io::{BufReader, BufWriter};
16use std::path::Path;
17
18use serde::{Deserialize, Serialize};
19
20#[cfg(feature = "npy")]
21use ndarray::{s, Array1, Array2, Axis};
22#[cfg(feature = "npy")]
23use rayon::prelude::*;
24
25#[cfg(feature = "npy")]
26use crate::codec::ResidualCodec;
27#[cfg(feature = "npy")]
28use crate::error::Error;
29use crate::error::Result;
30#[cfg(feature = "npy")]
31use crate::index::Metadata;
32#[cfg(feature = "npy")]
33use crate::kmeans::compute_kmeans;
34use crate::kmeans::ComputeKmeansConfig;
35#[cfg(feature = "npy")]
36use crate::utils::quantile;
37
/// Number of documents compressed and written per chunk when the caller does not
/// pass an explicit batch size (same default as fast-plaid).
const DEFAULT_BATCH_SIZE: usize = 50_000;
40
/// Configuration for index updates.
///
/// Groups the knobs used when adding documents to an existing index: batching,
/// the K-means run used for centroid expansion (see `to_kmeans_config`), and the
/// thresholds for buffering and start-from-scratch rebuilds. The `Default` impl
/// holds the default values quoted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateConfig {
    /// Batch size for processing documents (default: 50,000, i.e. `DEFAULT_BATCH_SIZE`)
    pub batch_size: usize,
    /// Number of K-means iterations for centroid expansion (default: 4)
    pub kmeans_niters: usize,
    /// Max points per centroid for K-means (default: 256). Centroid expansion also
    /// uses it to decide how many new centroids to create.
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means (default: `None`, meaning auto-calculated)
    pub n_samples_kmeans: Option<usize>,
    /// Random seed forwarded to K-means (default: 42)
    pub seed: u64,
    /// If index has fewer docs than this, rebuild from scratch instead of
    /// updating incrementally (default: 999)
    pub start_from_scratch: usize,
    /// Buffer size before triggering centroid expansion (default: 100)
    pub buffer_size: usize,
}
59
60impl Default for UpdateConfig {
61    fn default() -> Self {
62        Self {
63            batch_size: DEFAULT_BATCH_SIZE,
64            kmeans_niters: 4,
65            max_points_per_centroid: 256,
66            n_samples_kmeans: None,
67            seed: 42,
68            start_from_scratch: 999,
69            buffer_size: 100,
70        }
71    }
72}
73
74impl UpdateConfig {
75    /// Convert to ComputeKmeansConfig for centroid expansion.
76    pub fn to_kmeans_config(&self) -> ComputeKmeansConfig {
77        ComputeKmeansConfig {
78            kmeans_niters: self.kmeans_niters,
79            max_points_per_centroid: self.max_points_per_centroid,
80            seed: self.seed,
81            n_samples_kmeans: self.n_samples_kmeans,
82            num_partitions: None,
83        }
84    }
85}
86
87// ============================================================================
88// Buffer Management
89// ============================================================================
90
/// Load buffered embeddings from `buffer.npy`.
///
/// Returns an empty vector if `buffer.npy` does not exist or cannot be decoded as
/// a 2-D `f32` array (best-effort: an unreadable buffer is treated as empty).
///
/// When `buffer_lengths.json` is present, the flattened array is split back into
/// per-document arrays using those row counts. Splitting stops at the first length
/// that is negative or would run past the end of the array, so a truncated or
/// corrupted lengths file yields only the documents that fit instead of panicking.
/// Without a lengths file the whole array is returned as a single document
/// (legacy layout).
///
/// # Errors
///
/// Returns an error if an existing file cannot be opened, or if
/// `buffer_lengths.json` is not a JSON array of integers.
#[cfg(feature = "npy")]
pub fn load_buffer(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    /// Split `flat` into consecutive row blocks of the given lengths, stopping at
    /// the first length that is invalid or out of range.
    fn split_rows(flat: &Array2<f32>, lengths: &[i64]) -> Vec<Array2<f32>> {
        let mut docs = Vec::with_capacity(lengths.len());
        let mut start = 0usize;
        for &len in lengths {
            // Negative or overflowing lengths indicate corruption; stop rather than
            // building an invalid slice.
            let end = match usize::try_from(len).ok().and_then(|n| start.checked_add(n)) {
                Some(end) if end <= flat.nrows() => end,
                _ => break,
            };
            docs.push(flat.slice(s![start..end, ..]).to_owned());
            start = end;
        }
        docs
    }

    let buffer_path = index_path.join("buffer.npy");
    if !buffer_path.exists() {
        return Ok(Vec::new());
    }

    // Deliberately best-effort: a buffer that fails to parse is treated as empty.
    let flat: Array2<f32> = match Array2::read_npy(File::open(&buffer_path)?) {
        Ok(arr) => arr,
        Err(_) => return Ok(Vec::new()),
    };

    let lengths_path = index_path.join("buffer_lengths.json");
    if !lengths_path.exists() {
        // Legacy layout: no lengths file, so the whole buffer is one document.
        return Ok(vec![flat]);
    }

    let lengths: Vec<i64> = serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;
    Ok(split_rows(&flat, &lengths))
}
136
/// Save embeddings to `buffer.npy`.
///
/// All documents are concatenated row-wise into one `f32` array. The per-document
/// row counts are written to `buffer_lengths.json` so `load_buffer` can split the
/// array again. The number of documents is written to `buffer_info.json` for
/// deletion tracking.
///
/// Saving an empty slice removes any previously saved buffer files. Otherwise
/// stale documents would be reloaded and `buffer_info.json` would keep reporting
/// the old count.
///
/// # Errors
///
/// Returns [`Error::Shape`] if the documents do not all share the same embedding
/// dimension. Propagates I/O, `.npy` and JSON serialization errors, including
/// errors raised when flushing the JSON files.
#[cfg(feature = "npy")]
pub fn save_buffer(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;
    use std::io::Write;

    if embeddings.is_empty() {
        return clear_buffer(index_path);
    }

    let dim = embeddings[0].ncols();
    if let Some(doc) = embeddings.iter().find(|e| e.ncols() != dim) {
        return Err(Error::Shape(format!(
            "buffer embeddings have inconsistent dimensions: expected {}, found {}",
            dim,
            doc.ncols()
        )));
    }
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    // Concatenate every document and remember its row count.
    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut lengths: Vec<i64> = Vec::with_capacity(embeddings.len());
    let mut offset = 0;
    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let buffer_path = index_path.join("buffer.npy");
    flat.write_npy(File::create(&buffer_path)?)?;

    // Flush explicitly: dropping a BufWriter silently discards write errors.
    let lengths_path = index_path.join("buffer_lengths.json");
    let mut lengths_writer = BufWriter::new(File::create(&lengths_path)?);
    serde_json::to_writer(&mut lengths_writer, &lengths)?;
    lengths_writer.flush()?;

    // Save buffer info for deletion tracking (number of documents).
    let info_path = index_path.join("buffer_info.json");
    let buffer_info = serde_json::json!({ "num_docs": embeddings.len() });
    let mut info_writer = BufWriter::new(File::create(&info_path)?);
    serde_json::to_writer(&mut info_writer, &buffer_info)?;
    info_writer.flush()?;

    Ok(())
}
179
/// Load the number of buffered documents recorded in `buffer_info.json`.
///
/// Returns 0 when the file is missing or has no unsigned-integer `num_docs` field.
///
/// # Errors
///
/// Propagates I/O errors opening the file and JSON parse errors.
#[cfg(feature = "npy")]
pub fn load_buffer_info(index_path: &Path) -> Result<usize> {
    let info_path = index_path.join("buffer_info.json");
    if info_path.exists() {
        let reader = BufReader::new(File::open(&info_path)?);
        let info: serde_json::Value = serde_json::from_reader(reader)?;
        let num_docs = match info.get("num_docs").and_then(serde_json::Value::as_u64) {
            Some(count) => count as usize,
            None => 0,
        };
        Ok(num_docs)
    } else {
        Ok(0)
    }
}
194
195/// Clear buffer files.
196pub fn clear_buffer(index_path: &Path) -> Result<()> {
197    let buffer_path = index_path.join("buffer.npy");
198    let lengths_path = index_path.join("buffer_lengths.json");
199    let info_path = index_path.join("buffer_info.json");
200
201    if buffer_path.exists() {
202        fs::remove_file(&buffer_path)?;
203    }
204    if lengths_path.exists() {
205        fs::remove_file(&lengths_path)?;
206    }
207    if info_path.exists() {
208        fs::remove_file(&info_path)?;
209    }
210
211    Ok(())
212}
213
/// Load embeddings stored for rebuild (`embeddings.npy` + `embeddings_lengths.json`).
///
/// Loads the raw embeddings saved by `save_embeddings_npy` for start-from-scratch
/// rebuilds. They are stored as one flat 2-D `f32` array plus a JSON list of
/// per-document row counts.
///
/// Returns an empty vector when `embeddings.npy` is absent. Returns the whole array
/// as a single document when the lengths file is absent. Splitting stops at the
/// first length that is negative or would run past the end of the array, instead
/// of panicking on an invalid slice.
///
/// # Errors
///
/// Propagates I/O errors, `.npy` decoding errors and JSON parse errors.
#[cfg(feature = "npy")]
pub fn load_embeddings_npy(index_path: &Path) -> Result<Vec<Array2<f32>>> {
    use ndarray_npy::ReadNpyExt;

    let emb_path = index_path.join("embeddings.npy");
    if !emb_path.exists() {
        return Ok(Vec::new());
    }

    let flat: Array2<f32> = Array2::read_npy(File::open(&emb_path)?)?;

    let lengths_path = index_path.join("embeddings_lengths.json");
    if !lengths_path.exists() {
        // No lengths file: return the whole array as a single document.
        return Ok(vec![flat]);
    }

    let lengths: Vec<i64> = serde_json::from_reader(BufReader::new(File::open(&lengths_path)?))?;

    let mut result = Vec::with_capacity(lengths.len());
    let mut start = 0usize;
    for &len in &lengths {
        // Negative or overflowing lengths indicate corruption; keep what fits.
        let end = match usize::try_from(len).ok().and_then(|n| start.checked_add(n)) {
            Some(end) if end <= flat.nrows() => end,
            _ => break,
        };
        result.push(flat.slice(s![start..end, ..]).to_owned());
        start = end;
    }

    Ok(result)
}
256
/// Save embeddings for potential rebuild (start-from-scratch mode).
///
/// Stores the documents concatenated row-wise in `embeddings.npy`. The
/// per-document row counts go in `embeddings_lengths.json`. This matches
/// fast-plaid's behavior of storing raw embeddings while the index is below the
/// `start_from_scratch` threshold. An empty slice is a no-op.
///
/// # Errors
///
/// Returns [`Error::Shape`] if the documents do not all share the same embedding
/// dimension. Propagates I/O, `.npy` and JSON serialization errors, including
/// errors raised when flushing the lengths file.
#[cfg(feature = "npy")]
pub fn save_embeddings_npy(index_path: &Path, embeddings: &[Array2<f32>]) -> Result<()> {
    use ndarray_npy::WriteNpyExt;
    use std::io::Write;

    if embeddings.is_empty() {
        return Ok(());
    }

    let dim = embeddings[0].ncols();
    if let Some(doc) = embeddings.iter().find(|e| e.ncols() != dim) {
        return Err(Error::Shape(format!(
            "embeddings have inconsistent dimensions: expected {}, found {}",
            dim,
            doc.ncols()
        )));
    }
    let total_rows: usize = embeddings.iter().map(|e| e.nrows()).sum();

    let mut flat = Array2::<f32>::zeros((total_rows, dim));
    let mut lengths: Vec<i64> = Vec::with_capacity(embeddings.len());
    let mut offset = 0;
    for emb in embeddings {
        let n = emb.nrows();
        flat.slice_mut(s![offset..offset + n, ..]).assign(emb);
        lengths.push(n as i64);
        offset += n;
    }

    let emb_path = index_path.join("embeddings.npy");
    flat.write_npy(File::create(&emb_path)?)?;

    // Flush explicitly: dropping a BufWriter silently discards write errors.
    let lengths_path = index_path.join("embeddings_lengths.json");
    let mut writer = BufWriter::new(File::create(&lengths_path)?);
    serde_json::to_writer(&mut writer, &lengths)?;
    writer.flush()?;

    Ok(())
}
294
295/// Clear embeddings.npy and embeddings_lengths.json.
296pub fn clear_embeddings_npy(index_path: &Path) -> Result<()> {
297    let emb_path = index_path.join("embeddings.npy");
298    let lengths_path = index_path.join("embeddings_lengths.json");
299
300    if emb_path.exists() {
301        fs::remove_file(&emb_path)?;
302    }
303    if lengths_path.exists() {
304        fs::remove_file(&lengths_path)?;
305    }
306    Ok(())
307}
308
/// Report whether `embeddings.npy` is present in `index_path`.
///
/// Used to detect start-from-scratch mode, where raw embeddings are kept on disk.
pub fn embeddings_npy_exists(index_path: &Path) -> bool {
    let emb_path = index_path.join("embeddings.npy");
    emb_path.exists()
}
313
314// ============================================================================
315// Cluster Threshold Management
316// ============================================================================
317
/// Load the cluster threshold from `cluster_threshold.npy`.
///
/// The file stores a 1-D `f32` array; its first element is the threshold.
///
/// # Errors
///
/// Returns [`Error::Update`] if the file is missing or holds an empty array
/// (previously an empty array panicked on indexing). Propagates I/O and `.npy`
/// decoding errors.
#[cfg(feature = "npy")]
pub fn load_cluster_threshold(index_path: &Path) -> Result<f32> {
    use ndarray_npy::ReadNpyExt;

    let thresh_path = index_path.join("cluster_threshold.npy");
    if !thresh_path.exists() {
        return Err(Error::Update("cluster_threshold.npy not found".into()));
    }

    let arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
    arr.get(0)
        .copied()
        .ok_or_else(|| Error::Update("cluster_threshold.npy is empty".into()))
}
331
/// Update `cluster_threshold.npy` with a count-weighted average.
///
/// The new threshold is the 0.75 quantile of `new_residual_norms`. If a previous
/// threshold is stored, the file is rewritten with
/// `(old * old_total_embeddings + new * new_count) / (old_total_embeddings + new_count)`.
/// If the file is missing, or holds an empty array, the new threshold is stored
/// as-is. Does nothing when `new_residual_norms` is empty.
///
/// # Errors
///
/// Propagates I/O and `.npy` read/write errors.
#[cfg(feature = "npy")]
pub fn update_cluster_threshold(
    index_path: &Path,
    new_residual_norms: &Array1<f32>,
    old_total_embeddings: usize,
) -> Result<()> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    let new_count = new_residual_norms.len();
    if new_count == 0 {
        return Ok(());
    }

    let new_threshold = quantile(new_residual_norms, 0.75);

    let thresh_path = index_path.join("cluster_threshold.npy");
    let old_threshold = if thresh_path.exists() {
        let old_arr: Array1<f32> = Array1::read_npy(File::open(&thresh_path)?)?;
        // An empty array is treated like a missing threshold instead of panicking.
        old_arr.get(0).copied()
    } else {
        None
    };

    let final_threshold = match old_threshold {
        Some(old) => {
            // Weight in f64: embedding counts routinely exceed f32's exact-integer range.
            let old_weight = old_total_embeddings as f64;
            let new_weight = new_count as f64;
            ((f64::from(old) * old_weight + f64::from(new_threshold) * new_weight)
                / (old_weight + new_weight)) as f32
        }
        None => new_threshold,
    };

    Array1::from_vec(vec![final_threshold]).write_npy(File::create(&thresh_path)?)?;

    Ok(())
}
363
364// ============================================================================
365// Centroid Expansion
366// ============================================================================
367
/// Find outlier embeddings that are far from all existing centroids.
///
/// Rows of `flat_embeddings` are scanned in parallel. For each row, the smallest
/// squared L2 distance to any row of `centroids` is computed, and the row index is
/// kept when that distance is strictly greater than `threshold_sq` (i.e. min L2² >
/// threshold²). Indices are returned in ascending order.
#[cfg(feature = "npy")]
fn find_outliers(
    flat_embeddings: &Array2<f32>,
    centroids: &Array2<f32>,
    threshold_sq: f32,
) -> Vec<usize> {
    flat_embeddings
        .axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .filter_map(|(row_idx, point)| {
            let nearest_sq = centroids
                .axis_iter(Axis(0))
                .fold(f32::INFINITY, |best, centroid| {
                    let dist_sq: f32 = point
                        .iter()
                        .zip(centroid.iter())
                        .map(|(p, c)| (p - c).powi(2))
                        .sum();
                    best.min(dist_sq)
                });
            (nearest_sq > threshold_sq).then_some(row_idx)
        })
        .collect()
}
402
/// Expand centroids by clustering embeddings far from existing centroids.
///
/// This implements fast-plaid's `update_centroids()` function:
/// 1. Flatten all new embeddings into one `[total_tokens, dim]` array.
/// 2. Find outliers: tokens whose squared distance to every existing centroid
///    exceeds `cluster_threshold²`.
/// 3. Cluster the outliers with K-means, targeting
///    `min(num_outliers, max(1, ceil(num_outliers / max_points_per_centroid)) * 4)`
///    partitions. A `max_points_per_centroid` of 0 is treated as 1.
/// 4. Append the new centroids to `centroids.npy`.
/// 5. Extend `ivf_lengths.npy` (if present) with zeros for the new centroids.
/// 6. Update `num_partitions` in `metadata.json` (if present).
///
/// Returns the number of centroid rows produced by K-means and appended. Returns 0
/// when `centroids.npy` is missing, there are no new tokens, or no token is an
/// outlier.
///
/// # Errors
///
/// Returns [`Error::Shape`] if a new document's embedding dimension differs from
/// the centroids' dimension. Propagates I/O, `.npy`, JSON and K-means errors.
#[cfg(feature = "npy")]
pub fn update_centroids(
    index_path: &Path,
    new_embeddings: &[Array2<f32>],
    cluster_threshold: f32,
    config: &UpdateConfig,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};
    use std::io::Write;

    let centroids_path = index_path.join("centroids.npy");
    if !centroids_path.exists() {
        return Ok(0);
    }

    let existing_centroids: Array2<f32> = Array2::read_npy(File::open(&centroids_path)?)?;
    let num_existing = existing_centroids.nrows();
    let dim = existing_centroids.ncols();

    let total_tokens: usize = new_embeddings.iter().map(|e| e.nrows()).sum();
    if total_tokens == 0 {
        return Ok(0);
    }

    // Reject mismatched dimensions up front instead of panicking in `assign`.
    if let Some(doc) = new_embeddings.iter().find(|e| e.ncols() != dim) {
        return Err(Error::Shape(format!(
            "new embeddings have dimension {}, but centroids have dimension {}",
            doc.ncols(),
            dim
        )));
    }

    // Flatten all new embeddings.
    let mut flat_embeddings = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;
    for emb in new_embeddings {
        let n = emb.nrows();
        flat_embeddings
            .slice_mut(s![offset..offset + n, ..])
            .assign(emb);
        offset += n;
    }

    let threshold_sq = cluster_threshold * cluster_threshold;
    let outlier_indices = find_outliers(&flat_embeddings, &existing_centroids, threshold_sq);
    let num_outliers = outlier_indices.len();
    if num_outliers == 0 {
        return Ok(0);
    }

    // k_update = min(num_outliers, max(1, ceil(num_outliers / points_per_centroid)) * 4).
    // Clamping to 1 avoids a division by zero (and the resulting overflow).
    let points_per_centroid = config.max_points_per_centroid.max(1);
    let target_k = ((num_outliers + points_per_centroid - 1) / points_per_centroid).max(1) * 4;
    let k_update = target_k.min(num_outliers); // Can't have more centroids than points

    let kmeans_config = ComputeKmeansConfig {
        num_partitions: Some(k_update),
        ..config.to_kmeans_config()
    };

    // compute_kmeans takes documents, so each outlier token becomes a one-row document.
    let outlier_docs: Vec<Array2<f32>> = outlier_indices
        .iter()
        .map(|&idx| flat_embeddings.slice(s![idx..idx + 1, ..]).to_owned())
        .collect();

    let new_centroids = compute_kmeans(&outlier_docs, &kmeans_config)?;
    let k_new = new_centroids.nrows();

    // Existing centroids keep their ids; new ones are appended after them.
    let new_num_centroids = num_existing + k_new;
    let mut final_centroids = Array2::<f32>::zeros((new_num_centroids, dim));
    final_centroids
        .slice_mut(s![..num_existing, ..])
        .assign(&existing_centroids);
    final_centroids
        .slice_mut(s![num_existing.., ..])
        .assign(&new_centroids);
    final_centroids.write_npy(File::create(&centroids_path)?)?;

    // New centroids start with empty inverted lists.
    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let old_lengths: Array1<i32> = Array1::read_npy(File::open(&ivf_lengths_path)?)?;
        let mut new_lengths = Array1::<i32>::zeros(new_num_centroids);
        new_lengths
            .slice_mut(s![..old_lengths.len()])
            .assign(&old_lengths);
        new_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
    }

    let meta_path = index_path.join("metadata.json");
    if meta_path.exists() {
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("num_partitions".to_string(), new_num_centroids.into());
        }

        // Flush explicitly: dropping a BufWriter silently discards write errors.
        let mut writer = BufWriter::new(File::create(&meta_path)?);
        serde_json::to_writer_pretty(&mut writer, &meta)?;
        writer.flush()?;
    }

    Ok(k_new)
}
528
529// ============================================================================
530// Low-Level Index Update
531// ============================================================================
532
/// Update an existing index with new documents.
///
/// New documents are compressed with `codec` in batches of `batch_size`. Each batch
/// is written as chunk files: `{i}.codes.npy`, `{i}.residuals.npy`,
/// `doclens.{i}.json` and `{i}.metadata.json`. If the last existing chunk holds
/// fewer than 2000 documents and its doclens file is present, the first batch is
/// merged into that chunk instead of starting a new one.
///
/// After the chunks are written, the function:
/// - optionally folds the new residual norms into `cluster_threshold.npy`;
/// - merges the new document ids into the inverted file (`ivf.npy`,
///   `ivf_lengths.npy`);
/// - rewrites `metadata.json` with the new totals.
///
/// # Arguments
///
/// * `embeddings` - List of new document embeddings, each of shape `[num_tokens, dim]`
/// * `index_path` - Path to the existing index directory
/// * `codec` - The loaded ResidualCodec for compression
/// * `batch_size` - Optional batch size for processing (default: 50,000; 0 is treated as 1)
/// * `update_threshold` - Whether to update the cluster threshold
///
/// # Returns
///
/// The number of new documents added. An empty `embeddings` slice is a no-op that
/// returns 0 without touching the index.
///
/// # Errors
///
/// Returns [`Error::IndexLoad`] if `metadata.json` or the last chunk's metadata
/// cannot be opened. Propagates I/O, `.npy`, JSON, quantization and shape errors.
#[cfg(feature = "npy")]
pub fn update_index(
    embeddings: &[Array2<f32>],
    index_path: &str,
    codec: &ResidualCodec,
    batch_size: Option<usize>,
    update_threshold: bool,
) -> Result<usize> {
    use ndarray_npy::{ReadNpyExt, WriteNpyExt};

    /// Write `value` as JSON to `path`, flushing explicitly so write errors are
    /// reported instead of being silently dropped with the `BufWriter`.
    fn write_json<T: Serialize + ?Sized>(path: &Path, value: &T, pretty: bool) -> Result<()> {
        use std::io::Write;

        let mut writer = BufWriter::new(File::create(path)?);
        if pretty {
            serde_json::to_writer_pretty(&mut writer, value)?;
        } else {
            serde_json::to_writer(&mut writer, value)?;
        }
        writer.flush()?;
        Ok(())
    }

    if embeddings.is_empty() {
        // Nothing to add. Returning here also keeps `num_chunks` intact: with no
        // batch written, selecting the last chunk for appending would otherwise
        // make the final metadata drop that chunk.
        return Ok(0);
    }

    // A zero batch size would never advance through `embeddings`.
    let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE).max(1);
    let index_dir = Path::new(index_path);

    // Load existing metadata
    let metadata_path = index_dir.join("metadata.json");
    let metadata: Metadata = serde_json::from_reader(BufReader::new(
        File::open(&metadata_path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
    ))?;

    let num_existing_chunks = metadata.num_chunks;
    let old_num_documents = metadata.num_documents;
    let old_total_embeddings = metadata.num_embeddings;
    let num_centroids = codec.num_centroids();
    let embedding_dim = codec.embedding_dim();
    let nbits = metadata.nbits;

    // By default, new documents start a fresh chunk after the existing ones.
    let mut start_chunk_idx = num_existing_chunks;
    let mut append_to_last = false;
    let mut current_emb_offset = old_total_embeddings;

    // Append to the last chunk instead if it has < 2000 documents. Its doclens file
    // must exist: without it the old chunk could not be merged and its codes and
    // residuals would be overwritten.
    if start_chunk_idx > 0 {
        let last_idx = start_chunk_idx - 1;
        let last_meta_path = index_dir.join(format!("{}.metadata.json", last_idx));
        let last_doclens_path = index_dir.join(format!("doclens.{}.json", last_idx));

        if last_meta_path.exists() && last_doclens_path.exists() {
            let last_meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&last_meta_path).map_err(
                    |e| Error::IndexLoad(format!("Failed to open chunk metadata: {}", e)),
                )?))?;

            if let Some(nd) = last_meta.get("num_documents").and_then(|x| x.as_u64()) {
                if nd < 2000 {
                    start_chunk_idx = last_idx;
                    append_to_last = true;

                    current_emb_offset =
                        match last_meta.get("embedding_offset").and_then(|x| x.as_u64()) {
                            Some(off) => off as usize,
                            None => {
                                let embs_in_last = last_meta
                                    .get("num_embeddings")
                                    .and_then(|x| x.as_u64())
                                    .unwrap_or(0) as usize;
                                // Saturate so inconsistent metadata cannot underflow.
                                old_total_embeddings.saturating_sub(embs_in_last)
                            }
                        };
                }
            }
        }
    }

    let num_new_documents = embeddings.len();
    // Ceiling division: the last batch may be partial.
    let n_new_chunks = (num_new_documents + batch_size - 1) / batch_size;

    let mut new_codes_accumulated: Vec<Vec<usize>> = Vec::with_capacity(num_new_documents);
    let mut new_doclens_accumulated: Vec<i64> = Vec::with_capacity(num_new_documents);
    let mut all_residual_norms: Vec<f32> = Vec::new();

    let progress = indicatif::ProgressBar::new(n_new_chunks as u64);
    progress.set_message("Updating index...");

    // Bytes per token once each of the `embedding_dim` values is packed into `nbits` bits.
    let packed_dim = embedding_dim * nbits / 8;

    for i in 0..n_new_chunks {
        let global_chunk_idx = start_chunk_idx + i;
        let chk_offset = i * batch_size;
        let chk_end = (chk_offset + batch_size).min(num_new_documents);
        let chunk_docs = &embeddings[chk_offset..chk_end];

        let mut chk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chk_doclens.iter().sum::<i64>() as usize;

        // Concatenate all embeddings in the chunk for batch processing.
        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // One centroid code per token.
        let batch_codes = codec.compress_into_codes(&batch_embeddings);

        // Residual = embedding - its code's centroid, computed in place in parallel.
        let mut batch_residuals = batch_embeddings;
        {
            let centroids = &codec.centroids;
            let codes = batch_codes
                .as_slice()
                .expect("compress_into_codes should return a contiguous array");
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(codes.par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        if update_threshold {
            all_residual_norms.extend(
                batch_residuals
                    .axis_iter(Axis(0))
                    .map(|row| row.dot(&row).sqrt()),
            );
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        let mut chk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
        let mut chk_residuals_list: Vec<u8> = batch_packed.iter().copied().collect();

        // Split codes back into per-document lists for IVF building.
        let mut code_offset = 0;
        for &len in &chk_doclens {
            let len_usize = len as usize;
            new_codes_accumulated
                .push(batch_codes.slice(s![code_offset..code_offset + len_usize]).to_vec());
            new_doclens_accumulated.push(len);
            code_offset += len_usize;
        }

        // The first batch is merged into the existing last chunk: prepend its data.
        if i == 0 && append_to_last {
            let old_doclens: Vec<i64> = serde_json::from_reader(BufReader::new(File::open(
                index_dir.join(format!("doclens.{}.json", global_chunk_idx)),
            )?))?;
            let old_codes: Array1<i64> = Array1::read_npy(File::open(
                index_dir.join(format!("{}.codes.npy", global_chunk_idx)),
            )?)?;
            let old_residuals: Array2<u8> = Array2::read_npy(File::open(
                index_dir.join(format!("{}.residuals.npy", global_chunk_idx)),
            )?)?;

            chk_codes_list = old_codes
                .iter()
                .map(|&x| x as usize)
                .chain(chk_codes_list)
                .collect();
            chk_residuals_list = old_residuals
                .iter()
                .copied()
                .chain(chk_residuals_list)
                .collect();
            chk_doclens = old_doclens.into_iter().chain(chk_doclens).collect();
        }

        // Save chunk data.
        let num_tokens = chk_codes_list.len();

        let codes_arr: Array1<i64> = chk_codes_list.iter().map(|&x| x as i64).collect();
        let codes_path = index_dir.join(format!("{}.codes.npy", global_chunk_idx));
        codes_arr.write_npy(File::create(&codes_path)?)?;

        let residuals_arr = Array2::from_shape_vec((num_tokens, packed_dim), chk_residuals_list)
            .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
        let residuals_path = index_dir.join(format!("{}.residuals.npy", global_chunk_idx));
        residuals_arr.write_npy(File::create(&residuals_path)?)?;

        let doclens_path = index_dir.join(format!("doclens.{}.json", global_chunk_idx));
        write_json(&doclens_path, &chk_doclens, false)?;

        let chk_meta = serde_json::json!({
            "num_documents": chk_doclens.len(),
            "num_embeddings": num_tokens,
            "embedding_offset": current_emb_offset,
        });
        current_emb_offset += num_tokens;

        let meta_path = index_dir.join(format!("{}.metadata.json", global_chunk_idx));
        write_json(&meta_path, &chk_meta, true)?;

        progress.inc(1);
    }
    progress.finish();

    if update_threshold && !all_residual_norms.is_empty() {
        let norms = Array1::from_vec(all_residual_norms);
        update_cluster_threshold(index_dir, &norms, old_total_embeddings)?;
    }

    // Partial IVF for the new documents: centroid code -> new document ids.
    let mut partition_pids_map: HashMap<usize, Vec<i64>> = HashMap::new();
    for (doc_idx, doc_codes) in new_codes_accumulated.iter().enumerate() {
        let pid = (old_num_documents + doc_idx) as i64;
        for &code in doc_codes {
            partition_pids_map.entry(code).or_default().push(pid);
        }
    }

    // Load the old IVF and merge the new postings into it.
    let ivf_path = index_dir.join("ivf.npy");
    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");

    let old_ivf: Array1<i64> = if ivf_path.exists() {
        Array1::read_npy(File::open(&ivf_path)?)?
    } else {
        Array1::zeros(0)
    };

    let old_ivf_lengths: Array1<i32> = if ivf_lengths_path.exists() {
        Array1::read_npy(File::open(&ivf_lengths_path)?)?
    } else {
        Array1::zeros(num_centroids)
    };

    // Prefix sums of the old list lengths give each centroid's start in `old_ivf`.
    let mut old_offsets: Vec<i64> = Vec::with_capacity(old_ivf_lengths.len() + 1);
    old_offsets.push(0);
    for &len in old_ivf_lengths.iter() {
        let last = *old_offsets.last().expect("old_offsets starts non-empty");
        old_offsets.push(last + len as i64);
    }

    let mut new_ivf_data: Vec<i64> = Vec::with_capacity(old_ivf.len());
    let mut new_ivf_lengths: Vec<i32> = Vec::with_capacity(num_centroids);

    for centroid_id in 0..num_centroids {
        // Centroids without an entry in the old lengths array (or with an
        // inconsistent one) contribute no old postings instead of panicking.
        let mut pids: Vec<i64> =
            match (old_offsets.get(centroid_id), old_ivf_lengths.get(centroid_id)) {
                (Some(&start), Some(&len)) if start >= 0 && len > 0 => {
                    let start = start as usize;
                    let end = start.saturating_add(len as usize);
                    if end <= old_ivf.len() {
                        old_ivf.slice(s![start..end]).to_vec()
                    } else {
                        Vec::new()
                    }
                }
                _ => Vec::new(),
            };

        if let Some(new_pids) = partition_pids_map.get(&centroid_id) {
            pids.extend_from_slice(new_pids);
        }

        // A document id appears at most once per centroid list.
        pids.sort_unstable();
        pids.dedup();

        new_ivf_lengths.push(pids.len() as i32);
        new_ivf_data.extend(pids);
    }

    Array1::from_vec(new_ivf_data).write_npy(File::create(&ivf_path)?)?;
    Array1::from_vec(new_ivf_lengths).write_npy(File::create(&ivf_lengths_path)?)?;

    // Update global metadata.
    let new_total_chunks = start_chunk_idx + n_new_chunks;
    let new_tokens_count: i64 = new_doclens_accumulated.iter().sum();
    let num_embeddings = old_total_embeddings + new_tokens_count as usize;
    let total_num_documents = old_num_documents + num_new_documents;

    // Running mean: re-weight the old average by the old document count.
    let new_avg_doclen = if total_num_documents > 0 {
        let old_sum = metadata.avg_doclen * old_num_documents as f64;
        (old_sum + new_tokens_count as f64) / total_num_documents as f64
    } else {
        0.0
    };

    let new_metadata = Metadata {
        num_chunks: new_total_chunks,
        nbits,
        num_partitions: num_centroids,
        num_embeddings,
        avg_doclen: new_avg_doclen,
        num_documents: total_num_documents,
    };

    write_json(&metadata_path, &new_metadata, true)?;

    Ok(num_new_documents)
}
867
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_config_default() {
        let UpdateConfig {
            batch_size,
            buffer_size,
            start_from_scratch,
            ..
        } = UpdateConfig::default();

        assert_eq!(batch_size, 50_000);
        assert_eq!(buffer_size, 100);
        assert_eq!(start_from_scratch, 999);
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_find_outliers() {
        // Two centroids, at (0, 0) and (1, 1).
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();

        // One point near each centroid, plus a far-away point at (5, 5).
        let points =
            Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.9, 0.9, 5.0, 5.0]).unwrap();

        // With threshold_sq = 1.0 only (5, 5) is too far from both centroids.
        let outliers = find_outliers(&points, &centroids, 1.0);

        assert_eq!(outliers, vec![2]);
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_buffer_roundtrip() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        // Three 4-dim documents with 3, 2 and 5 rows, filled with consecutive values.
        let mut next_value = 0usize;
        let embeddings: Vec<Array2<f32>> = [3usize, 2, 5]
            .iter()
            .map(|&rows| {
                let count = rows * 4;
                let values: Vec<f32> =
                    (next_value..next_value + count).map(|x| x as f32).collect();
                next_value += count;
                Array2::from_shape_vec((rows, 4), values).unwrap()
            })
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        // The lengths file must split the buffer back into three documents.
        assert_eq!(loaded.len(), 3, "Should have 3 documents, not 1");
        assert_eq!(loaded[0].nrows(), 3, "First doc should have 3 rows");
        assert_eq!(loaded[1].nrows(), 2, "Second doc should have 2 rows");
        assert_eq!(loaded[2].nrows(), 5, "Third doc should have 5 rows");

        for (got, want) in loaded.iter().zip(&embeddings) {
            assert_eq!(got, want);
        }
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_buffer_info_matches_buffer_len() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();

        // Five documents with 2, 3, 4, 5 and 6 rows.
        let embeddings: Vec<Array2<f32>> = (2..7)
            .map(|rows| Array2::from_shape_fn((rows, 4), |(r, c)| (r * 4 + c) as f32))
            .collect();

        save_buffer(dir.path(), &embeddings).unwrap();

        let info_count = load_buffer_info(dir.path()).unwrap();
        let loaded = load_buffer(dir.path()).unwrap();

        assert_eq!(info_count, 5, "buffer_info should report 5 docs");
        assert_eq!(
            loaded.len(),
            5,
            "load_buffer should return 5 docs to match buffer_info"
        );
    }
}