use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::utils::{quantile, quantiles};

/// Computes the nearest-centroid code for every embedding row and subtracts
/// the matched centroid in place, returning `(codes, residuals)`.
fn compress_and_residuals_cpu(
    embeddings: &Array2<f32>,
    codec: &ResidualCodec,
) -> (Array1<usize>, Array2<f32>) {
    use rayon::prelude::*;

    let codes = codec.compress_into_codes_cpu(embeddings);
    let mut residuals = embeddings.clone();

    let centroids = &codec.centroids;
    residuals
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .zip(codes.as_slice().unwrap().par_iter())
        .for_each(|(mut row, &code)| {
            let centroid = centroids.row(code);
            row.iter_mut()
                .zip(centroid.iter())
                .for_each(|(r, c)| *r -= c);
        });

    (codes, residuals)
}

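/// Configuration used when building an index from scratch.
///
/// A minimal construction sketch (the field values below are illustrative,
/// not tuned recommendations):
///
/// ```ignore
/// let config = IndexConfig {
///     nbits: 2,
///     batch_size: 10_000,
///     ..IndexConfig::default()
/// };
/// ```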
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits used to quantize each residual dimension.
    pub nbits: usize,
    /// Number of documents written per chunk.
    pub batch_size: usize,
    /// Optional RNG seed for reproducible sampling and k-means.
    pub seed: Option<u64>,
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// At or below this document count, raw embeddings are kept on disk so
    /// that later updates can rebuild the index from scratch.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
    #[serde(default)]
    pub force_cpu: bool,
}

fn default_start_from_scratch() -> usize {
    999
}

fn default_kmeans_niters() -> usize {
    4
}

fn default_max_points_per_centroid() -> usize {
    256
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            nbits: 4,
            batch_size: 50_000,
            seed: Some(42),
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
            force_cpu: false,
        }
    }
}

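/// Top-level index statistics persisted as `metadata.json`.
///
/// A loading sketch (assumes an index already exists at the given path):
///
/// ```ignore
/// use std::path::Path;
///
/// let metadata = Metadata::load_from_path(Path::new("my_index"))?;
/// println!("{} documents, {} embeddings", metadata.num_documents, metadata.num_embeddings);
/// ```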
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    pub num_chunks: usize,
    pub nbits: usize,
    pub num_partitions: usize,
    pub num_embeddings: usize,
    pub avg_doclen: f64,
    #[serde(default)]
    pub num_documents: usize,
    #[serde(default)]
    pub embedding_dim: usize,
    #[serde(default)]
    pub next_plaid_compatible: bool,
}

impl Metadata {
    pub fn load_from_path(index_path: &Path) -> Result<Self> {
        let metadata_path = index_path.join("metadata.json");
        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        // Older indexes may not store `num_documents`; recover it by summing
        // the per-chunk doclens files.
        if metadata.num_documents == 0 {
            let mut total_docs = 0usize;
            for chunk_idx in 0..metadata.num_chunks {
                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
                if let Ok(file) = File::open(&doclens_path) {
                    if let Ok(chunk_doclens) =
                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
                    {
                        total_docs += chunk_doclens.len();
                    }
                }
            }
            metadata.num_documents = total_docs;
        }

        Ok(metadata)
    }
}

/// Per-chunk bookkeeping stored as `{chunk}.metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    pub num_documents: usize,
    pub num_embeddings: usize,
    #[serde(default)]
    pub embedding_offset: usize,
}

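/// Builds all on-disk index artifacts from per-document embeddings and a
/// precomputed set of centroids.
///
/// A usage sketch (`docs`, `kmeans_config`, and `load_document_embeddings`
/// are hypothetical; centroids would normally come from `compute_kmeans`):
///
/// ```ignore
/// use ndarray::Array2;
///
/// let docs: Vec<Array2<f32>> = load_document_embeddings(); // hypothetical helper
/// let centroids = compute_kmeans(&docs, &kmeans_config)?;
/// let metadata = create_index_files(&docs, centroids, "my_index", &IndexConfig::default())?;
/// assert_eq!(metadata.num_documents, docs.len());
/// ```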
pub fn create_index_files(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let num_documents = embeddings.len();
    let embedding_dim = centroids.ncols();
    let num_centroids = centroids.nrows();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let avg_doclen = total_embeddings as f64 / num_documents as f64;

    // Heuristic sample size that grows with the square root of the corpus.
    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    // Hold out a small sample of embeddings to calibrate the residual
    // quantization buckets.
    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    // Bootstrap a codec with a zero average residual, solely to obtain
    // nearest-centroid codes for the held-out sample.
    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    let heldout_codes = if config.force_cpu {
        initial_codec.compress_into_codes_cpu(&heldout)
    } else {
        initial_codec.compress_into_codes(&heldout)
    };

    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    // The 75th-percentile residual norm is persisted below and reused as the
    // distance threshold when centroids are updated later.
    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    let cluster_threshold = quantile(&distances, 0.75);

    let avg_res_per_dim: Array1<f32> = residuals
        .axis_iter(Axis(1))
        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
        .collect();

    // Bucket cutoffs are interior quantiles of the residual distribution;
    // bucket weights are the bucket midpoints used at decompression time.
    let n_options = 1 << config.nbits;
    let quantile_values: Vec<f64> = (1..n_options)
        .map(|i| i as f64 / n_options as f64)
        .collect();
    let weight_quantile_values: Vec<f64> = (0..n_options)
        .map(|i| (i as f64 + 0.5) / n_options as f64)
        .collect();

    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

    let codec = ResidualCodec::new(
        config.nbits,
        centroids.clone(),
        avg_res_per_dim.clone(),
        Some(bucket_cutoffs.clone()),
        Some(bucket_weights.clone()),
    )?;

    use ndarray_npy::WriteNpyExt;

    let centroids_path = index_dir.join("centroids.npy");
    codec
        .centroids_view()
        .to_owned()
        .write_npy(File::create(&centroids_path)?)?;

    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;

    let weights_path = index_dir.join("bucket_weights.npy");
    bucket_weights.write_npy(File::create(&weights_path)?)?;

    let avg_res_path = index_dir.join("avg_residual.npy");
    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;

    let threshold_path = index_dir.join("cluster_threshold.npy");
    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;

    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;

    let plan_path = index_dir.join("plan.json");
    let plan = serde_json::json!({
        "nbits": config.nbits,
        "num_chunks": n_chunks,
    });
    let mut plan_file = File::create(&plan_path)?;
    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);

    for chunk_idx in 0..n_chunks {
        let start = chunk_idx * config.batch_size;
        let end = (start + config.batch_size).min(num_documents);
        let chunk_docs = &embeddings[start..end];

        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;

        // Concatenate this chunk's documents into one (tokens x dim) matrix.
        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // Prefer the CUDA path when available, falling back to CPU on error.
        let (batch_codes, batch_residuals) = {
            #[cfg(feature = "cuda")]
            {
                if !config.force_cpu {
                    if let Some(ctx) = crate::cuda::get_global_context() {
                        match crate::cuda::compress_and_residuals_cuda_batched(
                            ctx,
                            &batch_embeddings.view(),
                            &codec.centroids_view(),
                            None,
                        ) {
                            Ok(result) => result,
                            Err(e) => {
                                eprintln!(
                                    "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
                                    e
                                );
                                compress_and_residuals_cpu(&batch_embeddings, &codec)
                            }
                        }
                    } else {
                        compress_and_residuals_cpu(&batch_embeddings, &codec)
                    }
                } else {
                    compress_and_residuals_cpu(&batch_embeddings, &codec)
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                compress_and_residuals_cpu(&batch_embeddings, &codec)
            }
        };

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        for &len in &chunk_doclens {
            doc_lengths.push(len);
        }
        all_codes.extend(batch_codes.iter().copied());

        // The true embedding offset is filled in by a second pass below.
        let chunk_meta = ChunkMetadata {
            num_documents: end - start,
            num_embeddings: batch_codes.len(),
            embedding_offset: 0,
        };

        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;

        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;

        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;

        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
        batch_packed.write_npy(File::create(&residuals_path)?)?;
    }

    // Second pass: record each chunk's cumulative embedding offset.
    let mut current_offset = 0usize;
    for chunk_idx in 0..n_chunks {
        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("embedding_offset".to_string(), current_offset.into());
            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
            current_offset += num_emb;
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
    }

    // Build the IVF: for each centroid, the sorted, deduplicated list of
    // documents with at least one embedding assigned to it.
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;

    for (doc_id, &len) in doc_lengths.iter().enumerate() {
        for _ in 0..len {
            let code = all_codes[emb_idx];
            code_to_docs.entry(code).or_default().push(doc_id as i64);
            emb_idx += 1;
        }
    }

    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];

    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs: Vec<i64> = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }

    let ivf = Array1::from_vec(ivf_data);
    let ivf_lengths = Array1::from_vec(ivf_lengths);

    let ivf_path = index_dir.join("ivf.npy");
    ivf.write_npy(File::create(&ivf_path)?)?;

    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;

    let metadata = Metadata {
        num_chunks: n_chunks,
        nbits: config.nbits,
        num_partitions: num_centroids,
        num_embeddings: total_embeddings,
        avg_doclen,
        num_documents,
        embedding_dim,
        next_plaid_compatible: true,
    };

    let metadata_path = index_dir.join("metadata.json");
    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;

    Ok(metadata)
}

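/// Runs k-means over the embeddings to obtain centroids, then delegates to
/// [`create_index_files`].
///
/// A usage sketch (`docs` is a hypothetical collection of per-document
/// embedding matrices):
///
/// ```ignore
/// let metadata = create_index_with_kmeans_files(&docs, "my_index", &IndexConfig::default())?;
/// ```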
pub fn create_index_with_kmeans_files(
    embeddings: &[Array2<f32>],
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    if embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    // Warm up the global CUDA context before k-means, if it is available.
    #[cfg(feature = "cuda")]
    if !config.force_cpu {
        let _ = crate::cuda::get_global_context();
    }

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed.unwrap_or(42),
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: None,
        force_cpu: config.force_cpu,
    };

    let centroids = compute_kmeans(embeddings, &kmeans_config)?;

    let metadata = create_index_files(embeddings, centroids, index_path, config)?;

    // For small indexes, keep the raw embeddings on disk so that later
    // updates can rebuild from scratch instead of patching incrementally.
    if embeddings.len() <= config.start_from_scratch {
        let index_dir = std::path::Path::new(index_path);
        crate::update::save_embeddings_npy(index_dir, embeddings)?;
    }

    Ok(metadata)
}

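/// A memory-mapped PLAID index: small arrays (IVF, document lengths) are held
/// in memory while the large codes and residuals arrays stay on disk.
///
/// A loading sketch (assumes an index already exists at the given path):
///
/// ```ignore
/// let index = MmapIndex::load("my_index")?;
/// println!("{} documents indexed", index.num_documents());
/// ```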
pub struct MmapIndex {
    /// Filesystem path of the index directory.
    pub path: String,
    /// Index-wide statistics loaded from `metadata.json`.
    pub metadata: Metadata,
    /// Codec used to decompress residuals.
    pub codec: ResidualCodec,
    /// Concatenated IVF posting lists (document IDs per centroid).
    pub ivf: Array1<i64>,
    /// Length of each centroid's posting list.
    pub ivf_lengths: Array1<i32>,
    /// Prefix sums over `ivf_lengths` for slicing `ivf`.
    pub ivf_offsets: Array1<i64>,
    /// Number of embeddings per document.
    pub doc_lengths: Array1<i64>,
    /// Prefix sums over `doc_lengths` for slicing the mmapped arrays.
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped centroid codes, one per embedding.
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped packed residuals, one row per embedding.
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}

impl MmapIndex {
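    /// Opens an existing index, converting legacy fast-plaid layouts to the
    /// next-plaid format on first load.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let index = MmapIndex::load("my_index")?;
    /// assert_eq!(index.num_embeddings(), index.metadata.num_embeddings);
    /// ```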
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let mut metadata = Metadata::load_from_path(index_dir)?;

        // Legacy fast-plaid indexes are converted in place; any previously
        // merged arrays are removed so they get rebuilt below.
        if !metadata.next_plaid_compatible {
            eprintln!("Checking index format compatibility...");
            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
            if converted {
                eprintln!("Index converted to next-plaid compatible format.");
                let merged_codes = index_dir.join("merged_codes.npy");
                let merged_residuals = index_dir.join("merged_residuals.npy");
                let codes_manifest = index_dir.join("merged_codes.manifest.json");
                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
                for path in [
                    &merged_codes,
                    &merged_residuals,
                    &codes_manifest,
                    &residuals_manifest,
                ] {
                    if path.exists() {
                        let _ = fs::remove_file(path);
                    }
                }
            }

            metadata.next_plaid_compatible = true;
            let metadata_path = index_dir.join("metadata.json");
            let file = File::create(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
            eprintln!("Metadata updated with next_plaid_compatible: true");
        }

        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        // Prefix sums turn the per-centroid lengths into slice offsets.
        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        // Merge per-chunk arrays into single .npy files that can be mmapped.
        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

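    /// Returns the sorted, deduplicated document IDs stored in the IVF lists
    /// of the given centroids.
    ///
    /// A usage sketch (centroid indices are illustrative):
    ///
    /// ```ignore
    /// let candidates = index.get_candidates(&[0, 17, 42]);
    /// ```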
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

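    /// Decompresses and returns the full embedding matrix of one document.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let embeddings = index.get_document_embeddings(0)?;
    /// assert_eq!(embeddings.ncols(), index.embedding_dim());
    /// ```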
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_lengths.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let start = self.doc_offsets[doc_id];
        let end = self.doc_offsets[doc_id + 1];

        let codes_slice = self.mmap_codes.slice(start, end);
        let residuals_view = self.mmap_residuals.slice_rows(start, end);

        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));

        let residuals = residuals_view.to_owned();

        self.codec.decompress(&residuals, &codes.view())
    }

    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        doc_ids
            .iter()
            .map(|&doc_id| {
                if doc_id >= self.doc_lengths.len() {
                    return vec![];
                }
                let start = self.doc_offsets[doc_id];
                let end = self.doc_offsets[doc_id + 1];
                self.mmap_codes.slice(start, end).to_vec()
            })
            .collect()
    }

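    /// Decompresses several documents into one stacked matrix, returning the
    /// matrix together with each document's token count (0 for invalid IDs).
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let (embeddings, lengths) = index.decompress_documents(&[0, 1])?;
    /// assert_eq!(embeddings.nrows(), lengths.iter().sum::<usize>());
    /// ```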
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        // First pass: compute per-document lengths and the total token count.
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        // Second pass: gather codes and packed residuals for all documents.
        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

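    /// Searches the index with a single query embedding matrix, optionally
    /// restricted to a subset of document IDs.
    ///
    /// A usage sketch (assumes `SearchParameters` implements `Default`;
    /// `encode_query` is a hypothetical encoder):
    ///
    /// ```ignore
    /// let query: Array2<f32> = encode_query("hello world");
    /// let result = index.search(&query, &crate::search::SearchParameters::default(), None)?;
    /// ```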
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }

    pub fn num_partitions(&self) -> usize {
        self.metadata.num_partitions
    }

    pub fn avg_doclen(&self) -> f64 {
        self.metadata.avg_doclen
    }

    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }

    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }

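    /// Builds a new index on disk (see [`create_index_with_kmeans_files`])
    /// and immediately loads it.
    ///
    /// A usage sketch (`docs` is a hypothetical collection of per-document
    /// embedding matrices):
    ///
    /// ```ignore
    /// let index = MmapIndex::create_with_kmeans(&docs, "my_index", &IndexConfig::default())?;
    /// ```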
    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        create_index_with_kmeans_files(embeddings, index_path, config)?;

        Self::load(index_path)
    }

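    /// Appends new documents, buffering small updates and rebuilding from
    /// scratch while the index is still small. Returns the assigned document
    /// IDs.
    ///
    /// A usage sketch (`new_docs` is hypothetical):
    ///
    /// ```ignore
    /// let doc_ids = index.update(&new_docs, &crate::update::UpdateConfig::default())?;
    /// ```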
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::codec::ResidualCodec;
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        // While the index is small, rebuild it from scratch using the raw
        // embeddings saved at creation time.
        if self.metadata.num_documents <= config.start_from_scratch {
            let existing_embeddings = load_embeddings_npy(index_path)?;

            if existing_embeddings.len() == self.metadata.num_documents {
                let start_doc_id = existing_embeddings.len() as i64;

                let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                    .into_iter()
                    .chain(embeddings.iter().cloned())
                    .collect();

                let index_config = IndexConfig {
                    nbits: self.metadata.nbits,
                    batch_size: config.batch_size,
                    seed: Some(config.seed),
                    kmeans_niters: config.kmeans_niters,
                    max_points_per_centroid: config.max_points_per_centroid,
                    n_samples_kmeans: config.n_samples_kmeans,
                    start_from_scratch: config.start_from_scratch,
                    force_cpu: config.force_cpu,
                };

                *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

                // Once the index outgrows the rebuild threshold, the saved
                // raw embeddings are no longer needed.
                if combined_embeddings.len() > config.start_from_scratch
                    && embeddings_npy_exists(index_path)
                {
                    clear_embeddings_npy(index_path)?;
                }

                return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
            }
        }

        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        let start_doc_id: i64;

        let mut codec = ResidualCodec::load_from_dir(index_path)?;

        if total_new >= config.buffer_size {
            // The buffer is full: remove the previously buffered documents
            // from the index and re-add them together with the new ones.
            let num_buffered = load_buffer_info(index_path)?;

            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
                self.metadata = Metadata::load_from_path(index_path)?;
            }

            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            // Optionally grow the centroid set using the stored distance
            // threshold, reloading the codec if any centroids were added.
            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            clear_buffer(index_path)?;

            update_index(
                &combined,
                &path_str,
                &codec,
                Some(config.batch_size),
                true,
                config.force_cpu,
            )?;
        } else {
            // Buffer not yet full: append the new documents and remember them
            // in the buffer for a future consolidation.
            start_doc_id = self.metadata.num_documents as i64;

            let combined_buffer: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();
            save_buffer(index_path, &combined_buffer)?;

            update_index(
                embeddings,
                &path_str,
                &codec,
                Some(config.batch_size),
                false,
                config.force_cpu,
            )?;
        }

        *self = Self::load(&path_str)?;

        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }

    pub fn update_with_metadata(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
        metadata: Option<&[serde_json::Value]>,
    ) -> Result<Vec<i64>> {
        if let Some(meta) = metadata {
            if meta.len() != embeddings.len() {
                return Err(Error::Config(format!(
                    "Metadata length ({}) must match embeddings length ({})",
                    meta.len(),
                    embeddings.len()
                )));
            }
        }

        let doc_ids = self.update(embeddings, config)?;

        if let Some(meta) = metadata {
            crate::filtering::update(&self.path, meta, &doc_ids)?;
        }

        Ok(doc_ids)
    }

    pub fn update_or_create(
        embeddings: &[Array2<f32>],
        index_path: &str,
        index_config: &IndexConfig,
        update_config: &crate::update::UpdateConfig,
    ) -> Result<(Self, Vec<i64>)> {
        let index_dir = std::path::Path::new(index_path);
        let metadata_path = index_dir.join("metadata.json");

        if metadata_path.exists() {
            let mut index = Self::load(index_path)?;
            let doc_ids = index.update(embeddings, update_config)?;
            Ok((index, doc_ids))
        } else {
            let num_docs = embeddings.len();
            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
            Ok((index, doc_ids))
        }
    }

    pub fn reload(&mut self) -> Result<()> {
        *self = Self::load(&self.path)?;
        Ok(())
    }

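    /// Deletes documents and any metadata stored for them, returning the
    /// number of documents removed.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let removed = index.delete(&[0, 3])?;
    /// ```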
    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }

    /// Like [`delete`](Self::delete), but metadata removal can be skipped.
    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
        let path = self.path.clone();

        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;

        if delete_metadata && deleted > 0 {
            let index_path = std::path::Path::new(&path);
            let db_path = index_path.join("metadata.db");
            if db_path.exists() {
                crate::filtering::delete(&path, doc_ids)?;
            }
        }

        Ok(deleted)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
            let mut embeddings = Vec::new();
            for i in 0..count {
                let mut doc = Array2::<f32>::zeros((5, 32));
                for j in 0..5 {
                    for k in 0..32 {
                        doc[[j, k]] =
                            ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                    }
                }
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                embeddings.push(doc);
            }
            embeddings
        };

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let embeddings1 = create_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        let embeddings2 = create_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}