Skip to main content

next_plaid/
index.rs

1//! Index creation and management for PLAID
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
14use crate::utils::{quantile, quantiles};
15
16/// CPU implementation of fused compress_into_codes + residual computation.
17fn compress_and_residuals_cpu(
18    embeddings: &Array2<f32>,
19    codec: &ResidualCodec,
20) -> (Array1<usize>, Array2<f32>) {
21    use rayon::prelude::*;
22
23    // Use CPU-only version to ensure no CUDA is called
24    let codes = codec.compress_into_codes_cpu(embeddings);
25    let mut residuals = embeddings.clone();
26
27    let centroids = &codec.centroids;
28    residuals
29        .axis_iter_mut(Axis(0))
30        .into_par_iter()
31        .zip(codes.as_slice().unwrap().par_iter())
32        .for_each(|(mut row, &code)| {
33            let centroid = centroids.row(code);
34            row.iter_mut()
35                .zip(centroid.iter())
36                .for_each(|(r, c)| *r -= c);
37        });
38
39    (codes, residuals)
40}
41
/// Configuration for index creation.
///
/// When deserialized, `nbits`, `batch_size` and `seed` carry no serde default and
/// must be present; every other field falls back to its default when missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits for quantization (typically 2 or 4).
    /// Residuals are quantized into `2^nbits` buckets.
    pub nbits: usize,
    /// Number of documents processed and written per chunk.
    /// Must be greater than zero.
    pub batch_size: usize,
    /// Random seed for reproducibility.
    /// When `None`, document sampling for codec training uses an entropy-seeded
    /// RNG, while K-means falls back to seed 42.
    pub seed: Option<u64>,
    /// Number of K-means iterations (default: 4)
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of points per centroid for K-means (default: 256)
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means training.
    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Threshold for start-from-scratch mode (default: 999).
    /// When the number of documents is <= this threshold, raw embeddings are saved
    /// to embeddings.npy for potential rebuilds during updates.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
    /// Force CPU execution even when the CUDA feature is enabled.
    /// This applies to K-means and also to codec training and per-chunk compression
    /// during index creation.
    /// Useful for small batches where GPU initialization overhead exceeds benefits.
    #[serde(default)]
    pub force_cpu: bool,
}
71
/// Serde fallback for `IndexConfig::start_from_scratch` when the field is absent.
fn default_start_from_scratch() -> usize {
    const START_FROM_SCRATCH_THRESHOLD: usize = 999;
    START_FROM_SCRATCH_THRESHOLD
}
75
/// Serde fallback for `IndexConfig::kmeans_niters` when the field is absent.
fn default_kmeans_niters() -> usize {
    const KMEANS_ITERATIONS: usize = 4;
    KMEANS_ITERATIONS
}
79
/// Serde fallback for `IndexConfig::max_points_per_centroid` when the field is absent.
fn default_max_points_per_centroid() -> usize {
    const POINTS_PER_CENTROID_CAP: usize = 256;
    POINTS_PER_CENTROID_CAP
}
83
84impl Default for IndexConfig {
85    fn default() -> Self {
86        Self {
87            nbits: 4,
88            batch_size: 50_000,
89            seed: Some(42),
90            kmeans_niters: 4,
91            max_points_per_centroid: 256,
92            n_samples_kmeans: None,
93            start_from_scratch: 999,
94            force_cpu: false,
95        }
96    }
97}
98
/// Metadata for the index, persisted as `metadata.json` in the index directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    /// Number of chunks in the index (each chunk has its own codes/residuals/doclens files)
    pub num_chunks: usize,
    /// Number of bits for quantization
    pub nbits: usize,
    /// Number of partitions (centroids)
    pub num_partitions: usize,
    /// Total number of embeddings
    pub num_embeddings: usize,
    /// Average document length (total embeddings divided by number of documents)
    pub avg_doclen: f64,
    /// Total number of documents.
    /// Deserializes to 0 when missing; `Metadata::load_from_path` then infers it
    /// from the doclens files.
    #[serde(default)]
    pub num_documents: usize,
    /// Embedding dimension (columns of centroids matrix).
    /// Deserializes to 0 when missing from the JSON.
    #[serde(default)]
    pub embedding_dim: usize,
    /// Whether the index has been converted to next-plaid compatible format.
    /// If false or missing, the index may need fast-plaid to next-plaid conversion.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}
123
124impl Metadata {
125    /// Load metadata from a JSON file, inferring num_documents from doclens if not present.
126    pub fn load_from_path(index_path: &Path) -> Result<Self> {
127        let metadata_path = index_path.join("metadata.json");
128        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
129            File::open(&metadata_path)
130                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
131        ))?;
132
133        // If num_documents is 0 (default), infer from doclens files
134        if metadata.num_documents == 0 {
135            let mut total_docs = 0usize;
136            for chunk_idx in 0..metadata.num_chunks {
137                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
138                if let Ok(file) = File::open(&doclens_path) {
139                    if let Ok(chunk_doclens) =
140                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
141                    {
142                        total_docs += chunk_doclens.len();
143                    }
144                }
145            }
146            metadata.num_documents = total_docs;
147        }
148
149        Ok(metadata)
150    }
151}
152
/// Per-chunk metadata, persisted as `{chunk_idx}.metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Number of documents stored in this chunk
    pub num_documents: usize,
    /// Number of embeddings (codes / residual rows) stored in this chunk
    pub num_embeddings: usize,
    /// Global position of this chunk's first embedding: the sum of
    /// `num_embeddings` over all preceding chunks. Deserializes to 0 when missing.
    #[serde(default)]
    pub embedding_offset: usize,
}
161
162// ============================================================================
163// Standalone Index Creation Functions
164// ============================================================================
165
166/// Create index files on disk from embeddings and centroids.
167///
168/// This is a standalone function that creates all necessary index files
169/// without constructing an in-memory Index object. Both Index and MmapIndex
170/// can use this function to create their files, then load them in their
171/// preferred format.
172///
173/// # Arguments
174///
175/// * `embeddings` - List of document embeddings
176/// * `centroids` - Pre-computed centroids from K-means
177/// * `index_path` - Directory to save the index
178/// * `config` - Index configuration
179///
180/// # Returns
181///
182/// Metadata about the created index
183pub fn create_index_files(
184    embeddings: &[Array2<f32>],
185    centroids: Array2<f32>,
186    index_path: &str,
187    config: &IndexConfig,
188) -> Result<Metadata> {
189    let index_dir = Path::new(index_path);
190    fs::create_dir_all(index_dir)?;
191
192    let num_documents = embeddings.len();
193    let embedding_dim = centroids.ncols();
194    let num_centroids = centroids.nrows();
195
196    if num_documents == 0 {
197        return Err(Error::IndexCreation("No documents provided".into()));
198    }
199
200    // Calculate statistics
201    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
202    let avg_doclen = total_embeddings as f64 / num_documents as f64;
203
204    // Sample documents for codec training
205    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
206        .min(num_documents)
207        .max(1);
208
209    let mut rng = if let Some(seed) = config.seed {
210        use rand::SeedableRng;
211        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
212    } else {
213        use rand::SeedableRng;
214        rand_chacha::ChaCha8Rng::from_entropy()
215    };
216
217    use rand::seq::SliceRandom;
218    let mut indices: Vec<usize> = (0..num_documents).collect();
219    indices.shuffle(&mut rng);
220    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
221
222    // Collect sample embeddings for training
223    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
224    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
225    let mut collected = 0;
226
227    for &idx in sample_indices.iter().rev() {
228        if collected >= heldout_size {
229            break;
230        }
231        let emb = &embeddings[idx];
232        let take = (heldout_size - collected).min(emb.nrows());
233        for row in emb.axis_iter(Axis(0)).take(take) {
234            heldout_embeddings.extend(row.iter());
235        }
236        collected += take;
237    }
238
239    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
240        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
241
242    // Train codec: compute residuals and quantization parameters
243    let avg_residual = Array1::zeros(embedding_dim);
244    let initial_codec =
245        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
246
247    // Compute codes for heldout samples
248    // Use CPU-only version when force_cpu is set to avoid CUDA initialization overhead
249    let heldout_codes = if config.force_cpu {
250        initial_codec.compress_into_codes_cpu(&heldout)
251    } else {
252        initial_codec.compress_into_codes(&heldout)
253    };
254
255    // Compute residuals
256    let mut residuals = heldout.clone();
257    for i in 0..heldout.nrows() {
258        let centroid = initial_codec.centroids.row(heldout_codes[i]);
259        for j in 0..embedding_dim {
260            residuals[[i, j]] -= centroid[j];
261        }
262    }
263
264    // Compute cluster threshold from residual distances
265    let distances: Array1<f32> = residuals
266        .axis_iter(Axis(0))
267        .map(|row| row.dot(&row).sqrt())
268        .collect();
269    #[allow(unused_variables)]
270    let cluster_threshold = quantile(&distances, 0.75);
271
272    // Compute average residual per dimension
273    let avg_res_per_dim: Array1<f32> = residuals
274        .axis_iter(Axis(1))
275        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
276        .collect();
277
278    // Compute quantization buckets
279    let n_options = 1 << config.nbits;
280    let quantile_values: Vec<f64> = (1..n_options)
281        .map(|i| i as f64 / n_options as f64)
282        .collect();
283    let weight_quantile_values: Vec<f64> = (0..n_options)
284        .map(|i| (i as f64 + 0.5) / n_options as f64)
285        .collect();
286
287    // Flatten residuals for quantile computation
288    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
289    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
290    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
291
292    let codec = ResidualCodec::new(
293        config.nbits,
294        centroids.clone(),
295        avg_res_per_dim.clone(),
296        Some(bucket_cutoffs.clone()),
297        Some(bucket_weights.clone()),
298    )?;
299
300    // Save codec components
301    use ndarray_npy::WriteNpyExt;
302
303    let centroids_path = index_dir.join("centroids.npy");
304    codec
305        .centroids_view()
306        .to_owned()
307        .write_npy(File::create(&centroids_path)?)?;
308
309    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
310    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
311
312    let weights_path = index_dir.join("bucket_weights.npy");
313    bucket_weights.write_npy(File::create(&weights_path)?)?;
314
315    let avg_res_path = index_dir.join("avg_residual.npy");
316    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
317
318    let threshold_path = index_dir.join("cluster_threshold.npy");
319    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
320
321    // Process documents in chunks
322    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
323
324    // Save plan
325    let plan_path = index_dir.join("plan.json");
326    let plan = serde_json::json!({
327        "nbits": config.nbits,
328        "num_chunks": n_chunks,
329    });
330    let mut plan_file = File::create(&plan_path)?;
331    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
332
333    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
334    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
335
336    for chunk_idx in 0..n_chunks {
337        let start = chunk_idx * config.batch_size;
338        let end = (start + config.batch_size).min(num_documents);
339        let chunk_docs = &embeddings[start..end];
340
341        // Collect document lengths
342        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
343        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
344
345        // Concatenate all embeddings in the chunk for batch processing
346        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
347        let mut offset = 0;
348        for doc in chunk_docs {
349            let n = doc.nrows();
350            batch_embeddings
351                .slice_mut(s![offset..offset + n, ..])
352                .assign(doc);
353            offset += n;
354        }
355
356        // BATCH: Compress embeddings and compute residuals
357        // Try CUDA fused operation first, fall back to CPU (skip CUDA if force_cpu is set)
358        let (batch_codes, batch_residuals) = {
359            #[cfg(feature = "cuda")]
360            {
361                if !config.force_cpu {
362                    if let Some(ctx) = crate::cuda::get_global_context() {
363                        match crate::cuda::compress_and_residuals_cuda_batched(
364                            ctx,
365                            &batch_embeddings.view(),
366                            &codec.centroids_view(),
367                            None,
368                        ) {
369                            Ok(result) => result,
370                            Err(e) => {
371                                eprintln!(
372                                    "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
373                                    e
374                                );
375                                compress_and_residuals_cpu(&batch_embeddings, &codec)
376                            }
377                        }
378                    } else {
379                        compress_and_residuals_cpu(&batch_embeddings, &codec)
380                    }
381                } else {
382                    compress_and_residuals_cpu(&batch_embeddings, &codec)
383                }
384            }
385            #[cfg(not(feature = "cuda"))]
386            {
387                compress_and_residuals_cpu(&batch_embeddings, &codec)
388            }
389        };
390
391        // BATCH: Quantize all residuals at once
392        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
393
394        // Track codes for IVF building
395        for &len in &chunk_doclens {
396            doc_lengths.push(len);
397        }
398        all_codes.extend(batch_codes.iter().copied());
399
400        // Save chunk metadata
401        let chunk_meta = ChunkMetadata {
402            num_documents: end - start,
403            num_embeddings: batch_codes.len(),
404            embedding_offset: 0, // Will be updated later
405        };
406
407        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
408        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
409
410        // Save chunk doclens
411        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
412        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
413
414        // Save chunk codes
415        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
416        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
417        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
418
419        // Save chunk residuals
420        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
421        batch_packed.write_npy(File::create(&residuals_path)?)?;
422    }
423
424    // Update chunk metadata with global offsets
425    let mut current_offset = 0usize;
426    for chunk_idx in 0..n_chunks {
427        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
428        let mut meta: serde_json::Value =
429            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
430
431        if let Some(obj) = meta.as_object_mut() {
432            obj.insert("embedding_offset".to_string(), current_offset.into());
433            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
434            current_offset += num_emb;
435        }
436
437        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
438    }
439
440    // Build IVF (Inverted File)
441    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
442    let mut emb_idx = 0;
443
444    for (doc_id, &len) in doc_lengths.iter().enumerate() {
445        for _ in 0..len {
446            let code = all_codes[emb_idx];
447            code_to_docs.entry(code).or_default().push(doc_id as i64);
448            emb_idx += 1;
449        }
450    }
451
452    // Deduplicate document IDs per centroid
453    let mut ivf_data: Vec<i64> = Vec::new();
454    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
455
456    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
457        if let Some(docs) = code_to_docs.get(&centroid_id) {
458            let mut unique_docs: Vec<i64> = docs.clone();
459            unique_docs.sort_unstable();
460            unique_docs.dedup();
461            *ivf_len = unique_docs.len() as i32;
462            ivf_data.extend(unique_docs);
463        }
464    }
465
466    let ivf = Array1::from_vec(ivf_data);
467    let ivf_lengths = Array1::from_vec(ivf_lengths);
468
469    let ivf_path = index_dir.join("ivf.npy");
470    ivf.write_npy(File::create(&ivf_path)?)?;
471
472    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
473    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
474
475    // Save global metadata
476    let metadata = Metadata {
477        num_chunks: n_chunks,
478        nbits: config.nbits,
479        num_partitions: num_centroids,
480        num_embeddings: total_embeddings,
481        avg_doclen,
482        num_documents,
483        embedding_dim,
484        next_plaid_compatible: true, // Created by next-plaid, always compatible
485    };
486
487    let metadata_path = index_dir.join("metadata.json");
488    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
489
490    Ok(metadata)
491}
492
493/// Create index files with automatic K-means centroid computation.
494///
495/// This is a standalone function that runs K-means to compute centroids,
496/// then creates all index files on disk.
497///
498/// # Arguments
499///
500/// * `embeddings` - List of document embeddings
501/// * `index_path` - Directory to save the index
502/// * `config` - Index configuration
503///
504/// # Returns
505///
506/// Metadata about the created index
507pub fn create_index_with_kmeans_files(
508    embeddings: &[Array2<f32>],
509    index_path: &str,
510    config: &IndexConfig,
511) -> Result<Metadata> {
512    if embeddings.is_empty() {
513        return Err(Error::IndexCreation("No documents provided".into()));
514    }
515
516    // Pre-initialize CUDA if available (first init can take 10-20s due to driver initialization)
517    // Skip if force_cpu is set to avoid unnecessary initialization overhead
518    #[cfg(feature = "cuda")]
519    if !config.force_cpu {
520        let _ = crate::cuda::get_global_context();
521    }
522
523    // Build K-means configuration from IndexConfig
524    let kmeans_config = ComputeKmeansConfig {
525        kmeans_niters: config.kmeans_niters,
526        max_points_per_centroid: config.max_points_per_centroid,
527        seed: config.seed.unwrap_or(42),
528        n_samples_kmeans: config.n_samples_kmeans,
529        num_partitions: None, // Let the heuristic decide
530        force_cpu: config.force_cpu,
531    };
532
533    // Compute centroids using fast-plaid's approach
534    let centroids = compute_kmeans(embeddings, &kmeans_config)?;
535
536    // Create the index files
537    let metadata = create_index_files(embeddings, centroids, index_path, config)?;
538
539    // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
540    if embeddings.len() <= config.start_from_scratch {
541        let index_dir = std::path::Path::new(index_path);
542        crate::update::save_embeddings_npy(index_dir, embeddings)?;
543    }
544
545    Ok(metadata)
546}
547// ============================================================================
548// Memory-Mapped Index for Low Memory Usage
549// ============================================================================
550
/// A memory-mapped PLAID index for multi-vector search.
///
/// This struct uses memory-mapped files for the large arrays (codes and residuals)
/// instead of loading them entirely into RAM. The codec is loaded with
/// `ResidualCodec::load_mmap_from_dir` (memory-mapped centroids). The IVF,
/// its offsets, and the document lengths and offsets are held in memory.
///
/// # Memory Usage
///
/// Only small tensors (~50 MB for SciFact 5K docs) are loaded into RAM,
/// with code and residual data accessed via OS-managed memory mapping.
///
/// # Usage
///
/// ```ignore
/// use next_plaid::MmapIndex;
///
/// let index = MmapIndex::load("/path/to/index")?;
/// let results = index.search(&query, &params, None)?;
/// ```
pub struct MmapIndex {
    /// Path to the index directory
    pub path: String,
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression
    pub codec: ResidualCodec,
    /// IVF data (concatenated passage IDs per centroid)
    pub ivf: Array1<i64>,
    /// IVF lengths (number of passages per centroid)
    pub ivf_lengths: Array1<i32>,
    /// IVF offsets (cumulative offsets into ivf array); length is num_centroids + 1
    pub ivf_offsets: Array1<i64>,
    /// Document lengths (number of tokens per document)
    pub doc_lengths: Array1<i64>,
    /// Cumulative document offsets for indexing into codes/residuals;
    /// document `i` spans rows `doc_offsets[i]..doc_offsets[i + 1]` (length num_documents + 1)
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped merged codes array (public for search access)
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped merged residuals array (public for search access)
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
592
593impl MmapIndex {
594    /// Load a memory-mapped index from disk.
595    ///
596    /// This creates merged files for codes and residuals if they don't exist,
597    /// then memory-maps them for efficient access.
598    ///
599    /// If the index was created by fast-plaid, it will be automatically converted
600    /// to next-plaid compatible format on first load.
601    pub fn load(index_path: &str) -> Result<Self> {
602        use ndarray_npy::ReadNpyExt;
603
604        let index_dir = Path::new(index_path);
605
606        // Load metadata (infers num_documents from doclens if not present)
607        let mut metadata = Metadata::load_from_path(index_dir)?;
608
609        // Check if conversion from fast-plaid format is needed
610        if !metadata.next_plaid_compatible {
611            eprintln!("Checking index format compatibility...");
612            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
613            if converted {
614                eprintln!("Index converted to next-plaid compatible format.");
615                // Delete any existing merged files since the source files changed
616                let merged_codes = index_dir.join("merged_codes.npy");
617                let merged_residuals = index_dir.join("merged_residuals.npy");
618                let codes_manifest = index_dir.join("merged_codes.manifest.json");
619                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
620                for path in [
621                    &merged_codes,
622                    &merged_residuals,
623                    &codes_manifest,
624                    &residuals_manifest,
625                ] {
626                    if path.exists() {
627                        let _ = fs::remove_file(path);
628                    }
629                }
630            }
631
632            // Mark as compatible and save metadata
633            metadata.next_plaid_compatible = true;
634            let metadata_path = index_dir.join("metadata.json");
635            let file = File::create(&metadata_path)
636                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
637            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
638            eprintln!("Metadata updated with next_plaid_compatible: true");
639        }
640
641        // Load codec with memory-mapped centroids for reduced RAM usage.
642        // Other small tensors (bucket weights, etc.) are still loaded into memory.
643        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;
644
645        // Load IVF (small tensor)
646        let ivf_path = index_dir.join("ivf.npy");
647        let ivf: Array1<i64> = Array1::read_npy(
648            File::open(&ivf_path)
649                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
650        )
651        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
652
653        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
654        let ivf_lengths: Array1<i32> = Array1::read_npy(
655            File::open(&ivf_lengths_path)
656                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
657        )
658        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
659
660        // Compute IVF offsets
661        let num_centroids = ivf_lengths.len();
662        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
663        for i in 0..num_centroids {
664            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
665        }
666
667        // Load document lengths from all chunks
668        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
669        for chunk_idx in 0..metadata.num_chunks {
670            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
671            let chunk_doclens: Vec<i64> =
672                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
673            doc_lengths_vec.extend(chunk_doclens);
674        }
675        let doc_lengths = Array1::from_vec(doc_lengths_vec);
676
677        // Compute document offsets for indexing
678        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
679        for i in 0..doc_lengths.len() {
680            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
681        }
682
683        // Compute padding needed for StridedTensor compatibility
684        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
685        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
686        let padding_needed = max_len.saturating_sub(last_len);
687
688        // Create merged files if needed
689        let merged_codes_path =
690            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
691        let merged_residuals_path =
692            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;
693
694        // Memory-map the merged files
695        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
696        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;
697
698        Ok(Self {
699            path: index_path.to_string(),
700            metadata,
701            codec,
702            ivf,
703            ivf_lengths,
704            ivf_offsets,
705            doc_lengths,
706            doc_offsets,
707            mmap_codes,
708            mmap_residuals,
709        })
710    }
711
712    /// Get candidate documents from IVF for given centroid indices.
713    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
714        let mut candidates: Vec<i64> = Vec::new();
715
716        for &idx in centroid_indices {
717            if idx < self.ivf_lengths.len() {
718                let start = self.ivf_offsets[idx] as usize;
719                let len = self.ivf_lengths[idx] as usize;
720                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
721            }
722        }
723
724        candidates.sort_unstable();
725        candidates.dedup();
726        candidates
727    }
728
729    /// Get document embeddings by decompressing codes and residuals.
730    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
731        if doc_id >= self.doc_lengths.len() {
732            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
733        }
734
735        let start = self.doc_offsets[doc_id];
736        let end = self.doc_offsets[doc_id + 1];
737
738        // Get codes and residuals from mmap
739        let codes_slice = self.mmap_codes.slice(start, end);
740        let residuals_view = self.mmap_residuals.slice_rows(start, end);
741
742        // Convert codes to Array1<usize>
743        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
744
745        // Convert residuals to owned Array2
746        let residuals = residuals_view.to_owned();
747
748        // Decompress
749        self.codec.decompress(&residuals, &codes.view())
750    }
751
752    /// Get codes for a batch of document IDs (for approximate scoring).
753    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
754        doc_ids
755            .iter()
756            .map(|&doc_id| {
757                if doc_id >= self.doc_lengths.len() {
758                    return vec![];
759                }
760                let start = self.doc_offsets[doc_id];
761                let end = self.doc_offsets[doc_id + 1];
762                self.mmap_codes.slice(start, end).to_vec()
763            })
764            .collect()
765    }
766
767    /// Decompress embeddings for a batch of document IDs.
768    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
769        // Compute total tokens
770        let mut total_tokens = 0usize;
771        let mut lengths = Vec::with_capacity(doc_ids.len());
772        for &doc_id in doc_ids {
773            if doc_id >= self.doc_lengths.len() {
774                lengths.push(0);
775            } else {
776                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
777                lengths.push(len);
778                total_tokens += len;
779            }
780        }
781
782        if total_tokens == 0 {
783            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
784        }
785
786        // Gather all codes and residuals
787        let packed_dim = self.mmap_residuals.ncols();
788        let mut all_codes = Vec::with_capacity(total_tokens);
789        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
790        let mut offset = 0;
791
792        for &doc_id in doc_ids {
793            if doc_id >= self.doc_lengths.len() {
794                continue;
795            }
796            let start = self.doc_offsets[doc_id];
797            let end = self.doc_offsets[doc_id + 1];
798            let len = end - start;
799
800            // Append codes
801            let codes_slice = self.mmap_codes.slice(start, end);
802            all_codes.extend(codes_slice.iter().map(|&c| c as usize));
803
804            // Copy residuals
805            let residuals_view = self.mmap_residuals.slice_rows(start, end);
806            all_residuals
807                .slice_mut(s![offset..offset + len, ..])
808                .assign(&residuals_view);
809            offset += len;
810        }
811
812        let codes_arr = Array1::from_vec(all_codes);
813        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;
814
815        Ok((embeddings, lengths))
816    }
817
818    /// Search for similar documents.
819    ///
820    /// # Arguments
821    ///
822    /// * `query` - Query embedding matrix [num_tokens, dim]
823    /// * `params` - Search parameters
824    /// * `subset` - Optional subset of document IDs to search within
825    ///
826    /// # Returns
827    ///
828    /// Search result containing top-k document IDs and scores.
829    pub fn search(
830        &self,
831        query: &Array2<f32>,
832        params: &crate::search::SearchParameters,
833        subset: Option<&[i64]>,
834    ) -> Result<crate::search::SearchResult> {
835        crate::search::search_one_mmap(self, query, params, subset)
836    }
837
838    /// Search for multiple queries in batch.
839    ///
840    /// # Arguments
841    ///
842    /// * `queries` - Slice of query embedding matrices
843    /// * `params` - Search parameters
844    /// * `parallel` - If true, process queries in parallel using rayon
845    /// * `subset` - Optional subset of document IDs to search within
846    ///
847    /// # Returns
848    ///
849    /// Vector of search results, one per query.
850    pub fn search_batch(
851        &self,
852        queries: &[Array2<f32>],
853        params: &crate::search::SearchParameters,
854        parallel: bool,
855        subset: Option<&[i64]>,
856    ) -> Result<Vec<crate::search::SearchResult>> {
857        crate::search::search_many_mmap(self, queries, params, parallel, subset)
858    }
859
860    /// Get the number of documents in the index.
861    pub fn num_documents(&self) -> usize {
862        self.doc_lengths.len()
863    }
864
865    /// Get the total number of embeddings in the index.
866    pub fn num_embeddings(&self) -> usize {
867        self.metadata.num_embeddings
868    }
869
870    /// Get the number of partitions (centroids).
871    pub fn num_partitions(&self) -> usize {
872        self.metadata.num_partitions
873    }
874
875    /// Get the average document length.
876    pub fn avg_doclen(&self) -> f64 {
877        self.metadata.avg_doclen
878    }
879
880    /// Get the embedding dimension.
881    pub fn embedding_dim(&self) -> usize {
882        self.codec.embedding_dim()
883    }
884
885    /// Reconstruct embeddings for specific documents.
886    ///
887    /// This method retrieves the compressed codes and residuals for each document
888    /// from memory-mapped files and decompresses them to recover the original embeddings.
889    ///
890    /// # Arguments
891    ///
892    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
893    ///
894    /// # Returns
895    ///
896    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
897    ///
898    /// # Example
899    ///
900    /// ```ignore
901    /// use next_plaid::MmapIndex;
902    ///
903    /// let index = MmapIndex::load("/path/to/index")?;
904    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
905    ///
906    /// for (i, emb) in embeddings.iter().enumerate() {
907    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
908    /// }
909    /// ```
910    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
911        crate::embeddings::reconstruct_embeddings(self, doc_ids)
912    }
913
914    /// Reconstruct a single document's embeddings.
915    ///
916    /// Convenience method for reconstructing a single document.
917    ///
918    /// # Arguments
919    ///
920    /// * `doc_id` - Document ID to reconstruct (0-indexed)
921    ///
922    /// # Returns
923    ///
924    /// A 2D array with shape `[num_tokens, dim]`.
925    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
926        crate::embeddings::reconstruct_single(self, doc_id)
927    }
928
929    /// Create a new index from document embeddings with automatic centroid computation.
930    ///
931    /// This method:
932    /// 1. Computes centroids using K-means
933    /// 2. Creates index files on disk
934    /// 3. Loads the index using memory-mapped I/O
935    ///
936    /// Note: During creation, data is temporarily held in RAM for processing,
937    /// then written to disk and loaded as mmap.
938    ///
939    /// # Arguments
940    ///
941    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
942    /// * `index_path` - Directory to save the index
943    /// * `config` - Index configuration
944    ///
945    /// # Returns
946    ///
947    /// The created MmapIndex
948    pub fn create_with_kmeans(
949        embeddings: &[Array2<f32>],
950        index_path: &str,
951        config: &IndexConfig,
952    ) -> Result<Self> {
953        // Use standalone function to create files
954        create_index_with_kmeans_files(embeddings, index_path, config)?;
955
956        // Load as memory-mapped index
957        Self::load(index_path)
958    }
959
960    /// Update the index with new documents, matching fast-plaid behavior.
961    ///
962    /// This method adds new documents to an existing index with three possible paths:
963    ///
964    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
965    ///    - Loads existing embeddings from `embeddings.npy` if available
966    ///    - Combines with new embeddings
967    ///    - Rebuilds the entire index from scratch with fresh K-means
968    ///    - Clears `embeddings.npy` if total exceeds threshold
969    ///
970    /// 2. **Buffer mode** (total_new < buffer_size):
971    ///    - Adds new documents to the index without centroid expansion
972    ///    - Saves embeddings to buffer for later centroid expansion
973    ///
974    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
975    ///    - Deletes previously buffered documents
976    ///    - Expands centroids with outliers from combined buffer + new embeddings
977    ///    - Re-indexes all combined embeddings with expanded centroids
978    ///
979    /// # Arguments
980    ///
981    /// * `embeddings` - New document embeddings to add
982    /// * `config` - Update configuration
983    ///
984    /// # Returns
985    ///
986    /// Vector of document IDs assigned to the new embeddings
987    pub fn update(
988        &mut self,
989        embeddings: &[Array2<f32>],
990        config: &crate::update::UpdateConfig,
991    ) -> Result<Vec<i64>> {
992        use crate::codec::ResidualCodec;
993        use crate::update::{
994            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
995            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
996            update_centroids, update_index,
997        };
998
999        let path_str = self.path.clone();
1000        let index_path = std::path::Path::new(&path_str);
1001        let num_new_docs = embeddings.len();
1002
1003        // ==================================================================
1004        // Start-from-scratch mode (fast-plaid update.py:312-346)
1005        // ==================================================================
1006        if self.metadata.num_documents <= config.start_from_scratch {
1007            // Load existing embeddings if available
1008            let existing_embeddings = load_embeddings_npy(index_path)?;
1009
1010            // Check if embeddings.npy is in sync with the index.
1011            // If not (e.g., after delete when index was above threshold), we can't do
1012            // start-from-scratch mode because we don't have all the old embeddings.
1013            // Fall through to buffer mode instead.
1014            if existing_embeddings.len() == self.metadata.num_documents {
1015                // New documents start after existing documents
1016                let start_doc_id = existing_embeddings.len() as i64;
1017
1018                // Combine existing + new embeddings
1019                let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
1020                    .into_iter()
1021                    .chain(embeddings.iter().cloned())
1022                    .collect();
1023
1024                // Build IndexConfig from UpdateConfig for create_with_kmeans
1025                let index_config = IndexConfig {
1026                    nbits: self.metadata.nbits,
1027                    batch_size: config.batch_size,
1028                    seed: Some(config.seed),
1029                    kmeans_niters: config.kmeans_niters,
1030                    max_points_per_centroid: config.max_points_per_centroid,
1031                    n_samples_kmeans: config.n_samples_kmeans,
1032                    start_from_scratch: config.start_from_scratch,
1033                    force_cpu: config.force_cpu,
1034                };
1035
1036                // Rebuild index from scratch with fresh K-means
1037                *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
1038
1039                // If we've crossed the threshold, clear embeddings.npy
1040                if combined_embeddings.len() > config.start_from_scratch
1041                    && embeddings_npy_exists(index_path)
1042                {
1043                    clear_embeddings_npy(index_path)?;
1044                }
1045
1046                // Return the document IDs assigned to the new embeddings
1047                return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
1048            }
1049            // else: embeddings.npy is out of sync, fall through to buffer mode
1050        }
1051
1052        // Load buffer
1053        let buffer = load_buffer(index_path)?;
1054        let buffer_len = buffer.len();
1055        let total_new = embeddings.len() + buffer_len;
1056
1057        // Track the starting document ID for the new embeddings
1058        let start_doc_id: i64;
1059
1060        // Load codec for update operations
1061        let mut codec = ResidualCodec::load_from_dir(index_path)?;
1062
1063        // Check buffer threshold
1064        if total_new >= config.buffer_size {
1065            // Centroid expansion path (matches fast-plaid update.py:376-422)
1066
1067            // 1. Get number of buffered docs that were previously indexed
1068            let num_buffered = load_buffer_info(index_path)?;
1069
1070            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
1071            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
1072                let start_del_idx = self.metadata.num_documents - num_buffered;
1073                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
1074                    .map(|i| i as i64)
1075                    .collect();
1076                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
1077                // Reload metadata after delete
1078                self.metadata = Metadata::load_from_path(index_path)?;
1079            }
1080
1081            // New embeddings start after buffer is re-indexed
1082            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
1083
1084            // 3. Combine buffer + new embeddings
1085            let combined: Vec<Array2<f32>> = buffer
1086                .into_iter()
1087                .chain(embeddings.iter().cloned())
1088                .collect();
1089
1090            // 4. Expand centroids with outliers from combined embeddings
1091            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
1092                let new_centroids =
1093                    update_centroids(index_path, &combined, cluster_threshold, config)?;
1094                if new_centroids > 0 {
1095                    // Reload codec with new centroids
1096                    codec = ResidualCodec::load_from_dir(index_path)?;
1097                }
1098            }
1099
1100            // 5. Clear buffer
1101            clear_buffer(index_path)?;
1102
1103            // 6. Update index with ALL combined embeddings (buffer + new)
1104            update_index(
1105                &combined,
1106                &path_str,
1107                &codec,
1108                Some(config.batch_size),
1109                true,
1110                config.force_cpu,
1111            )?;
1112        } else {
1113            // Small update: add to buffer and index without centroid expansion
1114            // New documents start at current num_documents
1115            start_doc_id = self.metadata.num_documents as i64;
1116
1117            // Accumulate buffer: combine existing buffer with new embeddings
1118            let combined_buffer: Vec<Array2<f32>> = buffer
1119                .into_iter()
1120                .chain(embeddings.iter().cloned())
1121                .collect();
1122            save_buffer(index_path, &combined_buffer)?;
1123
1124            // Update index without threshold update
1125            update_index(
1126                embeddings,
1127                &path_str,
1128                &codec,
1129                Some(config.batch_size),
1130                false,
1131                config.force_cpu,
1132            )?;
1133        }
1134
1135        // Reload self as mmap
1136        *self = Self::load(&path_str)?;
1137
1138        // Return the document IDs assigned to the new embeddings
1139        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
1140    }
1141
1142    /// Update the index with new documents and optional metadata.
1143    ///
1144    /// # Arguments
1145    ///
1146    /// * `embeddings` - New document embeddings to add
1147    /// * `config` - Update configuration
1148    /// * `metadata` - Optional metadata for new documents
1149    ///
1150    /// # Returns
1151    ///
1152    /// Vector of document IDs assigned to the new embeddings
1153    pub fn update_with_metadata(
1154        &mut self,
1155        embeddings: &[Array2<f32>],
1156        config: &crate::update::UpdateConfig,
1157        metadata: Option<&[serde_json::Value]>,
1158    ) -> Result<Vec<i64>> {
1159        // Validate metadata length if provided
1160        if let Some(meta) = metadata {
1161            if meta.len() != embeddings.len() {
1162                return Err(Error::Config(format!(
1163                    "Metadata length ({}) must match embeddings length ({})",
1164                    meta.len(),
1165                    embeddings.len()
1166                )));
1167            }
1168        }
1169
1170        // Perform the update and get document IDs
1171        let doc_ids = self.update(embeddings, config)?;
1172
1173        // Add metadata if provided, using the assigned document IDs
1174        if let Some(meta) = metadata {
1175            crate::filtering::update(&self.path, meta, &doc_ids)?;
1176        }
1177
1178        Ok(doc_ids)
1179    }
1180
1181    /// Update an existing index or create a new one if it doesn't exist.
1182    ///
1183    /// # Arguments
1184    ///
1185    /// * `embeddings` - Document embeddings to add
1186    /// * `index_path` - Directory for the index
1187    /// * `index_config` - Configuration for index creation
1188    /// * `update_config` - Configuration for updates
1189    ///
1190    /// # Returns
1191    ///
1192    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and document IDs
1193    pub fn update_or_create(
1194        embeddings: &[Array2<f32>],
1195        index_path: &str,
1196        index_config: &IndexConfig,
1197        update_config: &crate::update::UpdateConfig,
1198    ) -> Result<(Self, Vec<i64>)> {
1199        let index_dir = std::path::Path::new(index_path);
1200        let metadata_path = index_dir.join("metadata.json");
1201
1202        if metadata_path.exists() {
1203            // Index exists, load and update
1204            let mut index = Self::load(index_path)?;
1205            let doc_ids = index.update(embeddings, update_config)?;
1206            Ok((index, doc_ids))
1207        } else {
1208            // Index doesn't exist, create new
1209            let num_docs = embeddings.len();
1210            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1211            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1212            Ok((index, doc_ids))
1213        }
1214    }
1215
1216    /// Reload the index from disk.
1217    ///
1218    /// This should be called after delete operations to refresh the in-memory
1219    /// representation with the updated on-disk state.
1220    pub fn reload(&mut self) -> Result<()> {
1221        *self = Self::load(&self.path)?;
1222        Ok(())
1223    }
1224
1225    /// Delete documents from the index.
1226    ///
1227    /// This performs the deletion on disk but does NOT reload the in-memory index.
1228    /// Call `reload()` after all delete operations are complete to refresh the index.
1229    ///
1230    /// # Arguments
1231    ///
1232    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
1233    ///
1234    /// # Returns
1235    ///
1236    /// The number of documents actually deleted
1237    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
1238        self.delete_with_options(doc_ids, true)
1239    }
1240
1241    /// Delete documents from the index with control over metadata deletion.
1242    ///
1243    /// This performs the deletion on disk but does NOT reload the in-memory index.
1244    /// Call `reload()` after all delete operations are complete to refresh the index.
1245    ///
1246    /// # Arguments
1247    ///
1248    /// * `doc_ids` - Slice of document IDs to delete
1249    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
1250    ///
1251    /// # Returns
1252    ///
1253    /// The number of documents actually deleted
1254    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
1255        let path = self.path.clone();
1256
1257        // Perform the deletion using standalone function
1258        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;
1259
1260        // Also delete from metadata.db if requested
1261        if delete_metadata && deleted > 0 {
1262            let index_path = std::path::Path::new(&path);
1263            let db_path = index_path.join("metadata.db");
1264            if db_path.exists() {
1265                crate::filtering::delete(&path, doc_ids)?;
1266            }
1267        }
1268
1269        Ok(deleted)
1270    }
1271}
1272
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use tempfile::tempdir;

    /// Builds `count` synthetic documents of 5 tokens x 32 dims with L2-normalized rows.
    ///
    /// Document `i` is seeded with `i + offset`, so different offsets produce
    /// distinct documents.
    fn synthetic_docs(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (0..count)
            .map(|i| {
                let base = (i + offset) as f32 * 0.1;
                let mut doc = Array2::<f32>::from_shape_fn((5, 32), |(j, k)| {
                    base + (j as f32 * 0.01) + (k as f32 * 0.001)
                });
                for mut row in doc.rows_mut() {
                    let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.mapv_inplace(|x| x / norm);
                    }
                }
                doc
            })
            .collect()
    }

    /// Small, fast index configuration shared by the update/create tests.
    fn small_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let defaults = IndexConfig::default();
        assert_eq!(defaults.nbits, 4);
        assert_eq!(defaults.batch_size, 50_000);
        assert_eq!(defaults.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        let dir = tempdir().unwrap();
        let index_path = dir.path().to_str().unwrap();

        let docs = synthetic_docs(5, 0);
        let update_config = crate::update::UpdateConfig::default();

        // No index on disk yet, so this must take the create path.
        let (index, doc_ids) =
            MmapIndex::update_or_create(&docs, index_path, &small_index_config(), &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // The core index files must now exist.
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        let dir = tempdir().unwrap();
        let index_path = dir.path().to_str().unwrap();

        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call: creates the index with 5 documents.
        let first_batch = synthetic_docs(5, 0);
        let (created, first_ids) =
            MmapIndex::update_or_create(&first_batch, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(created.metadata.num_documents, 5);
        assert_eq!(first_ids, vec![0, 1, 2, 3, 4]);

        // Second call: appends 3 more documents to the existing index.
        let second_batch = synthetic_docs(3, 5);
        let (updated, second_ids) =
            MmapIndex::update_or_create(&second_batch, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(updated.metadata.num_documents, 8);
        assert_eq!(second_ids, vec![5, 6, 7]);
    }
}