next_plaid/
index.rs

1//! Index creation and management for PLAID
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
14use crate::utils::{quantile, quantiles};
15
16/// Configuration for index creation
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct IndexConfig {
19    /// Number of bits for quantization (typically 2 or 4)
20    pub nbits: usize,
21    /// Batch size for processing
22    pub batch_size: usize,
23    /// Random seed for reproducibility
24    pub seed: Option<u64>,
25    /// Number of K-means iterations (default: 4)
26    #[serde(default = "default_kmeans_niters")]
27    pub kmeans_niters: usize,
28    /// Maximum number of points per centroid for K-means (default: 256)
29    #[serde(default = "default_max_points_per_centroid")]
30    pub max_points_per_centroid: usize,
31    /// Number of samples for K-means training.
32    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
33    #[serde(default)]
34    pub n_samples_kmeans: Option<usize>,
35    /// Threshold for start-from-scratch mode (default: 999).
36    /// When the number of documents is <= this threshold, raw embeddings are saved
37    /// to embeddings.npy for potential rebuilds during updates.
38    #[serde(default = "default_start_from_scratch")]
39    pub start_from_scratch: usize,
40}
41
42fn default_start_from_scratch() -> usize {
43    999
44}
45
46fn default_kmeans_niters() -> usize {
47    4
48}
49
50fn default_max_points_per_centroid() -> usize {
51    256
52}
53
54impl Default for IndexConfig {
55    fn default() -> Self {
56        Self {
57            nbits: 4,
58            batch_size: 50_000,
59            seed: Some(42),
60            kmeans_niters: 4,
61            max_points_per_centroid: 256,
62            n_samples_kmeans: None,
63            start_from_scratch: 999,
64        }
65    }
66}
67
68/// Metadata for the index
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Metadata {
71    /// Number of chunks in the index
72    pub num_chunks: usize,
73    /// Number of bits for quantization
74    pub nbits: usize,
75    /// Number of partitions (centroids)
76    pub num_partitions: usize,
77    /// Total number of embeddings
78    pub num_embeddings: usize,
79    /// Average document length
80    pub avg_doclen: f64,
81    /// Total number of documents
82    pub num_documents: usize,
83}
84
85/// Chunk metadata
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ChunkMetadata {
88    pub num_documents: usize,
89    pub num_embeddings: usize,
90    #[serde(default)]
91    pub embedding_offset: usize,
92}
93
94// ============================================================================
95// Standalone Index Creation Functions
96// ============================================================================
97
98/// Create index files on disk from embeddings and centroids.
99///
100/// This is a standalone function that creates all necessary index files
101/// without constructing an in-memory Index object. Both Index and MmapIndex
102/// can use this function to create their files, then load them in their
103/// preferred format.
104///
105/// # Arguments
106///
107/// * `embeddings` - List of document embeddings
108/// * `centroids` - Pre-computed centroids from K-means
109/// * `index_path` - Directory to save the index
110/// * `config` - Index configuration
111///
112/// # Returns
113///
114/// Metadata about the created index
115pub fn create_index_files(
116    embeddings: &[Array2<f32>],
117    centroids: Array2<f32>,
118    index_path: &str,
119    config: &IndexConfig,
120) -> Result<Metadata> {
121    let index_dir = Path::new(index_path);
122    fs::create_dir_all(index_dir)?;
123
124    let num_documents = embeddings.len();
125    let embedding_dim = centroids.ncols();
126    let num_centroids = centroids.nrows();
127
128    if num_documents == 0 {
129        return Err(Error::IndexCreation("No documents provided".into()));
130    }
131
132    // Calculate statistics
133    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
134    let avg_doclen = total_embeddings as f64 / num_documents as f64;
135
136    // Sample documents for codec training
137    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
138        .min(num_documents)
139        .max(1);
140
141    let mut rng = if let Some(seed) = config.seed {
142        use rand::SeedableRng;
143        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
144    } else {
145        use rand::SeedableRng;
146        rand_chacha::ChaCha8Rng::from_entropy()
147    };
148
149    use rand::seq::SliceRandom;
150    let mut indices: Vec<usize> = (0..num_documents).collect();
151    indices.shuffle(&mut rng);
152    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
153
154    // Collect sample embeddings for training
155    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
156    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
157    let mut collected = 0;
158
159    for &idx in sample_indices.iter().rev() {
160        if collected >= heldout_size {
161            break;
162        }
163        let emb = &embeddings[idx];
164        let take = (heldout_size - collected).min(emb.nrows());
165        for row in emb.axis_iter(Axis(0)).take(take) {
166            heldout_embeddings.extend(row.iter());
167        }
168        collected += take;
169    }
170
171    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
172        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
173
174    // Train codec: compute residuals and quantization parameters
175    let avg_residual = Array1::zeros(embedding_dim);
176    let initial_codec =
177        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
178
179    // Compute codes for heldout samples
180    let heldout_codes = initial_codec.compress_into_codes(&heldout);
181
182    // Compute residuals
183    let mut residuals = heldout.clone();
184    for i in 0..heldout.nrows() {
185        let centroid = initial_codec.centroids.row(heldout_codes[i]);
186        for j in 0..embedding_dim {
187            residuals[[i, j]] -= centroid[j];
188        }
189    }
190
191    // Compute cluster threshold from residual distances
192    let distances: Array1<f32> = residuals
193        .axis_iter(Axis(0))
194        .map(|row| row.dot(&row).sqrt())
195        .collect();
196    #[allow(unused_variables)]
197    let cluster_threshold = quantile(&distances, 0.75);
198
199    // Compute average residual per dimension
200    let avg_res_per_dim: Array1<f32> = residuals
201        .axis_iter(Axis(1))
202        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
203        .collect();
204
205    // Compute quantization buckets
206    let n_options = 1 << config.nbits;
207    let quantile_values: Vec<f64> = (1..n_options)
208        .map(|i| i as f64 / n_options as f64)
209        .collect();
210    let weight_quantile_values: Vec<f64> = (0..n_options)
211        .map(|i| (i as f64 + 0.5) / n_options as f64)
212        .collect();
213
214    // Flatten residuals for quantile computation
215    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
216    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
217    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
218
219    let codec = ResidualCodec::new(
220        config.nbits,
221        centroids.clone(),
222        avg_res_per_dim.clone(),
223        Some(bucket_cutoffs.clone()),
224        Some(bucket_weights.clone()),
225    )?;
226
227    // Save codec components
228    use ndarray_npy::WriteNpyExt;
229
230    let centroids_path = index_dir.join("centroids.npy");
231    codec
232        .centroids_view()
233        .to_owned()
234        .write_npy(File::create(&centroids_path)?)?;
235
236    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
237    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
238
239    let weights_path = index_dir.join("bucket_weights.npy");
240    bucket_weights.write_npy(File::create(&weights_path)?)?;
241
242    let avg_res_path = index_dir.join("avg_residual.npy");
243    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
244
245    let threshold_path = index_dir.join("cluster_threshold.npy");
246    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
247
248    // Process documents in chunks
249    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
250
251    // Save plan
252    let plan_path = index_dir.join("plan.json");
253    let plan = serde_json::json!({
254        "nbits": config.nbits,
255        "num_chunks": n_chunks,
256    });
257    let mut plan_file = File::create(&plan_path)?;
258    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
259
260    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
261    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
262
263    let progress = indicatif::ProgressBar::new(n_chunks as u64);
264    progress.set_message("Creating index...");
265
266    for chunk_idx in 0..n_chunks {
267        let start = chunk_idx * config.batch_size;
268        let end = (start + config.batch_size).min(num_documents);
269        let chunk_docs = &embeddings[start..end];
270
271        // Collect document lengths
272        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
273        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
274
275        // Concatenate all embeddings in the chunk for batch processing
276        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
277        let mut offset = 0;
278        for doc in chunk_docs {
279            let n = doc.nrows();
280            batch_embeddings
281                .slice_mut(s![offset..offset + n, ..])
282                .assign(doc);
283            offset += n;
284        }
285
286        // BATCH: Compress all embeddings at once
287        let batch_codes = codec.compress_into_codes(&batch_embeddings);
288
289        // BATCH: Compute residuals using parallel subtraction
290        let mut batch_residuals = batch_embeddings;
291        {
292            use rayon::prelude::*;
293            let centroids = &codec.centroids;
294            batch_residuals
295                .axis_iter_mut(Axis(0))
296                .into_par_iter()
297                .zip(batch_codes.as_slice().unwrap().par_iter())
298                .for_each(|(mut row, &code)| {
299                    let centroid = centroids.row(code);
300                    row.iter_mut()
301                        .zip(centroid.iter())
302                        .for_each(|(r, c)| *r -= c);
303                });
304        }
305
306        // BATCH: Quantize all residuals at once
307        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
308
309        // Track codes for IVF building
310        for &len in &chunk_doclens {
311            doc_lengths.push(len);
312        }
313        all_codes.extend(batch_codes.iter().copied());
314
315        // Save chunk metadata
316        let chunk_meta = ChunkMetadata {
317            num_documents: end - start,
318            num_embeddings: batch_codes.len(),
319            embedding_offset: 0, // Will be updated later
320        };
321
322        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
323        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
324
325        // Save chunk doclens
326        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
327        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
328
329        // Save chunk codes
330        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
331        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
332        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
333
334        // Save chunk residuals
335        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
336        batch_packed.write_npy(File::create(&residuals_path)?)?;
337
338        progress.inc(1);
339    }
340    progress.finish();
341
342    // Update chunk metadata with global offsets
343    let mut current_offset = 0usize;
344    for chunk_idx in 0..n_chunks {
345        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
346        let mut meta: serde_json::Value =
347            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
348
349        if let Some(obj) = meta.as_object_mut() {
350            obj.insert("embedding_offset".to_string(), current_offset.into());
351            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
352            current_offset += num_emb;
353        }
354
355        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
356    }
357
358    // Build IVF (Inverted File)
359    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
360    let mut emb_idx = 0;
361
362    for (doc_id, &len) in doc_lengths.iter().enumerate() {
363        for _ in 0..len {
364            let code = all_codes[emb_idx];
365            code_to_docs.entry(code).or_default().push(doc_id as i64);
366            emb_idx += 1;
367        }
368    }
369
370    // Deduplicate document IDs per centroid
371    let mut ivf_data: Vec<i64> = Vec::new();
372    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
373
374    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
375        if let Some(docs) = code_to_docs.get(&centroid_id) {
376            let mut unique_docs: Vec<i64> = docs.clone();
377            unique_docs.sort_unstable();
378            unique_docs.dedup();
379            *ivf_len = unique_docs.len() as i32;
380            ivf_data.extend(unique_docs);
381        }
382    }
383
384    let ivf = Array1::from_vec(ivf_data);
385    let ivf_lengths = Array1::from_vec(ivf_lengths);
386
387    let ivf_path = index_dir.join("ivf.npy");
388    ivf.write_npy(File::create(&ivf_path)?)?;
389
390    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
391    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
392
393    // Save global metadata
394    let metadata = Metadata {
395        num_chunks: n_chunks,
396        nbits: config.nbits,
397        num_partitions: num_centroids,
398        num_embeddings: total_embeddings,
399        avg_doclen,
400        num_documents,
401    };
402
403    let metadata_path = index_dir.join("metadata.json");
404    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
405
406    Ok(metadata)
407}
408
409/// Create index files with automatic K-means centroid computation.
410///
411/// This is a standalone function that runs K-means to compute centroids,
412/// then creates all index files on disk.
413///
414/// # Arguments
415///
416/// * `embeddings` - List of document embeddings
417/// * `index_path` - Directory to save the index
418/// * `config` - Index configuration
419///
420/// # Returns
421///
422/// Metadata about the created index
423pub fn create_index_with_kmeans_files(
424    embeddings: &[Array2<f32>],
425    index_path: &str,
426    config: &IndexConfig,
427) -> Result<Metadata> {
428    if embeddings.is_empty() {
429        return Err(Error::IndexCreation("No documents provided".into()));
430    }
431
432    // Build K-means configuration from IndexConfig
433    let kmeans_config = ComputeKmeansConfig {
434        kmeans_niters: config.kmeans_niters,
435        max_points_per_centroid: config.max_points_per_centroid,
436        seed: config.seed.unwrap_or(42),
437        n_samples_kmeans: config.n_samples_kmeans,
438        num_partitions: None, // Let the heuristic decide
439    };
440
441    // Compute centroids using fast-plaid's approach
442    let centroids = compute_kmeans(embeddings, &kmeans_config)?;
443
444    // Create the index files
445    let metadata = create_index_files(embeddings, centroids, index_path, config)?;
446
447    // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
448    if embeddings.len() <= config.start_from_scratch {
449        let index_dir = std::path::Path::new(index_path);
450        crate::update::save_embeddings_npy(index_dir, embeddings)?;
451    }
452
453    Ok(metadata)
454}
455// ============================================================================
456// Memory-Mapped Index for Low Memory Usage
457// ============================================================================
458
459/// A memory-mapped PLAID index for multi-vector search.
460///
461/// This struct uses memory-mapped files for the large arrays (codes and residuals)
462/// instead of loading them entirely into RAM. Only small tensors (centroids,
463/// bucket weights, IVF) are loaded into memory.
464///
465/// # Memory Usage
466///
467/// Only small tensors (~50 MB for SciFact 5K docs) are loaded into RAM,
468/// with code and residual data accessed via OS-managed memory mapping.
469///
470/// # Usage
471///
472/// ```ignore
473/// use next_plaid::MmapIndex;
474///
475/// let index = MmapIndex::load("/path/to/index")?;
476/// let results = index.search(&query, &params, None)?;
477/// ```
478pub struct MmapIndex {
479    /// Path to the index directory
480    pub path: String,
481    /// Index metadata
482    pub metadata: Metadata,
483    /// Residual codec for quantization/decompression
484    pub codec: ResidualCodec,
485    /// IVF data (concatenated passage IDs per centroid)
486    pub ivf: Array1<i64>,
487    /// IVF lengths (number of passages per centroid)
488    pub ivf_lengths: Array1<i32>,
489    /// IVF offsets (cumulative offsets into ivf array)
490    pub ivf_offsets: Array1<i64>,
491    /// Document lengths (number of tokens per document)
492    pub doc_lengths: Array1<i64>,
493    /// Cumulative document offsets for indexing into codes/residuals
494    pub doc_offsets: Array1<usize>,
495    /// Memory-mapped codes array (public for search access)
496    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
497    /// Memory-mapped residuals array (public for search access)
498    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
499}
500
501impl MmapIndex {
502    /// Load a memory-mapped index from disk.
503    ///
504    /// This creates merged files for codes and residuals if they don't exist,
505    /// then memory-maps them for efficient access.
506    pub fn load(index_path: &str) -> Result<Self> {
507        use ndarray_npy::ReadNpyExt;
508
509        let index_dir = Path::new(index_path);
510
511        // Load metadata
512        let metadata_path = index_dir.join("metadata.json");
513        let metadata: Metadata = serde_json::from_reader(BufReader::new(
514            File::open(&metadata_path)
515                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
516        ))?;
517
518        // Load codec with memory-mapped centroids for reduced RAM usage.
519        // Other small tensors (bucket weights, etc.) are still loaded into memory.
520        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;
521
522        // Load IVF (small tensor)
523        let ivf_path = index_dir.join("ivf.npy");
524        let ivf: Array1<i64> = Array1::read_npy(
525            File::open(&ivf_path)
526                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
527        )
528        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
529
530        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
531        let ivf_lengths: Array1<i32> = Array1::read_npy(
532            File::open(&ivf_lengths_path)
533                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
534        )
535        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
536
537        // Compute IVF offsets
538        let num_centroids = ivf_lengths.len();
539        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
540        for i in 0..num_centroids {
541            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
542        }
543
544        // Load document lengths from all chunks
545        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
546        for chunk_idx in 0..metadata.num_chunks {
547            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
548            let chunk_doclens: Vec<i64> =
549                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
550            doc_lengths_vec.extend(chunk_doclens);
551        }
552        let doc_lengths = Array1::from_vec(doc_lengths_vec);
553
554        // Compute document offsets for indexing
555        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
556        for i in 0..doc_lengths.len() {
557            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
558        }
559
560        // Compute padding needed for StridedTensor compatibility
561        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
562        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
563        let padding_needed = max_len.saturating_sub(last_len);
564
565        // Create merged files if needed
566        let merged_codes_path =
567            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
568        let merged_residuals_path =
569            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;
570
571        // Memory-map the merged files
572        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
573        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;
574
575        Ok(Self {
576            path: index_path.to_string(),
577            metadata,
578            codec,
579            ivf,
580            ivf_lengths,
581            ivf_offsets,
582            doc_lengths,
583            doc_offsets,
584            mmap_codes,
585            mmap_residuals,
586        })
587    }
588
589    /// Get candidate documents from IVF for given centroid indices.
590    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
591        let mut candidates: Vec<i64> = Vec::new();
592
593        for &idx in centroid_indices {
594            if idx < self.ivf_lengths.len() {
595                let start = self.ivf_offsets[idx] as usize;
596                let len = self.ivf_lengths[idx] as usize;
597                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
598            }
599        }
600
601        candidates.sort_unstable();
602        candidates.dedup();
603        candidates
604    }
605
606    /// Get document embeddings by decompressing codes and residuals.
607    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
608        if doc_id >= self.doc_lengths.len() {
609            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
610        }
611
612        let start = self.doc_offsets[doc_id];
613        let end = self.doc_offsets[doc_id + 1];
614
615        // Get codes and residuals from mmap
616        let codes_slice = self.mmap_codes.slice(start, end);
617        let residuals_view = self.mmap_residuals.slice_rows(start, end);
618
619        // Convert codes to Array1<usize>
620        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
621
622        // Convert residuals to owned Array2
623        let residuals = residuals_view.to_owned();
624
625        // Decompress
626        self.codec.decompress(&residuals, &codes.view())
627    }
628
629    /// Get codes for a batch of document IDs (for approximate scoring).
630    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
631        doc_ids
632            .iter()
633            .map(|&doc_id| {
634                if doc_id >= self.doc_lengths.len() {
635                    return vec![];
636                }
637                let start = self.doc_offsets[doc_id];
638                let end = self.doc_offsets[doc_id + 1];
639                self.mmap_codes.slice(start, end).to_vec()
640            })
641            .collect()
642    }
643
644    /// Decompress embeddings for a batch of document IDs.
645    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
646        // Compute total tokens
647        let mut total_tokens = 0usize;
648        let mut lengths = Vec::with_capacity(doc_ids.len());
649        for &doc_id in doc_ids {
650            if doc_id >= self.doc_lengths.len() {
651                lengths.push(0);
652            } else {
653                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
654                lengths.push(len);
655                total_tokens += len;
656            }
657        }
658
659        if total_tokens == 0 {
660            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
661        }
662
663        // Gather all codes and residuals
664        let packed_dim = self.mmap_residuals.ncols();
665        let mut all_codes = Vec::with_capacity(total_tokens);
666        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
667        let mut offset = 0;
668
669        for &doc_id in doc_ids {
670            if doc_id >= self.doc_lengths.len() {
671                continue;
672            }
673            let start = self.doc_offsets[doc_id];
674            let end = self.doc_offsets[doc_id + 1];
675            let len = end - start;
676
677            // Append codes
678            let codes_slice = self.mmap_codes.slice(start, end);
679            all_codes.extend(codes_slice.iter().map(|&c| c as usize));
680
681            // Copy residuals
682            let residuals_view = self.mmap_residuals.slice_rows(start, end);
683            all_residuals
684                .slice_mut(s![offset..offset + len, ..])
685                .assign(&residuals_view);
686            offset += len;
687        }
688
689        let codes_arr = Array1::from_vec(all_codes);
690        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;
691
692        Ok((embeddings, lengths))
693    }
694
695    /// Search for similar documents.
696    ///
697    /// # Arguments
698    ///
699    /// * `query` - Query embedding matrix [num_tokens, dim]
700    /// * `params` - Search parameters
701    /// * `subset` - Optional subset of document IDs to search within
702    ///
703    /// # Returns
704    ///
705    /// Search result containing top-k document IDs and scores.
706    pub fn search(
707        &self,
708        query: &Array2<f32>,
709        params: &crate::search::SearchParameters,
710        subset: Option<&[i64]>,
711    ) -> Result<crate::search::SearchResult> {
712        crate::search::search_one_mmap(self, query, params, subset)
713    }
714
715    /// Search for multiple queries in batch.
716    ///
717    /// # Arguments
718    ///
719    /// * `queries` - Slice of query embedding matrices
720    /// * `params` - Search parameters
721    /// * `parallel` - If true, process queries in parallel using rayon
722    /// * `subset` - Optional subset of document IDs to search within
723    ///
724    /// # Returns
725    ///
726    /// Vector of search results, one per query.
727    pub fn search_batch(
728        &self,
729        queries: &[Array2<f32>],
730        params: &crate::search::SearchParameters,
731        parallel: bool,
732        subset: Option<&[i64]>,
733    ) -> Result<Vec<crate::search::SearchResult>> {
734        crate::search::search_many_mmap(self, queries, params, parallel, subset)
735    }
736
737    /// Get the number of documents in the index.
738    pub fn num_documents(&self) -> usize {
739        self.doc_lengths.len()
740    }
741
742    /// Get the total number of embeddings in the index.
743    pub fn num_embeddings(&self) -> usize {
744        self.metadata.num_embeddings
745    }
746
747    /// Get the number of partitions (centroids).
748    pub fn num_partitions(&self) -> usize {
749        self.metadata.num_partitions
750    }
751
752    /// Get the average document length.
753    pub fn avg_doclen(&self) -> f64 {
754        self.metadata.avg_doclen
755    }
756
757    /// Get the embedding dimension.
758    pub fn embedding_dim(&self) -> usize {
759        self.codec.embedding_dim()
760    }
761
762    /// Reconstruct embeddings for specific documents.
763    ///
764    /// This method retrieves the compressed codes and residuals for each document
765    /// from memory-mapped files and decompresses them to recover the original embeddings.
766    ///
767    /// # Arguments
768    ///
769    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
770    ///
771    /// # Returns
772    ///
773    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
774    ///
775    /// # Example
776    ///
777    /// ```ignore
778    /// use next_plaid::MmapIndex;
779    ///
780    /// let index = MmapIndex::load("/path/to/index")?;
781    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
782    ///
783    /// for (i, emb) in embeddings.iter().enumerate() {
784    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
785    /// }
786    /// ```
787    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
788        crate::embeddings::reconstruct_embeddings(self, doc_ids)
789    }
790
791    /// Reconstruct a single document's embeddings.
792    ///
793    /// Convenience method for reconstructing a single document.
794    ///
795    /// # Arguments
796    ///
797    /// * `doc_id` - Document ID to reconstruct (0-indexed)
798    ///
799    /// # Returns
800    ///
801    /// A 2D array with shape `[num_tokens, dim]`.
802    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
803        crate::embeddings::reconstruct_single(self, doc_id)
804    }
805
806    /// Create a new index from document embeddings with automatic centroid computation.
807    ///
808    /// This method:
809    /// 1. Computes centroids using K-means
810    /// 2. Creates index files on disk
811    /// 3. Loads the index using memory-mapped I/O
812    ///
813    /// Note: During creation, data is temporarily held in RAM for processing,
814    /// then written to disk and loaded as mmap.
815    ///
816    /// # Arguments
817    ///
818    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
819    /// * `index_path` - Directory to save the index
820    /// * `config` - Index configuration
821    ///
822    /// # Returns
823    ///
824    /// The created MmapIndex
825    pub fn create_with_kmeans(
826        embeddings: &[Array2<f32>],
827        index_path: &str,
828        config: &IndexConfig,
829    ) -> Result<Self> {
830        // Use standalone function to create files
831        create_index_with_kmeans_files(embeddings, index_path, config)?;
832
833        // Load as memory-mapped index
834        Self::load(index_path)
835    }
836
837    /// Update the index with new documents, matching fast-plaid behavior.
838    ///
839    /// This method adds new documents to an existing index with three possible paths:
840    ///
841    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
842    ///    - Loads existing embeddings from `embeddings.npy` if available
843    ///    - Combines with new embeddings
844    ///    - Rebuilds the entire index from scratch with fresh K-means
845    ///    - Clears `embeddings.npy` if total exceeds threshold
846    ///
847    /// 2. **Buffer mode** (total_new < buffer_size):
848    ///    - Adds new documents to the index without centroid expansion
849    ///    - Saves embeddings to buffer for later centroid expansion
850    ///
851    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
852    ///    - Deletes previously buffered documents
853    ///    - Expands centroids with outliers from combined buffer + new embeddings
854    ///    - Re-indexes all combined embeddings with expanded centroids
855    ///
856    /// # Arguments
857    ///
858    /// * `embeddings` - New document embeddings to add
859    /// * `config` - Update configuration
860    ///
861    /// # Returns
862    ///
863    /// Vector of document IDs assigned to the new embeddings
864    pub fn update(
865        &mut self,
866        embeddings: &[Array2<f32>],
867        config: &crate::update::UpdateConfig,
868    ) -> Result<Vec<i64>> {
869        use crate::codec::ResidualCodec;
870        use crate::update::{
871            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
872            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
873            update_centroids, update_index,
874        };
875
876        let path_str = self.path.clone();
877        let index_path = std::path::Path::new(&path_str);
878        let num_new_docs = embeddings.len();
879
880        // ==================================================================
881        // Start-from-scratch mode (fast-plaid update.py:312-346)
882        // ==================================================================
883        if self.metadata.num_documents <= config.start_from_scratch {
884            // Load existing embeddings if available
885            let existing_embeddings = load_embeddings_npy(index_path)?;
886            // New documents start after existing documents
887            let start_doc_id = existing_embeddings.len() as i64;
888
889            // Combine existing + new embeddings
890            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
891                .into_iter()
892                .chain(embeddings.iter().cloned())
893                .collect();
894
895            // Build IndexConfig from UpdateConfig for create_with_kmeans
896            let index_config = IndexConfig {
897                nbits: self.metadata.nbits,
898                batch_size: config.batch_size,
899                seed: Some(config.seed),
900                kmeans_niters: config.kmeans_niters,
901                max_points_per_centroid: config.max_points_per_centroid,
902                n_samples_kmeans: config.n_samples_kmeans,
903                start_from_scratch: config.start_from_scratch,
904            };
905
906            // Rebuild index from scratch with fresh K-means
907            *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
908
909            // If we've crossed the threshold, clear embeddings.npy
910            if combined_embeddings.len() > config.start_from_scratch
911                && embeddings_npy_exists(index_path)
912            {
913                clear_embeddings_npy(index_path)?;
914            }
915
916            // Return the document IDs assigned to the new embeddings
917            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
918        }
919
920        // Load buffer
921        let buffer = load_buffer(index_path)?;
922        let buffer_len = buffer.len();
923        let total_new = embeddings.len() + buffer_len;
924
925        // Track the starting document ID for the new embeddings
926        let start_doc_id: i64;
927
928        // Load codec for update operations
929        let mut codec = ResidualCodec::load_from_dir(index_path)?;
930
931        // Check buffer threshold
932        if total_new >= config.buffer_size {
933            // Centroid expansion path (matches fast-plaid update.py:376-422)
934
935            // 1. Get number of buffered docs that were previously indexed
936            let num_buffered = load_buffer_info(index_path)?;
937
938            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
939            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
940                let start_del_idx = self.metadata.num_documents - num_buffered;
941                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
942                    .map(|i| i as i64)
943                    .collect();
944                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
945                // Reload metadata after delete
946                let metadata_path = index_path.join("metadata.json");
947                self.metadata = serde_json::from_reader(std::io::BufReader::new(
948                    std::fs::File::open(&metadata_path)?,
949                ))?;
950            }
951
952            // New embeddings start after buffer is re-indexed
953            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
954
955            // 3. Combine buffer + new embeddings
956            let combined: Vec<Array2<f32>> = buffer
957                .into_iter()
958                .chain(embeddings.iter().cloned())
959                .collect();
960
961            // 4. Expand centroids with outliers from combined embeddings
962            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
963                let new_centroids =
964                    update_centroids(index_path, &combined, cluster_threshold, config)?;
965                if new_centroids > 0 {
966                    // Reload codec with new centroids
967                    codec = ResidualCodec::load_from_dir(index_path)?;
968                }
969            }
970
971            // 5. Clear buffer
972            clear_buffer(index_path)?;
973
974            // 6. Update index with ALL combined embeddings (buffer + new)
975            update_index(&combined, &path_str, &codec, Some(config.batch_size), true)?;
976        } else {
977            // Small update: add to buffer and index without centroid expansion
978            // New documents start at current num_documents
979            start_doc_id = self.metadata.num_documents as i64;
980
981            // Accumulate buffer: combine existing buffer with new embeddings
982            let combined_buffer: Vec<Array2<f32>> = buffer
983                .into_iter()
984                .chain(embeddings.iter().cloned())
985                .collect();
986            save_buffer(index_path, &combined_buffer)?;
987
988            // Update index without threshold update
989            update_index(
990                embeddings,
991                &path_str,
992                &codec,
993                Some(config.batch_size),
994                false,
995            )?;
996        }
997
998        // Reload self as mmap
999        *self = Self::load(&path_str)?;
1000
1001        // Return the document IDs assigned to the new embeddings
1002        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
1003    }
1004
1005    /// Update the index with new documents and optional metadata.
1006    ///
1007    /// # Arguments
1008    ///
1009    /// * `embeddings` - New document embeddings to add
1010    /// * `config` - Update configuration
1011    /// * `metadata` - Optional metadata for new documents
1012    ///
1013    /// # Returns
1014    ///
1015    /// Vector of document IDs assigned to the new embeddings
1016    pub fn update_with_metadata(
1017        &mut self,
1018        embeddings: &[Array2<f32>],
1019        config: &crate::update::UpdateConfig,
1020        metadata: Option<&[serde_json::Value]>,
1021    ) -> Result<Vec<i64>> {
1022        // Validate metadata length if provided
1023        if let Some(meta) = metadata {
1024            if meta.len() != embeddings.len() {
1025                return Err(Error::Config(format!(
1026                    "Metadata length ({}) must match embeddings length ({})",
1027                    meta.len(),
1028                    embeddings.len()
1029                )));
1030            }
1031        }
1032
1033        // Perform the update and get document IDs
1034        let doc_ids = self.update(embeddings, config)?;
1035
1036        // Add metadata if provided, using the assigned document IDs
1037        if let Some(meta) = metadata {
1038            crate::filtering::update(&self.path, meta, &doc_ids)?;
1039        }
1040
1041        Ok(doc_ids)
1042    }
1043
1044    /// Update an existing index or create a new one if it doesn't exist.
1045    ///
1046    /// # Arguments
1047    ///
1048    /// * `embeddings` - Document embeddings to add
1049    /// * `index_path` - Directory for the index
1050    /// * `index_config` - Configuration for index creation
1051    /// * `update_config` - Configuration for updates
1052    ///
1053    /// # Returns
1054    ///
1055    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and document IDs
1056    pub fn update_or_create(
1057        embeddings: &[Array2<f32>],
1058        index_path: &str,
1059        index_config: &IndexConfig,
1060        update_config: &crate::update::UpdateConfig,
1061    ) -> Result<(Self, Vec<i64>)> {
1062        let index_dir = std::path::Path::new(index_path);
1063        let metadata_path = index_dir.join("metadata.json");
1064
1065        if metadata_path.exists() {
1066            // Index exists, load and update
1067            let mut index = Self::load(index_path)?;
1068            let doc_ids = index.update(embeddings, update_config)?;
1069            Ok((index, doc_ids))
1070        } else {
1071            // Index doesn't exist, create new
1072            let num_docs = embeddings.len();
1073            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1074            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1075            Ok((index, doc_ids))
1076        }
1077    }
1078
1079    /// Delete documents from the index.
1080    ///
1081    /// # Arguments
1082    ///
1083    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
1084    ///
1085    /// # Returns
1086    ///
1087    /// The number of documents actually deleted
1088    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
1089        self.delete_with_options(doc_ids, true)
1090    }
1091
1092    /// Delete documents from the index with control over metadata deletion.
1093    ///
1094    /// # Arguments
1095    ///
1096    /// * `doc_ids` - Slice of document IDs to delete
1097    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
1098    ///
1099    /// # Returns
1100    ///
1101    /// The number of documents actually deleted
1102    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
1103        let path = self.path.clone();
1104
1105        // Perform the deletion using standalone function
1106        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;
1107
1108        // Also delete from metadata.db if requested
1109        if delete_metadata && deleted > 0 {
1110            let index_path = std::path::Path::new(&path);
1111            let db_path = index_path.join("metadata.db");
1112            if db_path.exists() {
1113                crate::filtering::delete(&path, doc_ids)?;
1114            }
1115        }
1116
1117        // Reload self as mmap
1118        *self = Self::load(&path)?;
1119
1120        Ok(deleted)
1121    }
1122}
1123
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use tempfile::tempdir;

    /// Build `count` documents of 5 tokens x 32 dims. Each value depends on the
    /// document number (`offset + i`), the token row and the dimension. Every
    /// row is L2-normalised (rows with zero norm are left unchanged).
    fn synthetic_docs(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (offset..offset + count)
            .map(|doc| {
                let mut emb = Array2::<f32>::from_shape_fn((5, 32), |(tok, dim)| {
                    (doc as f32 * 0.1) + (tok as f32 * 0.01) + (dim as f32 * 0.001)
                });
                for mut row in emb.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.mapv_inplace(|x| x / norm);
                    }
                }
                emb
            })
            .collect()
    }

    /// Small, fast configuration shared by the update-or-create tests.
    fn small_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let defaults = IndexConfig::default();
        assert_eq!(defaults.nbits, 4);
        assert_eq!(defaults.batch_size, 50_000);
        assert_eq!(defaults.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let docs = synthetic_docs(5, 0);
        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // Nothing exists on disk yet, so this must take the creation path.
        let (index, doc_ids) =
            MmapIndex::update_or_create(&docs, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // The core index files must now be present.
        for file in ["metadata.json", "centroids.npy"] {
            assert!(temp_dir.path().join(file).exists());
        }
    }

    #[test]
    fn test_update_or_create_existing_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call creates the index with 5 documents.
        let first_batch = synthetic_docs(5, 0);
        let (created, created_ids) =
            MmapIndex::update_or_create(&first_batch, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(created.metadata.num_documents, 5);
        assert_eq!(created_ids, vec![0, 1, 2, 3, 4]);

        // Second call appends 3 more documents to the now-existing index.
        let second_batch = synthetic_docs(3, 5);
        let (updated, appended_ids) =
            MmapIndex::update_or_create(&second_batch, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(updated.metadata.num_documents, 8);
        assert_eq!(appended_ids, vec![5, 6, 7]);
    }
}