next_plaid/
index.rs

1//! Index creation and management for PLAID
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
14use crate::utils::{quantile, quantiles};
15
/// Configuration for index creation.
///
/// Fields marked `#[serde(default …)]` may be absent from a serialized
/// configuration; the named fallback function (or the type's default)
/// supplies the value in that case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits for quantization (typically 2 or 4).
    /// Index creation derives `1 << nbits` quantization buckets from it.
    pub nbits: usize,
    /// Batch size for processing: the number of documents placed in each
    /// on-disk chunk during index creation.
    pub batch_size: usize,
    /// Random seed for reproducibility.
    /// `None` seeds the codec-training sampler from entropy; the K-means
    /// wrapper substitutes 42 when no seed is given.
    pub seed: Option<u64>,
    /// Number of K-means iterations (default: 4)
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of points per centroid for K-means (default: 256)
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means training.
    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
    // NOTE(review): the heuristic itself lives in the kmeans module, which is
    // not visible here — confirm the formula against `compute_kmeans`.
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Threshold for start-from-scratch mode (default: 999).
    /// When the number of documents is <= this threshold, raw embeddings are saved
    /// to embeddings.npy for potential rebuilds during updates.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
}
41
42fn default_start_from_scratch() -> usize {
43    999
44}
45
46fn default_kmeans_niters() -> usize {
47    4
48}
49
50fn default_max_points_per_centroid() -> usize {
51    256
52}
53
54impl Default for IndexConfig {
55    fn default() -> Self {
56        Self {
57            nbits: 4,
58            batch_size: 50_000,
59            seed: Some(42),
60            kmeans_niters: 4,
61            max_points_per_centroid: 256,
62            n_samples_kmeans: None,
63            start_from_scratch: 999,
64        }
65    }
66}
67
/// Metadata for the index, persisted as `metadata.json` in the index directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    /// Number of chunks in the index
    pub num_chunks: usize,
    /// Number of bits for quantization
    pub nbits: usize,
    /// Number of partitions (centroids)
    pub num_partitions: usize,
    /// Total number of embeddings
    pub num_embeddings: usize,
    /// Average document length
    pub avg_doclen: f64,
    /// Total number of documents.
    /// Deserializes to 0 when absent; `Metadata::load_from_path` then infers it
    /// from the per-chunk doclens files.
    #[serde(default)]
    pub num_documents: usize,
    /// Whether the index has been converted to next-plaid compatible format.
    /// If false or missing, the index may need fast-plaid to next-plaid conversion.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}
89
90impl Metadata {
91    /// Load metadata from a JSON file, inferring num_documents from doclens if not present.
92    pub fn load_from_path(index_path: &Path) -> Result<Self> {
93        let metadata_path = index_path.join("metadata.json");
94        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
95            File::open(&metadata_path)
96                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
97        ))?;
98
99        // If num_documents is 0 (default), infer from doclens files
100        if metadata.num_documents == 0 {
101            let mut total_docs = 0usize;
102            for chunk_idx in 0..metadata.num_chunks {
103                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
104                if let Ok(file) = File::open(&doclens_path) {
105                    if let Ok(chunk_doclens) =
106                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
107                    {
108                        total_docs += chunk_doclens.len();
109                    }
110                }
111            }
112            metadata.num_documents = total_docs;
113        }
114
115        Ok(metadata)
116    }
117}
118
/// Chunk metadata, persisted as `{chunk_idx}.metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Number of documents stored in this chunk.
    pub num_documents: usize,
    /// Number of embeddings (tokens) stored in this chunk.
    pub num_embeddings: usize,
    /// Number of embeddings stored in all preceding chunks, i.e. the global
    /// index of this chunk's first embedding. Deserializes to 0 when absent.
    #[serde(default)]
    pub embedding_offset: usize,
}
127
128// ============================================================================
129// Standalone Index Creation Functions
130// ============================================================================
131
132/// Create index files on disk from embeddings and centroids.
133///
134/// This is a standalone function that creates all necessary index files
135/// without constructing an in-memory Index object. Both Index and MmapIndex
136/// can use this function to create their files, then load them in their
137/// preferred format.
138///
139/// # Arguments
140///
141/// * `embeddings` - List of document embeddings
142/// * `centroids` - Pre-computed centroids from K-means
143/// * `index_path` - Directory to save the index
144/// * `config` - Index configuration
145///
146/// # Returns
147///
148/// Metadata about the created index
149pub fn create_index_files(
150    embeddings: &[Array2<f32>],
151    centroids: Array2<f32>,
152    index_path: &str,
153    config: &IndexConfig,
154) -> Result<Metadata> {
155    let index_dir = Path::new(index_path);
156    fs::create_dir_all(index_dir)?;
157
158    let num_documents = embeddings.len();
159    let embedding_dim = centroids.ncols();
160    let num_centroids = centroids.nrows();
161
162    if num_documents == 0 {
163        return Err(Error::IndexCreation("No documents provided".into()));
164    }
165
166    // Calculate statistics
167    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
168    let avg_doclen = total_embeddings as f64 / num_documents as f64;
169
170    // Sample documents for codec training
171    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
172        .min(num_documents)
173        .max(1);
174
175    let mut rng = if let Some(seed) = config.seed {
176        use rand::SeedableRng;
177        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
178    } else {
179        use rand::SeedableRng;
180        rand_chacha::ChaCha8Rng::from_entropy()
181    };
182
183    use rand::seq::SliceRandom;
184    let mut indices: Vec<usize> = (0..num_documents).collect();
185    indices.shuffle(&mut rng);
186    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
187
188    // Collect sample embeddings for training
189    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
190    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
191    let mut collected = 0;
192
193    for &idx in sample_indices.iter().rev() {
194        if collected >= heldout_size {
195            break;
196        }
197        let emb = &embeddings[idx];
198        let take = (heldout_size - collected).min(emb.nrows());
199        for row in emb.axis_iter(Axis(0)).take(take) {
200            heldout_embeddings.extend(row.iter());
201        }
202        collected += take;
203    }
204
205    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
206        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
207
208    // Train codec: compute residuals and quantization parameters
209    let avg_residual = Array1::zeros(embedding_dim);
210    let initial_codec =
211        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
212
213    // Compute codes for heldout samples
214    let heldout_codes = initial_codec.compress_into_codes(&heldout);
215
216    // Compute residuals
217    let mut residuals = heldout.clone();
218    for i in 0..heldout.nrows() {
219        let centroid = initial_codec.centroids.row(heldout_codes[i]);
220        for j in 0..embedding_dim {
221            residuals[[i, j]] -= centroid[j];
222        }
223    }
224
225    // Compute cluster threshold from residual distances
226    let distances: Array1<f32> = residuals
227        .axis_iter(Axis(0))
228        .map(|row| row.dot(&row).sqrt())
229        .collect();
230    #[allow(unused_variables)]
231    let cluster_threshold = quantile(&distances, 0.75);
232
233    // Compute average residual per dimension
234    let avg_res_per_dim: Array1<f32> = residuals
235        .axis_iter(Axis(1))
236        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
237        .collect();
238
239    // Compute quantization buckets
240    let n_options = 1 << config.nbits;
241    let quantile_values: Vec<f64> = (1..n_options)
242        .map(|i| i as f64 / n_options as f64)
243        .collect();
244    let weight_quantile_values: Vec<f64> = (0..n_options)
245        .map(|i| (i as f64 + 0.5) / n_options as f64)
246        .collect();
247
248    // Flatten residuals for quantile computation
249    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
250    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
251    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
252
253    let codec = ResidualCodec::new(
254        config.nbits,
255        centroids.clone(),
256        avg_res_per_dim.clone(),
257        Some(bucket_cutoffs.clone()),
258        Some(bucket_weights.clone()),
259    )?;
260
261    // Save codec components
262    use ndarray_npy::WriteNpyExt;
263
264    let centroids_path = index_dir.join("centroids.npy");
265    codec
266        .centroids_view()
267        .to_owned()
268        .write_npy(File::create(&centroids_path)?)?;
269
270    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
271    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
272
273    let weights_path = index_dir.join("bucket_weights.npy");
274    bucket_weights.write_npy(File::create(&weights_path)?)?;
275
276    let avg_res_path = index_dir.join("avg_residual.npy");
277    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
278
279    let threshold_path = index_dir.join("cluster_threshold.npy");
280    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
281
282    // Process documents in chunks
283    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
284
285    // Save plan
286    let plan_path = index_dir.join("plan.json");
287    let plan = serde_json::json!({
288        "nbits": config.nbits,
289        "num_chunks": n_chunks,
290    });
291    let mut plan_file = File::create(&plan_path)?;
292    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
293
294    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
295    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
296
297    for chunk_idx in 0..n_chunks {
298        let start = chunk_idx * config.batch_size;
299        let end = (start + config.batch_size).min(num_documents);
300        let chunk_docs = &embeddings[start..end];
301
302        // Collect document lengths
303        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
304        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
305
306        // Concatenate all embeddings in the chunk for batch processing
307        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
308        let mut offset = 0;
309        for doc in chunk_docs {
310            let n = doc.nrows();
311            batch_embeddings
312                .slice_mut(s![offset..offset + n, ..])
313                .assign(doc);
314            offset += n;
315        }
316
317        // BATCH: Compress all embeddings at once
318        let batch_codes = codec.compress_into_codes(&batch_embeddings);
319
320        // BATCH: Compute residuals using parallel subtraction
321        let mut batch_residuals = batch_embeddings;
322        {
323            use rayon::prelude::*;
324            let centroids = &codec.centroids;
325            batch_residuals
326                .axis_iter_mut(Axis(0))
327                .into_par_iter()
328                .zip(batch_codes.as_slice().unwrap().par_iter())
329                .for_each(|(mut row, &code)| {
330                    let centroid = centroids.row(code);
331                    row.iter_mut()
332                        .zip(centroid.iter())
333                        .for_each(|(r, c)| *r -= c);
334                });
335        }
336
337        // BATCH: Quantize all residuals at once
338        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
339
340        // Track codes for IVF building
341        for &len in &chunk_doclens {
342            doc_lengths.push(len);
343        }
344        all_codes.extend(batch_codes.iter().copied());
345
346        // Save chunk metadata
347        let chunk_meta = ChunkMetadata {
348            num_documents: end - start,
349            num_embeddings: batch_codes.len(),
350            embedding_offset: 0, // Will be updated later
351        };
352
353        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
354        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
355
356        // Save chunk doclens
357        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
358        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
359
360        // Save chunk codes
361        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
362        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
363        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
364
365        // Save chunk residuals
366        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
367        batch_packed.write_npy(File::create(&residuals_path)?)?;
368    }
369
370    // Update chunk metadata with global offsets
371    let mut current_offset = 0usize;
372    for chunk_idx in 0..n_chunks {
373        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
374        let mut meta: serde_json::Value =
375            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
376
377        if let Some(obj) = meta.as_object_mut() {
378            obj.insert("embedding_offset".to_string(), current_offset.into());
379            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
380            current_offset += num_emb;
381        }
382
383        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
384    }
385
386    // Build IVF (Inverted File)
387    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
388    let mut emb_idx = 0;
389
390    for (doc_id, &len) in doc_lengths.iter().enumerate() {
391        for _ in 0..len {
392            let code = all_codes[emb_idx];
393            code_to_docs.entry(code).or_default().push(doc_id as i64);
394            emb_idx += 1;
395        }
396    }
397
398    // Deduplicate document IDs per centroid
399    let mut ivf_data: Vec<i64> = Vec::new();
400    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
401
402    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
403        if let Some(docs) = code_to_docs.get(&centroid_id) {
404            let mut unique_docs: Vec<i64> = docs.clone();
405            unique_docs.sort_unstable();
406            unique_docs.dedup();
407            *ivf_len = unique_docs.len() as i32;
408            ivf_data.extend(unique_docs);
409        }
410    }
411
412    let ivf = Array1::from_vec(ivf_data);
413    let ivf_lengths = Array1::from_vec(ivf_lengths);
414
415    let ivf_path = index_dir.join("ivf.npy");
416    ivf.write_npy(File::create(&ivf_path)?)?;
417
418    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
419    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
420
421    // Save global metadata
422    let metadata = Metadata {
423        num_chunks: n_chunks,
424        nbits: config.nbits,
425        num_partitions: num_centroids,
426        num_embeddings: total_embeddings,
427        avg_doclen,
428        num_documents,
429        next_plaid_compatible: true, // Created by next-plaid, always compatible
430    };
431
432    let metadata_path = index_dir.join("metadata.json");
433    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
434
435    Ok(metadata)
436}
437
438/// Create index files with automatic K-means centroid computation.
439///
440/// This is a standalone function that runs K-means to compute centroids,
441/// then creates all index files on disk.
442///
443/// # Arguments
444///
445/// * `embeddings` - List of document embeddings
446/// * `index_path` - Directory to save the index
447/// * `config` - Index configuration
448///
449/// # Returns
450///
451/// Metadata about the created index
452pub fn create_index_with_kmeans_files(
453    embeddings: &[Array2<f32>],
454    index_path: &str,
455    config: &IndexConfig,
456) -> Result<Metadata> {
457    if embeddings.is_empty() {
458        return Err(Error::IndexCreation("No documents provided".into()));
459    }
460
461    // Build K-means configuration from IndexConfig
462    let kmeans_config = ComputeKmeansConfig {
463        kmeans_niters: config.kmeans_niters,
464        max_points_per_centroid: config.max_points_per_centroid,
465        seed: config.seed.unwrap_or(42),
466        n_samples_kmeans: config.n_samples_kmeans,
467        num_partitions: None, // Let the heuristic decide
468    };
469
470    // Compute centroids using fast-plaid's approach
471    let centroids = compute_kmeans(embeddings, &kmeans_config)?;
472
473    // Create the index files
474    let metadata = create_index_files(embeddings, centroids, index_path, config)?;
475
476    // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
477    if embeddings.len() <= config.start_from_scratch {
478        let index_dir = std::path::Path::new(index_path);
479        crate::update::save_embeddings_npy(index_dir, embeddings)?;
480    }
481
482    Ok(metadata)
483}
484// ============================================================================
485// Memory-Mapped Index for Low Memory Usage
486// ============================================================================
487
/// A memory-mapped PLAID index for multi-vector search.
///
/// This struct uses memory-mapped files for the large arrays (codes and residuals)
/// instead of loading them entirely into RAM. Only small tensors (centroids,
/// bucket weights, IVF) are loaded into memory.
///
/// # Memory Usage
///
/// Only small tensors (~50 MB for SciFact 5K docs) are loaded into RAM,
/// with code and residual data accessed via OS-managed memory mapping.
///
/// # Usage
///
/// ```ignore
/// use next_plaid::MmapIndex;
///
/// let index = MmapIndex::load("/path/to/index")?;
/// let results = index.search(&query, &params, None)?;
/// ```
pub struct MmapIndex {
    /// Path to the index directory
    pub path: String,
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression
    pub codec: ResidualCodec,
    /// IVF data (concatenated passage IDs per centroid)
    pub ivf: Array1<i64>,
    /// IVF lengths (number of passages per centroid)
    pub ivf_lengths: Array1<i32>,
    /// IVF offsets (cumulative offsets into ivf array).
    /// Has one more entry than `ivf_lengths`; entry `i` is where centroid `i`'s
    /// passage IDs start in `ivf`.
    pub ivf_offsets: Array1<i64>,
    /// Document lengths (number of tokens per document)
    pub doc_lengths: Array1<i64>,
    /// Cumulative document offsets for indexing into codes/residuals.
    /// Has one more entry than `doc_lengths`; document `d` occupies rows
    /// `doc_offsets[d]..doc_offsets[d + 1]`.
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped codes array (public for search access)
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped residuals array (public for search access)
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
529
530impl MmapIndex {
531    /// Load a memory-mapped index from disk.
532    ///
533    /// This creates merged files for codes and residuals if they don't exist,
534    /// then memory-maps them for efficient access.
535    ///
536    /// If the index was created by fast-plaid, it will be automatically converted
537    /// to next-plaid compatible format on first load.
538    pub fn load(index_path: &str) -> Result<Self> {
539        use ndarray_npy::ReadNpyExt;
540
541        let index_dir = Path::new(index_path);
542
543        // Load metadata (infers num_documents from doclens if not present)
544        let mut metadata = Metadata::load_from_path(index_dir)?;
545
546        // Check if conversion from fast-plaid format is needed
547        if !metadata.next_plaid_compatible {
548            eprintln!("Checking index format compatibility...");
549            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
550            if converted {
551                eprintln!("Index converted to next-plaid compatible format.");
552                // Delete any existing merged files since the source files changed
553                let merged_codes = index_dir.join("merged_codes.npy");
554                let merged_residuals = index_dir.join("merged_residuals.npy");
555                let codes_manifest = index_dir.join("merged_codes.manifest.json");
556                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
557                for path in [
558                    &merged_codes,
559                    &merged_residuals,
560                    &codes_manifest,
561                    &residuals_manifest,
562                ] {
563                    if path.exists() {
564                        let _ = fs::remove_file(path);
565                    }
566                }
567            }
568
569            // Mark as compatible and save metadata
570            metadata.next_plaid_compatible = true;
571            let metadata_path = index_dir.join("metadata.json");
572            let file = File::create(&metadata_path)
573                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
574            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
575            eprintln!("Metadata updated with next_plaid_compatible: true");
576        }
577
578        // Load codec with memory-mapped centroids for reduced RAM usage.
579        // Other small tensors (bucket weights, etc.) are still loaded into memory.
580        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;
581
582        // Load IVF (small tensor)
583        let ivf_path = index_dir.join("ivf.npy");
584        let ivf: Array1<i64> = Array1::read_npy(
585            File::open(&ivf_path)
586                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
587        )
588        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
589
590        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
591        let ivf_lengths: Array1<i32> = Array1::read_npy(
592            File::open(&ivf_lengths_path)
593                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
594        )
595        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
596
597        // Compute IVF offsets
598        let num_centroids = ivf_lengths.len();
599        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
600        for i in 0..num_centroids {
601            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
602        }
603
604        // Load document lengths from all chunks
605        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
606        for chunk_idx in 0..metadata.num_chunks {
607            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
608            let chunk_doclens: Vec<i64> =
609                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
610            doc_lengths_vec.extend(chunk_doclens);
611        }
612        let doc_lengths = Array1::from_vec(doc_lengths_vec);
613
614        // Compute document offsets for indexing
615        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
616        for i in 0..doc_lengths.len() {
617            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
618        }
619
620        // Compute padding needed for StridedTensor compatibility
621        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
622        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
623        let padding_needed = max_len.saturating_sub(last_len);
624
625        // Create merged files if needed
626        let merged_codes_path =
627            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
628        let merged_residuals_path =
629            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;
630
631        // Memory-map the merged files
632        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
633        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;
634
635        Ok(Self {
636            path: index_path.to_string(),
637            metadata,
638            codec,
639            ivf,
640            ivf_lengths,
641            ivf_offsets,
642            doc_lengths,
643            doc_offsets,
644            mmap_codes,
645            mmap_residuals,
646        })
647    }
648
649    /// Get candidate documents from IVF for given centroid indices.
650    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
651        let mut candidates: Vec<i64> = Vec::new();
652
653        for &idx in centroid_indices {
654            if idx < self.ivf_lengths.len() {
655                let start = self.ivf_offsets[idx] as usize;
656                let len = self.ivf_lengths[idx] as usize;
657                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
658            }
659        }
660
661        candidates.sort_unstable();
662        candidates.dedup();
663        candidates
664    }
665
666    /// Get document embeddings by decompressing codes and residuals.
667    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
668        if doc_id >= self.doc_lengths.len() {
669            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
670        }
671
672        let start = self.doc_offsets[doc_id];
673        let end = self.doc_offsets[doc_id + 1];
674
675        // Get codes and residuals from mmap
676        let codes_slice = self.mmap_codes.slice(start, end);
677        let residuals_view = self.mmap_residuals.slice_rows(start, end);
678
679        // Convert codes to Array1<usize>
680        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
681
682        // Convert residuals to owned Array2
683        let residuals = residuals_view.to_owned();
684
685        // Decompress
686        self.codec.decompress(&residuals, &codes.view())
687    }
688
689    /// Get codes for a batch of document IDs (for approximate scoring).
690    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
691        doc_ids
692            .iter()
693            .map(|&doc_id| {
694                if doc_id >= self.doc_lengths.len() {
695                    return vec![];
696                }
697                let start = self.doc_offsets[doc_id];
698                let end = self.doc_offsets[doc_id + 1];
699                self.mmap_codes.slice(start, end).to_vec()
700            })
701            .collect()
702    }
703
704    /// Decompress embeddings for a batch of document IDs.
705    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
706        // Compute total tokens
707        let mut total_tokens = 0usize;
708        let mut lengths = Vec::with_capacity(doc_ids.len());
709        for &doc_id in doc_ids {
710            if doc_id >= self.doc_lengths.len() {
711                lengths.push(0);
712            } else {
713                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
714                lengths.push(len);
715                total_tokens += len;
716            }
717        }
718
719        if total_tokens == 0 {
720            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
721        }
722
723        // Gather all codes and residuals
724        let packed_dim = self.mmap_residuals.ncols();
725        let mut all_codes = Vec::with_capacity(total_tokens);
726        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
727        let mut offset = 0;
728
729        for &doc_id in doc_ids {
730            if doc_id >= self.doc_lengths.len() {
731                continue;
732            }
733            let start = self.doc_offsets[doc_id];
734            let end = self.doc_offsets[doc_id + 1];
735            let len = end - start;
736
737            // Append codes
738            let codes_slice = self.mmap_codes.slice(start, end);
739            all_codes.extend(codes_slice.iter().map(|&c| c as usize));
740
741            // Copy residuals
742            let residuals_view = self.mmap_residuals.slice_rows(start, end);
743            all_residuals
744                .slice_mut(s![offset..offset + len, ..])
745                .assign(&residuals_view);
746            offset += len;
747        }
748
749        let codes_arr = Array1::from_vec(all_codes);
750        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;
751
752        Ok((embeddings, lengths))
753    }
754
    /// Search for similar documents.
    ///
    /// Delegates to `crate::search::search_one_mmap` with this index.
    ///
    /// # Arguments
    ///
    /// * `query` - Query embedding matrix [num_tokens, dim]
    /// * `params` - Search parameters
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Search result containing top-k document IDs and scores.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `crate::search::search_one_mmap`.
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }
774
    /// Search for multiple queries in batch.
    ///
    /// Delegates to `crate::search::search_many_mmap` with this index.
    ///
    /// # Arguments
    ///
    /// * `queries` - Slice of query embedding matrices
    /// * `params` - Search parameters
    /// * `parallel` - If true, process queries in parallel using rayon
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Vector of search results, one per query.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `crate::search::search_many_mmap`.
    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }
796
    /// Get the number of documents in the index.
    ///
    /// Counted from the loaded document-length array, not read from
    /// `metadata.num_documents`.
    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }
801
    /// Get the total number of embeddings in the index.
    ///
    /// This is the value recorded in the index metadata; it is not recomputed
    /// from the document lengths.
    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }
806
807    /// Get the number of partitions (centroids).
808    pub fn num_partitions(&self) -> usize {
809        self.metadata.num_partitions
810    }
811
812    /// Get the average document length.
813    pub fn avg_doclen(&self) -> f64 {
814        self.metadata.avg_doclen
815    }
816
817    /// Get the embedding dimension.
818    pub fn embedding_dim(&self) -> usize {
819        self.codec.embedding_dim()
820    }
821
822    /// Reconstruct embeddings for specific documents.
823    ///
824    /// This method retrieves the compressed codes and residuals for each document
825    /// from memory-mapped files and decompresses them to recover the original embeddings.
826    ///
827    /// # Arguments
828    ///
829    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
830    ///
831    /// # Returns
832    ///
833    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
834    ///
835    /// # Example
836    ///
837    /// ```ignore
838    /// use next_plaid::MmapIndex;
839    ///
840    /// let index = MmapIndex::load("/path/to/index")?;
841    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
842    ///
843    /// for (i, emb) in embeddings.iter().enumerate() {
844    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
845    /// }
846    /// ```
847    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
848        crate::embeddings::reconstruct_embeddings(self, doc_ids)
849    }
850
851    /// Reconstruct a single document's embeddings.
852    ///
853    /// Convenience method for reconstructing a single document.
854    ///
855    /// # Arguments
856    ///
857    /// * `doc_id` - Document ID to reconstruct (0-indexed)
858    ///
859    /// # Returns
860    ///
861    /// A 2D array with shape `[num_tokens, dim]`.
862    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
863        crate::embeddings::reconstruct_single(self, doc_id)
864    }
865
866    /// Create a new index from document embeddings with automatic centroid computation.
867    ///
868    /// This method:
869    /// 1. Computes centroids using K-means
870    /// 2. Creates index files on disk
871    /// 3. Loads the index using memory-mapped I/O
872    ///
873    /// Note: During creation, data is temporarily held in RAM for processing,
874    /// then written to disk and loaded as mmap.
875    ///
876    /// # Arguments
877    ///
878    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
879    /// * `index_path` - Directory to save the index
880    /// * `config` - Index configuration
881    ///
882    /// # Returns
883    ///
884    /// The created MmapIndex
885    pub fn create_with_kmeans(
886        embeddings: &[Array2<f32>],
887        index_path: &str,
888        config: &IndexConfig,
889    ) -> Result<Self> {
890        // Use standalone function to create files
891        create_index_with_kmeans_files(embeddings, index_path, config)?;
892
893        // Load as memory-mapped index
894        Self::load(index_path)
895    }
896
897    /// Update the index with new documents, matching fast-plaid behavior.
898    ///
899    /// This method adds new documents to an existing index with three possible paths:
900    ///
901    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
902    ///    - Loads existing embeddings from `embeddings.npy` if available
903    ///    - Combines with new embeddings
904    ///    - Rebuilds the entire index from scratch with fresh K-means
905    ///    - Clears `embeddings.npy` if total exceeds threshold
906    ///
907    /// 2. **Buffer mode** (total_new < buffer_size):
908    ///    - Adds new documents to the index without centroid expansion
909    ///    - Saves embeddings to buffer for later centroid expansion
910    ///
911    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
912    ///    - Deletes previously buffered documents
913    ///    - Expands centroids with outliers from combined buffer + new embeddings
914    ///    - Re-indexes all combined embeddings with expanded centroids
915    ///
916    /// # Arguments
917    ///
918    /// * `embeddings` - New document embeddings to add
919    /// * `config` - Update configuration
920    ///
921    /// # Returns
922    ///
923    /// Vector of document IDs assigned to the new embeddings
924    pub fn update(
925        &mut self,
926        embeddings: &[Array2<f32>],
927        config: &crate::update::UpdateConfig,
928    ) -> Result<Vec<i64>> {
929        use crate::codec::ResidualCodec;
930        use crate::update::{
931            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
932            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
933            update_centroids, update_index,
934        };
935
936        let path_str = self.path.clone();
937        let index_path = std::path::Path::new(&path_str);
938        let num_new_docs = embeddings.len();
939
940        // ==================================================================
941        // Start-from-scratch mode (fast-plaid update.py:312-346)
942        // ==================================================================
943        if self.metadata.num_documents <= config.start_from_scratch {
944            // Load existing embeddings if available
945            let existing_embeddings = load_embeddings_npy(index_path)?;
946            // New documents start after existing documents
947            let start_doc_id = existing_embeddings.len() as i64;
948
949            // Combine existing + new embeddings
950            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
951                .into_iter()
952                .chain(embeddings.iter().cloned())
953                .collect();
954
955            // Build IndexConfig from UpdateConfig for create_with_kmeans
956            let index_config = IndexConfig {
957                nbits: self.metadata.nbits,
958                batch_size: config.batch_size,
959                seed: Some(config.seed),
960                kmeans_niters: config.kmeans_niters,
961                max_points_per_centroid: config.max_points_per_centroid,
962                n_samples_kmeans: config.n_samples_kmeans,
963                start_from_scratch: config.start_from_scratch,
964            };
965
966            // Rebuild index from scratch with fresh K-means
967            *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
968
969            // If we've crossed the threshold, clear embeddings.npy
970            if combined_embeddings.len() > config.start_from_scratch
971                && embeddings_npy_exists(index_path)
972            {
973                clear_embeddings_npy(index_path)?;
974            }
975
976            // Return the document IDs assigned to the new embeddings
977            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
978        }
979
980        // Load buffer
981        let buffer = load_buffer(index_path)?;
982        let buffer_len = buffer.len();
983        let total_new = embeddings.len() + buffer_len;
984
985        // Track the starting document ID for the new embeddings
986        let start_doc_id: i64;
987
988        // Load codec for update operations
989        let mut codec = ResidualCodec::load_from_dir(index_path)?;
990
991        // Check buffer threshold
992        if total_new >= config.buffer_size {
993            // Centroid expansion path (matches fast-plaid update.py:376-422)
994
995            // 1. Get number of buffered docs that were previously indexed
996            let num_buffered = load_buffer_info(index_path)?;
997
998            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
999            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
1000                let start_del_idx = self.metadata.num_documents - num_buffered;
1001                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
1002                    .map(|i| i as i64)
1003                    .collect();
1004                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
1005                // Reload metadata after delete
1006                self.metadata = Metadata::load_from_path(index_path)?;
1007            }
1008
1009            // New embeddings start after buffer is re-indexed
1010            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
1011
1012            // 3. Combine buffer + new embeddings
1013            let combined: Vec<Array2<f32>> = buffer
1014                .into_iter()
1015                .chain(embeddings.iter().cloned())
1016                .collect();
1017
1018            // 4. Expand centroids with outliers from combined embeddings
1019            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
1020                let new_centroids =
1021                    update_centroids(index_path, &combined, cluster_threshold, config)?;
1022                if new_centroids > 0 {
1023                    // Reload codec with new centroids
1024                    codec = ResidualCodec::load_from_dir(index_path)?;
1025                }
1026            }
1027
1028            // 5. Clear buffer
1029            clear_buffer(index_path)?;
1030
1031            // 6. Update index with ALL combined embeddings (buffer + new)
1032            update_index(&combined, &path_str, &codec, Some(config.batch_size), true)?;
1033        } else {
1034            // Small update: add to buffer and index without centroid expansion
1035            // New documents start at current num_documents
1036            start_doc_id = self.metadata.num_documents as i64;
1037
1038            // Accumulate buffer: combine existing buffer with new embeddings
1039            let combined_buffer: Vec<Array2<f32>> = buffer
1040                .into_iter()
1041                .chain(embeddings.iter().cloned())
1042                .collect();
1043            save_buffer(index_path, &combined_buffer)?;
1044
1045            // Update index without threshold update
1046            update_index(
1047                embeddings,
1048                &path_str,
1049                &codec,
1050                Some(config.batch_size),
1051                false,
1052            )?;
1053        }
1054
1055        // Reload self as mmap
1056        *self = Self::load(&path_str)?;
1057
1058        // Return the document IDs assigned to the new embeddings
1059        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
1060    }
1061
1062    /// Update the index with new documents and optional metadata.
1063    ///
1064    /// # Arguments
1065    ///
1066    /// * `embeddings` - New document embeddings to add
1067    /// * `config` - Update configuration
1068    /// * `metadata` - Optional metadata for new documents
1069    ///
1070    /// # Returns
1071    ///
1072    /// Vector of document IDs assigned to the new embeddings
1073    pub fn update_with_metadata(
1074        &mut self,
1075        embeddings: &[Array2<f32>],
1076        config: &crate::update::UpdateConfig,
1077        metadata: Option<&[serde_json::Value]>,
1078    ) -> Result<Vec<i64>> {
1079        // Validate metadata length if provided
1080        if let Some(meta) = metadata {
1081            if meta.len() != embeddings.len() {
1082                return Err(Error::Config(format!(
1083                    "Metadata length ({}) must match embeddings length ({})",
1084                    meta.len(),
1085                    embeddings.len()
1086                )));
1087            }
1088        }
1089
1090        // Perform the update and get document IDs
1091        let doc_ids = self.update(embeddings, config)?;
1092
1093        // Add metadata if provided, using the assigned document IDs
1094        if let Some(meta) = metadata {
1095            crate::filtering::update(&self.path, meta, &doc_ids)?;
1096        }
1097
1098        Ok(doc_ids)
1099    }
1100
1101    /// Update an existing index or create a new one if it doesn't exist.
1102    ///
1103    /// # Arguments
1104    ///
1105    /// * `embeddings` - Document embeddings to add
1106    /// * `index_path` - Directory for the index
1107    /// * `index_config` - Configuration for index creation
1108    /// * `update_config` - Configuration for updates
1109    ///
1110    /// # Returns
1111    ///
1112    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and document IDs
1113    pub fn update_or_create(
1114        embeddings: &[Array2<f32>],
1115        index_path: &str,
1116        index_config: &IndexConfig,
1117        update_config: &crate::update::UpdateConfig,
1118    ) -> Result<(Self, Vec<i64>)> {
1119        let index_dir = std::path::Path::new(index_path);
1120        let metadata_path = index_dir.join("metadata.json");
1121
1122        if metadata_path.exists() {
1123            // Index exists, load and update
1124            let mut index = Self::load(index_path)?;
1125            let doc_ids = index.update(embeddings, update_config)?;
1126            Ok((index, doc_ids))
1127        } else {
1128            // Index doesn't exist, create new
1129            let num_docs = embeddings.len();
1130            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1131            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1132            Ok((index, doc_ids))
1133        }
1134    }
1135
1136    /// Delete documents from the index.
1137    ///
1138    /// # Arguments
1139    ///
1140    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
1141    ///
1142    /// # Returns
1143    ///
1144    /// The number of documents actually deleted
1145    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
1146        self.delete_with_options(doc_ids, true)
1147    }
1148
1149    /// Delete documents from the index with control over metadata deletion.
1150    ///
1151    /// # Arguments
1152    ///
1153    /// * `doc_ids` - Slice of document IDs to delete
1154    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
1155    ///
1156    /// # Returns
1157    ///
1158    /// The number of documents actually deleted
1159    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
1160        let path = self.path.clone();
1161
1162        // Perform the deletion using standalone function
1163        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;
1164
1165        // Also delete from metadata.db if requested
1166        if delete_metadata && deleted > 0 {
1167            let index_path = std::path::Path::new(&path);
1168            let db_path = index_path.join("metadata.db");
1169            if db_path.exists() {
1170                crate::filtering::delete(&path, doc_ids)?;
1171            }
1172        }
1173
1174        // Reload self as mmap
1175        *self = Self::load(&path)?;
1176
1177        Ok(deleted)
1178    }
1179}
1180
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use tempfile::tempdir;

    /// Build `count` deterministic 5x32 documents whose rows are L2-normalized.
    ///
    /// `offset` shifts the per-document base value so that successive batches
    /// contain different documents.
    fn make_embeddings(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (0..count)
            .map(|i| {
                let base = (i + offset) as f32 * 0.1;
                let mut doc = Array2::<f32>::from_shape_fn((5, 32), |(j, k)| {
                    base + (j as f32 * 0.01) + (k as f32 * 0.001)
                });
                // Normalize rows
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.mapv_inplace(|x| x / norm);
                    }
                }
                doc
            })
            .collect()
    }

    /// Small, fast index configuration shared by the update-or-create tests.
    fn small_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let IndexConfig {
            nbits,
            batch_size,
            seed,
            ..
        } = IndexConfig::default();
        assert_eq!(nbits, 4);
        assert_eq!(batch_size, 50_000);
        assert_eq!(seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        // Create test embeddings (5 documents)
        let embeddings = make_embeddings(5, 0);
        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // Index doesn't exist - should create new
        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // Verify index was created
        for file_name in ["metadata.json", "centroids.npy"] {
            assert!(temp_dir.path().join(file_name).exists());
        }
    }

    #[test]
    fn test_update_or_create_existing_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call - creates index with 5 documents
        let first_batch = make_embeddings(5, 0);
        let (created, created_ids) =
            MmapIndex::update_or_create(&first_batch, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(created.metadata.num_documents, 5);
        assert_eq!(created_ids, vec![0, 1, 2, 3, 4]);

        // Second call - updates existing index with 3 more documents
        let second_batch = make_embeddings(3, 5);
        let (updated, updated_ids) =
            MmapIndex::update_or_create(&second_batch, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(updated.metadata.num_documents, 8);
        assert_eq!(updated_ids, vec![5, 6, 7]);
    }
}