next_plaid/
index.rs

1//! Index creation and management for PLAID
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13#[cfg(feature = "npy")]
14use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
15use crate::strided_tensor::{IvfStridedTensor, StridedTensor};
16use crate::utils::{quantile, quantiles};
17
18/// Configuration for index creation
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct IndexConfig {
21    /// Number of bits for quantization (typically 2 or 4)
22    pub nbits: usize,
23    /// Batch size for processing
24    pub batch_size: usize,
25    /// Random seed for reproducibility
26    pub seed: Option<u64>,
27    /// Number of K-means iterations (default: 4)
28    #[serde(default = "default_kmeans_niters")]
29    pub kmeans_niters: usize,
30    /// Maximum number of points per centroid for K-means (default: 256)
31    #[serde(default = "default_max_points_per_centroid")]
32    pub max_points_per_centroid: usize,
33    /// Number of samples for K-means training.
34    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
35    #[serde(default)]
36    pub n_samples_kmeans: Option<usize>,
37    /// Threshold for start-from-scratch mode (default: 999).
38    /// When the number of documents is <= this threshold, raw embeddings are saved
39    /// to embeddings.npy for potential rebuilds during updates.
40    #[serde(default = "default_start_from_scratch")]
41    pub start_from_scratch: usize,
42}
43
44fn default_start_from_scratch() -> usize {
45    999
46}
47
48fn default_kmeans_niters() -> usize {
49    4
50}
51
52fn default_max_points_per_centroid() -> usize {
53    256
54}
55
56impl Default for IndexConfig {
57    fn default() -> Self {
58        Self {
59            nbits: 4,
60            batch_size: 50_000,
61            seed: Some(42),
62            kmeans_niters: 4,
63            max_points_per_centroid: 256,
64            n_samples_kmeans: None,
65            start_from_scratch: 999,
66        }
67    }
68}
69
70/// Metadata for the index
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct Metadata {
73    /// Number of chunks in the index
74    pub num_chunks: usize,
75    /// Number of bits for quantization
76    pub nbits: usize,
77    /// Number of partitions (centroids)
78    pub num_partitions: usize,
79    /// Total number of embeddings
80    pub num_embeddings: usize,
81    /// Average document length
82    pub avg_doclen: f64,
83    /// Total number of documents
84    pub num_documents: usize,
85}
86
87/// Chunk metadata
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ChunkMetadata {
90    pub num_documents: usize,
91    pub num_embeddings: usize,
92    #[serde(default)]
93    pub embedding_offset: usize,
94}
95
96/// A PLAID index for multi-vector search
97pub struct Index {
98    /// Path to the index directory
99    pub path: String,
100    /// Index metadata
101    pub metadata: Metadata,
102    /// Residual codec for quantization/decompression
103    pub codec: ResidualCodec,
104    /// Inverted file: list of document IDs per centroid
105    pub ivf: Array1<i64>,
106    /// Lengths of each inverted list
107    pub ivf_lengths: Array1<i32>,
108    /// Cumulative offsets for IVF lists
109    pub ivf_offsets: Array1<i64>,
110    /// Document codes (centroid assignments) per document
111    pub doc_codes: Vec<Array1<usize>>,
112    /// Document lengths
113    pub doc_lengths: Array1<i64>,
114    /// Packed residuals per document
115    pub doc_residuals: Vec<Array2<u8>>,
116}
117
118impl Index {
119    /// Create a new index from document embeddings.
120    ///
121    /// # Arguments
122    ///
123    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
124    /// * `centroids` - Pre-computed centroids from k-means, shape `[num_centroids, dim]`
125    /// * `index_path` - Directory to save the index
126    /// * `config` - Index configuration
127    ///
128    /// # Returns
129    ///
130    /// The created index
131    pub fn create(
132        embeddings: &[Array2<f32>],
133        centroids: Array2<f32>,
134        index_path: &str,
135        config: &IndexConfig,
136    ) -> Result<Self> {
137        let index_dir = Path::new(index_path);
138        fs::create_dir_all(index_dir)?;
139
140        let num_documents = embeddings.len();
141        let embedding_dim = centroids.ncols();
142        let num_centroids = centroids.nrows();
143
144        if num_documents == 0 {
145            return Err(Error::IndexCreation("No documents provided".into()));
146        }
147
148        // Calculate statistics
149        let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
150        let avg_doclen = total_embeddings as f64 / num_documents as f64;
151
152        // Sample documents for codec training
153        let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
154            .min(num_documents)
155            .max(1);
156
157        let mut rng = if let Some(seed) = config.seed {
158            use rand::SeedableRng;
159            rand_chacha::ChaCha8Rng::seed_from_u64(seed)
160        } else {
161            use rand::SeedableRng;
162            rand_chacha::ChaCha8Rng::from_entropy()
163        };
164
165        use rand::seq::SliceRandom;
166        let mut indices: Vec<usize> = (0..num_documents).collect();
167        indices.shuffle(&mut rng);
168        let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
169
170        // Collect sample embeddings for training
171        let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
172        let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
173        let mut collected = 0;
174
175        for &idx in sample_indices.iter().rev() {
176            if collected >= heldout_size {
177                break;
178            }
179            let emb = &embeddings[idx];
180            let take = (heldout_size - collected).min(emb.nrows());
181            for row in emb.axis_iter(Axis(0)).take(take) {
182                heldout_embeddings.extend(row.iter());
183            }
184            collected += take;
185        }
186
187        let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
188            .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
189
190        // Train codec: compute residuals and quantization parameters
191        let avg_residual = Array1::zeros(embedding_dim);
192        let initial_codec =
193            ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
194
195        // Compute codes for heldout samples
196        let heldout_codes = initial_codec.compress_into_codes(&heldout);
197
198        // Compute residuals
199        let mut residuals = heldout.clone();
200        for i in 0..heldout.nrows() {
201            let centroid = initial_codec.centroids.row(heldout_codes[i]);
202            for j in 0..embedding_dim {
203                residuals[[i, j]] -= centroid[j];
204            }
205        }
206
207        // Compute cluster threshold from residual distances
208        let distances: Array1<f32> = residuals
209            .axis_iter(Axis(0))
210            .map(|row| row.dot(&row).sqrt())
211            .collect();
212        #[allow(unused_variables)]
213        let cluster_threshold = quantile(&distances, 0.75);
214
215        // Compute average residual per dimension
216        let avg_res_per_dim: Array1<f32> = residuals
217            .axis_iter(Axis(1))
218            .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
219            .collect();
220
221        // Compute quantization buckets
222        let n_options = 1 << config.nbits;
223        let quantile_values: Vec<f64> = (1..n_options)
224            .map(|i| i as f64 / n_options as f64)
225            .collect();
226        let weight_quantile_values: Vec<f64> = (0..n_options)
227            .map(|i| (i as f64 + 0.5) / n_options as f64)
228            .collect();
229
230        // Flatten residuals for quantile computation
231        let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
232        let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
233        let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
234
235        let codec = ResidualCodec::new(
236            config.nbits,
237            centroids.clone(),
238            avg_res_per_dim.clone(),
239            Some(bucket_cutoffs.clone()),
240            Some(bucket_weights.clone()),
241        )?;
242
243        // Save codec components
244        #[cfg(feature = "npy")]
245        {
246            use ndarray_npy::WriteNpyExt;
247
248            let centroids_path = index_dir.join("centroids.npy");
249            codec
250                .centroids_view()
251                .to_owned()
252                .write_npy(File::create(&centroids_path)?)?;
253
254            let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
255            bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
256
257            let weights_path = index_dir.join("bucket_weights.npy");
258            bucket_weights.write_npy(File::create(&weights_path)?)?;
259
260            let avg_res_path = index_dir.join("avg_residual.npy");
261            avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
262
263            let threshold_path = index_dir.join("cluster_threshold.npy");
264            Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
265        }
266
267        // Process documents in chunks
268        let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
269
270        // Save plan
271        let plan_path = index_dir.join("plan.json");
272        let plan = serde_json::json!({
273            "nbits": config.nbits,
274            "num_chunks": n_chunks,
275        });
276        let mut plan_file = File::create(&plan_path)?;
277        writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
278
279        let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
280        let mut doc_codes: Vec<Array1<usize>> = Vec::with_capacity(num_documents);
281        let mut doc_residuals: Vec<Array2<u8>> = Vec::with_capacity(num_documents);
282        let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
283
284        let progress = indicatif::ProgressBar::new(n_chunks as u64);
285        progress.set_message("Creating index...");
286
287        for chunk_idx in 0..n_chunks {
288            let start = chunk_idx * config.batch_size;
289            let end = (start + config.batch_size).min(num_documents);
290            let chunk_docs = &embeddings[start..end];
291
292            // Collect document lengths
293            let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
294            let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
295
296            // Concatenate all embeddings in the chunk for batch processing
297            let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
298            let mut offset = 0;
299            for doc in chunk_docs {
300                let n = doc.nrows();
301                batch_embeddings
302                    .slice_mut(s![offset..offset + n, ..])
303                    .assign(doc);
304                offset += n;
305            }
306
307            // BATCH: Compress all embeddings at once
308            let batch_codes = codec.compress_into_codes(&batch_embeddings);
309
310            // BATCH: Compute residuals using parallel subtraction
311            let mut batch_residuals = batch_embeddings;
312            {
313                use rayon::prelude::*;
314                let centroids = &codec.centroids;
315                batch_residuals
316                    .axis_iter_mut(Axis(0))
317                    .into_par_iter()
318                    .zip(batch_codes.as_slice().unwrap().par_iter())
319                    .for_each(|(mut row, &code)| {
320                        let centroid = centroids.row(code);
321                        row.iter_mut()
322                            .zip(centroid.iter())
323                            .for_each(|(r, c)| *r -= c);
324                    });
325            }
326
327            // BATCH: Quantize all residuals at once
328            let batch_packed = codec.quantize_residuals(&batch_residuals)?;
329
330            // Split results back into per-document arrays
331            let mut code_offset = 0;
332            for &len in &chunk_doclens {
333                let len_usize = len as usize;
334                doc_lengths.push(len);
335
336                let codes: Array1<usize> = batch_codes
337                    .slice(s![code_offset..code_offset + len_usize])
338                    .to_owned();
339                all_codes.extend(codes.iter().copied());
340                doc_codes.push(codes);
341
342                let packed = batch_packed
343                    .slice(s![code_offset..code_offset + len_usize, ..])
344                    .to_owned();
345                doc_residuals.push(packed);
346
347                code_offset += len_usize;
348            }
349
350            let chunk_codes_list: Vec<usize> = batch_codes.iter().copied().collect();
351
352            // Save chunk metadata
353            let chunk_meta = ChunkMetadata {
354                num_documents: end - start,
355                num_embeddings: chunk_codes_list.len(),
356                embedding_offset: 0, // Will be updated later
357            };
358
359            let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
360            serde_json::to_writer_pretty(
361                BufWriter::new(File::create(&chunk_meta_path)?),
362                &chunk_meta,
363            )?;
364
365            // Save chunk doclens
366            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
367            serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
368
369            #[cfg(feature = "npy")]
370            {
371                use ndarray_npy::WriteNpyExt;
372
373                // Save chunk codes (already in batch form)
374                let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
375                let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
376                chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
377
378                // Save chunk residuals (already in batch form)
379                let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
380                batch_packed.write_npy(File::create(&residuals_path)?)?;
381            }
382
383            progress.inc(1);
384        }
385        progress.finish();
386
387        // Update chunk metadata with global offsets
388        let mut current_offset = 0usize;
389        for chunk_idx in 0..n_chunks {
390            let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
391            let mut meta: serde_json::Value =
392                serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
393
394            if let Some(obj) = meta.as_object_mut() {
395                obj.insert("embedding_offset".to_string(), current_offset.into());
396                let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
397                current_offset += num_emb;
398            }
399
400            serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
401        }
402
403        // Build IVF (Inverted File)
404        let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
405        let mut emb_idx = 0;
406
407        for (doc_id, &len) in doc_lengths.iter().enumerate() {
408            for _ in 0..len {
409                let code = all_codes[emb_idx];
410                code_to_docs.entry(code).or_default().push(doc_id as i64);
411                emb_idx += 1;
412            }
413        }
414
415        // Deduplicate document IDs per centroid
416        let mut ivf_data: Vec<i64> = Vec::new();
417        let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
418
419        for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
420            if let Some(docs) = code_to_docs.get(&centroid_id) {
421                let mut unique_docs: Vec<i64> = docs.clone();
422                unique_docs.sort_unstable();
423                unique_docs.dedup();
424                *ivf_len = unique_docs.len() as i32;
425                ivf_data.extend(unique_docs);
426            }
427        }
428
429        let ivf = Array1::from_vec(ivf_data);
430        let ivf_lengths = Array1::from_vec(ivf_lengths);
431
432        // Compute IVF offsets
433        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
434        for i in 0..num_centroids {
435            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
436        }
437
438        #[cfg(feature = "npy")]
439        {
440            use ndarray_npy::WriteNpyExt;
441
442            let ivf_path = index_dir.join("ivf.npy");
443            ivf.write_npy(File::create(&ivf_path)?)?;
444
445            let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
446            ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
447        }
448
449        // Save global metadata
450        let metadata = Metadata {
451            num_chunks: n_chunks,
452            nbits: config.nbits,
453            num_partitions: num_centroids,
454            num_embeddings: total_embeddings,
455            avg_doclen,
456            num_documents,
457        };
458
459        let metadata_path = index_dir.join("metadata.json");
460        serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
461
462        let doc_lengths_arr = Array1::from_vec(doc_lengths);
463
464        Ok(Self {
465            path: index_path.to_string(),
466            metadata,
467            codec,
468            ivf,
469            ivf_lengths,
470            ivf_offsets,
471            doc_codes,
472            doc_lengths: doc_lengths_arr,
473            doc_residuals,
474        })
475    }
476
477    /// Create a new index from document embeddings with automatic centroid computation.
478    ///
479    /// This method implements the same logic as fast-plaid's `create()`:
480    /// 1. Computes centroids using K-means with automatic K calculation
481    /// 2. Creates the index using the computed centroids
482    /// 3. If num_documents <= start_from_scratch, saves raw embeddings for potential rebuilds
483    ///
484    /// # Arguments
485    ///
486    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
487    /// * `index_path` - Directory to save the index
488    /// * `config` - Index configuration (includes K-means parameters)
489    ///
490    /// # Returns
491    ///
492    /// The created index
493    #[cfg(feature = "npy")]
494    pub fn create_with_kmeans(
495        embeddings: &[Array2<f32>],
496        index_path: &str,
497        config: &IndexConfig,
498    ) -> Result<Self> {
499        if embeddings.is_empty() {
500            return Err(Error::IndexCreation("No documents provided".into()));
501        }
502
503        // Build K-means configuration from IndexConfig
504        let kmeans_config = ComputeKmeansConfig {
505            kmeans_niters: config.kmeans_niters,
506            max_points_per_centroid: config.max_points_per_centroid,
507            seed: config.seed.unwrap_or(42),
508            n_samples_kmeans: config.n_samples_kmeans,
509            num_partitions: None, // Let the heuristic decide
510        };
511
512        // Compute centroids using fast-plaid's approach
513        let centroids = compute_kmeans(embeddings, &kmeans_config)?;
514
515        // Create the index with the computed centroids
516        let index = Self::create(embeddings, centroids, index_path, config)?;
517
518        // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
519        // This matches fast-plaid's behavior in create() (fast_plaid.py:499-506)
520        if embeddings.len() <= config.start_from_scratch {
521            let index_dir = std::path::Path::new(index_path);
522            crate::update::save_embeddings_npy(index_dir, embeddings)?;
523        }
524
525        Ok(index)
526    }
527
528    /// Load an existing index from disk.
529    #[cfg(feature = "npy")]
530    pub fn load(index_path: &str) -> Result<Self> {
531        use ndarray_npy::ReadNpyExt;
532
533        let index_dir = Path::new(index_path);
534
535        // Load metadata
536        let metadata_path = index_dir.join("metadata.json");
537        let metadata: Metadata = serde_json::from_reader(BufReader::new(
538            File::open(&metadata_path)
539                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
540        ))?;
541
542        // Load codec (with owned centroids for in-memory Index)
543        let codec = ResidualCodec::load_from_dir(index_dir)?;
544
545        // Load IVF
546        let ivf_path = index_dir.join("ivf.npy");
547        let ivf: Array1<i64> = Array1::read_npy(
548            File::open(&ivf_path)
549                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
550        )
551        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
552
553        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
554        let ivf_lengths: Array1<i32> = Array1::read_npy(
555            File::open(&ivf_lengths_path)
556                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
557        )
558        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
559
560        // Compute IVF offsets
561        let num_centroids = ivf_lengths.len();
562        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
563        for i in 0..num_centroids {
564            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
565        }
566
567        // Load document codes and residuals
568        let mut doc_codes: Vec<Array1<usize>> = Vec::with_capacity(metadata.num_documents);
569        let mut doc_residuals: Vec<Array2<u8>> = Vec::with_capacity(metadata.num_documents);
570        let mut doc_lengths: Vec<i64> = Vec::with_capacity(metadata.num_documents);
571
572        for chunk_idx in 0..metadata.num_chunks {
573            // Load doclens
574            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
575            let chunk_doclens: Vec<i64> =
576                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
577
578            // Load codes
579            let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
580            let chunk_codes: Array1<i64> = Array1::read_npy(File::open(&codes_path)?)?;
581
582            // Load residuals
583            let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
584            let chunk_residuals: Array2<u8> = Array2::read_npy(File::open(&residuals_path)?)?;
585
586            // Split codes and residuals by document
587            let mut code_offset = 0;
588            for &len in &chunk_doclens {
589                let len_usize = len as usize;
590                doc_lengths.push(len);
591
592                let codes: Array1<usize> = chunk_codes
593                    .slice(s![code_offset..code_offset + len_usize])
594                    .iter()
595                    .map(|&x| x as usize)
596                    .collect();
597                doc_codes.push(codes);
598
599                let res = chunk_residuals
600                    .slice(s![code_offset..code_offset + len_usize, ..])
601                    .to_owned();
602                doc_residuals.push(res);
603
604                code_offset += len_usize;
605            }
606        }
607
608        Ok(Self {
609            path: index_path.to_string(),
610            metadata,
611            codec,
612            ivf,
613            ivf_lengths,
614            ivf_offsets,
615            doc_codes,
616            doc_lengths: Array1::from_vec(doc_lengths),
617            doc_residuals,
618        })
619    }
620
621    /// Get candidate documents from IVF for given centroid indices.
622    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
623        let mut candidates: Vec<i64> = Vec::new();
624
625        for &idx in centroid_indices {
626            if idx < self.ivf_lengths.len() {
627                let start = self.ivf_offsets[idx] as usize;
628                let len = self.ivf_lengths[idx] as usize;
629                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
630            }
631        }
632
633        candidates.sort_unstable();
634        candidates.dedup();
635        candidates
636    }
637
638    /// Get document embeddings by decompressing codes and residuals.
639    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
640        if doc_id >= self.doc_codes.len() {
641            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
642        }
643
644        let codes = &self.doc_codes[doc_id];
645        let residuals = &self.doc_residuals[doc_id];
646
647        self.codec.decompress(residuals, &codes.view())
648    }
649
650    /// Reconstruct embeddings for specific documents.
651    ///
652    /// This is a convenience method that converts the Index to a LoadedIndex
653    /// and then reconstructs the embeddings. For repeated reconstruction calls,
654    /// consider using `LoadedIndex::reconstruct()` directly.
655    ///
656    /// # Arguments
657    ///
658    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
659    ///
660    /// # Returns
661    ///
662    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
663    ///
664    /// # Example
665    ///
666    /// ```ignore
667    /// use next_plaid::Index;
668    ///
669    /// let index = Index::load("/path/to/index")?;
670    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
671    /// ```
672    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
673        // For Index, we can directly decompress per-document
674        use rayon::prelude::*;
675
676        let num_documents = self.doc_codes.len();
677
678        // Validate document IDs
679        for &doc_id in doc_ids {
680            if doc_id < 0 || doc_id as usize >= num_documents {
681                return Err(Error::Search(format!(
682                    "Invalid document ID: {} (index has {} documents)",
683                    doc_id, num_documents
684                )));
685            }
686        }
687
688        // Process documents in parallel
689        doc_ids
690            .par_iter()
691            .map(|&doc_id| self.get_document_embeddings(doc_id as usize))
692            .collect()
693    }
694
695    /// Reconstruct a single document's embeddings.
696    ///
697    /// Convenience method for reconstructing a single document.
698    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
699        self.get_document_embeddings(doc_id as usize)
700    }
701
702    /// Update the index with new documents, matching fast-plaid behavior.
703    ///
704    /// This method adds new documents to an existing index with three possible paths:
705    ///
706    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
707    ///    - Loads existing embeddings from `embeddings.npy` if available
708    ///    - Combines with new embeddings
709    ///    - Rebuilds the entire index from scratch with fresh K-means
710    ///    - Clears `embeddings.npy` if total exceeds threshold
711    ///
712    /// 2. **Buffer mode** (total_new < buffer_size):
713    ///    - Adds new documents to the index without centroid expansion
714    ///    - Saves embeddings to buffer for later centroid expansion
715    ///
716    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
717    ///    - Deletes previously buffered documents
718    ///    - Expands centroids with outliers from combined buffer + new embeddings
719    ///    - Re-indexes all combined embeddings with expanded centroids
720    ///
721    /// # Arguments
722    ///
723    /// * `embeddings` - New document embeddings to add
724    /// * `config` - Update configuration
725    ///
726    /// # Returns
727    ///
728    /// Vector of document IDs assigned to the new embeddings, with self reloaded to reflect changes.
729    #[cfg(feature = "npy")]
730    pub fn update(
731        &mut self,
732        embeddings: &[Array2<f32>],
733        config: &crate::update::UpdateConfig,
734    ) -> Result<Vec<i64>> {
735        use crate::update::{
736            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
737            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
738            update_centroids, update_index,
739        };
740
741        // Clone path to avoid borrow issues when reassigning self
742        let path_str = self.path.clone();
743        let index_path = std::path::Path::new(&path_str);
744        let num_new_docs = embeddings.len();
745
746        // ==================================================================
747        // Start-from-scratch mode (fast-plaid update.py:312-346)
748        // ==================================================================
749        if self.metadata.num_documents <= config.start_from_scratch {
750            // Load existing embeddings if available
751            let existing_embeddings = load_embeddings_npy(index_path)?;
752            // New documents start after existing documents
753            let start_doc_id = existing_embeddings.len() as i64;
754
755            // Combine existing + new embeddings
756            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
757                .into_iter()
758                .chain(embeddings.iter().cloned())
759                .collect();
760
761            // Build IndexConfig from UpdateConfig for create_with_kmeans
762            let index_config = crate::index::IndexConfig {
763                nbits: self.metadata.nbits,
764                batch_size: config.batch_size,
765                seed: Some(config.seed),
766                kmeans_niters: config.kmeans_niters,
767                max_points_per_centroid: config.max_points_per_centroid,
768                n_samples_kmeans: config.n_samples_kmeans,
769                start_from_scratch: config.start_from_scratch,
770            };
771
772            // Rebuild index from scratch with fresh K-means
773            // Note: create_with_kmeans will save embeddings.npy if below threshold
774            *self = Index::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
775
776            // If we've crossed the threshold, clear embeddings.npy
777            // (create_with_kmeans won't save it if above threshold)
778            if combined_embeddings.len() > config.start_from_scratch
779                && embeddings_npy_exists(index_path)
780            {
781                clear_embeddings_npy(index_path)?;
782            }
783
784            // Return the document IDs assigned to the new embeddings
785            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
786        }
787
788        // Load buffer
789        let buffer = load_buffer(index_path)?;
790        let buffer_len = buffer.len();
791        let total_new = embeddings.len() + buffer_len;
792
793        // Track the starting document ID for the new embeddings
794        let start_doc_id: i64;
795
796        // Check buffer threshold
797        if total_new >= config.buffer_size {
798            // Centroid expansion path (matches fast-plaid update.py:376-422)
799
800            // 1. Get number of buffered docs that were previously indexed
801            let num_buffered = load_buffer_info(index_path)?;
802
803            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
804            //    This matches fast-plaid's delete_fn(subset=documents_to_delete)
805            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
806                let start_del_idx = self.metadata.num_documents - num_buffered;
807                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
808                    .map(|i| i as i64)
809                    .collect();
810                crate::delete::delete_from_index(&docs_to_delete, &path_str)?;
811                // Reload after delete to get updated metadata
812                *self = Index::load(&path_str)?;
813            }
814
815            // New embeddings start after buffer is re-indexed
816            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
817
818            // 3. Combine buffer + new embeddings
819            let combined: Vec<Array2<f32>> = buffer
820                .into_iter()
821                .chain(embeddings.iter().cloned())
822                .collect();
823
824            // 4. Expand centroids with outliers from combined embeddings
825            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
826                let new_centroids =
827                    update_centroids(index_path, &combined, cluster_threshold, config)?;
828                if new_centroids > 0 {
829                    // Reload codec with new centroids
830                    self.codec = ResidualCodec::load_from_dir(index_path)?;
831                }
832            }
833
834            // 5. Clear buffer
835            clear_buffer(index_path)?;
836
837            // 6. Update index with ALL combined embeddings (buffer + new)
838            update_index(
839                &combined,
840                &path_str,
841                &self.codec,
842                Some(config.batch_size),
843                true,
844            )?;
845        } else {
846            // Small update: add to buffer and index without centroid expansion
847            // New documents start at current num_documents
848            start_doc_id = self.metadata.num_documents as i64;
849
850            save_buffer(index_path, embeddings)?;
851
852            // Update index without threshold update
853            update_index(
854                embeddings,
855                &path_str,
856                &self.codec,
857                Some(config.batch_size),
858                false,
859            )?;
860        }
861
862        // Reload the index
863        *self = Index::load(&path_str)?;
864
865        // Return the document IDs assigned to the new embeddings
866        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
867    }
868
869    /// Update the index with new documents and optional metadata.
870    ///
871    /// This method adds new documents and their metadata to an existing index.
872    /// Uses the same buffer mechanism as `update()`, but also manages the
873    /// metadata database.
874    ///
875    /// # Arguments
876    ///
877    /// * `embeddings` - New document embeddings to add
878    /// * `config` - Update configuration
879    /// * `metadata` - Optional metadata for new documents (must match embeddings length)
880    ///
881    /// # Returns
882    ///
883    /// Vector of document IDs assigned to the new embeddings.
884    #[cfg(all(feature = "npy", feature = "filtering"))]
885    pub fn update_with_metadata(
886        &mut self,
887        embeddings: &[Array2<f32>],
888        config: &crate::update::UpdateConfig,
889        metadata: Option<&[serde_json::Value]>,
890    ) -> Result<Vec<i64>> {
891        // Validate metadata length if provided
892        if let Some(meta) = metadata {
893            if meta.len() != embeddings.len() {
894                return Err(Error::Config(format!(
895                    "Metadata length ({}) must match embeddings length ({})",
896                    meta.len(),
897                    embeddings.len()
898                )));
899            }
900        }
901
902        // Perform the update and get document IDs
903        let doc_ids = self.update(embeddings, config)?;
904
905        // Add metadata if provided, using the assigned document IDs
906        if let Some(meta) = metadata {
907            crate::filtering::update(&self.path, meta, &doc_ids)?;
908        }
909
910        Ok(doc_ids)
911    }
912
913    /// Update an existing index or create a new one if it doesn't exist.
914    ///
915    /// This is a convenience method that:
916    /// 1. Checks if an index exists at the given path
917    /// 2. If not, creates a new index with `create_with_kmeans`
918    /// 3. If yes, loads the index and calls `update`
919    ///
920    /// # Arguments
921    ///
922    /// * `embeddings` - Document embeddings to add
923    /// * `index_path` - Directory for the index
924    /// * `index_config` - Configuration for index creation (used if creating new)
925    /// * `update_config` - Configuration for updates (used if updating existing)
926    ///
927    /// # Returns
928    ///
929    /// A tuple of (Index, `Vec<i64>`) containing the created/updated Index and
930    /// the document IDs assigned to the embeddings.
931    ///
932    /// # Example
933    ///
934    /// ```ignore
935    /// use next_plaid::{Index, IndexConfig, UpdateConfig};
936    ///
937    /// let embeddings: Vec<Array2<f32>> = load_embeddings();
938    /// let index_config = IndexConfig::default();
939    /// let update_config = UpdateConfig::default();
940    ///
941    /// // Creates index if it doesn't exist, otherwise updates it
942    /// let (index, doc_ids) = Index::update_or_create(
943    ///     &embeddings,
944    ///     "path/to/index",
945    ///     &index_config,
946    ///     &update_config,
947    /// )?;
948    /// ```
949    #[cfg(feature = "npy")]
950    pub fn update_or_create(
951        embeddings: &[Array2<f32>],
952        index_path: &str,
953        index_config: &IndexConfig,
954        update_config: &crate::update::UpdateConfig,
955    ) -> Result<(Self, Vec<i64>)> {
956        let index_dir = std::path::Path::new(index_path);
957        let metadata_path = index_dir.join("metadata.json");
958
959        if metadata_path.exists() {
960            // Index exists, load and update
961            let mut index = Self::load(index_path)?;
962            let doc_ids = index.update(embeddings, update_config)?;
963            Ok((index, doc_ids))
964        } else {
965            // Index doesn't exist, create new - document IDs are 0..num_embeddings
966            let num_docs = embeddings.len();
967            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
968            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
969            Ok((index, doc_ids))
970        }
971    }
972
973    /// Simple update without buffer mechanism.
974    ///
975    /// This directly adds documents to the index without centroid expansion
976    /// or buffer management. Useful for testing or when you want full control.
977    #[cfg(feature = "npy")]
978    pub fn update_simple(
979        &mut self,
980        embeddings: &[Array2<f32>],
981        batch_size: Option<usize>,
982    ) -> Result<()> {
983        crate::update::update_index(embeddings, &self.path, &self.codec, batch_size, true)?;
984
985        // Reload the index
986        *self = Index::load(&self.path)?;
987        Ok(())
988    }
989
990    /// Delete documents from the index.
991    ///
992    /// This removes the specified documents by rewriting the index chunks
993    /// and rebuilding the IVF. The index is reloaded after deletion.
994    ///
995    /// # Arguments
996    ///
997    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
998    ///
999    /// # Returns
1000    ///
1001    /// The number of documents actually deleted.
1002    ///
1003    /// # Example
1004    ///
1005    /// ```ignore
1006    /// use next_plaid::Index;
1007    ///
1008    /// let mut index = Index::load("/path/to/index")?;
1009    /// let deleted = index.delete(&[2, 5, 7])?;
1010    /// println!("Deleted {} documents", deleted);
1011    /// ```
1012    #[cfg(feature = "npy")]
1013    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
1014        self.delete_with_options(doc_ids, true)
1015    }
1016
1017    /// Delete documents from the index with control over metadata deletion.
1018    ///
1019    /// This is useful when you want to delete documents without affecting the
1020    /// metadata database (e.g., during buffer management in updates).
1021    ///
1022    /// # Arguments
1023    ///
1024    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
1025    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
1026    ///
1027    /// # Returns
1028    ///
1029    /// The number of documents actually deleted.
1030    #[cfg(feature = "npy")]
1031    pub fn delete_with_options(
1032        &mut self,
1033        doc_ids: &[i64],
1034        #[allow(unused_variables)] delete_metadata: bool,
1035    ) -> Result<usize> {
1036        let deleted = crate::delete::delete_from_index(doc_ids, &self.path)?;
1037
1038        // Delete from metadata database if it exists and requested
1039        #[cfg(feature = "filtering")]
1040        if delete_metadata && crate::filtering::exists(&self.path) {
1041            crate::filtering::delete(&self.path, doc_ids)?;
1042        }
1043
1044        // Reload the index
1045        *self = Index::load(&self.path)?;
1046        Ok(deleted)
1047    }
1048}
1049
/// A loaded index optimized for search with StridedTensor storage.
///
/// This struct contains all data required for search operations. The data is
/// stored in a format optimized for batch lookups: `StridedTensor` provides
/// efficient retrieval of variable-length, per-document sequences.
///
/// Build it with `LoadedIndex::from_index` (conversion from an in-memory
/// `Index`) or `LoadedIndex::load` (load from disk, then convert).
pub struct LoadedIndex {
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression
    pub codec: ResidualCodec,
    /// IVF index mapping centroids to document IDs
    pub ivf: IvfStridedTensor,
    /// Document codes (centroid assignments), one row per token, stored
    /// contiguously with documents back to back
    pub doc_codes: StridedTensor<usize>,
    /// Packed residuals, one row per token (`embedding_dim * nbits / 8` bytes
    /// wide), stored contiguously with documents back to back
    pub doc_residuals: StridedTensor<u8>,
    /// Number of bits for residual quantization (mirrors `metadata.nbits` when
    /// built via `LoadedIndex::load`)
    pub nbits: usize,
}
1069
1070impl LoadedIndex {
1071    /// Create a LoadedIndex from an existing Index.
1072    ///
1073    /// This converts the Index's per-document storage to contiguous StridedTensor storage.
1074    pub fn from_index(index: Index) -> Self {
1075        let embedding_dim = index.codec.embedding_dim();
1076        let packed_dim = embedding_dim * index.metadata.nbits / 8;
1077        let num_documents = index.doc_codes.len();
1078
1079        // Convert doc_codes to contiguous storage
1080        let total_codes: usize = index.doc_lengths.iter().sum::<i64>() as usize;
1081        let mut codes_data = Array2::<usize>::zeros((total_codes, 1));
1082        let mut offset = 0;
1083
1084        for codes in &index.doc_codes {
1085            for (j, &code) in codes.iter().enumerate() {
1086                codes_data[[offset + j, 0]] = code;
1087            }
1088            offset += codes.len();
1089        }
1090
1091        let doc_codes = StridedTensor::new(codes_data, index.doc_lengths.clone());
1092
1093        // Convert doc_residuals to contiguous storage
1094        let mut residuals_data = Array2::<u8>::zeros((total_codes, packed_dim));
1095        offset = 0;
1096
1097        for residuals in &index.doc_residuals {
1098            residuals_data
1099                .slice_mut(s![offset..offset + residuals.nrows(), ..])
1100                .assign(residuals);
1101            offset += residuals.nrows();
1102        }
1103
1104        let doc_residuals = StridedTensor::new(residuals_data, index.doc_lengths.clone());
1105
1106        // Convert IVF to IvfStridedTensor
1107        let ivf = IvfStridedTensor::new(index.ivf, index.ivf_lengths);
1108
1109        Self {
1110            metadata: index.metadata,
1111            codec: index.codec,
1112            ivf,
1113            doc_codes,
1114            doc_residuals,
1115            nbits: num_documents, // Will be set below
1116        }
1117    }
1118
1119    /// Load a LoadedIndex from disk.
1120    #[cfg(feature = "npy")]
1121    pub fn load(index_path: &str) -> Result<Self> {
1122        let index = Index::load(index_path)?;
1123        let nbits = index.metadata.nbits;
1124        let mut loaded = Self::from_index(index);
1125        loaded.nbits = nbits;
1126        Ok(loaded)
1127    }
1128
1129    /// Get candidate documents from IVF for given centroid indices.
1130    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
1131        self.ivf.lookup(centroid_indices)
1132    }
1133
1134    /// Lookup codes and residuals for a batch of document IDs.
1135    ///
1136    /// Returns (codes, residuals, lengths) for efficient batch decompression.
1137    pub fn lookup_documents(&self, doc_ids: &[usize]) -> (Array1<usize>, Array2<u8>, Array1<i64>) {
1138        let (codes, lengths) = self.doc_codes.lookup_codes(doc_ids);
1139        let (residuals, _) = self.doc_residuals.lookup_2d(doc_ids);
1140        (codes, residuals, lengths)
1141    }
1142
1143    /// Decompress embeddings for a batch of document IDs.
1144    ///
1145    /// Returns the decompressed embeddings as a contiguous array along with lengths.
1146    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Array1<i64>)> {
1147        let (codes, residuals, lengths) = self.lookup_documents(doc_ids);
1148
1149        // Decompress using the codec
1150        let embeddings = self.codec.decompress(&residuals, &codes.view())?;
1151
1152        Ok((embeddings, lengths))
1153    }
1154
1155    /// Get the number of documents in the index.
1156    pub fn num_documents(&self) -> usize {
1157        self.doc_codes.len()
1158    }
1159
1160    /// Get the embedding dimension.
1161    pub fn embedding_dim(&self) -> usize {
1162        self.codec.embedding_dim()
1163    }
1164
1165    /// Reconstruct embeddings for specific documents.
1166    ///
1167    /// This method retrieves the compressed codes and residuals for each document
1168    /// and decompresses them to recover the original embeddings.
1169    ///
1170    /// # Arguments
1171    ///
1172    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
1173    ///
1174    /// # Returns
1175    ///
1176    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
1177    ///
1178    /// # Example
1179    ///
1180    /// ```ignore
1181    /// use next_plaid::LoadedIndex;
1182    ///
1183    /// let index = LoadedIndex::load("/path/to/index")?;
1184    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
1185    ///
1186    /// for (i, emb) in embeddings.iter().enumerate() {
1187    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
1188    /// }
1189    /// ```
1190    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
1191        crate::embeddings::reconstruct_embeddings(self, doc_ids)
1192    }
1193
1194    /// Reconstruct a single document's embeddings.
1195    ///
1196    /// Convenience method for reconstructing a single document.
1197    ///
1198    /// # Arguments
1199    ///
1200    /// * `doc_id` - Document ID to reconstruct (0-indexed)
1201    ///
1202    /// # Returns
1203    ///
1204    /// A 2D array with shape `[num_tokens, dim]`.
1205    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
1206        crate::embeddings::reconstruct_single(self, doc_id)
1207    }
1208}
1209
1210// ============================================================================
1211// Memory-Mapped Index for Low Memory Usage
1212// ============================================================================
1213
/// A memory-mapped index optimized for low memory usage.
///
/// The large per-token arrays (codes and residuals) are not read into RAM.
/// Instead, `MmapIndex::load` merges the per-chunk files into single files and
/// memory-maps them. The codec is loaded with memory-mapped centroids
/// (`ResidualCodec::load_mmap_from_dir`). The remaining small data is held in
/// memory: the codec's other tensors (e.g. bucket weights), the IVF arrays,
/// and the per-document lengths/offsets.
///
/// # Memory Usage
///
/// For a typical index:
/// - `Index`: All data in RAM (~491 MB for SciFact 5K docs)
/// - `MmapIndex`: Only small tensors (~50 MB) + OS-managed mmap
///
/// # Usage
///
/// ```ignore
/// use next_plaid::MmapIndex;
///
/// let index = MmapIndex::load("/path/to/index")?;
/// let results = index.search(&query, &params, None)?;
/// ```
#[cfg(feature = "npy")]
pub struct MmapIndex {
    /// Path to the index directory
    pub path: String,
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression (centroids memory-mapped)
    pub codec: ResidualCodec,
    /// IVF data (concatenated passage IDs per centroid)
    pub ivf: Array1<i64>,
    /// IVF lengths (number of passages per centroid)
    pub ivf_lengths: Array1<i32>,
    /// IVF offsets (cumulative offsets into the ivf array; length is the number
    /// of centroids + 1, starting at 0)
    pub ivf_offsets: Array1<i64>,
    /// Document lengths (number of tokens per document)
    pub doc_lengths: Array1<i64>,
    /// Cumulative document offsets into codes/residuals (length is the number of
    /// documents + 1); document `d` spans rows `doc_offsets[d]..doc_offsets[d + 1]`
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped merged codes array (public for search access)
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped merged residuals array (public for search access)
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
1257
#[cfg(feature = "npy")]
impl MmapIndex {
    /// Open the index stored in `index_path` in memory-mapped mode.
    ///
    /// This loads `metadata.json`, the codec (with memory-mapped centroids),
    /// `ivf.npy`, `ivf_lengths.npy` and every `doclens.{chunk}.json`. It then
    /// makes sure the merged codes/residuals files exist, creating them from
    /// the per-chunk files when missing, and memory-maps them.
    ///
    /// # Errors
    ///
    /// Fails when a required file cannot be opened or parsed, or when merging
    /// or memory-mapping the chunk files fails.
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let dir = Path::new(index_path);

        // Index-wide metadata.
        let metadata_file = File::open(dir.join("metadata.json"))
            .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?;
        let metadata: Metadata = serde_json::from_reader(BufReader::new(metadata_file))?;

        // Centroids are memory-mapped; the codec's other small tensors
        // (bucket weights, etc.) are read into RAM.
        let codec = ResidualCodec::load_mmap_from_dir(dir)?;

        // IVF: concatenated document IDs plus per-centroid list lengths.
        let ivf_file = File::open(dir.join("ivf.npy"))
            .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?;
        let ivf: Array1<i64> = Array1::read_npy(ivf_file)
            .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_file = File::open(dir.join("ivf_lengths.npy"))
            .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?;
        let ivf_lengths: Array1<i32> = Array1::read_npy(ivf_lengths_file)
            .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        // Running sums with a leading 0: centroid `c` owns
        // `ivf[ivf_offsets[c]..ivf_offsets[c + 1]]`.
        let ivf_offsets: Array1<i64> = Array1::from_iter(std::iter::once(0i64).chain(
            ivf_lengths.iter().scan(0i64, |running, &len| {
                *running += len as i64;
                Some(*running)
            }),
        ));

        // Token count of every document, concatenated across chunks in order.
        let mut all_doclens: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_file = File::open(dir.join(format!("doclens.{}.json", chunk_idx)))?;
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(doclens_file))?;
            all_doclens.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(all_doclens);

        // Running sums with a leading 0: document `d` occupies token rows
        // `doc_offsets[d]..doc_offsets[d + 1]` of the merged codes/residuals.
        let doc_offsets: Array1<usize> = Array1::from_iter(std::iter::once(0usize).chain(
            doc_lengths.iter().scan(0usize, |running, &len| {
                *running += len as usize;
                Some(*running)
            }),
        ));

        // Padding for StridedTensor compatibility: longest document length
        // minus the last document's length (clamped at zero).
        let longest_doc = doc_lengths.iter().copied().max().unwrap_or(0) as usize;
        let final_doc = doc_lengths.last().copied().unwrap_or(0) as usize;
        let padding_needed = longest_doc.saturating_sub(final_doc);

        // Build the merged files if they are not already present, then map them.
        let merged_codes_path =
            crate::mmap::merge_codes_chunks(dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(dir, metadata.num_chunks, padding_needed)?;
        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

    /// Collect the document IDs listed in the IVF for the given centroids.
    ///
    /// Out-of-range centroid indices are ignored. The result is sorted in
    /// ascending order and contains no duplicates.
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let num_centroids = self.ivf_lengths.len();
        let mut candidates: Vec<i64> = Vec::new();

        for &centroid in centroid_indices {
            if centroid >= num_centroids {
                continue;
            }
            let begin = self.ivf_offsets[centroid] as usize;
            let count = self.ivf_lengths[centroid] as usize;
            candidates.extend(self.ivf.slice(s![begin..begin + count]).iter());
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    /// Decompress the embeddings of a single document.
    ///
    /// # Errors
    ///
    /// Returns `Error::Search` when `doc_id` is out of range, and propagates
    /// errors from `ResidualCodec::decompress`.
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        let in_range = doc_id < self.doc_lengths.len();
        if !in_range {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let (start, end) = (self.doc_offsets[doc_id], self.doc_offsets[doc_id + 1]);

        // Codes come from the mmap as i64; the codec expects usize.
        let codes: Array1<usize> =
            Array1::from_iter(self.mmap_codes.slice(start, end).iter().map(|&c| c as usize));
        let residuals = self.mmap_residuals.slice_rows(start, end).to_owned();

        self.codec.decompress(&residuals, &codes.view())
    }

    /// Fetch the raw codes of each requested document (for approximate scoring).
    ///
    /// The output has one entry per input ID, in input order. An out-of-range
    /// ID yields an empty vector.
    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        let num_docs = self.doc_lengths.len();
        let mut per_doc = Vec::with_capacity(doc_ids.len());

        for &doc_id in doc_ids {
            let codes = if doc_id < num_docs {
                self.mmap_codes
                    .slice(self.doc_offsets[doc_id], self.doc_offsets[doc_id + 1])
                    .to_vec()
            } else {
                Vec::new()
            };
            per_doc.push(codes);
        }

        per_doc
    }

    /// Decompress embeddings for a batch of document IDs.
    ///
    /// Returns all requested documents' embeddings stacked into one array,
    /// plus the per-document token counts in input order. Out-of-range IDs
    /// contribute a length of 0 and no rows. If no tokens are selected at all,
    /// an empty `(0, embedding_dim)` array is returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ResidualCodec::decompress`.
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        let num_docs = self.doc_lengths.len();

        // Token row range of each requested document (None for invalid IDs).
        let ranges: Vec<Option<(usize, usize)>> = doc_ids
            .iter()
            .map(|&doc_id| {
                (doc_id < num_docs)
                    .then(|| (self.doc_offsets[doc_id], self.doc_offsets[doc_id + 1]))
            })
            .collect();
        let lengths: Vec<usize> = ranges
            .iter()
            .map(|range| range.map_or(0, |(start, end)| end - start))
            .collect();
        let total_tokens: usize = lengths.iter().sum();

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        // Gather codes and residual rows for every valid document, in order.
        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes: Vec<usize> = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut row = 0;

        for &(start, end) in ranges.iter().flatten() {
            let span = end - start;
            all_codes.extend(self.mmap_codes.slice(start, end).iter().map(|&c| c as usize));
            all_residuals
                .slice_mut(s![row..row + span, ..])
                .assign(&self.mmap_residuals.slice_rows(start, end));
            row += span;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

    /// Search for documents similar to `query`.
    ///
    /// # Arguments
    ///
    /// * `query` - Query embedding matrix [num_tokens, dim]
    /// * `params` - Search parameters
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Search result containing top-k document IDs and scores, as produced by
    /// `crate::search::search_one_mmap`.
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

    /// Search for multiple queries in one call.
    ///
    /// # Arguments
    ///
    /// * `queries` - Slice of query embedding matrices
    /// * `params` - Search parameters
    /// * `parallel` - If true, process queries in parallel using rayon
    /// * `subset` - Optional subset of document IDs to search within
    ///
    /// # Returns
    ///
    /// Vector of search results, one per query, as produced by
    /// `crate::search::search_many_mmap`.
    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    /// Number of documents in the index (one entry per loaded document length).
    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    /// Embedding dimension reported by the codec.
    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    /// Reconstruct embeddings for specific documents.
    ///
    /// The compressed codes and residuals of each document are read from the
    /// memory-mapped files and decompressed to recover the original embeddings.
    ///
    /// # Arguments
    ///
    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
    ///
    /// # Returns
    ///
    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use next_plaid::MmapIndex;
    ///
    /// let index = MmapIndex::load("/path/to/index")?;
    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
    ///
    /// for (i, emb) in embeddings.iter().enumerate() {
    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
    /// }
    /// ```
    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings_mmap(self, doc_ids)
    }

    /// Reconstruct the embeddings of one document.
    ///
    /// # Arguments
    ///
    /// * `doc_id` - Document ID to reconstruct (0-indexed)
    ///
    /// # Returns
    ///
    /// A 2D array with shape `[num_tokens, dim]`.
    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single_mmap(self, doc_id)
    }
}
1549
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of token rows in every synthetic test document.
    #[cfg(feature = "npy")]
    const TOKENS_PER_DOC: usize = 5;

    /// Embedding width of every synthetic test document.
    #[cfg(feature = "npy")]
    const EMBEDDING_DIM: usize = 32;

    /// Builds `count` deterministic documents of shape `[TOKENS_PER_DOC, EMBEDDING_DIM]`.
    ///
    /// Document `i` is seeded from `i + offset`, so batches generated with different
    /// offsets do not collide. Every row is scaled to unit L2 norm (rows with a zero
    /// norm are left as-is).
    #[cfg(feature = "npy")]
    fn synthetic_docs(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (0..count)
            .map(|i| {
                let seed = (i + offset) as f32;
                let mut doc = Array2::<f32>::from_shape_fn(
                    (TOKENS_PER_DOC, EMBEDDING_DIM),
                    |(j, k)| (seed * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001),
                );
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.mapv_inplace(|x| x / norm);
                    }
                }
                doc
            })
            .collect()
    }

    /// Small, fast index configuration shared by the update-or-create tests.
    #[cfg(feature = "npy")]
    fn small_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let IndexConfig {
            nbits,
            batch_size,
            seed,
            ..
        } = IndexConfig::default();
        assert_eq!(nbits, 4);
        assert_eq!(batch_size, 50_000);
        assert_eq!(seed, Some(42));
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_update_or_create_new_index() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let root = temp_dir.path();
        let index_path = root.to_str().unwrap();

        let embeddings = synthetic_docs(5, 0);
        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // Nothing exists at `index_path` yet, so this call has to build a fresh index.
        let (index, doc_ids) =
            Index::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // The freshly built index must have been persisted to disk.
        assert!(root.join("metadata.json").exists());
        assert!(root.join("centroids.npy").exists());
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_update_or_create_existing_index() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call: the directory is empty, so an index holding documents 0..5 is created.
        let first_batch = synthetic_docs(5, 0);
        let (created, created_ids) =
            Index::update_or_create(&first_batch, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(created.metadata.num_documents, 5);
        assert_eq!(created_ids, vec![0, 1, 2, 3, 4]);

        // Second call: the index now exists, so the three new documents are added to it
        // and receive the next consecutive IDs.
        let second_batch = synthetic_docs(3, 5);
        let (updated, appended_ids) =
            Index::update_or_create(&second_batch, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(updated.metadata.num_documents, 8);
        assert_eq!(appended_ids, vec![5, 6, 7]);
    }
}