Skip to main content

next_plaid/
index.rs

1//! Index creation and management for PLAID
2
3use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
14use crate::utils::{quantile, quantiles};
15
16/// CPU implementation of fused compress_into_codes + residual computation.
17fn compress_and_residuals_cpu(
18    embeddings: &Array2<f32>,
19    codec: &ResidualCodec,
20) -> (Array1<usize>, Array2<f32>) {
21    use rayon::prelude::*;
22
23    // Use CPU-only version to ensure no CUDA is called
24    let codes = codec.compress_into_codes_cpu(embeddings);
25    let mut residuals = embeddings.clone();
26
27    let centroids = &codec.centroids;
28    residuals
29        .axis_iter_mut(Axis(0))
30        .into_par_iter()
31        .zip(codes.as_slice().unwrap().par_iter())
32        .for_each(|(mut row, &code)| {
33            let centroid = centroids.row(code);
34            row.iter_mut()
35                .zip(centroid.iter())
36                .for_each(|(r, c)| *r -= c);
37        });
38
39    (codes, residuals)
40}
41
/// Configuration for index creation.
///
/// `nbits` and `batch_size` carry no `#[serde(default)]` attribute; the other fields fall back
/// to the serde default named on each of them when absent from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits for quantization (typically 2 or 4).
    /// Codec training splits residuals into `2^nbits` quantile buckets.
    pub nbits: usize,
    /// Number of documents per on-disk chunk when writing an index.
    pub batch_size: usize,
    /// Random seed for reproducibility.
    /// Seeds the document sampling used for codec training (OS entropy when `None`); the
    /// K-means configuration falls back to seed 42 when this is `None`.
    pub seed: Option<u64>,
    /// Number of K-means iterations (default: 4)
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of points per centroid for K-means (default: 256)
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Number of samples for K-means training (forwarded to the K-means configuration).
    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Threshold for start-from-scratch mode (default: 999).
    /// When the number of documents is <= this threshold, raw embeddings are saved
    /// to embeddings.npy for potential rebuilds during updates.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
    /// Force CPU execution even when the `cuda` feature is enabled.
    /// Applies to K-means, to centroid assignment of the codec-training sample, and to the
    /// fused code/residual computation performed while encoding chunks.
    /// Useful for small batches where GPU initialization overhead exceeds benefits.
    #[serde(default)]
    pub force_cpu: bool,
    /// FTS5 tokenizer for full-text search over metadata.
    /// Default: `Unicode61` (word-level). Use `Trigram` for code / substring search.
    #[serde(default)]
    pub fts_tokenizer: crate::text_search::FtsTokenizer,
}
75
/// Serde fallback for `IndexConfig::start_from_scratch`: document-count threshold (999).
fn default_start_from_scratch() -> usize {
    const START_FROM_SCRATCH_THRESHOLD: usize = 999;
    START_FROM_SCRATCH_THRESHOLD
}
79
/// Serde fallback for `IndexConfig::kmeans_niters`: four K-means iterations.
fn default_kmeans_niters() -> usize {
    const KMEANS_ITERATIONS: usize = 4;
    KMEANS_ITERATIONS
}
83
/// Serde fallback for `IndexConfig::max_points_per_centroid`: at most 256 points per centroid.
fn default_max_points_per_centroid() -> usize {
    const MAX_POINTS_PER_CENTROID: usize = 256;
    MAX_POINTS_PER_CENTROID
}
87
88impl Default for IndexConfig {
89    fn default() -> Self {
90        Self {
91            nbits: 4,
92            batch_size: 50_000,
93            seed: Some(42),
94            kmeans_niters: 4,
95            max_points_per_centroid: 256,
96            n_samples_kmeans: None,
97            start_from_scratch: 999,
98            force_cpu: false,
99            fts_tokenizer: crate::text_search::FtsTokenizer::default(),
100        }
101    }
102}
103
/// Global metadata for an index, persisted as `metadata.json` in the index directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    /// Number of chunks in the index (per-chunk files are numbered `0..num_chunks`)
    pub num_chunks: usize,
    /// Number of bits for quantization
    pub nbits: usize,
    /// Number of partitions (centroids)
    pub num_partitions: usize,
    /// Total number of embeddings
    pub num_embeddings: usize,
    /// Average document length (embeddings per document)
    pub avg_doclen: f64,
    /// Total number of documents.
    /// Reads as 0 when absent; `Metadata::load_from_path` then infers it from the
    /// per-chunk `doclens.{i}.json` files.
    #[serde(default)]
    pub num_documents: usize,
    /// Embedding dimension (columns of centroids matrix). Reads as 0 when absent.
    #[serde(default)]
    pub embedding_dim: usize,
    /// Whether the index has been converted to next-plaid compatible format.
    /// If false or missing, the index may need fast-plaid to next-plaid conversion.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}
128
129impl Metadata {
130    /// Load metadata from a JSON file, inferring num_documents from doclens if not present.
131    pub fn load_from_path(index_path: &Path) -> Result<Self> {
132        let metadata_path = index_path.join("metadata.json");
133        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
134            File::open(&metadata_path)
135                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
136        ))?;
137
138        // If num_documents is 0 (default), infer from doclens files
139        if metadata.num_documents == 0 {
140            let mut total_docs = 0usize;
141            for chunk_idx in 0..metadata.num_chunks {
142                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
143                if let Ok(file) = File::open(&doclens_path) {
144                    if let Ok(chunk_doclens) =
145                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
146                    {
147                        total_docs += chunk_doclens.len();
148                    }
149                }
150            }
151            metadata.num_documents = total_docs;
152        }
153
154        Ok(metadata)
155    }
156}
157
/// Per-chunk metadata, persisted as `{chunk_idx}.metadata.json` in the index directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Number of documents stored in this chunk.
    pub num_documents: usize,
    /// Number of embeddings (codes / residual rows) stored in this chunk.
    pub num_embeddings: usize,
    /// Global index of this chunk's first embedding: the sum of `num_embeddings` over all
    /// preceding chunks. Reads as 0 when absent.
    #[serde(default)]
    pub embedding_offset: usize,
}
166
/// One chunk of documents after centroid assignment and residual quantization, ready to be
/// written to disk by `write_index_from_encoded_chunks`.
#[derive(Debug, Clone)]
pub struct EncodedIndexChunk {
    /// Centroid code of every embedding in the chunk, in document order.
    pub codes: Array1<i64>,
    /// Packed quantized residuals, one row per code. When produced by `encode_index_chunk`,
    /// each row holds `embedding_dim * nbits / 8` bytes.
    pub residuals: Array2<u8>,
    /// Number of embeddings of each document in the chunk; expected to sum to `codes.len()`.
    pub doclens: Vec<i64>,
}
173
/// A trained residual codec together with the statistics persisted alongside it, as returned
/// by `prepare_codec_artifacts`.
pub struct PreparedCodecArtifacts {
    /// Residual codec built from the centroids and the statistics below.
    pub codec: ResidualCodec,
    /// 0.75 quantile of the held-out residual L2 norms (written to `cluster_threshold.npy`).
    pub cluster_threshold: f32,
    /// Residual quantiles at `i / 2^nbits` for `i in 1..2^nbits` (written to `bucket_cutoffs.npy`).
    pub bucket_cutoffs: Array1<f32>,
    /// Residual quantiles at `(i + 0.5) / 2^nbits` for `i in 0..2^nbits`
    /// (written to `bucket_weights.npy`).
    pub bucket_weights: Array1<f32>,
    /// Mean absolute held-out residual per embedding dimension (written to `avg_residual.npy`).
    pub avg_res_per_dim: Array1<f32>,
}
181
182pub fn prepare_codec_artifacts(
183    embeddings: &[Array2<f32>],
184    centroids: Array2<f32>,
185    config: &IndexConfig,
186) -> Result<PreparedCodecArtifacts> {
187    let embedding_dim = centroids.ncols();
188    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
189    let num_documents = embeddings.len();
190
191    if num_documents == 0 {
192        return Err(Error::IndexCreation("No documents provided".into()));
193    }
194
195    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
196        .min(num_documents)
197        .max(1);
198
199    let mut rng = if let Some(seed) = config.seed {
200        use rand::SeedableRng;
201        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
202    } else {
203        use rand::SeedableRng;
204        rand_chacha::ChaCha8Rng::from_entropy()
205    };
206
207    use rand::seq::SliceRandom;
208    let mut indices: Vec<usize> = (0..num_documents).collect();
209    indices.shuffle(&mut rng);
210    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
211
212    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
213    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
214    let mut collected = 0;
215
216    for &idx in sample_indices.iter().rev() {
217        if collected >= heldout_size {
218            break;
219        }
220        let emb = &embeddings[idx];
221        let take = (heldout_size - collected).min(emb.nrows());
222        for row in emb.axis_iter(Axis(0)).take(take) {
223            heldout_embeddings.extend(row.iter());
224        }
225        collected += take;
226    }
227
228    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
229        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
230
231    let avg_residual = Array1::zeros(embedding_dim);
232    let initial_codec =
233        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
234
235    let heldout_codes = if config.force_cpu {
236        initial_codec.compress_into_codes_cpu(&heldout)
237    } else {
238        initial_codec.compress_into_codes(&heldout)
239    };
240
241    let mut residuals = heldout.clone();
242    for i in 0..heldout.nrows() {
243        let centroid = initial_codec.centroids.row(heldout_codes[i]);
244        for j in 0..embedding_dim {
245            residuals[[i, j]] -= centroid[j];
246        }
247    }
248
249    let distances: Array1<f32> = residuals
250        .axis_iter(Axis(0))
251        .map(|row| row.dot(&row).sqrt())
252        .collect();
253    let cluster_threshold = quantile(&distances, 0.75);
254
255    let avg_res_per_dim: Array1<f32> = residuals
256        .axis_iter(Axis(1))
257        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
258        .collect();
259
260    let n_options = 1 << config.nbits;
261    let quantile_values: Vec<f64> = (1..n_options)
262        .map(|i| i as f64 / n_options as f64)
263        .collect();
264    let weight_quantile_values: Vec<f64> = (0..n_options)
265        .map(|i| (i as f64 + 0.5) / n_options as f64)
266        .collect();
267
268    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
269    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
270    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
271
272    let codec = ResidualCodec::new(
273        config.nbits,
274        centroids,
275        avg_res_per_dim.clone(),
276        Some(bucket_cutoffs.clone()),
277        Some(bucket_weights.clone()),
278    )?;
279
280    Ok(PreparedCodecArtifacts {
281        codec,
282        cluster_threshold,
283        bucket_cutoffs,
284        bucket_weights,
285        avg_res_per_dim,
286    })
287}
288
289pub fn encode_index_chunk(
290    embeddings: &[Array2<f32>],
291    codec: &ResidualCodec,
292    force_cpu: bool,
293) -> Result<EncodedIndexChunk> {
294    let embedding_dim = codec.embedding_dim();
295    let packed_dim = embedding_dim * codec.nbits / 8;
296    let doclens: Vec<i64> = embeddings.iter().map(|d| d.nrows() as i64).collect();
297    let total_tokens: usize = doclens.iter().sum::<i64>() as usize;
298
299    #[cfg(not(feature = "cuda"))]
300    let _ = force_cpu;
301
302    let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
303    let mut offset = 0;
304    for doc in embeddings {
305        let n = doc.nrows();
306        batch_embeddings
307            .slice_mut(s![offset..offset + n, ..])
308            .assign(doc);
309        offset += n;
310    }
311
312    let (batch_codes, batch_residuals) = {
313        #[cfg(feature = "cuda")]
314        {
315            let force_gpu = crate::is_force_gpu();
316            if !force_cpu {
317                if let Some(ctx) = crate::cuda::get_global_context() {
318                    match crate::cuda::compress_and_residuals_cuda_batched(
319                        &ctx,
320                        &batch_embeddings.view(),
321                        &codec.centroids_view(),
322                        None,
323                    ) {
324                        Ok(result) => result,
325                        Err(e) => {
326                            if force_gpu {
327                                panic!(
328                                    "FORCE_GPU is set but CUDA compress_and_residuals failed: {}",
329                                    e
330                                );
331                            }
332                            println!(
333                                "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
334                                e
335                            );
336                            compress_and_residuals_cpu(&batch_embeddings, codec)
337                        }
338                    }
339                } else if force_gpu {
340                    panic!("FORCE_GPU is set but CUDA context is unavailable");
341                } else {
342                    compress_and_residuals_cpu(&batch_embeddings, codec)
343                }
344            } else {
345                compress_and_residuals_cpu(&batch_embeddings, codec)
346            }
347        }
348        #[cfg(not(feature = "cuda"))]
349        {
350            compress_and_residuals_cpu(&batch_embeddings, codec)
351        }
352    };
353
354    let batch_packed = codec.quantize_residuals(&batch_residuals)?;
355    let (raw_residuals, residuals_offset) = batch_packed.into_raw_vec_and_offset();
356    if residuals_offset != Some(0) {
357        return Err(Error::Shape(format!(
358            "Unexpected residual packing offset: {:?}",
359            residuals_offset
360        )));
361    }
362    let residuals = Array2::from_shape_vec((batch_codes.len(), packed_dim), raw_residuals)
363        .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
364    let codes: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
365
366    Ok(EncodedIndexChunk {
367        codes,
368        residuals,
369        doclens,
370    })
371}
372
373pub fn write_index_from_encoded_chunks(
374    chunks: &[EncodedIndexChunk],
375    codec_artifacts: &PreparedCodecArtifacts,
376    index_path: &str,
377    config: &IndexConfig,
378) -> Result<Metadata> {
379    use ndarray_npy::WriteNpyExt;
380
381    let index_dir = Path::new(index_path);
382    fs::create_dir_all(index_dir)?;
383
384    let embedding_dim = codec_artifacts.codec.embedding_dim();
385    let num_centroids = codec_artifacts.codec.num_centroids();
386    let total_embeddings: usize = chunks.iter().map(|c| c.codes.len()).sum();
387    let num_documents: usize = chunks.iter().map(|c| c.doclens.len()).sum();
388    let avg_doclen = if num_documents > 0 {
389        total_embeddings as f64 / num_documents as f64
390    } else {
391        0.0
392    };
393
394    let centroids_path = index_dir.join("centroids.npy");
395    codec_artifacts
396        .codec
397        .centroids_view()
398        .to_owned()
399        .write_npy(File::create(&centroids_path)?)?;
400    codec_artifacts
401        .bucket_cutoffs
402        .write_npy(File::create(index_dir.join("bucket_cutoffs.npy"))?)?;
403    codec_artifacts
404        .bucket_weights
405        .write_npy(File::create(index_dir.join("bucket_weights.npy"))?)?;
406    codec_artifacts
407        .avg_res_per_dim
408        .write_npy(File::create(index_dir.join("avg_residual.npy"))?)?;
409    Array1::from_vec(vec![codec_artifacts.cluster_threshold])
410        .write_npy(File::create(index_dir.join("cluster_threshold.npy"))?)?;
411
412    let n_chunks = chunks.len();
413    let plan = serde_json::json!({
414        "nbits": config.nbits,
415        "num_chunks": n_chunks,
416    });
417    let mut plan_file = File::create(index_dir.join("plan.json"))?;
418    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
419
420    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
421    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
422    let mut current_offset = 0usize;
423
424    for (chunk_idx, chunk) in chunks.iter().enumerate() {
425        let chunk_meta = ChunkMetadata {
426            num_documents: chunk.doclens.len(),
427            num_embeddings: chunk.codes.len(),
428            embedding_offset: current_offset,
429        };
430        current_offset += chunk.codes.len();
431
432        serde_json::to_writer_pretty(
433            BufWriter::new(File::create(
434                index_dir.join(format!("{}.metadata.json", chunk_idx)),
435            )?),
436            &chunk_meta,
437        )?;
438        serde_json::to_writer(
439            BufWriter::new(File::create(
440                index_dir.join(format!("doclens.{}.json", chunk_idx)),
441            )?),
442            &chunk.doclens,
443        )?;
444        chunk.codes.write_npy(File::create(
445            index_dir.join(format!("{}.codes.npy", chunk_idx)),
446        )?)?;
447        chunk.residuals.write_npy(File::create(
448            index_dir.join(format!("{}.residuals.npy", chunk_idx)),
449        )?)?;
450
451        doc_lengths.extend_from_slice(&chunk.doclens);
452        all_codes.extend(chunk.codes.iter().map(|&x| x as usize));
453    }
454
455    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
456    let mut emb_idx = 0;
457    for (doc_id, &len) in doc_lengths.iter().enumerate() {
458        for _ in 0..len {
459            let code = all_codes[emb_idx];
460            code_to_docs.entry(code).or_default().push(doc_id as i64);
461            emb_idx += 1;
462        }
463    }
464
465    let mut ivf_data: Vec<i64> = Vec::new();
466    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
467    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
468        if let Some(docs) = code_to_docs.get(&centroid_id) {
469            let mut unique_docs = docs.clone();
470            unique_docs.sort_unstable();
471            unique_docs.dedup();
472            *ivf_len = unique_docs.len() as i32;
473            ivf_data.extend(unique_docs);
474        }
475    }
476
477    Array1::from_vec(ivf_data).write_npy(File::create(index_dir.join("ivf.npy"))?)?;
478    Array1::from_vec(ivf_lengths).write_npy(File::create(index_dir.join("ivf_lengths.npy"))?)?;
479
480    let metadata = Metadata {
481        num_chunks: n_chunks,
482        nbits: config.nbits,
483        num_partitions: num_centroids,
484        num_embeddings: total_embeddings,
485        avg_doclen,
486        num_documents,
487        embedding_dim,
488        next_plaid_compatible: true,
489    };
490    serde_json::to_writer_pretty(
491        BufWriter::new(File::create(index_dir.join("metadata.json"))?),
492        &metadata,
493    )?;
494
495    Ok(metadata)
496}
497
498// ============================================================================
499// Standalone Index Creation Functions
500// ============================================================================
501
502/// Create index files on disk from embeddings and centroids.
503///
504/// This is a standalone function that creates all necessary index files
505/// without constructing an in-memory Index object. Both Index and MmapIndex
506/// can use this function to create their files, then load them in their
507/// preferred format.
508///
509/// # Arguments
510///
511/// * `embeddings` - List of document embeddings
512/// * `centroids` - Pre-computed centroids from K-means
513/// * `index_path` - Directory to save the index
514/// * `config` - Index configuration
515///
516/// # Returns
517///
518/// Metadata about the created index
519pub fn create_index_files(
520    embeddings: &[Array2<f32>],
521    centroids: Array2<f32>,
522    index_path: &str,
523    config: &IndexConfig,
524) -> Result<Metadata> {
525    let index_dir = Path::new(index_path);
526    fs::create_dir_all(index_dir)?;
527
528    let num_documents = embeddings.len();
529    let embedding_dim = centroids.ncols();
530    let num_centroids = centroids.nrows();
531
532    if num_documents == 0 {
533        return Err(Error::IndexCreation("No documents provided".into()));
534    }
535
536    // Calculate statistics
537    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
538    let avg_doclen = total_embeddings as f64 / num_documents as f64;
539
540    // Sample documents for codec training
541    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
542        .min(num_documents)
543        .max(1);
544
545    let mut rng = if let Some(seed) = config.seed {
546        use rand::SeedableRng;
547        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
548    } else {
549        use rand::SeedableRng;
550        rand_chacha::ChaCha8Rng::from_entropy()
551    };
552
553    use rand::seq::SliceRandom;
554    let mut indices: Vec<usize> = (0..num_documents).collect();
555    indices.shuffle(&mut rng);
556    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();
557
558    // Collect sample embeddings for training
559    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
560    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
561    let mut collected = 0;
562
563    for &idx in sample_indices.iter().rev() {
564        if collected >= heldout_size {
565            break;
566        }
567        let emb = &embeddings[idx];
568        let take = (heldout_size - collected).min(emb.nrows());
569        for row in emb.axis_iter(Axis(0)).take(take) {
570            heldout_embeddings.extend(row.iter());
571        }
572        collected += take;
573    }
574
575    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
576        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;
577
578    // Train codec: compute residuals and quantization parameters
579    let avg_residual = Array1::zeros(embedding_dim);
580    let initial_codec =
581        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;
582
583    // Compute codes for heldout samples
584    // Use CPU-only version when force_cpu is set to avoid CUDA initialization overhead
585    let heldout_codes = if config.force_cpu {
586        initial_codec.compress_into_codes_cpu(&heldout)
587    } else {
588        initial_codec.compress_into_codes(&heldout)
589    };
590
591    // Compute residuals
592    let mut residuals = heldout.clone();
593    for i in 0..heldout.nrows() {
594        let centroid = initial_codec.centroids.row(heldout_codes[i]);
595        for j in 0..embedding_dim {
596            residuals[[i, j]] -= centroid[j];
597        }
598    }
599
600    // Compute cluster threshold from residual distances
601    let distances: Array1<f32> = residuals
602        .axis_iter(Axis(0))
603        .map(|row| row.dot(&row).sqrt())
604        .collect();
605    #[allow(unused_variables)]
606    let cluster_threshold = quantile(&distances, 0.75);
607
608    // Compute average residual per dimension
609    let avg_res_per_dim: Array1<f32> = residuals
610        .axis_iter(Axis(1))
611        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
612        .collect();
613
614    // Compute quantization buckets
615    let n_options = 1 << config.nbits;
616    let quantile_values: Vec<f64> = (1..n_options)
617        .map(|i| i as f64 / n_options as f64)
618        .collect();
619    let weight_quantile_values: Vec<f64> = (0..n_options)
620        .map(|i| (i as f64 + 0.5) / n_options as f64)
621        .collect();
622
623    // Flatten residuals for quantile computation
624    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
625    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
626    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));
627
628    let codec = ResidualCodec::new(
629        config.nbits,
630        centroids.clone(),
631        avg_res_per_dim.clone(),
632        Some(bucket_cutoffs.clone()),
633        Some(bucket_weights.clone()),
634    )?;
635
636    // Save codec components
637    use ndarray_npy::WriteNpyExt;
638
639    let centroids_path = index_dir.join("centroids.npy");
640    codec
641        .centroids_view()
642        .to_owned()
643        .write_npy(File::create(&centroids_path)?)?;
644
645    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
646    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;
647
648    let weights_path = index_dir.join("bucket_weights.npy");
649    bucket_weights.write_npy(File::create(&weights_path)?)?;
650
651    let avg_res_path = index_dir.join("avg_residual.npy");
652    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;
653
654    let threshold_path = index_dir.join("cluster_threshold.npy");
655    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
656
657    // Process documents in chunks
658    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;
659
660    // Save plan
661    let plan_path = index_dir.join("plan.json");
662    let plan = serde_json::json!({
663        "nbits": config.nbits,
664        "num_chunks": n_chunks,
665    });
666    let mut plan_file = File::create(&plan_path)?;
667    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;
668
669    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
670    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
671
672    for chunk_idx in 0..n_chunks {
673        let start = chunk_idx * config.batch_size;
674        let end = (start + config.batch_size).min(num_documents);
675        let chunk_docs = &embeddings[start..end];
676
677        // Collect document lengths
678        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
679        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;
680
681        // Concatenate all embeddings in the chunk for batch processing
682        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
683        let mut offset = 0;
684        for doc in chunk_docs {
685            let n = doc.nrows();
686            batch_embeddings
687                .slice_mut(s![offset..offset + n, ..])
688                .assign(doc);
689            offset += n;
690        }
691
692        // BATCH: Compress embeddings and compute residuals
693        // Try CUDA fused operation first, fall back to CPU (skip CUDA if force_cpu is set)
694        let (batch_codes, batch_residuals) = {
695            #[cfg(feature = "cuda")]
696            {
697                let force_gpu = crate::is_force_gpu();
698                if !config.force_cpu {
699                    if let Some(ctx) = crate::cuda::get_global_context() {
700                        match crate::cuda::compress_and_residuals_cuda_batched(
701                            &ctx,
702                            &batch_embeddings.view(),
703                            &codec.centroids_view(),
704                            None,
705                        ) {
706                            Ok(result) => result,
707                            Err(e) => {
708                                if force_gpu {
709                                    panic!("FORCE_GPU is set but CUDA compress_and_residuals failed: {}", e);
710                                }
711                                eprintln!(
712                                    "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
713                                    e
714                                );
715                                compress_and_residuals_cpu(&batch_embeddings, &codec)
716                            }
717                        }
718                    } else if force_gpu {
719                        panic!("FORCE_GPU is set but CUDA context is unavailable");
720                    } else {
721                        compress_and_residuals_cpu(&batch_embeddings, &codec)
722                    }
723                } else {
724                    compress_and_residuals_cpu(&batch_embeddings, &codec)
725                }
726            }
727            #[cfg(not(feature = "cuda"))]
728            {
729                compress_and_residuals_cpu(&batch_embeddings, &codec)
730            }
731        };
732
733        // BATCH: Quantize all residuals at once
734        let batch_packed = codec.quantize_residuals(&batch_residuals)?;
735
736        // Track codes for IVF building
737        for &len in &chunk_doclens {
738            doc_lengths.push(len);
739        }
740        all_codes.extend(batch_codes.iter().copied());
741
742        // Save chunk metadata
743        let chunk_meta = ChunkMetadata {
744            num_documents: end - start,
745            num_embeddings: batch_codes.len(),
746            embedding_offset: 0, // Will be updated later
747        };
748
749        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
750        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;
751
752        // Save chunk doclens
753        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
754        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;
755
756        // Save chunk codes
757        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
758        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
759        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;
760
761        // Save chunk residuals
762        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
763        batch_packed.write_npy(File::create(&residuals_path)?)?;
764    }
765
766    // Update chunk metadata with global offsets
767    let mut current_offset = 0usize;
768    for chunk_idx in 0..n_chunks {
769        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
770        let mut meta: serde_json::Value =
771            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;
772
773        if let Some(obj) = meta.as_object_mut() {
774            obj.insert("embedding_offset".to_string(), current_offset.into());
775            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
776            current_offset += num_emb;
777        }
778
779        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
780    }
781
782    // Build IVF (Inverted File)
783    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
784    let mut emb_idx = 0;
785
786    for (doc_id, &len) in doc_lengths.iter().enumerate() {
787        for _ in 0..len {
788            let code = all_codes[emb_idx];
789            code_to_docs.entry(code).or_default().push(doc_id as i64);
790            emb_idx += 1;
791        }
792    }
793
794    // Deduplicate document IDs per centroid
795    let mut ivf_data: Vec<i64> = Vec::new();
796    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
797
798    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
799        if let Some(docs) = code_to_docs.get(&centroid_id) {
800            let mut unique_docs: Vec<i64> = docs.clone();
801            unique_docs.sort_unstable();
802            unique_docs.dedup();
803            *ivf_len = unique_docs.len() as i32;
804            ivf_data.extend(unique_docs);
805        }
806    }
807
808    let ivf = Array1::from_vec(ivf_data);
809    let ivf_lengths = Array1::from_vec(ivf_lengths);
810
811    let ivf_path = index_dir.join("ivf.npy");
812    ivf.write_npy(File::create(&ivf_path)?)?;
813
814    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
815    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
816
817    // Save global metadata
818    let metadata = Metadata {
819        num_chunks: n_chunks,
820        nbits: config.nbits,
821        num_partitions: num_centroids,
822        num_embeddings: total_embeddings,
823        avg_doclen,
824        num_documents,
825        embedding_dim,
826        next_plaid_compatible: true, // Created by next-plaid, always compatible
827    };
828
829    let metadata_path = index_dir.join("metadata.json");
830    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;
831
832    Ok(metadata)
833}
834
835/// Create index files with automatic K-means centroid computation.
836///
837/// This is a standalone function that runs K-means to compute centroids,
838/// then creates all index files on disk.
839///
840/// # Arguments
841///
842/// * `embeddings` - List of document embeddings
843/// * `index_path` - Directory to save the index
844/// * `config` - Index configuration
845///
846/// # Returns
847///
848/// Metadata about the created index
849pub fn create_index_with_kmeans_files(
850    embeddings: &[Array2<f32>],
851    index_path: &str,
852    config: &IndexConfig,
853) -> Result<Metadata> {
854    if embeddings.is_empty() {
855        return Err(Error::IndexCreation("No documents provided".into()));
856    }
857
858    // Pre-initialize CUDA if available (first init can take 10-20s due to driver initialization)
859    // Skip if force_cpu is set to avoid unnecessary initialization overhead
860    #[cfg(feature = "cuda")]
861    if !config.force_cpu {
862        if crate::is_force_gpu() {
863            crate::cuda::get_global_context()
864                .expect("FORCE_GPU is set but CUDA context failed to initialize");
865        } else {
866            let _ = crate::cuda::get_global_context();
867        }
868    }
869
870    // Build K-means configuration from IndexConfig
871    let kmeans_config = ComputeKmeansConfig {
872        kmeans_niters: config.kmeans_niters,
873        max_points_per_centroid: config.max_points_per_centroid,
874        seed: config.seed.unwrap_or(42),
875        n_samples_kmeans: config.n_samples_kmeans,
876        num_partitions: None, // Let the heuristic decide
877        force_cpu: config.force_cpu,
878    };
879
880    // Compute centroids using fast-plaid's approach
881    let centroids = compute_kmeans(embeddings, &kmeans_config)?;
882
883    // Create the index files
884    let metadata = create_index_files(embeddings, centroids, index_path, config)?;
885
886    // If below start_from_scratch threshold, save raw embeddings for potential rebuilds
887    if embeddings.len() <= config.start_from_scratch {
888        let index_dir = std::path::Path::new(index_path);
889        crate::update::save_embeddings_npy(index_dir, embeddings)?;
890    }
891
892    Ok(metadata)
893}
894// ============================================================================
895// Memory-Mapped Index for Low Memory Usage
896// ============================================================================
897
/// A memory-mapped PLAID index for multi-vector search.
///
/// The large per-token arrays (codes and packed residuals) are accessed through
/// memory-mapped files instead of being read fully into RAM. `MmapIndex::load`
/// reads the IVF arrays and the per-document lengths/offsets into memory. It
/// loads the codec via `ResidualCodec::load_mmap_from_dir`, whose centroids are
/// memory-mapped as well; the codec's other small tensors (e.g. bucket weights)
/// are held in RAM.
///
/// # Memory Usage
///
/// Only the small tensors listed above are resident in RAM, while code and
/// residual data is paged in by the OS on demand. (The original author
/// reported roughly 50 MB resident for SciFact, ~5K documents; treat that as
/// an anecdotal figure, not a guarantee.)
///
/// # Usage
///
/// ```ignore
/// use next_plaid::MmapIndex;
///
/// let index = MmapIndex::load("/path/to/index")?;
/// let results = index.search(&query, &params, None)?;
/// ```
pub struct MmapIndex {
    /// Path to the index directory
    pub path: String,
    /// Index metadata
    pub metadata: Metadata,
    /// Residual codec for quantization/decompression
    pub codec: ResidualCodec,
    /// IVF data: passage IDs for every centroid, concatenated in centroid order
    pub ivf: Array1<i64>,
    /// IVF lengths (number of passages listed for each centroid)
    pub ivf_lengths: Array1<i32>,
    /// Cumulative offsets into `ivf`; as built by `load` this has
    /// `ivf_lengths.len() + 1` entries, with centroid `i` spanning
    /// `ivf_offsets[i]..ivf_offsets[i + 1]`
    pub ivf_offsets: Array1<i64>,
    /// Document lengths (number of tokens per document)
    pub doc_lengths: Array1<i64>,
    /// Cumulative token offsets into codes/residuals; as built by `load` this
    /// has `doc_lengths.len() + 1` entries, with document `d` spanning rows
    /// `doc_offsets[d]..doc_offsets[d + 1]`
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped merged codes array, one centroid code per token
    /// (public for search access)
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped merged residuals array, one packed row per token
    /// (public for search access)
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
939
940impl MmapIndex {
941    /// Load a memory-mapped index from disk.
942    ///
943    /// This creates merged files for codes and residuals if they don't exist,
944    /// then memory-maps them for efficient access.
945    ///
946    /// If the index was created by fast-plaid, it will be automatically converted
947    /// to next-plaid compatible format on first load.
948    pub fn load(index_path: &str) -> Result<Self> {
949        use ndarray_npy::ReadNpyExt;
950
951        let index_dir = Path::new(index_path);
952
953        // Load metadata (infers num_documents from doclens if not present)
954        let mut metadata = Metadata::load_from_path(index_dir)?;
955
956        // Check if conversion from fast-plaid format is needed
957        if !metadata.next_plaid_compatible {
958            eprintln!("Checking index format compatibility...");
959            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
960            if converted {
961                eprintln!("Index converted to next-plaid compatible format.");
962                // Delete any existing merged files since the source files changed
963                let merged_codes = index_dir.join("merged_codes.npy");
964                let merged_residuals = index_dir.join("merged_residuals.npy");
965                let codes_manifest = index_dir.join("merged_codes.manifest.json");
966                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
967                for path in [
968                    &merged_codes,
969                    &merged_residuals,
970                    &codes_manifest,
971                    &residuals_manifest,
972                ] {
973                    if path.exists() {
974                        let _ = fs::remove_file(path);
975                    }
976                }
977            }
978
979            // Mark as compatible and save metadata
980            metadata.next_plaid_compatible = true;
981            let metadata_path = index_dir.join("metadata.json");
982            let file = File::create(&metadata_path)
983                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
984            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
985            eprintln!("Metadata updated with next_plaid_compatible: true");
986        }
987
988        // Load codec with memory-mapped centroids for reduced RAM usage.
989        // Other small tensors (bucket weights, etc.) are still loaded into memory.
990        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;
991
992        // Load IVF (small tensor)
993        let ivf_path = index_dir.join("ivf.npy");
994        let ivf: Array1<i64> = Array1::read_npy(
995            File::open(&ivf_path)
996                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
997        )
998        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;
999
1000        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
1001        let ivf_lengths: Array1<i32> = Array1::read_npy(
1002            File::open(&ivf_lengths_path)
1003                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
1004        )
1005        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;
1006
1007        // Compute IVF offsets
1008        let num_centroids = ivf_lengths.len();
1009        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
1010        for i in 0..num_centroids {
1011            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
1012        }
1013
1014        // Load document lengths from all chunks
1015        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
1016        for chunk_idx in 0..metadata.num_chunks {
1017            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
1018            let chunk_doclens: Vec<i64> =
1019                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
1020            doc_lengths_vec.extend(chunk_doclens);
1021        }
1022        let doc_lengths = Array1::from_vec(doc_lengths_vec);
1023
1024        // Compute document offsets for indexing
1025        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
1026        for i in 0..doc_lengths.len() {
1027            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
1028        }
1029
1030        // Compute padding needed for StridedTensor compatibility
1031        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
1032        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
1033        let padding_needed = max_len.saturating_sub(last_len);
1034
1035        let merged_codes_path =
1036            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
1037        let merged_residuals_path =
1038            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;
1039
1040        let (mmap_codes, mmap_residuals) = (
1041            crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?,
1042            crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?,
1043        );
1044
1045        Ok(Self {
1046            path: index_path.to_string(),
1047            metadata,
1048            codec,
1049            ivf,
1050            ivf_lengths,
1051            ivf_offsets,
1052            doc_lengths,
1053            doc_offsets,
1054            mmap_codes,
1055            mmap_residuals,
1056        })
1057    }
1058
1059    /// Get candidate documents from IVF for given centroid indices.
1060    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
1061        let mut candidates: Vec<i64> = Vec::new();
1062
1063        for &idx in centroid_indices {
1064            if idx < self.ivf_lengths.len() {
1065                let start = self.ivf_offsets[idx] as usize;
1066                let len = self.ivf_lengths[idx] as usize;
1067                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
1068            }
1069        }
1070
1071        candidates.sort_unstable();
1072        candidates.dedup();
1073        candidates
1074    }
1075
1076    /// Get document embeddings by decompressing codes and residuals.
1077    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
1078        if doc_id >= self.doc_lengths.len() {
1079            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
1080        }
1081
1082        let start = self.doc_offsets[doc_id];
1083        let end = self.doc_offsets[doc_id + 1];
1084
1085        // Get codes and residuals from mmap
1086        let codes_slice = self.mmap_codes.slice(start, end);
1087        let residuals_view = self.mmap_residuals.slice_rows(start, end);
1088
1089        // Convert codes to Array1<usize>
1090        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
1091
1092        // Convert residuals to owned Array2
1093        let residuals = residuals_view.to_owned();
1094
1095        // Decompress
1096        self.codec.decompress(&residuals, &codes.view())
1097    }
1098
1099    /// Get codes for a batch of document IDs (for approximate scoring).
1100    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
1101        doc_ids
1102            .iter()
1103            .map(|&doc_id| {
1104                if doc_id >= self.doc_lengths.len() {
1105                    return vec![];
1106                }
1107                let start = self.doc_offsets[doc_id];
1108                let end = self.doc_offsets[doc_id + 1];
1109                self.mmap_codes.slice(start, end).to_vec()
1110            })
1111            .collect()
1112    }
1113
1114    /// Decompress embeddings for a batch of document IDs.
1115    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
1116        // Compute total tokens
1117        let mut total_tokens = 0usize;
1118        let mut lengths = Vec::with_capacity(doc_ids.len());
1119        for &doc_id in doc_ids {
1120            if doc_id >= self.doc_lengths.len() {
1121                lengths.push(0);
1122            } else {
1123                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
1124                lengths.push(len);
1125                total_tokens += len;
1126            }
1127        }
1128
1129        if total_tokens == 0 {
1130            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
1131        }
1132
1133        // Gather all codes and residuals
1134        let packed_dim = self.mmap_residuals.ncols();
1135        let mut all_codes = Vec::with_capacity(total_tokens);
1136        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
1137        let mut offset = 0;
1138
1139        for &doc_id in doc_ids {
1140            if doc_id >= self.doc_lengths.len() {
1141                continue;
1142            }
1143            let start = self.doc_offsets[doc_id];
1144            let end = self.doc_offsets[doc_id + 1];
1145            let len = end - start;
1146
1147            // Append codes
1148            let codes_slice = self.mmap_codes.slice(start, end);
1149            all_codes.extend(codes_slice.iter().map(|&c| c as usize));
1150
1151            // Copy residuals
1152            let residuals_view = self.mmap_residuals.slice_rows(start, end);
1153            all_residuals
1154                .slice_mut(s![offset..offset + len, ..])
1155                .assign(&residuals_view);
1156            offset += len;
1157        }
1158
1159        let codes_arr = Array1::from_vec(all_codes);
1160        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;
1161
1162        Ok((embeddings, lengths))
1163    }
1164
1165    /// Search for similar documents.
1166    ///
1167    /// # Arguments
1168    ///
1169    /// * `query` - Query embedding matrix [num_tokens, dim]
1170    /// * `params` - Search parameters
1171    /// * `subset` - Optional subset of document IDs to search within
1172    ///
1173    /// # Returns
1174    ///
1175    /// Search result containing top-k document IDs and scores.
1176    pub fn search(
1177        &self,
1178        query: &Array2<f32>,
1179        params: &crate::search::SearchParameters,
1180        subset: Option<&[i64]>,
1181    ) -> Result<crate::search::SearchResult> {
1182        crate::search::search_one_mmap(self, query, params, subset)
1183    }
1184
1185    /// Search for multiple queries in batch.
1186    ///
1187    /// # Arguments
1188    ///
1189    /// * `queries` - Slice of query embedding matrices
1190    /// * `params` - Search parameters
1191    /// * `parallel` - If true, process queries in parallel using rayon
1192    /// * `subset` - Optional subset of document IDs to search within
1193    ///
1194    /// # Returns
1195    ///
1196    /// Vector of search results, one per query.
1197    pub fn search_batch(
1198        &self,
1199        queries: &[Array2<f32>],
1200        params: &crate::search::SearchParameters,
1201        parallel: bool,
1202        subset: Option<&[i64]>,
1203    ) -> Result<Vec<crate::search::SearchResult>> {
1204        crate::search::search_many_mmap(self, queries, params, parallel, subset)
1205    }
1206
1207    /// Get the number of documents in the index.
1208    pub fn num_documents(&self) -> usize {
1209        self.doc_lengths.len()
1210    }
1211
1212    /// Get the total number of embeddings in the index.
1213    pub fn num_embeddings(&self) -> usize {
1214        self.metadata.num_embeddings
1215    }
1216
1217    /// Get the number of partitions (centroids).
1218    pub fn num_partitions(&self) -> usize {
1219        self.metadata.num_partitions
1220    }
1221
1222    /// Get the average document length.
1223    pub fn avg_doclen(&self) -> f64 {
1224        self.metadata.avg_doclen
1225    }
1226
1227    /// Get the embedding dimension.
1228    pub fn embedding_dim(&self) -> usize {
1229        self.codec.embedding_dim()
1230    }
1231
1232    /// Release all memory-mapped file handles.
1233    ///
1234    /// On Windows, files that are memory-mapped cannot be deleted, renamed, or
1235    /// truncated (OS error 1224 / ERROR_USER_MAPPED_FILE). This method replaces
1236    /// file-backed mmaps with anonymous (non-file) mmaps so that subsequent
1237    /// file operations on the index directory can proceed.
1238    ///
1239    /// After calling this, the index is not usable for search — it must be
1240    /// reloaded via `Self::load()`.
1241    fn release_mmaps(&mut self) {
1242        self.mmap_codes = crate::mmap::MmapNpyArray1I64::empty();
1243        self.mmap_residuals = crate::mmap::MmapNpyArray2U8::empty();
1244        self.codec.centroids = crate::codec::CentroidStore::Owned(Array2::zeros((0, 0)));
1245    }
1246
1247    /// Reconstruct embeddings for specific documents.
1248    ///
1249    /// This method retrieves the compressed codes and residuals for each document
1250    /// from memory-mapped files and decompresses them to recover the original embeddings.
1251    ///
1252    /// # Arguments
1253    ///
1254    /// * `doc_ids` - Slice of document IDs to reconstruct (0-indexed)
1255    ///
1256    /// # Returns
1257    ///
1258    /// A vector of 2D arrays, one per document. Each array has shape `[num_tokens, dim]`.
1259    ///
1260    /// # Example
1261    ///
1262    /// ```ignore
1263    /// use next_plaid::MmapIndex;
1264    ///
1265    /// let index = MmapIndex::load("/path/to/index")?;
1266    /// let embeddings = index.reconstruct(&[0, 1, 2])?;
1267    ///
1268    /// for (i, emb) in embeddings.iter().enumerate() {
1269    ///     println!("Document {}: {} tokens x {} dim", i, emb.nrows(), emb.ncols());
1270    /// }
1271    /// ```
1272    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
1273        crate::embeddings::reconstruct_embeddings(self, doc_ids)
1274    }
1275
1276    /// Reconstruct a single document's embeddings.
1277    ///
1278    /// Convenience method for reconstructing a single document.
1279    ///
1280    /// # Arguments
1281    ///
1282    /// * `doc_id` - Document ID to reconstruct (0-indexed)
1283    ///
1284    /// # Returns
1285    ///
1286    /// A 2D array with shape `[num_tokens, dim]`.
1287    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
1288        crate::embeddings::reconstruct_single(self, doc_id)
1289    }
1290
1291    /// Create a new index from document embeddings with automatic centroid computation.
1292    ///
1293    /// This method:
1294    /// 1. Computes centroids using K-means
1295    /// 2. Creates index files on disk
1296    /// 3. Loads the index using memory-mapped I/O
1297    ///
1298    /// Note: During creation, data is temporarily held in RAM for processing,
1299    /// then written to disk and loaded as mmap.
1300    ///
1301    /// # Arguments
1302    ///
1303    /// * `embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
1304    /// * `index_path` - Directory to save the index
1305    /// * `config` - Index configuration
1306    ///
1307    /// # Returns
1308    ///
1309    /// The created MmapIndex
1310    pub fn create_with_kmeans(
1311        embeddings: &[Array2<f32>],
1312        index_path: &str,
1313        config: &IndexConfig,
1314    ) -> Result<Self> {
1315        // Use standalone function to create files
1316        create_index_with_kmeans_files(embeddings, index_path, config)?;
1317
1318        // Load as memory-mapped index
1319        Self::load(index_path)
1320    }
1321
1322    /// Update the index with new documents, matching fast-plaid behavior.
1323    ///
1324    /// This method adds new documents to an existing index with three possible paths:
1325    ///
1326    /// 1. **Start-from-scratch mode** (num_documents <= start_from_scratch):
1327    ///    - Loads existing embeddings from `embeddings.npy` if available
1328    ///    - Combines with new embeddings
1329    ///    - Rebuilds the entire index from scratch with fresh K-means
1330    ///    - Clears `embeddings.npy` if total exceeds threshold
1331    ///
1332    /// 2. **Buffer mode** (total_new < buffer_size):
1333    ///    - Adds new documents to the index without centroid expansion
1334    ///    - Saves embeddings to buffer for later centroid expansion
1335    ///
1336    /// 3. **Centroid expansion mode** (total_new >= buffer_size):
1337    ///    - Deletes previously buffered documents
1338    ///    - Expands centroids with outliers from combined buffer + new embeddings
1339    ///    - Re-indexes all combined embeddings with expanded centroids
1340    ///
1341    /// # Arguments
1342    ///
1343    /// * `embeddings` - New document embeddings to add
1344    /// * `config` - Update configuration
1345    ///
1346    /// # Returns
1347    ///
1348    /// Vector of document IDs assigned to the new embeddings
1349    pub fn update(
1350        &mut self,
1351        embeddings: &[Array2<f32>],
1352        config: &crate::update::UpdateConfig,
1353    ) -> Result<Vec<i64>> {
1354        use crate::codec::ResidualCodec;
1355        use crate::update::{
1356            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
1357            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
1358            update_centroids, update_index,
1359        };
1360
1361        let path_str = self.path.clone();
1362        let index_path = std::path::Path::new(&path_str);
1363        let num_new_docs = embeddings.len();
1364
1365        // Release mmap handles before any file operations (delete, rename,
1366        // truncate). On Windows, files that are memory-mapped cannot be
1367        // modified, causing OS error 1224 (ERROR_USER_MAPPED_FILE).
1368        // The index will be fully reloaded from disk at the end of this method.
1369        self.release_mmaps();
1370
1371        // ==================================================================
1372        // Start-from-scratch mode (fast-plaid update.py:312-346)
1373        // ==================================================================
1374        if self.metadata.num_documents <= config.start_from_scratch {
1375            // Load existing embeddings if available
1376            let existing_embeddings = load_embeddings_npy(index_path)?;
1377
1378            // Check if embeddings.npy is in sync with the index.
1379            // If not (e.g., after delete when index was above threshold), we can't do
1380            // start-from-scratch mode because we don't have all the old embeddings.
1381            // Fall through to buffer mode instead.
1382            if existing_embeddings.len() == self.metadata.num_documents {
1383                // New documents start after existing documents
1384                let start_doc_id = existing_embeddings.len() as i64;
1385
1386                // Combine existing + new embeddings
1387                let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
1388                    .into_iter()
1389                    .chain(embeddings.iter().cloned())
1390                    .collect();
1391
1392                // Build IndexConfig from UpdateConfig for create_with_kmeans
1393                let index_config = IndexConfig {
1394                    nbits: self.metadata.nbits,
1395                    batch_size: config.batch_size,
1396                    seed: Some(config.seed),
1397                    kmeans_niters: config.kmeans_niters,
1398                    max_points_per_centroid: config.max_points_per_centroid,
1399                    n_samples_kmeans: config.n_samples_kmeans,
1400                    start_from_scratch: config.start_from_scratch,
1401                    force_cpu: config.force_cpu,
1402                    ..Default::default()
1403                };
1404
1405                // Rebuild index from scratch with fresh K-means
1406                *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;
1407
1408                // If we've crossed the threshold, clear embeddings.npy
1409                if combined_embeddings.len() > config.start_from_scratch
1410                    && embeddings_npy_exists(index_path)
1411                {
1412                    clear_embeddings_npy(index_path)?;
1413                }
1414
1415                // Return the document IDs assigned to the new embeddings
1416                return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
1417            }
1418            // else: embeddings.npy is out of sync, fall through to buffer mode
1419        }
1420
1421        // Load buffer
1422        let buffer = load_buffer(index_path)?;
1423        let buffer_len = buffer.len();
1424        let total_new = embeddings.len() + buffer_len;
1425
1426        // Track the starting document ID for the new embeddings
1427        let start_doc_id: i64;
1428
1429        // Load codec for update operations
1430        let mut codec = ResidualCodec::load_from_dir(index_path)?;
1431
1432        // Check buffer threshold
1433        if total_new >= config.buffer_size {
1434            // Centroid expansion path (matches fast-plaid update.py:376-422)
1435
1436            // 1. Get number of buffered docs that were previously indexed
1437            let num_buffered = load_buffer_info(index_path)?;
1438
1439            // 2. Delete buffered docs from index (they were indexed without centroid expansion)
1440            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
1441                let start_del_idx = self.metadata.num_documents - num_buffered;
1442                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
1443                    .map(|i| i as i64)
1444                    .collect();
1445                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
1446                // Reload metadata after delete
1447                self.metadata = Metadata::load_from_path(index_path)?;
1448            }
1449
1450            // New embeddings start after buffer is re-indexed
1451            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;
1452
1453            // 3. Combine buffer + new embeddings
1454            let combined: Vec<Array2<f32>> = buffer
1455                .into_iter()
1456                .chain(embeddings.iter().cloned())
1457                .collect();
1458
1459            // 4. Expand centroids with outliers from combined embeddings
1460            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
1461                let new_centroids =
1462                    update_centroids(index_path, &combined, cluster_threshold, config)?;
1463                if new_centroids > 0 {
1464                    // Reload codec with new centroids
1465                    codec = ResidualCodec::load_from_dir(index_path)?;
1466                }
1467            }
1468
1469            // 5. Clear buffer
1470            clear_buffer(index_path)?;
1471
1472            // 6. Update index with ALL combined embeddings (buffer + new)
1473            update_index(
1474                &combined,
1475                &path_str,
1476                &codec,
1477                Some(config.batch_size),
1478                true,
1479                config.force_cpu,
1480            )?;
1481        } else {
1482            // Small update: add to buffer and index without centroid expansion
1483            // New documents start at current num_documents
1484            start_doc_id = self.metadata.num_documents as i64;
1485
1486            // Accumulate buffer: combine existing buffer with new embeddings
1487            let combined_buffer: Vec<Array2<f32>> = buffer
1488                .into_iter()
1489                .chain(embeddings.iter().cloned())
1490                .collect();
1491            save_buffer(index_path, &combined_buffer)?;
1492
1493            // Update index without threshold update
1494            update_index(
1495                embeddings,
1496                &path_str,
1497                &codec,
1498                Some(config.batch_size),
1499                false,
1500                config.force_cpu,
1501            )?;
1502        }
1503
1504        // Reload self as mmap
1505        *self = Self::load(&path_str)?;
1506
1507        // Return the document IDs assigned to the new embeddings
1508        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
1509    }
1510
1511    /// Update the index with new documents and optional metadata.
1512    ///
1513    /// # Arguments
1514    ///
1515    /// * `embeddings` - New document embeddings to add
1516    /// * `config` - Update configuration
1517    /// * `metadata` - Optional metadata for new documents
1518    ///
1519    /// # Returns
1520    ///
1521    /// Vector of document IDs assigned to the new embeddings
1522    pub fn update_with_metadata(
1523        &mut self,
1524        embeddings: &[Array2<f32>],
1525        config: &crate::update::UpdateConfig,
1526        metadata: Option<&[serde_json::Value]>,
1527    ) -> Result<Vec<i64>> {
1528        // Validate metadata length if provided
1529        if let Some(meta) = metadata {
1530            if meta.len() != embeddings.len() {
1531                return Err(Error::Config(format!(
1532                    "Metadata length ({}) must match embeddings length ({})",
1533                    meta.len(),
1534                    embeddings.len()
1535                )));
1536            }
1537        }
1538
1539        // Perform the update and get document IDs
1540        let doc_ids = self.update(embeddings, config)?;
1541
1542        // Add metadata if provided, using the assigned document IDs
1543        if let Some(meta) = metadata {
1544            crate::filtering::update(&self.path, meta, &doc_ids)?;
1545        }
1546
1547        Ok(doc_ids)
1548    }
1549
1550    /// Update an existing index or create a new one if it doesn't exist.
1551    ///
1552    /// # Arguments
1553    ///
1554    /// * `embeddings` - Document embeddings to add
1555    /// * `index_path` - Directory for the index
1556    /// * `index_config` - Configuration for index creation
1557    /// * `update_config` - Configuration for updates
1558    ///
1559    /// # Returns
1560    ///
1561    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and document IDs
1562    pub fn update_or_create(
1563        embeddings: &[Array2<f32>],
1564        index_path: &str,
1565        index_config: &IndexConfig,
1566        update_config: &crate::update::UpdateConfig,
1567    ) -> Result<(Self, Vec<i64>)> {
1568        let index_dir = std::path::Path::new(index_path);
1569        let metadata_path = index_dir.join("metadata.json");
1570
1571        if metadata_path.exists() {
1572            // Index exists, load and update
1573            let mut index = Self::load(index_path)?;
1574            let doc_ids = index.update(embeddings, update_config)?;
1575            Ok((index, doc_ids))
1576        } else {
1577            // Index doesn't exist, create new
1578            let num_docs = embeddings.len();
1579            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1580            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1581            Ok((index, doc_ids))
1582        }
1583    }
1584
1585    /// Update an existing index or create a new one, with metadata and automatic
1586    /// FTS5 full-text indexing.
1587    ///
1588    /// This is the primary entry point for streaming document ingestion. On each
1589    /// call, embeddings and their metadata are added to the index. The FTS5
1590    /// full-text search index over metadata is kept in sync automatically.
1591    ///
1592    /// # Arguments
1593    ///
1594    /// * `embeddings` - Document embeddings to add
1595    /// * `index_path` - Directory for the index
1596    /// * `index_config` - Configuration for index creation (used only on first call)
1597    /// * `update_config` - Configuration for updates
1598    /// * `metadata` - Optional metadata for the documents (one JSON object per embedding)
1599    ///
1600    /// # Returns
1601    ///
1602    /// A tuple of (MmapIndex, `Vec<i64>`) containing the index and assigned document IDs
1603    pub fn update_or_create_with_metadata(
1604        embeddings: &[Array2<f32>],
1605        index_path: &str,
1606        index_config: &IndexConfig,
1607        update_config: &crate::update::UpdateConfig,
1608        metadata: Option<&[serde_json::Value]>,
1609    ) -> Result<(Self, Vec<i64>)> {
1610        if let Some(meta) = metadata {
1611            if meta.len() != embeddings.len() {
1612                return Err(Error::Config(format!(
1613                    "Metadata length ({}) must match embeddings length ({})",
1614                    meta.len(),
1615                    embeddings.len()
1616                )));
1617            }
1618        }
1619
1620        let index_dir = std::path::Path::new(index_path);
1621        let metadata_json_path = index_dir.join("metadata.json");
1622
1623        let (index, doc_ids) = if metadata_json_path.exists() {
1624            let mut index = Self::load(index_path)?;
1625            let doc_ids = index.update(embeddings, update_config)?;
1626            (index, doc_ids)
1627        } else {
1628            let num_docs = embeddings.len();
1629            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1630            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1631            (index, doc_ids)
1632        };
1633
1634        if let Some(meta) = metadata {
1635            if crate::filtering::exists(index_path) {
1636                crate::filtering::update(index_path, meta, &doc_ids)?;
1637            } else {
1638                crate::filtering::create(index_path, meta, &doc_ids)?;
1639            }
1640            // Index metadata into FTS5 for full-text search
1641            crate::text_search::index(index_path, meta, &doc_ids, &index_config.fts_tokenizer)?;
1642        }
1643
1644        Ok((index, doc_ids))
1645    }
1646
1647    /// Reload the index from disk.
1648    ///
1649    /// This should be called after delete operations to refresh the in-memory
1650    /// representation with the updated on-disk state.
1651    pub fn reload(&mut self) -> Result<()> {
1652        let path = self.path.clone();
1653        // Release mmap handles before reloading so that merge_*_chunks can
1654        // rename/overwrite the merged files on Windows (OS error 1224).
1655        self.release_mmaps();
1656        *self = Self::load(&path)?;
1657        Ok(())
1658    }
1659
1660    /// Delete documents from the index.
1661    ///
1662    /// This performs the deletion on disk but does NOT reload the in-memory index.
1663    /// Call `reload()` after all delete operations are complete to refresh the index.
1664    ///
1665    /// # Arguments
1666    ///
1667    /// * `doc_ids` - Slice of document IDs to delete (0-indexed)
1668    ///
1669    /// # Returns
1670    ///
1671    /// The number of documents actually deleted
1672    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
1673        self.delete_with_options(doc_ids, true)
1674    }
1675
1676    /// Delete documents from the index with control over metadata deletion.
1677    ///
1678    /// This performs the deletion on disk but does NOT reload the in-memory index.
1679    /// Call `reload()` after all delete operations are complete to refresh the index.
1680    ///
1681    /// # Arguments
1682    ///
1683    /// * `doc_ids` - Slice of document IDs to delete
1684    /// * `delete_metadata` - If true, also delete from metadata.db if it exists
1685    ///
1686    /// # Returns
1687    ///
1688    /// The number of documents actually deleted
1689    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
1690        let path = self.path.clone();
1691
1692        // Release mmap handles before deletion. delete_from_index calls
1693        // clear_merged_files which removes the memory-mapped merged files.
1694        // On Windows this fails with OS error 1224 if the mmaps are active.
1695        self.release_mmaps();
1696
1697        // Perform the deletion using standalone function
1698        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;
1699
1700        // Also delete from metadata.db if requested
1701        if delete_metadata && deleted > 0 {
1702            let index_path = std::path::Path::new(&path);
1703            let db_path = index_path.join("metadata.db");
1704            if db_path.exists() {
1705                crate::filtering::delete(&path, doc_ids)?;
1706                // Rebuild FTS5 index after metadata re-indexing
1707                crate::text_search::rebuild(&path)?;
1708            }
1709        }
1710
1711        Ok(deleted)
1712    }
1713}
1714
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use tempfile::tempdir;

    /// Tokens (rows) per synthetic test document.
    const TOKENS_PER_DOC: usize = 5;
    /// Embedding dimension of the synthetic test documents.
    const DIM: usize = 32;

    /// Build `count` small test documents with L2-normalised rows.
    ///
    /// Element `[j, k]` of document `i` is
    /// `(i + offset) * 0.1 + j * 0.01 + k * 0.001` before normalisation, so
    /// calls with different offsets yield distinct documents.
    fn make_test_embeddings(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (0..count)
            .map(|i| {
                let mut doc = Array2::<f32>::from_shape_fn((TOKENS_PER_DOC, DIM), |(j, k)| {
                    ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001)
                });
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                doc
            })
            .collect()
    }

    /// Small, fast index configuration shared by the tests below.
    fn small_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let embeddings = make_test_embeddings(5, 0);
        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // Index doesn't exist - should create new
        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // Verify index was created
        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let index_config = small_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call - creates index with 5 documents
        let embeddings1 = make_test_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        // Drop previous index to release mmap handles before updating.
        // On Windows, files cannot be modified while memory-mapped.
        drop(index1);

        // Second call - updates existing index with 3 more documents
        let embeddings2 = make_test_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }

    #[test]
    fn test_update_or_create_with_metadata_rejects_length_mismatch() {
        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        // Three documents but only one metadata object.
        let embeddings = make_test_embeddings(3, 0);
        let metadata = vec![serde_json::json!({ "title": "only one" })];

        let result = MmapIndex::update_or_create_with_metadata(
            &embeddings,
            index_path,
            &small_index_config(),
            &crate::update::UpdateConfig::default(),
            Some(metadata.as_slice()),
        );

        assert!(matches!(result, Err(Error::Config(_))));
        // Validation happens before any index files are written.
        assert!(!temp_dir.path().join("metadata.json").exists());
    }
}