use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::utils::{quantile, quantiles};

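/// Configuration for building a residual-quantized index.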
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits used to quantize each residual dimension.
    pub nbits: usize,
    /// Number of documents indexed per chunk.
    pub batch_size: usize,
    /// Random seed used for sampling and k-means.
    pub seed: Option<u64>,
    /// Number of k-means iterations.
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of points assigned to a centroid during k-means.
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Optional cap on the number of embeddings sampled for k-means.
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Collections with at most this many documents are rebuilt from scratch on update.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
}

fn default_start_from_scratch() -> usize {
    999
}

fn default_kmeans_niters() -> usize {
    4
}

fn default_max_points_per_centroid() -> usize {
    256
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            nbits: 4,
            batch_size: 50_000,
            seed: Some(42),
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
        }
    }
}

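/// Index-level metadata persisted to `metadata.json`.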
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    pub num_chunks: usize,
    pub nbits: usize,
    pub num_partitions: usize,
    pub num_embeddings: usize,
    pub avg_doclen: f64,
    /// Total document count; recomputed from the doclens files when absent.
    #[serde(default)]
    pub num_documents: usize,
    /// Whether the on-disk layout is compatible with next-plaid.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}

impl Metadata {
    /// Loads `metadata.json` from the index directory. For older indexes that
    /// predate the `num_documents` field, the count is recomputed by summing
    /// the lengths of the per-chunk doclens files.
    pub fn load_from_path(index_path: &Path) -> Result<Self> {
        let metadata_path = index_path.join("metadata.json");
        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        if metadata.num_documents == 0 {
            let mut total_docs = 0usize;
            for chunk_idx in 0..metadata.num_chunks {
                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
                if let Ok(file) = File::open(&doclens_path) {
                    if let Ok(chunk_doclens) =
                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
                    {
                        total_docs += chunk_doclens.len();
                    }
                }
            }
            metadata.num_documents = total_docs;
        }

        Ok(metadata)
    }
}

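/// Per-chunk metadata persisted alongside each chunk's codes and residuals.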
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    pub num_documents: usize,
    pub num_embeddings: usize,
    /// Offset of this chunk's first embedding within the whole index.
    #[serde(default)]
    pub embedding_offset: usize,
}

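/// Builds all on-disk index artifacts from pre-computed embeddings and
/// centroids: the codec files, per-chunk codes and residuals, doclens,
/// IVF lists, and `metadata.json`. Returns the resulting [`Metadata`].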
pub fn create_index_files(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let num_documents = embeddings.len();
    let embedding_dim = centroids.ncols();
    let num_centroids = centroids.nrows();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let avg_doclen = total_embeddings as f64 / num_documents as f64;

    // Heuristic sample size, growing with the square root of the collection
    // size and clamped to [1, num_documents].
    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    // Hold out up to 5% of the embeddings (capped at 50k) to calibrate the
    // residual quantization buckets.
    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    // Bootstrap a codec with a zero average residual so the heldout
    // embeddings can be assigned to centroids.
    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    let heldout_codes = initial_codec.compress_into_codes(&heldout);

    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    // The 75th percentile of the residual norms is saved as the clustering
    // threshold used by incremental updates.
    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    let cluster_threshold = quantile(&distances, 0.75);

    let avg_res_per_dim: Array1<f32> = residuals
        .axis_iter(Axis(1))
        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
        .collect();

    // With nbits per dimension there are 2^nbits buckets: the cutoffs are the
    // interior quantiles and each bucket's weight is the quantile at its midpoint.
    let n_options = 1 << config.nbits;
    let quantile_values: Vec<f64> = (1..n_options)
        .map(|i| i as f64 / n_options as f64)
        .collect();
    let weight_quantile_values: Vec<f64> = (0..n_options)
        .map(|i| (i as f64 + 0.5) / n_options as f64)
        .collect();

    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

    let codec = ResidualCodec::new(
        config.nbits,
        centroids.clone(),
        avg_res_per_dim.clone(),
        Some(bucket_cutoffs.clone()),
        Some(bucket_weights.clone()),
    )?;

    use ndarray_npy::WriteNpyExt;

    let centroids_path = index_dir.join("centroids.npy");
    codec
        .centroids_view()
        .to_owned()
        .write_npy(File::create(&centroids_path)?)?;

    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;

    let weights_path = index_dir.join("bucket_weights.npy");
    bucket_weights.write_npy(File::create(&weights_path)?)?;

    let avg_res_path = index_dir.join("avg_residual.npy");
    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;

    let threshold_path = index_dir.join("cluster_threshold.npy");
    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;

    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;

    let plan_path = index_dir.join("plan.json");
    let plan = serde_json::json!({
        "nbits": config.nbits,
        "num_chunks": n_chunks,
    });
    let mut plan_file = File::create(&plan_path)?;
    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);

    for chunk_idx in 0..n_chunks {
        let start = chunk_idx * config.batch_size;
        let end = (start + config.batch_size).min(num_documents);
        let chunk_docs = &embeddings[start..end];

        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;

        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        let batch_codes = codec.compress_into_codes(&batch_embeddings);

        // Subtract each token's assigned centroid in place, in parallel.
        let mut batch_residuals = batch_embeddings;
        {
            use rayon::prelude::*;
            let centroids = &codec.centroids;
            batch_residuals
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .zip(batch_codes.as_slice().unwrap().par_iter())
                .for_each(|(mut row, &code)| {
                    let centroid = centroids.row(code);
                    row.iter_mut()
                        .zip(centroid.iter())
                        .for_each(|(r, c)| *r -= c);
                });
        }

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        for &len in &chunk_doclens {
            doc_lengths.push(len);
        }
        all_codes.extend(batch_codes.iter().copied());

        let chunk_meta = ChunkMetadata {
            num_documents: end - start,
            num_embeddings: batch_codes.len(),
            // Placeholder; the real offsets are filled in by the second pass below.
            embedding_offset: 0,
        };

        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;

        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;

        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;

        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
        batch_packed.write_npy(File::create(&residuals_path)?)?;
    }

    // Second pass: rewrite each chunk's metadata with its cumulative embedding offset.
    let mut current_offset = 0usize;
    for chunk_idx in 0..n_chunks {
        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("embedding_offset".to_string(), current_offset.into());
            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
            current_offset += num_emb;
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
    }

    // Build the inverted file: for each centroid, the sorted set of documents
    // that have at least one token assigned to it.
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;

    for (doc_id, &len) in doc_lengths.iter().enumerate() {
        for _ in 0..len {
            let code = all_codes[emb_idx];
            code_to_docs.entry(code).or_default().push(doc_id as i64);
            emb_idx += 1;
        }
    }

    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];

    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs: Vec<i64> = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }

    let ivf = Array1::from_vec(ivf_data);
    let ivf_lengths = Array1::from_vec(ivf_lengths);

    let ivf_path = index_dir.join("ivf.npy");
    ivf.write_npy(File::create(&ivf_path)?)?;

    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;

    let metadata = Metadata {
        num_chunks: n_chunks,
        nbits: config.nbits,
        num_partitions: num_centroids,
        num_embeddings: total_embeddings,
        avg_doclen,
        num_documents,
        next_plaid_compatible: true,
    };

    let metadata_path = index_dir.join("metadata.json");
    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;

    Ok(metadata)
}

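/// Runs k-means over the embeddings to obtain centroids, then builds the
/// index files. For small collections (at most `start_from_scratch`
/// documents), the raw embeddings are also saved so later updates can
/// rebuild the index from scratch.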
pub fn create_index_with_kmeans_files(
    embeddings: &[Array2<f32>],
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    if embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed.unwrap_or(42),
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: None,
    };

    let centroids = compute_kmeans(embeddings, &kmeans_config)?;

    let metadata = create_index_files(embeddings, centroids, index_path, config)?;

    // Keep the raw embeddings around for small indexes so updates can rebuild
    // from scratch instead of patching incrementally.
    if embeddings.len() <= config.start_from_scratch {
        let index_dir = std::path::Path::new(index_path);
        crate::update::save_embeddings_npy(index_dir, embeddings)?;
    }

    Ok(metadata)
}
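
/// An index whose codes and residuals are memory-mapped from merged `.npy`
/// files, allowing search without loading all embeddings into memory.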
pub struct MmapIndex {
    pub path: String,
    pub metadata: Metadata,
    pub codec: ResidualCodec,
    /// Flattened IVF lists: document IDs grouped by centroid.
    pub ivf: Array1<i64>,
    /// Number of document IDs in each centroid's IVF list.
    pub ivf_lengths: Array1<i32>,
    /// Prefix sums over `ivf_lengths`; list `i` spans `ivf_offsets[i]..ivf_offsets[i + 1]`.
    pub ivf_offsets: Array1<i64>,
    pub doc_lengths: Array1<i64>,
    /// Prefix sums over `doc_lengths`; document `i` spans `doc_offsets[i]..doc_offsets[i + 1]`.
    pub doc_offsets: Array1<usize>,
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}

impl MmapIndex {
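    /// Loads an index from disk, converting legacy fast-plaid layouts when
    /// necessary and merging the per-chunk codes and residuals into
    /// memory-mapped files.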
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let mut metadata = Metadata::load_from_path(index_dir)?;

        if !metadata.next_plaid_compatible {
            eprintln!("Checking index format compatibility...");
            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
            if converted {
                eprintln!("Index converted to next-plaid compatible format.");
                // Drop any stale merged files so they are rebuilt from the
                // converted chunks below.
                let merged_codes = index_dir.join("merged_codes.npy");
                let merged_residuals = index_dir.join("merged_residuals.npy");
                let codes_manifest = index_dir.join("merged_codes.manifest.json");
                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
                for path in [
                    &merged_codes,
                    &merged_residuals,
                    &codes_manifest,
                    &residuals_manifest,
                ] {
                    if path.exists() {
                        let _ = fs::remove_file(path);
                    }
                }
            }

            metadata.next_plaid_compatible = true;
            let metadata_path = index_dir.join("metadata.json");
            let file = File::create(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
            eprintln!("Metadata updated with next_plaid_compatible: true");
        }

        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

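    /// Collects the document IDs stored in the IVF lists of the given
    /// centroids, sorted and deduplicated.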
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

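    /// Decompresses and returns the embeddings of a single document.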
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_lengths.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let start = self.doc_offsets[doc_id];
        let end = self.doc_offsets[doc_id + 1];

        let codes_slice = self.mmap_codes.slice(start, end);
        let residuals_view = self.mmap_residuals.slice_rows(start, end);

        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));

        let residuals = residuals_view.to_owned();

        self.codec.decompress(&residuals, &codes.view())
    }

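    /// Returns the raw centroid codes for each requested document; unknown
    /// document IDs yield empty vectors.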
    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        doc_ids
            .iter()
            .map(|&doc_id| {
                if doc_id >= self.doc_lengths.len() {
                    return vec![];
                }
                let start = self.doc_offsets[doc_id];
                let end = self.doc_offsets[doc_id + 1];
                self.mmap_codes.slice(start, end).to_vec()
            })
            .collect()
    }

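    /// Decompresses several documents in one pass, returning the concatenated
    /// embeddings along with each document's token count (0 for invalid IDs).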
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        // First pass: compute per-document lengths and the total token count.
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        // Second pass: gather codes and packed residuals into contiguous buffers.
        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

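    /// Searches the index with a single query matrix, optionally restricted
    /// to a subset of document IDs.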
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

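    /// Searches the index with multiple queries, optionally in parallel.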
    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }

    pub fn num_partitions(&self) -> usize {
        self.metadata.num_partitions
    }

    pub fn avg_doclen(&self) -> f64 {
        self.metadata.avg_doclen
    }

    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }

    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }

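    /// Builds a brand-new index from embeddings (running k-means for the
    /// centroids) and immediately loads it. A minimal usage sketch, assuming
    /// the types are re-exported from the crate root and that
    /// `load_document_embeddings` is your own helper:
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// // One (num_tokens x dim) matrix per document.
    /// let docs: Vec<Array2<f32>> = load_document_embeddings();
    /// let config = IndexConfig::default();
    /// let index = MmapIndex::create_with_kmeans(&docs, "/tmp/my_index", &config)?;
    /// assert_eq!(index.num_documents(), docs.len());
    /// ```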
    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        create_index_with_kmeans_files(embeddings, index_path, config)?;

        Self::load(index_path)
    }

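    /// Appends documents to the index and returns their assigned IDs.
    ///
    /// Small indexes (at most `start_from_scratch` documents) are rebuilt from
    /// scratch with fresh k-means. Otherwise new documents accumulate in a
    /// buffer; once the buffer reaches `buffer_size`, the buffered documents
    /// are re-indexed together, updating the centroids if needed.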
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::codec::ResidualCodec;
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        // Small index: rebuild everything from scratch with fresh k-means.
        if self.metadata.num_documents <= config.start_from_scratch {
            let existing_embeddings = load_embeddings_npy(index_path)?;
            let start_doc_id = existing_embeddings.len() as i64;

            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            let index_config = IndexConfig {
                nbits: self.metadata.nbits,
                batch_size: config.batch_size,
                seed: Some(config.seed),
                kmeans_niters: config.kmeans_niters,
                max_points_per_centroid: config.max_points_per_centroid,
                n_samples_kmeans: config.n_samples_kmeans,
                start_from_scratch: config.start_from_scratch,
            };

            *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

            // Once the collection outgrows the from-scratch regime, the raw
            // embeddings are no longer needed on disk.
            if combined_embeddings.len() > config.start_from_scratch
                && embeddings_npy_exists(index_path)
            {
                clear_embeddings_npy(index_path)?;
            }

            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
        }

        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        let start_doc_id: i64;

        let mut codec = ResidualCodec::load_from_dir(index_path)?;

        if total_new >= config.buffer_size {
            // The buffer is full: remove the temporarily indexed buffered
            // documents, then re-index them together with the new ones.
            let num_buffered = load_buffer_info(index_path)?;

            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
                self.metadata = Metadata::load_from_path(index_path)?;
            }

            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            // Grow the centroid set for embeddings beyond the clustering
            // threshold, reloading the codec if any centroids were added.
            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            clear_buffer(index_path)?;

            update_index(&combined, &path_str, &codec, Some(config.batch_size), true)?;
        } else {
            start_doc_id = self.metadata.num_documents as i64;

            // Buffer not yet full: persist the grown buffer and index only
            // the new documents with the existing centroids.
            let combined_buffer: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();
            save_buffer(index_path, &combined_buffer)?;

            update_index(
                embeddings,
                &path_str,
                &codec,
                Some(config.batch_size),
                false,
            )?;
        }

        *self = Self::load(&path_str)?;

        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }

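    /// Like [`update`](Self::update), but also stores per-document JSON
    /// metadata for filtering. `metadata` must have one entry per document.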
    pub fn update_with_metadata(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
        metadata: Option<&[serde_json::Value]>,
    ) -> Result<Vec<i64>> {
        if let Some(meta) = metadata {
            if meta.len() != embeddings.len() {
                return Err(Error::Config(format!(
                    "Metadata length ({}) must match embeddings length ({})",
                    meta.len(),
                    embeddings.len()
                )));
            }
        }

        let doc_ids = self.update(embeddings, config)?;

        if let Some(meta) = metadata {
            crate::filtering::update(&self.path, meta, &doc_ids)?;
        }

        Ok(doc_ids)
    }

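    /// Updates the index at `index_path` if it exists, otherwise creates it.
    /// Returns the index together with the IDs assigned to the new documents.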
    pub fn update_or_create(
        embeddings: &[Array2<f32>],
        index_path: &str,
        index_config: &IndexConfig,
        update_config: &crate::update::UpdateConfig,
    ) -> Result<(Self, Vec<i64>)> {
        let index_dir = std::path::Path::new(index_path);
        let metadata_path = index_dir.join("metadata.json");

        if metadata_path.exists() {
            let mut index = Self::load(index_path)?;
            let doc_ids = index.update(embeddings, update_config)?;
            Ok((index, doc_ids))
        } else {
            let num_docs = embeddings.len();
            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
            Ok((index, doc_ids))
        }
    }

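    /// Deletes documents (and their filtering metadata, if any) from the
    /// index, returning the number of documents removed.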
    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }

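    /// Deletes documents from the index; when `delete_metadata` is true, the
    /// corresponding entries in `metadata.db` are removed as well.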
    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
        let path = self.path.clone();

        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;

        if delete_metadata && deleted > 0 {
            let index_path = std::path::Path::new(&path);
            let db_path = index_path.join("metadata.db");
            if db_path.exists() {
                crate::filtering::delete(&path, doc_ids)?;
            }
        }

        *self = Self::load(&path)?;

        Ok(deleted)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_update_or_create_new_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let (index, doc_ids) =
            MmapIndex::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
            let mut embeddings = Vec::new();
            for i in 0..count {
                let mut doc = Array2::<f32>::zeros((5, 32));
                for j in 0..5 {
                    for k in 0..32 {
                        doc[[j, k]] =
                            ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                    }
                }
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                embeddings.push(doc);
            }
            embeddings
        };

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let embeddings1 = create_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        let embeddings2 = create_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}