1use std::collections::BTreeMap;
4use std::fs::{self, File};
5use std::io::{BufReader, BufWriter, Write};
6use std::path::Path;
7
8use ndarray::{s, Array1, Array2, Axis};
9use serde::{Deserialize, Serialize};
10
11use crate::codec::ResidualCodec;
12use crate::error::{Error, Result};
13use crate::kmeans::{compute_kmeans, ComputeKmeansConfig};
14use crate::utils::{quantile, quantiles};
15
/// Assigns each embedding row to its nearest centroid and subtracts that
/// centroid in place, returning the per-row centroid codes and the residual
/// matrix (CPU path; rows are processed in parallel via rayon).
fn compress_and_residuals_cpu(
    embeddings: &Array2<f32>,
    codec: &ResidualCodec,
) -> (Array1<usize>, Array2<f32>) {
    use rayon::prelude::*;

    // Nearest-centroid assignment; one code per embedding row.
    let codes = codec.compress_into_codes_cpu(embeddings);
    let mut residuals = embeddings.clone();

    let centroids = &codec.centroids;
    // `as_slice().unwrap()` relies on `codes` being a freshly built,
    // contiguous standard-layout Array1 — NOTE(review): invariant of
    // `compress_into_codes_cpu`; confirm if its implementation changes.
    residuals
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .zip(codes.as_slice().unwrap().par_iter())
        .for_each(|(mut row, &code)| {
            // residual = embedding - assigned centroid, elementwise.
            let centroid = centroids.row(code);
            row.iter_mut()
                .zip(centroid.iter())
                .for_each(|(r, c)| *r -= c);
        });

    (codes, residuals)
}
41
/// User-facing configuration for building an index.
///
/// Serde defaults mirror [`IndexConfig::default`] so a partially specified
/// JSON config deserializes to the same values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Bits per residual dimension used by the quantizer (e.g. 4 → 16 buckets).
    pub nbits: usize,
    /// Number of documents encoded per chunk in `create_index_files`.
    pub batch_size: usize,
    /// RNG seed for held-out sampling and k-means; `None` uses OS entropy.
    pub seed: Option<u64>,
    /// Number of k-means iterations (forwarded to `ComputeKmeansConfig`).
    #[serde(default = "default_kmeans_niters")]
    pub kmeans_niters: usize,
    /// Cap on points sampled per centroid during k-means training.
    #[serde(default = "default_max_points_per_centroid")]
    pub max_points_per_centroid: usize,
    /// Optional override for the number of k-means training samples.
    #[serde(default)]
    pub n_samples_kmeans: Option<usize>,
    /// If the corpus has at most this many documents, raw embeddings are also
    /// saved alongside the index (see `create_index_with_kmeans_files`).
    #[serde(default = "default_start_from_scratch")]
    pub start_from_scratch: usize,
    /// Force the CPU encoding path even when the `cuda` feature is enabled.
    #[serde(default)]
    pub force_cpu: bool,
    /// Tokenizer configuration for full-text search.
    /// NOTE(review): semantics defined in `crate::text_search`; not used in
    /// this file.
    #[serde(default)]
    pub fts_tokenizer: crate::text_search::FtsTokenizer,
}
75
/// Serde default for `IndexConfig::start_from_scratch`.
fn default_start_from_scratch() -> usize {
    999
}
79
/// Serde default for `IndexConfig::kmeans_niters`.
fn default_kmeans_niters() -> usize {
    4
}
83
/// Serde default for `IndexConfig::max_points_per_centroid`.
fn default_max_points_per_centroid() -> usize {
    256
}
87
88impl Default for IndexConfig {
89 fn default() -> Self {
90 Self {
91 nbits: 4,
92 batch_size: 50_000,
93 seed: Some(42),
94 kmeans_niters: 4,
95 max_points_per_centroid: 256,
96 n_samples_kmeans: None,
97 start_from_scratch: 999,
98 force_cpu: false,
99 fts_tokenizer: crate::text_search::FtsTokenizer::default(),
100 }
101 }
102}
103
/// Index-level metadata persisted as `metadata.json` in the index directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metadata {
    /// Number of on-disk chunks (`{i}.codes.npy`, `{i}.residuals.npy`, ...).
    pub num_chunks: usize,
    /// Bits per residual dimension used by the quantizer.
    pub nbits: usize,
    /// Number of IVF partitions (== number of centroids).
    pub num_partitions: usize,
    /// Total token embeddings across all chunks.
    pub num_embeddings: usize,
    /// Mean tokens per document (num_embeddings / num_documents).
    pub avg_doclen: f64,
    /// Total documents; 0 in older files and backfilled by `load_from_path`.
    #[serde(default)]
    pub num_documents: usize,
    /// Embedding dimensionality; 0 in older files.
    #[serde(default)]
    pub embedding_dim: usize,
    /// Whether the on-disk layout is already in next-plaid format; when
    /// false, `MmapIndex::load` converts the index first.
    #[serde(default)]
    pub next_plaid_compatible: bool,
}
128
impl Metadata {
    /// Reads `metadata.json` from `index_path`.
    ///
    /// For older indexes that predate the `num_documents` field (serde
    /// default 0), the count is backfilled by summing the lengths of the
    /// per-chunk `doclens.{i}.json` files. Missing or unreadable doclens
    /// files are skipped deliberately (best effort), so the backfilled
    /// count may undercount on a damaged index.
    ///
    /// # Errors
    /// Returns `Error::IndexLoad` if `metadata.json` cannot be opened, or a
    /// deserialization error if it is malformed.
    pub fn load_from_path(index_path: &Path) -> Result<Self> {
        let metadata_path = index_path.join("metadata.json");
        let mut metadata: Metadata = serde_json::from_reader(BufReader::new(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata: {}", e)))?,
        ))?;

        if metadata.num_documents == 0 {
            // Legacy index: derive the document count from the doclens files.
            let mut total_docs = 0usize;
            for chunk_idx in 0..metadata.num_chunks {
                let doclens_path = index_path.join(format!("doclens.{}.json", chunk_idx));
                if let Ok(file) = File::open(&doclens_path) {
                    if let Ok(chunk_doclens) =
                        serde_json::from_reader::<_, Vec<i64>>(BufReader::new(file))
                    {
                        total_docs += chunk_doclens.len();
                    }
                }
            }
            metadata.num_documents = total_docs;
        }

        Ok(metadata)
    }
}
157
/// Per-chunk metadata persisted as `{i}.metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Documents stored in this chunk.
    pub num_documents: usize,
    /// Token embeddings stored in this chunk.
    pub num_embeddings: usize,
    /// Global embedding index of this chunk's first token (0 in older files).
    #[serde(default)]
    pub embedding_offset: usize,
}
166
/// In-memory result of encoding one chunk of documents.
#[derive(Debug, Clone)]
pub struct EncodedIndexChunk {
    /// Centroid code per token embedding.
    pub codes: Array1<i64>,
    /// Packed quantized residuals, one row of `dim * nbits / 8` bytes per token.
    pub residuals: Array2<u8>,
    /// Token count per document in this chunk.
    pub doclens: Vec<i64>,
}
173
/// Trained codec plus the statistics derived from the held-out sample,
/// as produced by [`prepare_codec_artifacts`].
pub struct PreparedCodecArtifacts {
    /// Fully configured residual codec (centroids + bucket tables).
    pub codec: ResidualCodec,
    /// 75th percentile of held-out residual L2 norms.
    pub cluster_threshold: f32,
    /// Quantization bucket boundaries over residual values.
    pub bucket_cutoffs: Array1<f32>,
    /// Representative (mid-quantile) value per quantization bucket.
    pub bucket_weights: Array1<f32>,
    /// Mean absolute residual per embedding dimension.
    pub avg_res_per_dim: Array1<f32>,
}
181
/// Trains the residual quantizer: samples a held-out subset of embeddings,
/// computes their residuals against `centroids`, and derives the bucket
/// cutoffs/weights, per-dimension average residual, and cluster threshold
/// needed to build the final [`ResidualCodec`].
///
/// Mirrors the inline training logic of `create_index_files`; keep the two
/// in sync.
///
/// # Errors
/// Returns `Error::IndexCreation` when `embeddings` is empty or the held-out
/// matrix cannot be assembled, plus any codec construction errors.
pub fn prepare_codec_artifacts(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    config: &IndexConfig,
) -> Result<PreparedCodecArtifacts> {
    let embedding_dim = centroids.ncols();
    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let num_documents = embeddings.len();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    // Sample size heuristic: ~16 * sqrt(120 * num_documents), clamped to
    // [1, num_documents].
    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    // Seeded RNG for reproducible sampling when a seed is configured.
    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    // Held-out set: 5% of all tokens, capped at 50k, drawn from the sampled
    // documents (iterated in reverse).
    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    // Temporary codec with zero avg residual and no buckets, used only to
    // assign the held-out rows to centroids.
    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    let heldout_codes = if config.force_cpu {
        initial_codec.compress_into_codes_cpu(&heldout)
    } else {
        initial_codec.compress_into_codes(&heldout)
    };

    // Residual = held-out embedding minus its assigned centroid.
    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    // Cluster threshold: 75th percentile of residual L2 norms.
    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    let cluster_threshold = quantile(&distances, 0.75);

    // Mean |residual| per dimension.
    let avg_res_per_dim: Array1<f32> = residuals
        .axis_iter(Axis(1))
        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
        .collect();

    // With nbits bits there are 2^nbits buckets: cutoffs at the interior
    // quantiles, weights at each bucket's mid-quantile.
    let n_options = 1 << config.nbits;
    let quantile_values: Vec<f64> = (1..n_options)
        .map(|i| i as f64 / n_options as f64)
        .collect();
    let weight_quantile_values: Vec<f64> = (0..n_options)
        .map(|i| (i as f64 + 0.5) / n_options as f64)
        .collect();

    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

    let codec = ResidualCodec::new(
        config.nbits,
        centroids,
        avg_res_per_dim.clone(),
        Some(bucket_cutoffs.clone()),
        Some(bucket_weights.clone()),
    )?;

    Ok(PreparedCodecArtifacts {
        codec,
        cluster_threshold,
        bucket_cutoffs,
        bucket_weights,
        avg_res_per_dim,
    })
}
288
289pub fn encode_index_chunk(
290 embeddings: &[Array2<f32>],
291 codec: &ResidualCodec,
292 force_cpu: bool,
293) -> Result<EncodedIndexChunk> {
294 let embedding_dim = codec.embedding_dim();
295 let packed_dim = embedding_dim * codec.nbits / 8;
296 let doclens: Vec<i64> = embeddings.iter().map(|d| d.nrows() as i64).collect();
297 let total_tokens: usize = doclens.iter().sum::<i64>() as usize;
298
299 #[cfg(not(feature = "cuda"))]
300 let _ = force_cpu;
301
302 let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
303 let mut offset = 0;
304 for doc in embeddings {
305 let n = doc.nrows();
306 batch_embeddings
307 .slice_mut(s![offset..offset + n, ..])
308 .assign(doc);
309 offset += n;
310 }
311
312 let (batch_codes, batch_residuals) = {
313 #[cfg(feature = "cuda")]
314 {
315 let force_gpu = crate::is_force_gpu();
316 if !force_cpu {
317 if let Some(ctx) = crate::cuda::get_global_context() {
318 match crate::cuda::compress_and_residuals_cuda_batched(
319 &ctx,
320 &batch_embeddings.view(),
321 &codec.centroids_view(),
322 None,
323 ) {
324 Ok(result) => result,
325 Err(e) => {
326 if force_gpu {
327 panic!(
328 "FORCE_GPU is set but CUDA compress_and_residuals failed: {}",
329 e
330 );
331 }
332 println!(
333 "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
334 e
335 );
336 compress_and_residuals_cpu(&batch_embeddings, codec)
337 }
338 }
339 } else if force_gpu {
340 panic!("FORCE_GPU is set but CUDA context is unavailable");
341 } else {
342 compress_and_residuals_cpu(&batch_embeddings, codec)
343 }
344 } else {
345 compress_and_residuals_cpu(&batch_embeddings, codec)
346 }
347 }
348 #[cfg(not(feature = "cuda"))]
349 {
350 compress_and_residuals_cpu(&batch_embeddings, codec)
351 }
352 };
353
354 let batch_packed = codec.quantize_residuals(&batch_residuals)?;
355 let (raw_residuals, residuals_offset) = batch_packed.into_raw_vec_and_offset();
356 if residuals_offset != Some(0) {
357 return Err(Error::Shape(format!(
358 "Unexpected residual packing offset: {:?}",
359 residuals_offset
360 )));
361 }
362 let residuals = Array2::from_shape_vec((batch_codes.len(), packed_dim), raw_residuals)
363 .map_err(|e| Error::Shape(format!("Failed to reshape residuals: {}", e)))?;
364 let codes: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
365
366 Ok(EncodedIndexChunk {
367 codes,
368 residuals,
369 doclens,
370 })
371}
372
/// Writes a complete index directory from already-encoded chunks and trained
/// codec artifacts: codec arrays, `plan.json`, per-chunk metadata/doclens/
/// codes/residuals, the IVF (`ivf.npy` + `ivf_lengths.npy`), and finally
/// `metadata.json`.
///
/// # Errors
/// Propagates any filesystem, serialization, or npy-writing error.
pub fn write_index_from_encoded_chunks(
    chunks: &[EncodedIndexChunk],
    codec_artifacts: &PreparedCodecArtifacts,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    use ndarray_npy::WriteNpyExt;

    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let embedding_dim = codec_artifacts.codec.embedding_dim();
    let num_centroids = codec_artifacts.codec.num_centroids();
    let total_embeddings: usize = chunks.iter().map(|c| c.codes.len()).sum();
    let num_documents: usize = chunks.iter().map(|c| c.doclens.len()).sum();
    let avg_doclen = if num_documents > 0 {
        total_embeddings as f64 / num_documents as f64
    } else {
        0.0
    };

    // Codec artifacts shared by all chunks.
    let centroids_path = index_dir.join("centroids.npy");
    codec_artifacts
        .codec
        .centroids_view()
        .to_owned()
        .write_npy(File::create(&centroids_path)?)?;
    codec_artifacts
        .bucket_cutoffs
        .write_npy(File::create(index_dir.join("bucket_cutoffs.npy"))?)?;
    codec_artifacts
        .bucket_weights
        .write_npy(File::create(index_dir.join("bucket_weights.npy"))?)?;
    codec_artifacts
        .avg_res_per_dim
        .write_npy(File::create(index_dir.join("avg_residual.npy"))?)?;
    Array1::from_vec(vec![codec_artifacts.cluster_threshold])
        .write_npy(File::create(index_dir.join("cluster_threshold.npy"))?)?;

    let n_chunks = chunks.len();
    let plan = serde_json::json!({
        "nbits": config.nbits,
        "num_chunks": n_chunks,
    });
    let mut plan_file = File::create(index_dir.join("plan.json"))?;
    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

    // Flat code/doclen accumulators used below to build the IVF.
    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);
    let mut current_offset = 0usize;

    for (chunk_idx, chunk) in chunks.iter().enumerate() {
        let chunk_meta = ChunkMetadata {
            num_documents: chunk.doclens.len(),
            num_embeddings: chunk.codes.len(),
            embedding_offset: current_offset,
        };
        current_offset += chunk.codes.len();

        serde_json::to_writer_pretty(
            BufWriter::new(File::create(
                index_dir.join(format!("{}.metadata.json", chunk_idx)),
            )?),
            &chunk_meta,
        )?;
        serde_json::to_writer(
            BufWriter::new(File::create(
                index_dir.join(format!("doclens.{}.json", chunk_idx)),
            )?),
            &chunk.doclens,
        )?;
        chunk.codes.write_npy(File::create(
            index_dir.join(format!("{}.codes.npy", chunk_idx)),
        )?)?;
        chunk.residuals.write_npy(File::create(
            index_dir.join(format!("{}.residuals.npy", chunk_idx)),
        )?)?;

        doc_lengths.extend_from_slice(&chunk.doclens);
        all_codes.extend(chunk.codes.iter().map(|&x| x as usize));
    }

    // Invert: centroid code -> document ids of every token assigned to it.
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;
    for (doc_id, &len) in doc_lengths.iter().enumerate() {
        for _ in 0..len {
            let code = all_codes[emb_idx];
            code_to_docs.entry(code).or_default().push(doc_id as i64);
            emb_idx += 1;
        }
    }

    // Flatten into the IVF: per centroid, the sorted unique doc ids.
    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];
    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }

    Array1::from_vec(ivf_data).write_npy(File::create(index_dir.join("ivf.npy"))?)?;
    Array1::from_vec(ivf_lengths).write_npy(File::create(index_dir.join("ivf_lengths.npy"))?)?;

    let metadata = Metadata {
        num_chunks: n_chunks,
        nbits: config.nbits,
        num_partitions: num_centroids,
        num_embeddings: total_embeddings,
        avg_doclen,
        num_documents,
        embedding_dim,
        next_plaid_compatible: true,
    };
    serde_json::to_writer_pretty(
        BufWriter::new(File::create(index_dir.join("metadata.json"))?),
        &metadata,
    )?;

    Ok(metadata)
}
497
/// Builds a complete index on disk from per-document embeddings and
/// precomputed centroids: trains the quantization buckets on a held-out
/// sample, encodes documents in batches of `config.batch_size`, writes every
/// artifact under `index_path`, and returns the resulting [`Metadata`].
///
/// NOTE(review): this duplicates the sampling/bucket-training logic of
/// `prepare_codec_artifacts`, the batch-encoding logic of
/// `encode_index_chunk`, and the file layout of
/// `write_index_from_encoded_chunks`; keep all of them in sync when changing
/// any one.
///
/// # Errors
/// Returns `Error::IndexCreation` when `embeddings` is empty or the held-out
/// sample cannot be assembled, plus any codec, serialization, or I/O errors.
pub fn create_index_files(
    embeddings: &[Array2<f32>],
    centroids: Array2<f32>,
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    let index_dir = Path::new(index_path);
    fs::create_dir_all(index_dir)?;

    let num_documents = embeddings.len();
    let embedding_dim = centroids.ncols();
    let num_centroids = centroids.nrows();

    if num_documents == 0 {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let total_embeddings: usize = embeddings.iter().map(|e| e.nrows()).sum();
    let avg_doclen = total_embeddings as f64 / num_documents as f64;

    // Sample size heuristic: ~16 * sqrt(120 * num_documents), clamped to
    // [1, num_documents].
    let sample_count = ((16.0 * (120.0 * num_documents as f64).sqrt()) as usize)
        .min(num_documents)
        .max(1);

    // Seeded RNG for reproducible sampling when a seed is configured.
    let mut rng = if let Some(seed) = config.seed {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::seed_from_u64(seed)
    } else {
        use rand::SeedableRng;
        rand_chacha::ChaCha8Rng::from_entropy()
    };

    use rand::seq::SliceRandom;
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    let sample_indices: Vec<usize> = indices.into_iter().take(sample_count).collect();

    // Held-out set: 5% of all tokens, capped at 50k, drawn from the sampled
    // documents (iterated in reverse).
    let heldout_size = (0.05 * total_embeddings as f64).min(50000.0) as usize;
    let mut heldout_embeddings: Vec<f32> = Vec::with_capacity(heldout_size * embedding_dim);
    let mut collected = 0;

    for &idx in sample_indices.iter().rev() {
        if collected >= heldout_size {
            break;
        }
        let emb = &embeddings[idx];
        let take = (heldout_size - collected).min(emb.nrows());
        for row in emb.axis_iter(Axis(0)).take(take) {
            heldout_embeddings.extend(row.iter());
        }
        collected += take;
    }

    let heldout = Array2::from_shape_vec((collected, embedding_dim), heldout_embeddings)
        .map_err(|e| Error::IndexCreation(format!("Failed to create heldout array: {}", e)))?;

    // Temporary codec with zero avg residual and no buckets, used only to
    // assign the held-out rows to centroids.
    let avg_residual = Array1::zeros(embedding_dim);
    let initial_codec =
        ResidualCodec::new(config.nbits, centroids.clone(), avg_residual, None, None)?;

    let heldout_codes = if config.force_cpu {
        initial_codec.compress_into_codes_cpu(&heldout)
    } else {
        initial_codec.compress_into_codes(&heldout)
    };

    // Residual = held-out embedding minus its assigned centroid.
    let mut residuals = heldout.clone();
    for i in 0..heldout.nrows() {
        let centroid = initial_codec.centroids.row(heldout_codes[i]);
        for j in 0..embedding_dim {
            residuals[[i, j]] -= centroid[j];
        }
    }

    // Cluster threshold: 75th percentile of residual L2 norms.
    let distances: Array1<f32> = residuals
        .axis_iter(Axis(0))
        .map(|row| row.dot(&row).sqrt())
        .collect();
    #[allow(unused_variables)]
    let cluster_threshold = quantile(&distances, 0.75);

    // Mean |residual| per dimension.
    let avg_res_per_dim: Array1<f32> = residuals
        .axis_iter(Axis(1))
        .map(|col| col.iter().map(|x| x.abs()).sum::<f32>() / col.len() as f32)
        .collect();

    // With nbits bits there are 2^nbits buckets: cutoffs at the interior
    // quantiles, weights at each bucket's mid-quantile.
    let n_options = 1 << config.nbits;
    let quantile_values: Vec<f64> = (1..n_options)
        .map(|i| i as f64 / n_options as f64)
        .collect();
    let weight_quantile_values: Vec<f64> = (0..n_options)
        .map(|i| (i as f64 + 0.5) / n_options as f64)
        .collect();

    let flat_residuals: Array1<f32> = residuals.iter().copied().collect();
    let bucket_cutoffs = Array1::from_vec(quantiles(&flat_residuals, &quantile_values));
    let bucket_weights = Array1::from_vec(quantiles(&flat_residuals, &weight_quantile_values));

    let codec = ResidualCodec::new(
        config.nbits,
        centroids.clone(),
        avg_res_per_dim.clone(),
        Some(bucket_cutoffs.clone()),
        Some(bucket_weights.clone()),
    )?;

    use ndarray_npy::WriteNpyExt;

    // Persist the shared codec artifacts.
    let centroids_path = index_dir.join("centroids.npy");
    codec
        .centroids_view()
        .to_owned()
        .write_npy(File::create(&centroids_path)?)?;

    let cutoffs_path = index_dir.join("bucket_cutoffs.npy");
    bucket_cutoffs.write_npy(File::create(&cutoffs_path)?)?;

    let weights_path = index_dir.join("bucket_weights.npy");
    bucket_weights.write_npy(File::create(&weights_path)?)?;

    let avg_res_path = index_dir.join("avg_residual.npy");
    avg_res_per_dim.write_npy(File::create(&avg_res_path)?)?;

    let threshold_path = index_dir.join("cluster_threshold.npy");
    Array1::from_vec(vec![cluster_threshold]).write_npy(File::create(&threshold_path)?)?;

    let n_chunks = (num_documents as f64 / config.batch_size as f64).ceil() as usize;

    let plan_path = index_dir.join("plan.json");
    let plan = serde_json::json!({
        "nbits": config.nbits,
        "num_chunks": n_chunks,
    });
    let mut plan_file = File::create(&plan_path)?;
    writeln!(plan_file, "{}", serde_json::to_string_pretty(&plan)?)?;

    // Flat code/doclen accumulators used below to build the IVF.
    let mut all_codes: Vec<usize> = Vec::with_capacity(total_embeddings);
    let mut doc_lengths: Vec<i64> = Vec::with_capacity(num_documents);

    // Encode and write one chunk of `batch_size` documents at a time.
    for chunk_idx in 0..n_chunks {
        let start = chunk_idx * config.batch_size;
        let end = (start + config.batch_size).min(num_documents);
        let chunk_docs = &embeddings[start..end];

        let chunk_doclens: Vec<i64> = chunk_docs.iter().map(|d| d.nrows() as i64).collect();
        let total_tokens: usize = chunk_doclens.iter().sum::<i64>() as usize;

        // Concatenate this chunk's documents into one contiguous batch.
        let mut batch_embeddings = Array2::<f32>::zeros((total_tokens, embedding_dim));
        let mut offset = 0;
        for doc in chunk_docs {
            let n = doc.nrows();
            batch_embeddings
                .slice_mut(s![offset..offset + n, ..])
                .assign(doc);
            offset += n;
        }

        // CUDA when available (unless force_cpu), with CPU fallback.
        let (batch_codes, batch_residuals) = {
            #[cfg(feature = "cuda")]
            {
                let force_gpu = crate::is_force_gpu();
                if !config.force_cpu {
                    if let Some(ctx) = crate::cuda::get_global_context() {
                        match crate::cuda::compress_and_residuals_cuda_batched(
                            &ctx,
                            &batch_embeddings.view(),
                            &codec.centroids_view(),
                            None,
                        ) {
                            Ok(result) => result,
                            Err(e) => {
                                if force_gpu {
                                    panic!("FORCE_GPU is set but CUDA compress_and_residuals failed: {}", e);
                                }
                                eprintln!(
                                    "[next-plaid] CUDA compress_and_residuals failed: {}, falling back to CPU",
                                    e
                                );
                                compress_and_residuals_cpu(&batch_embeddings, &codec)
                            }
                        }
                    } else if force_gpu {
                        panic!("FORCE_GPU is set but CUDA context is unavailable");
                    } else {
                        compress_and_residuals_cpu(&batch_embeddings, &codec)
                    }
                } else {
                    compress_and_residuals_cpu(&batch_embeddings, &codec)
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                compress_and_residuals_cpu(&batch_embeddings, &codec)
            }
        };

        let batch_packed = codec.quantize_residuals(&batch_residuals)?;

        for &len in &chunk_doclens {
            doc_lengths.push(len);
        }
        all_codes.extend(batch_codes.iter().copied());

        let chunk_meta = ChunkMetadata {
            num_documents: end - start,
            num_embeddings: batch_codes.len(),
            // Placeholder; the real offset is patched in the second pass below.
            embedding_offset: 0,
        };

        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &chunk_meta)?;

        let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
        serde_json::to_writer(BufWriter::new(File::create(&doclens_path)?), &chunk_doclens)?;

        let chunk_codes_arr: Array1<i64> = batch_codes.iter().map(|&x| x as i64).collect();
        let codes_path = index_dir.join(format!("{}.codes.npy", chunk_idx));
        chunk_codes_arr.write_npy(File::create(&codes_path)?)?;

        let residuals_path = index_dir.join(format!("{}.residuals.npy", chunk_idx));
        batch_packed.write_npy(File::create(&residuals_path)?)?;
    }

    // Second pass: rewrite each chunk's metadata with its cumulative
    // embedding offset (unknown during the first pass).
    let mut current_offset = 0usize;
    for chunk_idx in 0..n_chunks {
        let chunk_meta_path = index_dir.join(format!("{}.metadata.json", chunk_idx));
        let mut meta: serde_json::Value =
            serde_json::from_reader(BufReader::new(File::open(&chunk_meta_path)?))?;

        if let Some(obj) = meta.as_object_mut() {
            obj.insert("embedding_offset".to_string(), current_offset.into());
            let num_emb = obj["num_embeddings"].as_u64().unwrap_or(0) as usize;
            current_offset += num_emb;
        }

        serde_json::to_writer_pretty(BufWriter::new(File::create(&chunk_meta_path)?), &meta)?;
    }

    // Invert: centroid code -> document ids of every token assigned to it.
    let mut code_to_docs: BTreeMap<usize, Vec<i64>> = BTreeMap::new();
    let mut emb_idx = 0;

    for (doc_id, &len) in doc_lengths.iter().enumerate() {
        for _ in 0..len {
            let code = all_codes[emb_idx];
            code_to_docs.entry(code).or_default().push(doc_id as i64);
            emb_idx += 1;
        }
    }

    // Flatten into the IVF: per centroid, the sorted unique doc ids.
    let mut ivf_data: Vec<i64> = Vec::new();
    let mut ivf_lengths: Vec<i32> = vec![0; num_centroids];

    for (centroid_id, ivf_len) in ivf_lengths.iter_mut().enumerate() {
        if let Some(docs) = code_to_docs.get(&centroid_id) {
            let mut unique_docs: Vec<i64> = docs.clone();
            unique_docs.sort_unstable();
            unique_docs.dedup();
            *ivf_len = unique_docs.len() as i32;
            ivf_data.extend(unique_docs);
        }
    }

    let ivf = Array1::from_vec(ivf_data);
    let ivf_lengths = Array1::from_vec(ivf_lengths);

    let ivf_path = index_dir.join("ivf.npy");
    ivf.write_npy(File::create(&ivf_path)?)?;

    let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
    ivf_lengths.write_npy(File::create(&ivf_lengths_path)?)?;

    let metadata = Metadata {
        num_chunks: n_chunks,
        nbits: config.nbits,
        num_partitions: num_centroids,
        num_embeddings: total_embeddings,
        avg_doclen,
        num_documents,
        embedding_dim,
        next_plaid_compatible: true,
    };

    let metadata_path = index_dir.join("metadata.json");
    serde_json::to_writer_pretty(BufWriter::new(File::create(&metadata_path)?), &metadata)?;

    Ok(metadata)
}
834
/// Trains centroids with k-means over `embeddings`, then builds the index
/// via [`create_index_files`]. For small corpora (at most
/// `config.start_from_scratch` documents) the raw embeddings are also saved
/// so later updates can retrain from scratch.
///
/// # Errors
/// Returns `Error::IndexCreation` for an empty corpus, plus any k-means or
/// index-creation errors.
pub fn create_index_with_kmeans_files(
    embeddings: &[Array2<f32>],
    index_path: &str,
    config: &IndexConfig,
) -> Result<Metadata> {
    if embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    // Eagerly initialize the CUDA context (and fail fast under FORCE_GPU).
    #[cfg(feature = "cuda")]
    if !config.force_cpu {
        if crate::is_force_gpu() {
            crate::cuda::get_global_context()
                .expect("FORCE_GPU is set but CUDA context failed to initialize");
        } else {
            let _ = crate::cuda::get_global_context();
        }
    }

    let kmeans_config = ComputeKmeansConfig {
        kmeans_niters: config.kmeans_niters,
        max_points_per_centroid: config.max_points_per_centroid,
        seed: config.seed.unwrap_or(42),
        n_samples_kmeans: config.n_samples_kmeans,
        num_partitions: None,
        force_cpu: config.force_cpu,
    };

    let centroids = compute_kmeans(embeddings, &kmeans_config)?;

    let metadata = create_index_files(embeddings, centroids, index_path, config)?;

    // Keep raw embeddings for small indexes to allow from-scratch retraining.
    if embeddings.len() <= config.start_from_scratch {
        let index_dir = std::path::Path::new(index_path);
        crate::update::save_embeddings_npy(index_dir, embeddings)?;
    }

    Ok(metadata)
}
/// A loaded index whose codes and residuals are memory-mapped rather than
/// held in RAM.
pub struct MmapIndex {
    /// Filesystem path of the index directory.
    pub path: String,
    /// Index-level metadata (`metadata.json`).
    pub metadata: Metadata,
    /// Residual codec (centroids + bucket tables).
    pub codec: ResidualCodec,
    /// Flattened IVF: document ids, concatenated per centroid.
    pub ivf: Array1<i64>,
    /// Number of doc ids per centroid in `ivf`.
    pub ivf_lengths: Array1<i32>,
    /// Prefix sums of `ivf_lengths` (len = num_centroids + 1).
    pub ivf_offsets: Array1<i64>,
    /// Token count per document.
    pub doc_lengths: Array1<i64>,
    /// Prefix sums of `doc_lengths` (len = num_documents + 1); global token
    /// index of each document's first embedding.
    pub doc_offsets: Array1<usize>,
    /// Memory-mapped merged centroid codes for all tokens.
    pub mmap_codes: crate::mmap::MmapNpyArray1I64,
    /// Memory-mapped merged packed residuals for all tokens.
    pub mmap_residuals: crate::mmap::MmapNpyArray2U8,
}
939
940impl MmapIndex {
    /// Loads an index directory into an `MmapIndex`.
    ///
    /// Steps: (1) if the metadata says the on-disk layout is not next-plaid
    /// compatible, convert it (dropping stale merged files) and persist the
    /// updated metadata; (2) load the codec and IVF arrays; (3) build prefix
    /// sums over IVF and document lengths; (4) merge per-chunk codes and
    /// residuals into single files and memory-map them.
    ///
    /// # Errors
    /// Returns `Error::IndexLoad` when any required file is missing or
    /// malformed, plus conversion/merge errors.
    pub fn load(index_path: &str) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;

        let index_dir = Path::new(index_path);

        let mut metadata = Metadata::load_from_path(index_dir)?;

        if !metadata.next_plaid_compatible {
            eprintln!("Checking index format compatibility...");
            let converted = crate::mmap::convert_fastplaid_to_nextplaid(index_dir)?;
            if converted {
                eprintln!("Index converted to next-plaid compatible format.");
                // Conversion invalidates any previously merged files; delete
                // them (best effort) so they are rebuilt below.
                let merged_codes = index_dir.join("merged_codes.npy");
                let merged_residuals = index_dir.join("merged_residuals.npy");
                let codes_manifest = index_dir.join("merged_codes.manifest.json");
                let residuals_manifest = index_dir.join("merged_residuals.manifest.json");
                for path in [
                    &merged_codes,
                    &merged_residuals,
                    &codes_manifest,
                    &residuals_manifest,
                ] {
                    if path.exists() {
                        let _ = fs::remove_file(path);
                    }
                }
            }

            // Persist the compatibility flag so conversion runs only once.
            metadata.next_plaid_compatible = true;
            let metadata_path = index_dir.join("metadata.json");
            let file = File::create(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to update metadata: {}", e)))?;
            serde_json::to_writer_pretty(BufWriter::new(file), &metadata)?;
            eprintln!("Metadata updated with next_plaid_compatible: true");
        }

        let codec = ResidualCodec::load_mmap_from_dir(index_dir)?;

        let ivf_path = index_dir.join("ivf.npy");
        let ivf: Array1<i64> = Array1::read_npy(
            File::open(&ivf_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf.npy: {}", e)))?;

        let ivf_lengths_path = index_dir.join("ivf_lengths.npy");
        let ivf_lengths: Array1<i32> = Array1::read_npy(
            File::open(&ivf_lengths_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open ivf_lengths.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read ivf_lengths.npy: {}", e)))?;

        // Prefix sums over IVF lengths: ivf_offsets[c]..ivf_offsets[c+1] is
        // centroid c's slice of `ivf`.
        let num_centroids = ivf_lengths.len();
        let mut ivf_offsets = Array1::<i64>::zeros(num_centroids + 1);
        for i in 0..num_centroids {
            ivf_offsets[i + 1] = ivf_offsets[i] + ivf_lengths[i] as i64;
        }

        // Concatenate per-chunk document lengths in chunk order.
        let mut doc_lengths_vec: Vec<i64> = Vec::with_capacity(metadata.num_documents);
        for chunk_idx in 0..metadata.num_chunks {
            let doclens_path = index_dir.join(format!("doclens.{}.json", chunk_idx));
            let chunk_doclens: Vec<i64> =
                serde_json::from_reader(BufReader::new(File::open(&doclens_path)?))?;
            doc_lengths_vec.extend(chunk_doclens);
        }
        let doc_lengths = Array1::from_vec(doc_lengths_vec);

        // Prefix sums over doc lengths: global token range per document.
        let mut doc_offsets = Array1::<usize>::zeros(doc_lengths.len() + 1);
        for i in 0..doc_lengths.len() {
            doc_offsets[i + 1] = doc_offsets[i] + doc_lengths[i] as usize;
        }

        // Pad the merged arrays so the last document spans at least
        // `max_len` rows. NOTE(review): presumably required so fixed-stride
        // reads past the final document stay in bounds — confirm against
        // `crate::mmap::merge_codes_chunks`.
        let max_len = doc_lengths.iter().cloned().max().unwrap_or(0) as usize;
        let last_len = *doc_lengths.last().unwrap_or(&0) as usize;
        let padding_needed = max_len.saturating_sub(last_len);

        let merged_codes_path =
            crate::mmap::merge_codes_chunks(index_dir, metadata.num_chunks, padding_needed)?;
        let merged_residuals_path =
            crate::mmap::merge_residuals_chunks(index_dir, metadata.num_chunks, padding_needed)?;

        let (mmap_codes, mmap_residuals) = (
            crate::mmap::MmapNpyArray1I64::from_npy_file(&merged_codes_path)?,
            crate::mmap::MmapNpyArray2U8::from_npy_file(&merged_residuals_path)?,
        );

        Ok(Self {
            path: index_path.to_string(),
            metadata,
            codec,
            ivf,
            ivf_lengths,
            ivf_offsets,
            doc_lengths,
            doc_offsets,
            mmap_codes,
            mmap_residuals,
        })
    }
1058
1059 pub fn get_candidates(&self, centroid_indices: &[usize]) -> Vec<i64> {
1061 let mut candidates: Vec<i64> = Vec::new();
1062
1063 for &idx in centroid_indices {
1064 if idx < self.ivf_lengths.len() {
1065 let start = self.ivf_offsets[idx] as usize;
1066 let len = self.ivf_lengths[idx] as usize;
1067 candidates.extend(self.ivf.slice(s![start..start + len]).iter());
1068 }
1069 }
1070
1071 candidates.sort_unstable();
1072 candidates.dedup();
1073 candidates
1074 }
1075
1076 pub fn get_document_embeddings(&self, doc_id: usize) -> Result<Array2<f32>> {
1078 if doc_id >= self.doc_lengths.len() {
1079 return Err(Error::Search(format!("Invalid document ID: {}", doc_id)));
1080 }
1081
1082 let start = self.doc_offsets[doc_id];
1083 let end = self.doc_offsets[doc_id + 1];
1084
1085 let codes_slice = self.mmap_codes.slice(start, end);
1087 let residuals_view = self.mmap_residuals.slice_rows(start, end);
1088
1089 let codes: Array1<usize> = Array1::from_iter(codes_slice.iter().map(|&c| c as usize));
1091
1092 let residuals = residuals_view.to_owned();
1094
1095 self.codec.decompress(&residuals, &codes.view())
1097 }
1098
1099 pub fn get_document_codes(&self, doc_ids: &[usize]) -> Vec<Vec<i64>> {
1101 doc_ids
1102 .iter()
1103 .map(|&doc_id| {
1104 if doc_id >= self.doc_lengths.len() {
1105 return vec![];
1106 }
1107 let start = self.doc_offsets[doc_id];
1108 let end = self.doc_offsets[doc_id + 1];
1109 self.mmap_codes.slice(start, end).to_vec()
1110 })
1111 .collect()
1112 }
1113
    /// Decompresses several documents in one codec call, returning their
    /// embeddings stacked into a single `(total_tokens, dim)` matrix plus
    /// the per-document token counts (in `doc_ids` order).
    ///
    /// Out-of-range document ids contribute a length of 0 and no rows.
    ///
    /// # Errors
    /// Propagates codec decompression errors.
    pub fn decompress_documents(&self, doc_ids: &[usize]) -> Result<(Array2<f32>, Vec<usize>)> {
        // First pass: compute per-document lengths and the total row count.
        let mut total_tokens = 0usize;
        let mut lengths = Vec::with_capacity(doc_ids.len());
        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                lengths.push(0);
            } else {
                let len = self.doc_offsets[doc_id + 1] - self.doc_offsets[doc_id];
                lengths.push(len);
                total_tokens += len;
            }
        }

        if total_tokens == 0 {
            return Ok((Array2::zeros((0, self.codec.embedding_dim())), lengths));
        }

        // Second pass: gather codes and packed residuals into one batch.
        let packed_dim = self.mmap_residuals.ncols();
        let mut all_codes = Vec::with_capacity(total_tokens);
        let mut all_residuals = Array2::<u8>::zeros((total_tokens, packed_dim));
        let mut offset = 0;

        for &doc_id in doc_ids {
            if doc_id >= self.doc_lengths.len() {
                continue;
            }
            let start = self.doc_offsets[doc_id];
            let end = self.doc_offsets[doc_id + 1];
            let len = end - start;

            let codes_slice = self.mmap_codes.slice(start, end);
            all_codes.extend(codes_slice.iter().map(|&c| c as usize));

            let residuals_view = self.mmap_residuals.slice_rows(start, end);
            all_residuals
                .slice_mut(s![offset..offset + len, ..])
                .assign(&residuals_view);
            offset += len;
        }

        let codes_arr = Array1::from_vec(all_codes);
        let embeddings = self.codec.decompress(&all_residuals, &codes_arr.view())?;

        Ok((embeddings, lengths))
    }
1164
    /// Runs a single query against this index, optionally restricted to a
    /// `subset` of document ids. Delegates to `crate::search::search_one_mmap`.
    pub fn search(
        &self,
        query: &Array2<f32>,
        params: &crate::search::SearchParameters,
        subset: Option<&[i64]>,
    ) -> Result<crate::search::SearchResult> {
        crate::search::search_one_mmap(self, query, params, subset)
    }
1184
    /// Search the index with several queries, returning one result per query.
    ///
    /// `parallel` is forwarded to [`crate::search::search_many_mmap`];
    /// presumably it toggles parallel execution across queries — see that
    /// function for the exact semantics.
    pub fn search_batch(
        &self,
        queries: &[Array2<f32>],
        params: &crate::search::SearchParameters,
        parallel: bool,
        subset: Option<&[i64]>,
    ) -> Result<Vec<crate::search::SearchResult>> {
        crate::search::search_many_mmap(self, queries, params, parallel, subset)
    }
1206
    /// Number of documents currently in the index.
    pub fn num_documents(&self) -> usize {
        self.doc_lengths.len()
    }
1211
    /// Total number of token embeddings stored, as recorded in the metadata.
    pub fn num_embeddings(&self) -> usize {
        self.metadata.num_embeddings
    }
1216
    /// Number of centroid partitions (clusters), as recorded in the metadata.
    pub fn num_partitions(&self) -> usize {
        self.metadata.num_partitions
    }
1221
    /// Average number of embeddings per document, as recorded in the metadata.
    pub fn avg_doclen(&self) -> f64 {
        self.metadata.avg_doclen
    }
1226
    /// Dimensionality of the token embeddings (taken from the codec).
    pub fn embedding_dim(&self) -> usize {
        self.codec.embedding_dim()
    }
1231
    /// Replace the memory-mapped codes/residuals and the centroid store with
    /// empty placeholders, dropping the mmaps so the underlying index files
    /// can be rewritten (update/delete call this before touching the files).
    ///
    /// NOTE(review): after this call the index is unusable until `Self::load`
    /// (or `reload`) repopulates it — confirm all callers do so.
    fn release_mmaps(&mut self) {
        self.mmap_codes = crate::mmap::MmapNpyArray1I64::empty();
        self.mmap_residuals = crate::mmap::MmapNpyArray2U8::empty();
        self.codec.centroids = crate::codec::CentroidStore::Owned(Array2::zeros((0, 0)));
    }
1246
    /// Reconstruct (decompress) the stored embeddings for each document in
    /// `doc_ids`. Delegates to [`crate::embeddings::reconstruct_embeddings`].
    pub fn reconstruct(&self, doc_ids: &[i64]) -> Result<Vec<Array2<f32>>> {
        crate::embeddings::reconstruct_embeddings(self, doc_ids)
    }
1275
    /// Reconstruct the stored embeddings for a single document.
    /// Delegates to [`crate::embeddings::reconstruct_single`].
    pub fn reconstruct_single(&self, doc_id: i64) -> Result<Array2<f32>> {
        crate::embeddings::reconstruct_single(self, doc_id)
    }
1290
    /// Build a new index at `index_path` from `embeddings` (one `Array2` of
    /// token vectors per document), training centroids with k-means, then
    /// load and return the freshly written index.
    pub fn create_with_kmeans(
        embeddings: &[Array2<f32>],
        index_path: &str,
        config: &IndexConfig,
    ) -> Result<Self> {
        // Write all index files to disk first, then mmap them back in.
        create_index_with_kmeans_files(embeddings, index_path, config)?;

        Self::load(index_path)
    }
1321
    /// Append `embeddings` (one `Array2` of token vectors per document) to the
    /// index, returning the document ids assigned to the new documents.
    ///
    /// Three paths, chosen by index size and buffer fill:
    /// 1. Small index (`num_documents <= config.start_from_scratch`) whose raw
    ///    embeddings are still on disk: rebuild the whole index from scratch
    ///    with freshly trained centroids.
    /// 2. Buffered + new docs reach `config.buffer_size`: delete the
    ///    previously buffered docs from the index, re-append buffer + new docs
    ///    (possibly growing the centroid set), and clear the buffer.
    /// 3. Otherwise: persist the new docs into the buffer and encode them
    ///    against the existing centroids.
    ///
    /// In paths 2 and 3, `self` is replaced by a freshly loaded index at the
    /// end; path 1 replaces `self` via `create_with_kmeans`.
    pub fn update(
        &mut self,
        embeddings: &[Array2<f32>],
        config: &crate::update::UpdateConfig,
    ) -> Result<Vec<i64>> {
        use crate::codec::ResidualCodec;
        use crate::update::{
            clear_buffer, clear_embeddings_npy, embeddings_npy_exists, load_buffer,
            load_buffer_info, load_cluster_threshold, load_embeddings_npy, save_buffer,
            update_centroids, update_index,
        };

        // Keep an owned copy of the path: `*self` is replaced below.
        let path_str = self.path.clone();
        let index_path = std::path::Path::new(&path_str);
        let num_new_docs = embeddings.len();

        // Drop the mmaps before any on-disk files are rewritten.
        self.release_mmaps();

        // Path 1: index is still small enough to rebuild from scratch.
        if self.metadata.num_documents <= config.start_from_scratch {
            let existing_embeddings = load_embeddings_npy(index_path)?;

            // NOTE(review): if embeddings.npy is out of sync with the metadata
            // (length mismatch), this silently falls through to the buffered
            // path below — confirm that is the intended recovery behavior.
            if existing_embeddings.len() == self.metadata.num_documents {
                // New docs are appended after all existing ones.
                let start_doc_id = existing_embeddings.len() as i64;

                let combined_embeddings: Vec<Array2<f32>> = existing_embeddings
                    .into_iter()
                    .chain(embeddings.iter().cloned())
                    .collect();

                // Rebuild with the original nbits but the update-config knobs.
                let index_config = IndexConfig {
                    nbits: self.metadata.nbits,
                    batch_size: config.batch_size,
                    seed: Some(config.seed),
                    kmeans_niters: config.kmeans_niters,
                    max_points_per_centroid: config.max_points_per_centroid,
                    n_samples_kmeans: config.n_samples_kmeans,
                    start_from_scratch: config.start_from_scratch,
                    force_cpu: config.force_cpu,
                    ..Default::default()
                };

                *self = Self::create_with_kmeans(&combined_embeddings, &path_str, &index_config)?;

                // Once the index outgrows the from-scratch threshold, the raw
                // embeddings file is no longer needed — drop it.
                if combined_embeddings.len() > config.start_from_scratch
                    && embeddings_npy_exists(index_path)
                {
                    clear_embeddings_npy(index_path)?;
                }

                return Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect());
            }
        }

        let buffer = load_buffer(index_path)?;
        let buffer_len = buffer.len();
        let total_new = embeddings.len() + buffer_len;

        let start_doc_id: i64;

        let mut codec = ResidualCodec::load_from_dir(index_path)?;

        if total_new >= config.buffer_size {
            // Path 2: flush the buffer. The buffered docs were already encoded
            // into the index, so remove them first and re-append them together
            // with the new docs (they may get better centroids this time).
            let num_buffered = load_buffer_info(index_path)?;

            if num_buffered > 0 && self.metadata.num_documents >= num_buffered {
                // Buffered docs occupy the tail of the id range.
                let start_del_idx = self.metadata.num_documents - num_buffered;
                let docs_to_delete: Vec<i64> = (start_del_idx..self.metadata.num_documents)
                    .map(|i| i as i64)
                    .collect();
                crate::delete::delete_from_index_keep_buffer(&docs_to_delete, &path_str)?;
                self.metadata = Metadata::load_from_path(index_path)?;
            }

            // Buffered docs are re-appended before the new ones, so the new
            // docs start after (remaining docs + buffer).
            start_doc_id = (self.metadata.num_documents + buffer_len) as i64;

            let combined: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();

            // Optionally grow the centroid set; reload the codec if it changed.
            if let Ok(cluster_threshold) = load_cluster_threshold(index_path) {
                let new_centroids =
                    update_centroids(index_path, &combined, cluster_threshold, config)?;
                if new_centroids > 0 {
                    codec = ResidualCodec::load_from_dir(index_path)?;
                }
            }

            clear_buffer(index_path)?;

            update_index(
                &combined,
                &path_str,
                &codec,
                Some(config.batch_size),
                true,
                config.force_cpu,
            )?;
        } else {
            // Path 3: buffer not full yet — persist new docs into the buffer
            // and encode only the new docs against the existing centroids.
            start_doc_id = self.metadata.num_documents as i64;

            let combined_buffer: Vec<Array2<f32>> = buffer
                .into_iter()
                .chain(embeddings.iter().cloned())
                .collect();
            save_buffer(index_path, &combined_buffer)?;

            update_index(
                embeddings,
                &path_str,
                &codec,
                Some(config.batch_size),
                false,
                config.force_cpu,
            )?;
        }

        // Re-mmap the rewritten files.
        *self = Self::load(&path_str)?;

        Ok((start_doc_id..start_doc_id + num_new_docs as i64).collect())
    }
1510
1511 pub fn update_with_metadata(
1523 &mut self,
1524 embeddings: &[Array2<f32>],
1525 config: &crate::update::UpdateConfig,
1526 metadata: Option<&[serde_json::Value]>,
1527 ) -> Result<Vec<i64>> {
1528 if let Some(meta) = metadata {
1530 if meta.len() != embeddings.len() {
1531 return Err(Error::Config(format!(
1532 "Metadata length ({}) must match embeddings length ({})",
1533 meta.len(),
1534 embeddings.len()
1535 )));
1536 }
1537 }
1538
1539 let doc_ids = self.update(embeddings, config)?;
1541
1542 if let Some(meta) = metadata {
1544 crate::filtering::update(&self.path, meta, &doc_ids)?;
1545 }
1546
1547 Ok(doc_ids)
1548 }
1549
1550 pub fn update_or_create(
1563 embeddings: &[Array2<f32>],
1564 index_path: &str,
1565 index_config: &IndexConfig,
1566 update_config: &crate::update::UpdateConfig,
1567 ) -> Result<(Self, Vec<i64>)> {
1568 let index_dir = std::path::Path::new(index_path);
1569 let metadata_path = index_dir.join("metadata.json");
1570
1571 if metadata_path.exists() {
1572 let mut index = Self::load(index_path)?;
1574 let doc_ids = index.update(embeddings, update_config)?;
1575 Ok((index, doc_ids))
1576 } else {
1577 let num_docs = embeddings.len();
1579 let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1580 let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1581 Ok((index, doc_ids))
1582 }
1583 }
1584
1585 pub fn update_or_create_with_metadata(
1604 embeddings: &[Array2<f32>],
1605 index_path: &str,
1606 index_config: &IndexConfig,
1607 update_config: &crate::update::UpdateConfig,
1608 metadata: Option<&[serde_json::Value]>,
1609 ) -> Result<(Self, Vec<i64>)> {
1610 if let Some(meta) = metadata {
1611 if meta.len() != embeddings.len() {
1612 return Err(Error::Config(format!(
1613 "Metadata length ({}) must match embeddings length ({})",
1614 meta.len(),
1615 embeddings.len()
1616 )));
1617 }
1618 }
1619
1620 let index_dir = std::path::Path::new(index_path);
1621 let metadata_json_path = index_dir.join("metadata.json");
1622
1623 let (index, doc_ids) = if metadata_json_path.exists() {
1624 let mut index = Self::load(index_path)?;
1625 let doc_ids = index.update(embeddings, update_config)?;
1626 (index, doc_ids)
1627 } else {
1628 let num_docs = embeddings.len();
1629 let index = Self::create_with_kmeans(embeddings, index_path, index_config)?;
1630 let doc_ids: Vec<i64> = (0..num_docs as i64).collect();
1631 (index, doc_ids)
1632 };
1633
1634 if let Some(meta) = metadata {
1635 if crate::filtering::exists(index_path) {
1636 crate::filtering::update(index_path, meta, &doc_ids)?;
1637 } else {
1638 crate::filtering::create(index_path, meta, &doc_ids)?;
1639 }
1640 crate::text_search::index(index_path, meta, &doc_ids, &index_config.fts_tokenizer)?;
1642 }
1643
1644 Ok((index, doc_ids))
1645 }
1646
    /// Re-open the index from disk, picking up any external changes to the
    /// underlying files.
    pub fn reload(&mut self) -> Result<()> {
        // Clone the path first: `release_mmaps` and the `*self` assignment
        // below replace the current instance.
        let path = self.path.clone();
        self.release_mmaps();
        *self = Self::load(&path)?;
        Ok(())
    }
1659
    /// Delete documents and their associated metadata (if any).
    ///
    /// Convenience wrapper for [`Self::delete_with_options`] with
    /// `delete_metadata = true`. Returns the number of documents deleted.
    pub fn delete(&mut self, doc_ids: &[i64]) -> Result<usize> {
        self.delete_with_options(doc_ids, true)
    }
1675
    /// Delete documents from the index; when `delete_metadata` is true, also
    /// remove their metadata rows and rebuild the full-text-search index.
    ///
    /// Returns the number of documents actually deleted.
    ///
    /// NOTE(review): unlike `update`/`reload`, this does not call
    /// `Self::load` after rewriting the files, so `self` is left with the
    /// empty placeholder mmaps and stale metadata — callers appear to need a
    /// `reload()` afterwards; confirm that contract.
    pub fn delete_with_options(&mut self, doc_ids: &[i64], delete_metadata: bool) -> Result<usize> {
        let path = self.path.clone();

        // Drop the mmaps before the on-disk files are rewritten.
        self.release_mmaps();

        let deleted = crate::delete::delete_from_index(doc_ids, &path)?;

        if delete_metadata && deleted > 0 {
            let index_path = std::path::Path::new(&path);
            let db_path = index_path.join("metadata.db");
            // Only touch the metadata stores if the sidecar DB exists.
            if db_path.exists() {
                crate::filtering::delete(&path, doc_ids)?;
                crate::text_search::rebuild(&path)?;
            }
        }

        Ok(deleted)
    }
1713}
1714
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build `count` documents of 5x32 L2-normalized embeddings whose values
    /// are a deterministic function of `(doc index + offset, row, col)`.
    /// Shared by both update_or_create tests (previously duplicated inline).
    fn make_embeddings(count: usize, offset: usize) -> Vec<Array2<f32>> {
        (0..count)
            .map(|i| {
                let mut doc = Array2::from_shape_fn((5, 32), |(j, k)| {
                    ((i + offset) as f32) * 0.1 + (j as f32) * 0.01 + (k as f32) * 0.001
                });
                // Normalize each token vector to unit length.
                for mut row in doc.rows_mut() {
                    let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if norm > 0.0 {
                        row.iter_mut().for_each(|x| *x /= norm);
                    }
                }
                doc
            })
            .collect()
    }

    /// Small configuration suitable for tiny test indexes.
    fn test_index_config() -> IndexConfig {
        IndexConfig {
            nbits: 2,
            batch_size: 50,
            seed: Some(42),
            kmeans_niters: 2,
            ..Default::default()
        }
    }

    #[test]
    fn test_index_config_default() {
        let config = IndexConfig::default();
        assert_eq!(config.nbits, 4);
        assert_eq!(config.batch_size, 50_000);
        assert_eq!(config.seed, Some(42));
        // Also cover the remaining defaults so a regression in `Default` or
        // the serde default helpers is caught here.
        assert_eq!(config.kmeans_niters, 4);
        assert_eq!(config.max_points_per_centroid, 256);
        assert_eq!(config.n_samples_kmeans, None);
        assert_eq!(config.start_from_scratch, 999);
        assert!(!config.force_cpu);
    }

    #[test]
    fn test_update_or_create_new_index() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();

        let embeddings = make_embeddings(5, 0);
        let (index, doc_ids) = MmapIndex::update_or_create(
            &embeddings,
            index_path,
            &test_index_config(),
            &crate::update::UpdateConfig::default(),
        )
        .expect("Failed to create index");

        assert_eq!(index.metadata.num_documents, 5);
        assert_eq!(doc_ids, vec![0, 1, 2, 3, 4]);

        // The on-disk artifacts of a fresh index must exist.
        assert!(temp_dir.path().join("metadata.json").exists());
        assert!(temp_dir.path().join("centroids.npy").exists());
    }

    #[test]
    fn test_update_or_create_existing_index() {
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let index_path = temp_dir.path().to_str().unwrap();
        let config = test_index_config();
        let update_config = crate::update::UpdateConfig::default();

        // First call takes the create path.
        let embeddings1 = make_embeddings(5, 0);
        let (index1, doc_ids1) =
            MmapIndex::update_or_create(&embeddings1, index_path, &config, &update_config)
                .expect("Failed to create index");
        assert_eq!(index1.metadata.num_documents, 5);
        assert_eq!(doc_ids1, vec![0, 1, 2, 3, 4]);

        // Release the mmaps before the files are rewritten in place.
        drop(index1);

        // Second call must take the update path and append documents.
        let embeddings2 = make_embeddings(3, 5);
        let (index2, doc_ids2) =
            MmapIndex::update_or_create(&embeddings2, index_path, &config, &update_config)
                .expect("Failed to update index");
        assert_eq!(index2.metadata.num_documents, 8);
        assert_eq!(doc_ids2, vec![5, 6, 7]);
    }
}