next_plaid/index.rs

//! Index creation and management for PLAID

use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::utils::{quantile, quantiles};

/// CPU implementation of fused compress_into_codes + residual computation.
fn compress_and_residuals_cpu(
    embeddings: &Array2<f32>,
    codec: &ResidualCodec,
) -> (Array1<usize>, Array2<f32>) {
    use rayon::prelude::*;

    // Use CPU-only version to ensure no CUDA is called
    let codes = codec.compress_into_codes_cpu(embeddings);
    let mut residuals = embeddings.clone();

    let centroids = &codec.centroids;
    residuals
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .zip(codes.as_slice().unwrap().par_iter())
        .for_each(|(mut row, &code)| {
            let centroid = centroids.row(code);
            row.iter_mut()
                .zip(centroid.iter())
                .for_each(|(r, c)| *r -= c);
        });

    (codes, residuals)
}

/// Configuration for index creation
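///
/// # Example
///
/// A minimal sketch of overriding a few fields (the shown values are
/// illustrative, not tuned recommendations):
///
/// ```ignore
/// use next_plaid::index::IndexConfig;
///
/// let config = IndexConfig {
///     nbits: 2,              // smaller residuals, lower recall
///     batch_size: 10_000,    // documents per chunk
///     ..Default::default()
/// };
/// ```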
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits for quantization (typically 2 or 4)
    pub nbits: usize,
    /// Batch size for processing
    pub batch_size: usize,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Number of K-means iterations (default: 4)
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of points per centroid for K-means (default: 256)
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means training.
    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Threshold for start-from-scratch mode (default: 999).
    /// When the number of documents is <= this threshold, raw embeddings are saved
    /// to embeddings.npy for potential rebuilds during updates.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
    /// Force CPU execution for K-means even when the CUDA feature is enabled.
    /// Useful for small batches where GPU initialization overhead exceeds benefits.
    #[serde(default)]
    pub force_cpu: bool,
}

fn default_start_from_scratch() -> usize {
    999
}

fn default_kmeans_niters() -> usize {
    4
}

fn default_max_points_per_centroid() -> usize {
    256
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            nbits: 4,
            batch_size: 50_000,
            seed: Some(42),
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
            force_cpu: false,
        }
    }
}

/// Metadata for the index
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    /// Number of chunks in the index
    pub num_chunks: usize,
    /// Number of bits for quantization
    pub nbits: usize,
    /// Number of partitions (centroids)
    pub num_partitions: usize,
    /// Total number of embeddings
    pub num_embeddings: usize,
    /// Average document length
    pub avg_doclen: f64,
    /// Total number of documents
    #[serde(default)]
    pub num_documents: usize,
    /// Whether the index has been converted to next-plaid compatible format.
    /// If false or missing, the index may need fast-plaid to next-plaid conversion.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}

impl Metadata {
    /// Load metadata from a JSON file, inferring num_documents from doclens if not present.
    pub fn load_from_path(index_path: &Path) -> Result<Self> {
        let metadata_path = index_path.join("metadata.json");
        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        // If num_documents is 0 (default), infer from doclens files
        if metadata.num_documents == 0 {
            let mut total_docs = 0usize;
            for chunk_idx in 0..metadata.num_chunks {
                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
                if let Ok(file) = File::open(&doclens_path) {
                    if let Ok(chunk_doclens) =
                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
                    {
                        total_docs += chunk_doclens.len();
                    }
                }
            }
            metadata.num_documents = total_docs;
        }

        Ok(metadata)
    }
}

/// Chunk metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    pub num_documents: usize,
    pub num_embeddings: usize,
    #[serde(default)]
    pub embedding_offset: usize,
}

// ============================================================================
// Standalone Index Creation Functions
// ============================================================================

/// Create index files on disk from embeddings and centroids.
///
/// This is a standalone function that creates all necessary index files
/// without constructing an in-memory Index object. Both Index and MmapIndex
/// can use this function to create their files, then load them in their
/// preferred format.
///
/// # Arguments
///
/// * `embeddings` - List of document embeddings
/// * `centroids` - Pre-computed centroids from K-means
/// * `index_path` - Directory to save the index
/// * `config` - Index configuration
///
/// # Returns
///
/// Metadata about the created index
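///
/// # Example
///
/// A minimal sketch, assuming centroids were computed elsewhere (e.g. via
/// `compute_kmeans`) and that this module is reachable as `next_plaid::index`:
///
/// ```ignore
/// use next_plaid::index::{create_index_files, IndexConfig};
///
/// let config = IndexConfig::default();
/// let metadata = create_index_files(&embeddings, centroids, "/tmp/index", &config)?;
/// println!("indexed {} embeddings", metadata.num_embeddings);
/// ```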
pub fn create_index_files(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let num_documents = embeddings.len();
    let embedding_dim = centroids.ncols();
    let num_centroids = centroids.nrows();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    // Calculate statistics
    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let avg_doclen = total_embeddings as f64 / num_documents as f64;

    // Sample documents for codec training
    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    // Collect sample embeddings for training
    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    // Train codec: compute residuals and quantization parameters
    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    // Compute codes for heldout samples
    // Use CPU-only version when force_cpu is set to avoid CUDA initialization overhead
    let heldout_codes = if config.force_cpu {
        initial_codec.compress_into_codes_cpu(&heldout)
    } else {
        initial_codec.compress_into_codes(&heldout)
    };

    // Compute residuals
    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    // Compute cluster threshold from residual distances
    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    let cluster_threshold = quantile(&distances, 0.75);

    // Compute average residual per dimension
    let avg_res_per_dim: Array1<f32> = residuals
        .axis_iter(Axis(1))
        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
        .collect();

    // Compute quantization buckets
    let n_options = 1 << config.nbits;
    let quantile_values: Vec<f64> = (1..n_options)
        .map(|i| i as f64 / n_options as f64)
        .collect();
    let weight_quantile_values: Vec<f64> = (0..n_options)
        .map(|i| (i as f64 + 0.5) / n_options as f64)
        .collect();

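    // Worked example: with nbits = 2, n_options = 4, so bucket_cutoffs come from
    // the residual quantiles {0.25, 0.50, 0.75} and bucket_weights from the
    // bucket midpoints, i.e. quantiles {0.125, 0.375, 0.625, 0.875}.
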
    // Flatten residuals for quantile computation
    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

    let codec = ResidualCodec::new(
        config.nbits,
        centroids.clone(),
        avg_res_per_dim.clone(),
        Some(bucket_cutoffs.clone()),
        Some(bucket_weights.clone()),
    )?;

    // Save codec components
    use ndarray_npy::WriteNpyExt;

    let centroids_path = index_dir.join("centroids.npy");
    codec
        .centroids_view()
        .to_owned()
        .write_npy(File::create(&centroids_path)?)?;

    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;

    let weights_path = index_dir.join("bucket_weights.npy");
    bucket_weights.write_npy(File::create(&weights_path)?)?;

    let avg_res_path = index_dir.join("avg_residual.npy");
    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;

    let threshold_path = index_dir.join("cluster_threshold.npy");
    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;

    // Process documents in chunks
    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;

    // Save plan
    let plan_path = index_dir.join("plan.json");
    let plan = serde_json::json!({
        "nbits": config.nbits,
        "num_chunks": n_chunks,
    });
    let mut plan_file = File::create(&plan_path)?;
    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);

    for chunk_idx in 0..n_chunks {
        let start = chunk_idx * config.batch_size;
        let end = (start + config.batch_size).min(num_documents);
        let chunk_docs = &embeddings[start..end];

        // Collect document lengths
        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;

        // Concatenate all embeddings in the chunk for batch processing
        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // BATCH: Compress embeddings and compute residuals
        // Try CUDA fused operation first, fall back to CPU (skip CUDA if force_cpu is set)
        let (batch_codes, batch_residuals) = {
            #[cfg(feature = "cuda")]
            {
                if !config.force_cpu {
                    if let Some(ctx) = crate::cuda::get_global_context() {
                        match crate::cuda::compress_and_residuals_cuda_batched(
                            ctx,
                            &batch_embeddings.view(),
                            &codec.centroids_view(),
                            None,
                        ) {
                            Ok(result) => result,
                            Err(e) => {
                                eprintln!(
                                    "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
                                    e
                                );
                                compress_and_residuals_cpu(&batch_embeddings, &codec)
                            }
                        }
                    } else {
                        compress_and_residuals_cpu(&batch_embeddings, &codec)
                    }
                } else {
                    compress_and_residuals_cpu(&batch_embeddings, &codec)
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                compress_and_residuals_cpu(&batch_embeddings, &codec)
            }
        };

        // BATCH: Quantize all residuals at once
        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        // Track codes for IVF building
        for &len in &chunk_doclens {
            doc_lengths.push(len);
        }
        all_codes.extend(batch_codes.iter().copied());

        // Save chunk metadata
        let chunk_meta = ChunkMetadata {
            num_documents: end - start,
            num_embeddings: batch_codes.len(),
            embedding_offset: 0, // Will be updated later
        };

        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;

        // Save chunk doclens
        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;

        // Save chunk codes
        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;

        // Save chunk residuals
        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
        batch_packed.write_npy(File::create(&residuals_path)?)?;
    }

    // Update chunk metadata with global offsets
    let mut current_offset = 0usize;
    for chunk_idx in 0..n_chunks {
        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("embedding_offset".to_string(), current_offset.into());
            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
            current_offset += num_emb;
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
    }

    // Build IVF (Inverted File)
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;

    for (doc_id, &len) in doc_lengths.iter().enumerate() {
        for _ in 0..len {
            let code = all_codes[emb_idx];
            code_to_docs.entry(code).or_default().push(doc_id as i64);
            emb_idx += 1;
        }
    }

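    // Illustrative example (hypothetical codes): if doc 0 has token codes
    // [0, 1, 0] and doc 1 has [1], then code_to_docs = {0: [0, 0], 1: [0, 1]};
    // after the dedup below, ivf = [0, 0, 1] and ivf_lengths = [1, 2].
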
    // Deduplicate document IDs per centroid
    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];

    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs: Vec<i64> = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }

    let ivf = Array1::from_vec(ivf_data);
    let ivf_lengths = Array1::from_vec(ivf_lengths);

    let ivf_path = index_dir.join("ivf.npy");
    ivf.write_npy(File::create(&ivf_path)?)?;

    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;

    // Save global metadata
    let metadata = Metadata {
        num_chunks: n_chunks,
        nbits: config.nbits,
        num_partitions: num_centroids,
        num_embeddings: total_embeddings,
        avg_doclen,
        num_documents,
        next_plaid_compatible: true, // Created by next-plaid, always compatible
    };

    let metadata_path = index_dir.join("metadata.json");
    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;

    Ok(metadata)
}

/// Create index files with automatic K-means centroid computation.
///
/// This is a standalone function that runs K-means to compute centroids,
/// then creates all index files on disk.
///
/// # Arguments
///
/// * `embeddings` - List of document embeddings
/// * `index_path` - Directory to save the index
/// * `config` - Index configuration
///
/// # Returns
///
/// Metadata about the created index
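///
/// # Example
///
/// A minimal sketch (paths are illustrative):
///
/// ```ignore
/// use next_plaid::index::{create_index_with_kmeans_files, IndexConfig};
///
/// // One [num_tokens, dim] matrix per document.
/// let config = IndexConfig::default();
/// let metadata = create_index_with_kmeans_files(&embeddings, "/tmp/index", &config)?;
/// assert_eq!(metadata.num_documents, embeddings.len());
/// ```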
pub fn create_index_with_kmeans_files(
    embeddings: &[Array2<f32>],
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    if embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    // Pre-initialize CUDA if available (first init can take 10-20s due to driver initialization)
    // Skip if force_cpu is set to avoid unnecessary initialization overhead
    #[cfg(feature = "cuda")]
    if !config.force_cpu {
        let _ = crate::cuda::get_global_context();
    }

    // Build K-means configuration from IndexConfig
    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed.unwrap_or(42),
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: None, // Let the heuristic decide
        force_cpu: config.force_cpu,
    };

    // Compute centroids using fast-plaid's approach
    let centroids = compute_kmeans(embeddings, &kmeans_config)?;

    // Create the index files
    let metadata = create_index_files(embeddings, centroids, index_path, config)?;

    // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
    if embeddings.len() <= config.start_from_scratch {
        let index_dir = std::path::Path::new(index_path);
        crate::update::save_embeddings_npy(index_dir, embeddings)?;
    }

    Ok(metadata)
}

// ============================================================================
// Memory-Mapped Index for Low Memory Usage
// ============================================================================

/// A memory-mapped PLAID index for multi-vector search.
///
/// This struct uses memory-mapped files for the large arrays (codes and residuals)
/// instead of loading them entirely into RAM. Only small tensors (centroids,
/// bucket weights, IVF) are loaded into memory.
///
/// # Memory Usage
///
/// Only small tensors (~50 MB for SciFact 5K docs) are loaded into RAM,
/// with code and residual data accessed via OS-managed memory mapping.
///
/// # Usage
///
/// ```ignore
/// use next_plaid::MmapIndex;
///
/// let index = MmapIndex::load("/path/to/index")?;
/// let results = index.search(&query, &params, None)?;
/// ```
pub struct MmapIndex {
    /// Path to the index directory
    pub path: String,
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression
    pub codec: ResidualCodec,
    /// IVF data (concatenated passage IDs per centroid)
    pub ivf: Array1<i64>,
    /// IVF lengths (number of passages per centroid)
    pub ivf_lengths: Array1<i32>,
    /// IVF offsets (cumulative offsets into ivf array)
    pub ivf_offsets: Array1<i64>,
    /// Document lengths (number of tokens per document)
    pub doc_lengths: Array1<i64>,
    /// Cumulative document offsets for indexing into codes/residuals
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped codes array (public for search access)
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped residuals array (public for search access)
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}

impl MmapIndex {
    /// Load a memory-mapped index from disk.
    ///
    /// This creates merged files for codes and residuals if they don't exist,
    /// then memory-maps them for efficient access.
    ///
    /// If the index was created by fast-plaid, it will be automatically converted
    /// to next-plaid compatible format on first load.
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        // Load metadata (infers num_documents from doclens if not present)
        let mut metadata = Metadata::load_from_path(index_dir)?;

        // Check if conversion from fast-plaid format is needed
        if !metadata.next_plaid_compatible {
            eprintln!("Checking index format compatibility...");
            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
            if converted {
                eprintln!("Index converted to next-plaid compatible format.");
                // Delete any existing merged files since the source files changed
                let merged_codes = index_dir.join("merged_codes.npy");
                let merged_residuals = index_dir.join("merged_residuals.npy");
                let codes_manifest = index_dir.join("merged_codes.manifest.json");
                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
                for path in [
                    &merged_codes,
                    &merged_residuals,
                    &codes_manifest,
                    &residuals_manifest,
                ] {
                    if path.exists() {
                        let _ = fs::remove_file(path);
                    }
                }
            }

            // Mark as compatible and save metadata
            metadata.next_plaid_compatible = true;
            let metadata_path = index_dir.join("metadata.json");
            let file = File::create(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
            eprintln!("Metadata updated with next_plaid_compatible: true");
        }

        // Load codec with memory-mapped centroids for reduced RAM usage.
        // Other small tensors (bucket weights, etc.) are still loaded into memory.
        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        // Load IVF (small tensor)
        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        // Compute IVF offsets
        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        // Load document lengths from all chunks
        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        // Compute document offsets for indexing
        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        // Compute padding needed for StridedTensor compatibility
        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        // Create merged files if needed
        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        // Memory-map the merged files
        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

    /// Get candidate documents from IVF for given centroid indices.
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    /// Get document embeddings by decompressing codes and residuals.
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_lengths.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let start = self.doc_offsets[doc_id];
        let end = self.doc_offsets[doc_id + 1];

        // Get codes and residuals from mmap
        let codes_slice = self.mmap_codes.slice(start, end);
        let residuals_view = self.mmap_residuals.slice_rows(start, end);

        // Convert codes to Array1<usize>
        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));

        // Convert residuals to owned Array2
        let residuals = residuals_view.to_owned();

        // Decompress
        self.codec.decompress(&residuals, &codes.view())
    }

    /// Get codes for a batch of document IDs (for approximate scoring).
    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        doc_ids
            .iter()
            .map(|&doc_id| {
                if doc_id >= self.doc_lengths.len() {
                    return vec![];
                }
                let start = self.doc_offsets[doc_id];
                let end = self.doc_offsets[doc_id + 1];
                self.mmap_codes.slice(start, end).to_vec()
            })
            .collect()
    }

    /// Decompress embeddings for a batch of document IDs.
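    ///
    /// Returns the decompressed embeddings concatenated row-wise, along with the
    /// per-document token counts in the same order as `doc_ids` (out-of-range IDs
    /// contribute a length of 0).
    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let (embeddings, lengths) = index.decompress_documents(&[0, 3, 7])?;
    /// assert_eq!(embeddings.nrows(), lengths.iter().sum::<usize>());
    /// ```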
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        // Compute total tokens
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        // Gather all codes and residuals
        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            // Append codes
            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            // Copy residuals
            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

    /// Search for similar documents.
    ///
    /// # Arguments
    ///
    /// * `query` - Query embedding matrix [num_tokens, dim]
    /// * `params` - Search parameters
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Search result containing top-k document IDs and scores.
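    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `SearchParameters` provides a `Default` impl:
    ///
    /// ```ignore
    /// let params = next_plaid::search::SearchParameters::default();
    /// let result = index.search(&query, &params, None)?;
    /// ```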
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

    /// Search for multiple queries in batch.
    ///
    /// # Arguments
    ///
    /// * `queries` - Slice of query embedding matrices
    /// * `params` - Search parameters
    /// * `parallel` - If true, process queries in parallel using rayon
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Vector of search results, one per query.
    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    /// Get the number of documents in the index.
    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    /// Get the total number of embeddings in the index.
    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }

    /// Get the number of partitions (centroids).
    pub fn num_partitions(&self) -> usize {
        self.metadata.num_partitions
    }

    /// Get the average document length.
    pub fn avg_doclen(&self) -> f64 {
        self.metadata.avg_doclen
    }

    /// Get the embedding dimension.
    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    /// Reconstruct embeddings for specific documents.
    ///
    /// This method retrieves the compressed codes and residuals for each document
    /// from memory-mapped files and decompresses them to recover the original embeddings.
    ///
    /// # Arguments
    ///
    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
    ///
    /// # Returns
    ///
    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use next_plaid::MmapIndex;
    ///
    /// let index = MmapIndex::load("/path/to/index")?;
    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
    ///
    /// for (i, emb) in embeddings.iter().enumerate() {
    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
    /// }
    /// ```
    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }

    /// Reconstruct a single document's embeddings.
    ///
    /// Convenience method for reconstructing a single document.
    ///
    /// # Arguments
    ///
    /// * `doc_id` - Document ID to reconstruct (0-indexed)
    ///
    /// # Returns
    ///
    /// A 2D array with shape `[num_tokens, dim]`.
    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }

    /// Create a new index from document embeddings with automatic centroid computation.
    ///
    /// This method:
    /// 1. Computes centroids using K-means
    /// 2. Creates index files on disk
    /// 3. Loads the index using memory-mapped I/O
    ///
    /// Note: During creation, data is temporarily held in RAM for processing,
    /// then written to disk and loaded as mmap.
    ///
    /// # Arguments
    ///
    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
    /// * `index_path` - Directory to save the index
    /// * `config` - Index configuration
    ///
    /// # Returns
    ///
    /// The created MmapIndex
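    ///
    /// # Example
    ///
    /// A minimal sketch (the crate-root re-export of `IndexConfig` is assumed):
    ///
    /// ```ignore
    /// use next_plaid::{IndexConfig, MmapIndex};
    ///
    /// let index = MmapIndex::create_with_kmeans(&embeddings, "/tmp/index", &IndexConfig::default())?;
    /// assert_eq!(index.num_documents(), embeddings.len());
    /// ```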
    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        // Use standalone function to create files
        create_index_with_kmeans_files(embeddings, index_path, config)?;

        // Load as memory-mapped index
        Self::load(index_path)
    }

    /// Update the index with new documents, matching fast-plaid behavior.
    ///
    /// This method adds new documents to an existing index with three possible paths:
    ///
    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
    ///    - Loads existing embeddings from `embeddings.npy` if available
    ///    - Combines with new embeddings
    ///    - Rebuilds the entire index from scratch with fresh K-means
    ///    - Clears `embeddings.npy` if total exceeds threshold
    ///
    /// 2. **Buffer mode** (total_new < buffer_size):
    ///    - Adds new documents to the index without centroid expansion
    ///    - Saves embeddings to buffer for later centroid expansion
    ///
    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
    ///    - Deletes previously buffered documents
    ///    - Expands centroids with outliers from combined buffer + new embeddings
    ///    - Re-indexes all combined embeddings with expanded centroids
    ///
    /// # Arguments
    ///
    /// * `embeddings` - New document embeddings to add
    /// * `config` - Update configuration
    ///
    /// # Returns
    ///
    /// Vector of document IDs assigned to the new embeddings
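    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// use next_plaid::update::UpdateConfig;
    ///
    /// let mut index = MmapIndex::load("/path/to/index")?;
    /// let new_ids = index.update(&new_embeddings, &UpdateConfig::default())?;
    /// println!("assigned IDs: {:?}", new_ids);
    /// ```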
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::codec::ResidualCodec;
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        // ==================================================================
        // Start-from-scratch mode (fast-plaid update.py:312-346)
        // ==================================================================
        if self.metadata.num_documents <= config.start_from_scratch {
            // Load existing embeddings if available
            let existing_embeddings = load_embeddings_npy(index_path)?;

            // Check if embeddings.npy is in sync with the index.
            // If not (e.g., after delete when index was above threshold), we can't do
            // start-from-scratch mode because we don't have all the old embeddings.
            // Fall through to buffer mode instead.
            if existing_embeddings.len() == self.metadata.num_documents {
                // New documents start after existing documents
                let start_doc_id = existing_embeddings.len() as i64;

                // Combine existing + new embeddings
                let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                    .into_iter()
                    .chain(embeddings.iter().cloned())
                    .collect();

                // Build IndexConfig from UpdateConfig for create_with_kmeans
                let index_config = IndexConfig {
                    nbits: self.metadata.nbits,
                    batch_size: config.batch_size,
                    seed: Some(config.seed),
                    kmeans_niters: config.kmeans_niters,
                    max_points_per_centroid: config.max_points_per_centroid,
                    n_samples_kmeans: config.n_samples_kmeans,
                    start_from_scratch: config.start_from_scratch,
                    force_cpu: config.force_cpu,
                };

                // Rebuild index from scratch with fresh K-means
                *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

                // If we've crossed the threshold, clear embeddings.npy
                if combined_embeddings.len() > config.start_from_scratch
                    && embeddings_npy_exists(index_path)
                {
                    clear_embeddings_npy(index_path)?;
                }

                // Return the document IDs assigned to the new embeddings
                return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
            }
            // else: embeddings.npy is out of sync, fall through to buffer mode
        }

        // Load buffer
        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        // Track the starting document ID for the new embeddings
        let start_doc_id: i64;

        // Load codec for update operations
        let mut codec = ResidualCodec::load_from_dir(index_path)?;

        // Check buffer threshold
        if total_new >= config.buffer_size {
            // Centroid expansion path (matches fast-plaid update.py:376-422)

            // 1. Get number of buffered docs that were previously indexed
            let num_buffered = load_buffer_info(index_path)?;

            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
                // Reload metadata after delete
                self.metadata = Metadata::load_from_path(index_path)?;
            }

            // New embeddings start after buffer is re-indexed
            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            // 3. Combine buffer + new embeddings
            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            // 4. Expand centroids with outliers from combined embeddings
            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    // Reload codec with new centroids
                    codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            // 5. Clear buffer
            clear_buffer(index_path)?;

            // 6. Update index with ALL combined embeddings (buffer + new)
            update_index(
                &combined,
                &path_str,
                &codec,
                Some(config.batch_size),
                true,
                config.force_cpu,
            )?;
        } else {
            // Small update: add to buffer and index without centroid expansion
            // New documents start at current num_documents
            start_doc_id = self.metadata.num_documents as i64;

            // Accumulate buffer: combine existing buffer with new embeddings
            let combined_buffer: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();
            save_buffer(index_path, &combined_buffer)?;

            // Update index without threshold update
            update_index(
                embeddings,
                &path_str,
                &codec,
                Some(config.batch_size),
                false,
                config.force_cpu,
            )?;
        }

        // Reload self as mmap
        *self = Self::load(&path_str)?;

        // Return the document IDs assigned to the new embeddings
        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }

    /// Update the index with new documents and optional metadata.
    ///
    /// # Arguments
    ///
    /// * `embeddings` - New document embeddings to add
    /// * `config` - Update configuration
    /// * `metadata` - Optional metadata for new documents
    ///
    /// # Returns
    ///
    /// Vector of document IDs assigned to the new embeddings
    pub fn update_with_metadata(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
        metadata: Option<&[serde_json::Value]>,
    ) -> Result<Vec<i64>> {
        // Validate metadata length if provided
        if let Some(meta) = metadata {
            if meta.len() != embeddings.len() {
                return Err(Error::Config(format!(
                    "Metadata length ({}) must match embeddings length ({})",
                    meta.len(),
                    embeddings.len()
                )));
            }
        }

        // Perform the update and get document IDs
        let doc_ids = self.update(embeddings, config)?;

        // Add metadata if provided, using the assigned document IDs
        if let Some(meta) = metadata {
            crate::filtering::update(&self.path, meta, &doc_ids)?;
        }

        Ok(doc_ids)
    }

    /// Update an existing index or create a new one if it doesn't exist.
    ///
    /// # Arguments
    ///
    /// * `embeddings` - Document embeddings to add
    /// * `index_path` - Directory for the index
    /// * `index_config` - Configuration for index creation
    /// * `update_config` - Configuration for updates
    ///
    /// # Returns
    ///
    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and document IDs
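    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let (index, doc_ids) = MmapIndex::update_or_create(
    ///     &embeddings,
    ///     "/path/to/index",
    ///     &IndexConfig::default(),
    ///     &UpdateConfig::default(),
    /// )?;
    /// ```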
    pub fn update_or_create(
        embeddings: &[Array2<f32>],
        index_path: &str,
        index_config: &IndexConfig,
        update_config: &crate::update::UpdateConfig,
    ) -> Result<(Self, Vec<i64>)> {
        let index_dir = std::path::Path::new(index_path);
        let metadata_path = index_dir.join("metadata.json");

        if metadata_path.exists() {
            // Index exists, load and update
            let mut index = Self::load(index_path)?;
            let doc_ids = index.update(embeddings, update_config)?;
            Ok((index, doc_ids))
        } else {
            // Index doesn't exist, create new
            let num_docs = embeddings.len();
            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
            Ok((index, doc_ids))
        }
    }

    /// Reload the index from disk.
    ///
    /// This should be called after delete operations to refresh the in-memory
    /// representation with the updated on-disk state.
    pub fn reload(&mut self) -> Result<()> {
        *self = Self::load(&self.path)?;
        Ok(())
    }

    /// Delete documents from the index.
    ///
    /// This performs the deletion on disk but does NOT reload the in-memory index.
    /// Call `reload()` after all delete operations are complete to refresh the index.
    ///
    /// # Arguments
    ///
    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
    ///
    /// # Returns
    ///
    /// The number of documents actually deleted
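    ///
    /// # Example
    ///
    /// A minimal sketch of the delete-then-reload pattern:
    ///
    /// ```ignore
    /// let deleted = index.delete(&[3, 17])?;
    /// index.reload()?; // refresh the in-memory view once all deletes are done
    /// ```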
    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }

    /// Delete documents from the index with control over metadata deletion.
    ///
    /// This performs the deletion on disk but does NOT reload the in-memory index.
    /// Call `reload()` after all delete operations are complete to refresh the index.
    ///
    /// # Arguments
    ///
    /// * `doc_ids` - Slice of document IDs to delete
    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
    ///
    /// # Returns
    ///
    /// The number of documents actually deleted
    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
        let path = self.path.clone();

        // Perform the deletion using standalone function
        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;

        // Also delete from metadata.db if requested
        if delete_metadata && deleted > 0 {
            let index_path = std::path::Path::new(&path);
            let db_path = index_path.join("metadata.db");
            if db_path.exists() {
                crate::filtering::delete(&path, doc_ids)?;
            }
        }

        Ok(deleted)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        // Create test embeddings (5 documents)
        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            // Normalize rows
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        // Index doesn't exist - should create new
        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // Verify index was created
        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        // Helper to create embeddings
        let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
            let mut embeddings = Vec::new();
            for i in 0..count {
                let mut doc = Array2::<f32>::zeros((5, 32));
                for j in 0..5 {
                    for k in 0..32 {
                        doc[[j, k]] =
                            ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                    }
                }
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                embeddings.push(doc);
            }
            embeddings
        };

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        // First call - creates index with 5 documents
        let embeddings1 = create_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        // Second call - updates existing index with 3 more documents
        let embeddings2 = create_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}
1385}