use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::codec::ResidualCodec;
use crate::error::{Error, Result};
#[cfg(feature = "npy")]
use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
use crate::strided_tensor::{IvfStridedTensor, StridedTensor};
use crate::utils::{quantile, quantiles};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Number of bits used to quantize each residual dimension.
    pub nbits: usize,
    /// Number of documents processed per indexing chunk.
    pub batch_size: usize,
    /// Optional seed for reproducible sampling and k-means.
    pub seed: Option<u64>,
    /// Number of k-means iterations.
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Maximum number of sampled points per centroid during k-means.
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Optional cap on the number of embeddings sampled for k-means.
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// Document-count threshold at or below which updates rebuild the index
    /// from scratch instead of patching it incrementally.
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
}

fn default_start_from_scratch() -> usize {
    999
}

fn default_kmeans_niters() -> usize {
    4
}

fn default_max_points_per_centroid() -> usize {
    256
}

impl Default for IndexConfig {
    fn default() -> Self {
        Self {
            nbits: 4,
            batch_size: 50_000,
            seed: Some(42),
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            n_samples_kmeans: None,
            start_from_scratch: 999,
        }
    }
}

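// A minimal configuration sketch (hypothetical values): override only the
// fields you care about and take the rest from `Default`:
//
//     let config = IndexConfig { nbits: 2, batch_size: 10_000, ..Default::default() };
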
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Global index statistics persisted in `metadata.json`.
pub struct Metadata {
    /// Number of on-disk chunks the index was written in.
    pub num_chunks: usize,
    /// Bits per quantized residual dimension.
    pub nbits: usize,
    /// Number of IVF partitions (centroids).
    pub num_partitions: usize,
    /// Total number of token embeddings across all documents.
    pub num_embeddings: usize,
    /// Average number of embeddings per document.
    pub avg_doclen: f64,
    /// Total number of indexed documents.
    pub num_documents: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
/// Per-chunk bookkeeping stored alongside each chunk's artifacts.
pub struct ChunkMetadata {
    pub num_documents: usize,
    pub num_embeddings: usize,
    #[serde(default)]
    pub embedding_offset: usize,
}

pub struct Index {
    /// Directory the index artifacts live in.
    pub path: String,
    pub metadata: Metadata,
    pub codec: ResidualCodec,
    /// Flattened IVF postings (document ids, grouped by centroid).
    pub ivf: Array1<i64>,
    /// Posting-list length per centroid.
    pub ivf_lengths: Array1<i32>,
    /// Exclusive prefix sums of `ivf_lengths` (length `num_centroids + 1`).
    pub ivf_offsets: Array1<i64>,
    /// Per-document centroid codes.
    pub doc_codes: Vec<Array1<usize>>,
    /// Number of embeddings per document.
    pub doc_lengths: Array1<i64>,
    /// Per-document packed residuals.
    pub doc_residuals: Vec<Array2<u8>>,
}

impl Index {
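    /// Builds an index from per-document embedding matrices and pre-computed
    /// centroids. Quantization buckets are fit on a held-out sample of
    /// residuals, every embedding is compressed chunk by chunk, the chunk
    /// artifacts are written under `index_path`, and an IVF mapping each
    /// centroid to the documents it covers is assembled at the end.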
    pub fn create(
        embeddings: &[Array2<f32>],
        centroids: Array2<f32>,
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        let index_dir = Path::new(index_path);
        fs::create_dir_all(index_dir)?;

        let num_documents = embeddings.len();
        let embedding_dim = centroids.ncols();
        let num_centroids = centroids.nrows();

        if num_documents == 0 {
            return Err(Error::IndexCreation("No documents provided".into()));
        }

        let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
        let avg_doclen = total_embeddings as f64 / num_documents as f64;

        let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
            .min(num_documents)
            .max(1);
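        // Sampling heuristic: 16 * sqrt(120 * num_documents) documents,
        // clamped to [1, num_documents]. For example (assumed size), 10_000
        // documents give 16 * sqrt(1.2e6) ≈ 17_527, clamped to 10_000.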

        let mut rng = if let Some(seed) = config.seed {
            use rand::SeedableRng;
            rand_chacha::ChaCha8Rng::seed_from_u64(seed)
        } else {
            use rand::SeedableRng;
            rand_chacha::ChaCha8Rng::from_entropy()
        };

        use rand::seq::SliceRandom;
        let mut indices: Vec<usize> = (0..num_documents).collect();
        indices.shuffle(&mut rng);
        let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

        // Hold out up to 5% of the embeddings (capped at 50k) to fit the
        // quantization buckets.
        let heldout_size = (0.05 * total_embeddings as f64).min(50_000.0) as usize;
        let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
        let mut collected = 0;

        for &idx in sample_indices.iter().rev() {
            if collected >= heldout_size {
                break;
            }
            let emb = &embeddings[idx];
            let take = (heldout_size - collected).min(emb.nrows());
            for row in emb.axis_iter(Axis(0)).take(take) {
                heldout_embeddings.extend(row.iter());
            }
            collected += take;
        }

        let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
            .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

        // Start from a codec with zero average residual just to assign the
        // heldout embeddings to their nearest centroids.
        let avg_residual = Array1::zeros(embedding_dim);
        let initial_codec =
            ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

        let heldout_codes = initial_codec.compress_into_codes(&heldout);

        let mut residuals = heldout.clone();
        for i in 0..heldout.nrows() {
            let centroid = initial_codec.centroids.row(heldout_codes[i]);
            for j in 0..embedding_dim {
                residuals[[i, j]] -= centroid[j];
            }
        }

        let distances: Array1<f32> = residuals
            .axis_iter(Axis(0))
            .map(|row| row.dot(&row).sqrt())
            .collect();
        // Only consumed when the "npy" feature persists it below.
        #[allow(unused_variables)]
        let cluster_threshold = quantile(&distances, 0.75);

        let avg_res_per_dim: Array1<f32> = residuals
            .axis_iter(Axis(1))
            .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
            .collect();

        let n_options = 1 << config.nbits;
        let quantile_values: Vec<f64> = (1..n_options)
            .map(|i| i as f64 / n_options as f64)
            .collect();
        let weight_quantile_values: Vec<f64> = (0..n_options)
            .map(|i| (i as f64 + 0.5) / n_options as f64)
            .collect();

        let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
        let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
        let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

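        // Worked example (nbits = 2): 2^2 = 4 buckets, so the cutoffs sit at
        // the 0.25 / 0.50 / 0.75 residual quantiles and the reconstruction
        // weights at the bucket midpoints, i.e. the 0.125 / 0.375 / 0.625 /
        // 0.875 quantiles.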
        let codec = ResidualCodec::new(
            config.nbits,
            centroids.clone(),
            avg_res_per_dim.clone(),
            Some(bucket_cutoffs.clone()),
            Some(bucket_weights.clone()),
        )?;

        #[cfg(feature = "npy")]
        {
            use ndarray_npy::WriteNpyExt;

            let centroids_path = index_dir.join("centroids.npy");
            codec
                .centroids_view()
                .to_owned()
                .write_npy(File::create(&centroids_path)?)?;

            let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
            bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;

            let weights_path = index_dir.join("bucket_weights.npy");
            bucket_weights.write_npy(File::create(&weights_path)?)?;

            let avg_res_path = index_dir.join("avg_residual.npy");
            avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;

            let threshold_path = index_dir.join("cluster_threshold.npy");
            Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;
        }

        let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;

        let plan_path = index_dir.join("plan.json");
        let plan = serde_json::json!({
            "nbits": config.nbits,
            "num_chunks": n_chunks,
        });
        let mut plan_file = File::create(&plan_path)?;
        writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

        let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
        let mut doc_codes: Vec<Array1<usize>> = Vec::with_capacity(num_documents);
        let mut doc_residuals: Vec<Array2<u8>> = Vec::with_capacity(num_documents);
        let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);

        let progress = indicatif::ProgressBar::new(n_chunks as u64);
        progress.set_message("Creating index...");

        for chunk_idx in 0..n_chunks {
            let start = chunk_idx * config.batch_size;
            let end = (start + config.batch_size).min(num_documents);
            let chunk_docs = &embeddings[start..end];

            let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
            let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;

            let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
            let mut offset = 0;
            for doc in chunk_docs {
                let n = doc.nrows();
                batch_embeddings
                    .slice_mut(s![offset..offset + n, ..])
                    .assign(doc);
                offset += n;
            }

            let batch_codes = codec.compress_into_codes(&batch_embeddings);

            // Turn the embeddings into residuals in place: subtract each
            // token's assigned centroid, one row per rayon task.
            let mut batch_residuals = batch_embeddings;
            {
                use rayon::prelude::*;
                let centroids = &codec.centroids;
                batch_residuals
                    .axis_iter_mut(Axis(0))
                    .into_par_iter()
                    .zip(batch_codes.as_slice().unwrap().par_iter())
                    .for_each(|(mut row, &code)| {
                        let centroid = centroids.row(code);
                        row.iter_mut()
                            .zip(centroid.iter())
                            .for_each(|(r, c)| *r -= c);
                    });
            }

            let batch_packed = codec.quantize_residuals(&batch_residuals)?;

            let mut code_offset = 0;
            for &len in &chunk_doclens {
                let len_usize = len as usize;
                doc_lengths.push(len);

                let codes: Array1<usize> = batch_codes
                    .slice(s![code_offset..code_offset + len_usize])
                    .to_owned();
                all_codes.extend(codes.iter().copied());
                doc_codes.push(codes);

                let packed = batch_packed
                    .slice(s![code_offset..code_offset + len_usize, ..])
                    .to_owned();
                doc_residuals.push(packed);

                code_offset += len_usize;
            }

            let chunk_meta = ChunkMetadata {
                num_documents: end - start,
                num_embeddings: batch_codes.len(),
                // Patched with the true offset in a second pass below.
                embedding_offset: 0,
            };

            let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
            serde_json::to_writer_pretty(
                BufWriter::new(File::create(&chunk_meta_path)?),
                &chunk_meta,
            )?;

            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;

            #[cfg(feature = "npy")]
            {
                use ndarray_npy::WriteNpyExt;

                let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
                let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
                chunk_codes_arr.write_npy(File::create(&codes_path)?)?;

                let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
                batch_packed.write_npy(File::create(&residuals_path)?)?;
            }

            progress.inc(1);
        }
        progress.finish();

        // Second pass: patch each chunk's metadata with its embedding offset.
        let mut current_offset = 0usize;
        for chunk_idx in 0..n_chunks {
            let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
            let mut meta: serde_json::Value =
                serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;

            if let Some(obj) = meta.as_object_mut() {
                obj.insert("embedding_offset".to_string(), current_offset.into());
                let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
                current_offset += num_emb;
            }

            serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
        }

        // Invert codes into an IVF: for each centroid, the sorted set of
        // documents with at least one embedding assigned to it.
        let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
        let mut emb_idx = 0;

        for (doc_id, &len) in doc_lengths.iter().enumerate() {
            for _ in 0..len {
                let code = all_codes[emb_idx];
                code_to_docs.entry(code).or_default().push(doc_id as i64);
                emb_idx += 1;
            }
        }

        let mut ivf_data: Vec<i64> = Vec::new();
        let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];

        for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
            if let Some(docs) = code_to_docs.get(&centroid_id) {
                let mut unique_docs: Vec<i64> = docs.clone();
                unique_docs.sort_unstable();
                unique_docs.dedup();
                *ivf_len = unique_docs.len() as i32;
                ivf_data.extend(unique_docs);
            }
        }

        let ivf = Array1::from_vec(ivf_data);
        let ivf_lengths = Array1::from_vec(ivf_lengths);

        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }
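        // Exclusive prefix sum over posting lengths: lengths [2, 0, 3] give
        // offsets [0, 2, 2, 5], so centroid i's postings are
        // ivf[ivf_offsets[i]..ivf_offsets[i + 1]].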

        #[cfg(feature = "npy")]
        {
            use ndarray_npy::WriteNpyExt;

            let ivf_path = index_dir.join("ivf.npy");
            ivf.write_npy(File::create(&ivf_path)?)?;

            let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
            ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;
        }

        let metadata = Metadata {
            num_chunks: n_chunks,
            nbits: config.nbits,
            num_partitions: num_centroids,
            num_embeddings: total_embeddings,
            avg_doclen,
            num_documents,
        };

        let metadata_path = index_dir.join("metadata.json");
        serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;

        let doc_lengths_arr = Array1::from_vec(doc_lengths);

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_codes,
            doc_lengths: doc_lengths_arr,
            doc_residuals,
        })
    }

    #[cfg(feature = "npy")]
    /// Like [`Index::create`], but computes the centroids with k-means first.
    /// For small collections (up to `start_from_scratch` documents) the raw
    /// embeddings are also persisted so later updates can rebuild from scratch.
    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        if embeddings.is_empty() {
            return Err(Error::IndexCreation("No documents provided".into()));
        }

        let kmeans_config = ComputeKmeansConfig {
            kmeans_niters: config.kmeans_niters,
            max_points_per_centroid: config.max_points_per_centroid,
            seed: config.seed.unwrap_or(42),
            n_samples_kmeans: config.n_samples_kmeans,
            num_partitions: None,
        };

        let centroids = compute_kmeans(embeddings, &kmeans_config)?;

        let index = Self::create(embeddings, centroids, index_path, config)?;

        if embeddings.len() <= config.start_from_scratch {
            let index_dir = std::path::Path::new(index_path);
            crate::update::save_embeddings_npy(index_dir, embeddings)?;
        }

        Ok(index)
    }

    #[cfg(feature = "npy")]
    /// Loads an index previously written by [`Index::create`] from `index_path`.
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let metadata_path = index_dir.join("metadata.json");
        let metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        let codec = ResidualCodec::load_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        let mut doc_codes: Vec<Array1<usize>> = Vec::with_capacity(metadata.num_documents);
        let mut doc_residuals: Vec<Array2<u8>> = Vec::with_capacity(metadata.num_documents);
        let mut doc_lengths: Vec<i64> = Vec::with_capacity(metadata.num_documents);

        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;

            let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
            let chunk_codes: Array1<i64> = Array1::read_npy(File::open(&codes_path)?)?;

            let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
            let chunk_residuals: Array2<u8> = Array2::read_npy(File::open(&residuals_path)?)?;

            let mut code_offset = 0;
            for &len in &chunk_doclens {
                let len_usize = len as usize;
                doc_lengths.push(len);

                let codes: Array1<usize> = chunk_codes
                    .slice(s![code_offset..code_offset + len_usize])
                    .iter()
                    .map(|&x| x as usize)
                    .collect();
                doc_codes.push(codes);

                let res = chunk_residuals
                    .slice(s![code_offset..code_offset + len_usize, ..])
                    .to_owned();
                doc_residuals.push(res);

                code_offset += len_usize;
            }
        }

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_codes,
            doc_lengths: Array1::from_vec(doc_lengths),
            doc_residuals,
        })
    }

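    /// Returns the sorted, deduplicated ids of all documents posted under the
    /// given centroid indices; out-of-range indices are skipped. A usage
    /// sketch (hypothetical ids):
    ///
    /// ```ignore
    /// let candidate_docs = index.get_candidates(&[0, 17, 42]);
    /// ```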
    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    /// Decompresses and returns the embeddings of a single document.
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_codes.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let codes = &self.doc_codes[doc_id];
        let residuals = &self.doc_residuals[doc_id];

        self.codec.decompress(residuals, &codes.view())
    }

    /// Decompresses the embeddings of several documents in parallel, failing
    /// fast if any id is out of range.
    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        use rayon::prelude::*;

        let num_documents = self.doc_codes.len();

        for &doc_id in doc_ids {
            if doc_id < 0 || doc_id as usize >= num_documents {
                return Err(Error::Search(format!(
                    "Invalid document ID: {} (index has {} documents)",
                    doc_id, num_documents
                )));
            }
        }

        doc_ids
            .par_iter()
            .map(|&doc_id| self.get_document_embeddings(doc_id as usize))
            .collect()
    }

    /// Decompresses the embeddings of a single document.
    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        self.get_document_embeddings(doc_id as usize)
    }

    #[cfg(feature = "npy")]
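    /// Adds documents to the index and returns their assigned ids. While the
    /// collection is small (at most `start_from_scratch` documents), the whole
    /// index is rebuilt from the persisted raw embeddings plus the new ones.
    /// Otherwise new documents are appended and also buffered on disk; once
    /// the buffer reaches `buffer_size`, the buffered copies are deleted and
    /// re-indexed in one pass, after possibly spawning new centroids for
    /// residuals beyond the stored cluster threshold.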
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        // Small collection: rebuild the whole index from the persisted raw
        // embeddings plus the new documents.
        if self.metadata.num_documents <= config.start_from_scratch {
            let existing_embeddings = load_embeddings_npy(index_path)?;
            let start_doc_id = existing_embeddings.len() as i64;

            let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            let index_config = crate::index::IndexConfig {
                nbits: self.metadata.nbits,
                batch_size: config.batch_size,
                seed: Some(config.seed),
                kmeans_niters: config.kmeans_niters,
                max_points_per_centroid: config.max_points_per_centroid,
                n_samples_kmeans: config.n_samples_kmeans,
                start_from_scratch: config.start_from_scratch,
            };

            *self = Index::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

            // Once the collection outgrows the from-scratch regime, the raw
            // embeddings are no longer needed on disk.
            if combined_embeddings.len() > config.start_from_scratch
                && embeddings_npy_exists(index_path)
            {
                clear_embeddings_npy(index_path)?;
            }

            return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
        }

        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        let start_doc_id: i64;

        if total_new >= config.buffer_size {
            let num_buffered = load_buffer_info(index_path)?;

            // Drop the previously buffered documents from the index; they are
            // re-added below as part of the combined batch.
            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index(&docs_to_delete, &path_str)?;
                *self = Index::load(&path_str)?;
            }

            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    self.codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            clear_buffer(index_path)?;

            update_index(
                &combined,
                &path_str,
                &self.codec,
                Some(config.batch_size),
                true,
            )?;
        } else {
            start_doc_id = self.metadata.num_documents as i64;

            save_buffer(index_path, embeddings)?;

            update_index(
                embeddings,
                &path_str,
                &self.codec,
                Some(config.batch_size),
                false,
            )?;
        }

        *self = Index::load(&path_str)?;

        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }

    #[cfg(all(feature = "npy", feature = "filtering"))]
    /// Like [`Index::update`], but also records optional per-document JSON
    /// metadata for filtered search.
    pub fn update_with_metadata(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
        metadata: Option<&[serde_json::Value]>,
    ) -> Result<Vec<i64>> {
        if let Some(meta) = metadata {
            if meta.len() != embeddings.len() {
                return Err(Error::Config(format!(
                    "Metadata length ({}) must match embeddings length ({})",
                    meta.len(),
                    embeddings.len()
                )));
            }
        }

        let doc_ids = self.update(embeddings, config)?;

        if let Some(meta) = metadata {
            crate::filtering::update(&self.path, meta, &doc_ids)?;
        }

        Ok(doc_ids)
    }

    #[cfg(feature = "npy")]
    /// Updates the index at `index_path` if it exists, otherwise creates it,
    /// returning the index together with the ids assigned to `embeddings`.
    pub fn update_or_create(
        embeddings: &[Array2<f32>],
        index_path: &str,
        index_config: &IndexConfig,
        update_config: &crate::update::UpdateConfig,
    ) -> Result<(Self, Vec<i64>)> {
        let index_dir = std::path::Path::new(index_path);
        let metadata_path = index_dir.join("metadata.json");

        if metadata_path.exists() {
            let mut index = Self::load(index_path)?;
            let doc_ids = index.update(embeddings, update_config)?;
            Ok((index, doc_ids))
        } else {
            let num_docs = embeddings.len();
            let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
            let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
            Ok((index, doc_ids))
        }
    }

    #[cfg(feature = "npy")]
    /// Appends documents without any buffering or centroid updates, then
    /// reloads the index from disk.
    pub fn update_simple(
        &mut self,
        embeddings: &[Array2<f32>],
        batch_size: Option<usize>,
    ) -> Result<()> {
        crate::update::update_index(embeddings, &self.path, &self.codec, batch_size, true)?;

        *self = Index::load(&self.path)?;
        Ok(())
    }

    #[cfg(feature = "npy")]
    /// Deletes documents by id, removing any attached filtering metadata too.
    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }

    #[cfg(feature = "npy")]
    /// Deletes documents by id, optionally leaving filtering metadata in
    /// place, and returns how many documents were removed.
    pub fn delete_with_options(
        &mut self,
        doc_ids: &[i64],
        #[allow(unused_variables)] delete_metadata: bool,
    ) -> Result<usize> {
        let deleted = crate::delete::delete_from_index(doc_ids, &self.path)?;

        #[cfg(feature = "filtering")]
        if delete_metadata && crate::filtering::exists(&self.path) {
            crate::filtering::delete(&self.path, doc_ids)?;
        }

        *self = Index::load(&self.path)?;
        Ok(deleted)
    }
}

/// An [`Index`] repacked into strided tensors for faster search-time lookups.
pub struct LoadedIndex {
    pub metadata: Metadata,
    pub codec: ResidualCodec,
    /// IVF postings in strided form.
    pub ivf: IvfStridedTensor,
    /// All centroid codes, strided by document.
    pub doc_codes: StridedTensor<usize>,
    /// All packed residuals, strided by document.
    pub doc_residuals: StridedTensor<u8>,
    /// Bits per quantized residual dimension.
    pub nbits: usize,
}

impl LoadedIndex {
    /// Flattens a loaded [`Index`] into contiguous strided tensors.
    pub fn from_index(index: Index) -> Self {
        let embedding_dim = index.codec.embedding_dim();
        let packed_dim = embedding_dim * index.metadata.nbits / 8;
        let nbits = index.metadata.nbits;

        let total_codes: usize = index.doc_lengths.iter().sum::<i64>() as usize;
        let mut codes_data = Array2::<usize>::zeros((total_codes, 1));
        let mut offset = 0;

        for codes in &index.doc_codes {
            for (j, &code) in codes.iter().enumerate() {
                codes_data[[offset + j, 0]] = code;
            }
            offset += codes.len();
        }

        let doc_codes = StridedTensor::new(codes_data, index.doc_lengths.clone());

        let mut residuals_data = Array2::<u8>::zeros((total_codes, packed_dim));
        offset = 0;

        for residuals in &index.doc_residuals {
            residuals_data
                .slice_mut(s![offset..offset + residuals.nrows(), ..])
                .assign(residuals);
            offset += residuals.nrows();
        }

        let doc_residuals = StridedTensor::new(residuals_data, index.doc_lengths.clone());

        let ivf = IvfStridedTensor::new(index.ivf, index.ivf_lengths);

        Self {
            metadata: index.metadata,
            codec: index.codec,
            ivf,
            doc_codes,
            doc_residuals,
            nbits,
        }
    }

    #[cfg(feature = "npy")]
    /// Loads an index from disk and converts it to the strided layout.
    pub fn load(index_path: &str) -> Result<Self> {
        Ok(Self::from_index(Index::load(index_path)?))
    }

    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        self.ivf.lookup(centroid_indices)
    }

    /// Fetches the codes, packed residuals, and lengths for a batch of documents.
    pub fn lookup_documents(&self, doc_ids: &[usize]) -> (Array1<usize>, Array2<u8>, Array1<i64>) {
        let (codes, lengths) = self.doc_codes.lookup_codes(doc_ids);
        let (residuals, _) = self.doc_residuals.lookup_2d(doc_ids);
        (codes, residuals, lengths)
    }

    /// Decompresses a batch of documents into one stacked embedding matrix
    /// plus per-document lengths.
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Array1<i64>)> {
        let (codes, residuals, lengths) = self.lookup_documents(doc_ids);

        let embeddings = self.codec.decompress(&residuals, &codes.view())?;

        Ok((embeddings, lengths))
    }

    pub fn num_documents(&self) -> usize {
        self.doc_codes.len()
    }

    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }

    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }
}


#[cfg(feature = "npy")]
/// An index whose per-token codes and residuals stay on disk and are accessed
/// through memory maps instead of being loaded into RAM.
pub struct MmapIndex {
    pub path: String,
    pub metadata: Metadata,
    pub codec: ResidualCodec,
    pub ivf: Array1<i64>,
    pub ivf_lengths: Array1<i32>,
    /// Exclusive prefix sums of `ivf_lengths`.
    pub ivf_offsets: Array1<i64>,
    pub doc_lengths: Array1<i64>,
    /// Exclusive prefix sums of `doc_lengths`; document `i` owns rows
    /// `doc_offsets[i]..doc_offsets[i + 1]` of the memory-mapped arrays.
    pub doc_offsets: Array1<usize>,
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}

#[cfg(feature = "npy")]
impl MmapIndex {
    /// Loads an index with memory-mapped codes and residuals.
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let metadata_path = index_dir.join("metadata.json");
        let metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        let mmap_codes = crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?;
        let mmap_residuals = crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?;

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }

    pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
        let mut candidates: Vec<i64> = Vec::new();

        for &idx in centroid_indices {
            if idx < self.ivf_lengths.len() {
                let start = self.ivf_offsets[idx] as usize;
                let len = self.ivf_lengths[idx] as usize;
                candidates.extend(self.ivf.slice(s![start..start + len]).iter());
            }
        }

        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    /// Decompresses a single document straight from the memory-mapped arrays.
    pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
        if doc_id >= self.doc_lengths.len() {
            return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
        }

        let start = self.doc_offsets[doc_id];
        let end = self.doc_offsets[doc_id + 1];

        let codes_slice = self.mmap_codes.slice(start, end);
        let residuals_view = self.mmap_residuals.slice_rows(start, end);

        let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));

        let residuals = residuals_view.to_owned();

        self.codec.decompress(&residuals, &codes.view())
    }

    /// Returns the raw centroid codes per document; unknown ids yield empty vectors.
    pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
        doc_ids
            .iter()
            .map(|&doc_id| {
                if doc_id >= self.doc_lengths.len() {
                    return vec![];
                }
                let start = self.doc_offsets[doc_id];
                let end = self.doc_offsets[doc_id + 1];
                self.mmap_codes.slice(start, end).to_vec()
            })
            .collect()
    }

    /// Decompresses a batch of documents into one stacked matrix; unknown ids
    /// contribute zero-length entries.
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }

    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }

    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }

    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }

    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }

    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings_mmap(self, doc_ids)
    }

    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single_mmap(self, doc_id)
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
    }

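    // A small sketch verifying that the serde defaults declared above kick in
    // when optional fields are missing from the JSON.
    #[test]
    fn test_index_config_serde_defaults() {
        let json = r#"{ "nbits": 4, "batch_size": 1000, "seed": null }"#;
        let config: IndexConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.kmeans_niters, 4);
        assert_eq!(config.max_points_per_centroid, 256);
        assert_eq!(config.n_samples_kmeans, None);
        assert_eq!(config.start_from_scratch, 999);
    }
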
    #[test]
    #[cfg(feature = "npy")]
    fn test_update_or_create_new_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let mut embeddings: Vec<Array2<f32>> = Vec::new();
        for i in 0..5 {
            let mut doc = Array2::<f32>::zeros((5, 32));
            for j in 0..5 {
                for k in 0..32 {
                    doc[[j, k]] = (i as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                }
            }
            for mut row in doc.rows_mut() {
                let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    row.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.push(doc);
        }

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let (index, doc_ids) =
            Index::update_or_create(&embeddings, index_path, &index_config, &update_config)
                .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    #[cfg(feature = "npy")]
    fn test_update_or_create_existing_index() {
        use ndarray::Array2;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let create_embeddings = |count: usize, offset: usize| -> Vec<Array2<f32>> {
            let mut embeddings = Vec::new();
            for i in 0..count {
                let mut doc = Array2::<f32>::zeros((5, 32));
                for j in 0..5 {
                    for k in 0..32 {
                        doc[[j, k]] =
                            ((i + offset) as f32 * 0.1) + (j as f32 * 0.01) + (k as f32 * 0.001);
                    }
                }
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                embeddings.push(doc);
            }
            embeddings
        };

        let index_config = IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        };
        let update_config = crate::update::UpdateConfig::default();

        let embeddings1 = create_embeddings(5, 0);
        let (index1, doc_ids1) =
            Index::update_or_create(&embeddings1, index_path, &index_config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        let embeddings2 = create_embeddings(3, 5);
        let (index2, doc_ids2) =
            Index::update_or_create(&embeddings2, index_path, &index_config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}