1use crate::multivector::{
12 codec::ResidualCodec,
13 search::{CandidateScorer, CentroidSelector, ScoreMerger},
14 types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig},
15};
16use crate::{Chunk, ChunkId, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
/// Multi-vector index: a trained residual codec plus every indexed token, grouped by the
/// centroid the codec assigned it to.
///
/// Lifecycle: `new` -> `train` (fits the codec) -> `insert`* (queues embeddings) ->
/// `build` (compresses queued embeddings into the flat arrays) -> `search`.
///
/// NOTE: field order is part of the serialized layout produced by serde; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarpIndex {
    /// Index configuration (see `WarpIndexConfig`).
    config: WarpIndexConfig,
    /// Trained residual codec; `None` until `train` succeeds.
    codec: Option<ResidualCodec>,
    /// Number of indexed tokens assigned to each centroid, indexed by centroid id.
    sizes: Vec<usize>,
    /// Start of each centroid's run in the flat per-token arrays (exclusive prefix sums of `sizes`).
    offsets: Vec<usize>,
    /// Chunk that owns each indexed token, laid out centroid by centroid.
    chunk_ids: Vec<ChunkId>,
    /// Position of each indexed token inside its chunk's embedding, same order as `chunk_ids`.
    token_indices: Vec<u16>,
    /// Compressed residual bytes for each indexed token, concatenated in the same order;
    /// presumably `config.packed_residual_size()` bytes per token — confirm against the codec.
    residuals: Vec<u8>,
    /// Chunk payloads keyed by id. Skipped by serde, so it is empty after deserialization.
    #[serde(skip)]
    chunks: HashMap<ChunkId, Chunk>,
    /// Embeddings accepted by `insert` but not yet compressed by `build`. Skipped by serde.
    #[serde(skip)]
    pending: Vec<(ChunkId, MultiVectorEmbedding)>,
    /// Whether `build` has populated the flat per-token arrays.
    is_built: bool,
}
68
69impl WarpIndex {
70 #[must_use]
72 pub fn new(config: WarpIndexConfig) -> Self {
73 Self {
74 config,
75 codec: None,
76 sizes: Vec::new(),
77 offsets: Vec::new(),
78 chunk_ids: Vec::new(),
79 token_indices: Vec::new(),
80 residuals: Vec::new(),
81 chunks: HashMap::new(),
82 pending: Vec::new(),
83 is_built: false,
84 }
85 }
86
87 #[must_use]
89 pub fn config(&self) -> &WarpIndexConfig {
90 &self.config
91 }
92
93 #[must_use]
95 pub fn codec(&self) -> Option<&ResidualCodec> {
96 self.codec.as_ref()
97 }
98
99 #[must_use]
101 pub fn is_trained(&self) -> bool {
102 self.codec.is_some()
103 }
104
105 #[must_use]
107 pub fn is_built(&self) -> bool {
108 self.is_built
109 }
110
111 #[must_use]
113 pub fn num_chunks(&self) -> usize {
114 self.chunks.len()
115 }
116
117 #[must_use]
119 pub fn num_tokens(&self) -> usize {
120 self.chunk_ids.len()
121 }
122
123 #[must_use]
125 pub fn is_empty(&self) -> bool {
126 self.chunks.is_empty()
127 }
128
129 #[must_use]
131 pub fn get_chunk(&self, id: &ChunkId) -> Option<&Chunk> {
132 self.chunks.get(id)
133 }
134
135 #[must_use]
137 pub fn memory_usage(&self) -> usize {
138 let codec_size = self
139 .codec
140 .as_ref()
141 .map(|c| {
142 c.centroids().len() * 4 + c.dim() * ((1 << c.nbits()) - 1) * 4 + c.dim() * (1 << c.nbits()) * 4 })
146 .unwrap_or(0);
147
148 let index_size = self.chunk_ids.len() * size_of::<ChunkId>()
149 + self.token_indices.len() * size_of::<u16>()
150 + self.residuals.len()
151 + self.sizes.len() * size_of::<usize>()
152 + self.offsets.len() * size_of::<usize>();
153
154 codec_size + index_size
155 }
156
157 pub fn train(&mut self, samples: &[MultiVectorEmbedding]) -> Result<()> {
169 let total_tokens: usize = samples.iter().map(|s| s.num_tokens()).sum();
171 let min_samples = self.config.effective_min_training_samples();
172
173 if total_tokens < min_samples {
174 return Err(crate::Error::InvalidInput(format!(
175 "Insufficient training tokens: {total_tokens} < {min_samples} required"
176 )));
177 }
178
179 let mut all_embeddings = Vec::with_capacity(total_tokens * self.config.token_dim);
181 for sample in samples {
182 all_embeddings.extend_from_slice(sample.as_slice());
183 }
184
185 let codec = ResidualCodec::train(
187 &all_embeddings,
188 self.config.token_dim,
189 self.config.num_centroids,
190 self.config.nbits,
191 self.config.kmeans_iterations,
192 )?;
193
194 self.codec = Some(codec);
195 Ok(())
196 }
197
198 pub fn insert(&mut self, chunk: Chunk, embedding: MultiVectorEmbedding) -> Result<()> {
208 if self.codec.is_none() {
209 return Err(crate::Error::InvalidInput(
210 "Codec not trained - call train() first".to_string(),
211 ));
212 }
213
214 if self.is_built {
215 return Err(crate::Error::InvalidInput(
216 "Index already built - cannot insert".to_string(),
217 ));
218 }
219
220 contract_pre_embedding_lookup!(embedding.as_slice());
222
223 let chunk_id = chunk.id;
224 self.chunks.insert(chunk_id, chunk);
225 self.pending.push((chunk_id, embedding));
226
227 Ok(())
228 }
229
230 pub fn build(&mut self) -> Result<()> {
239 let codec = self.codec.as_ref().ok_or_else(|| {
240 crate::Error::InvalidInput("Codec not trained - call train() first".to_string())
241 })?;
242
243 let mut centroid_assignments: Vec<Vec<(ChunkId, u16, Vec<u8>)>> =
245 vec![Vec::new(); self.config.num_centroids];
246
247 for (chunk_id, embedding) in &self.pending {
248 for (token_idx, token) in embedding.tokens().enumerate() {
249 let (centroid_id, residual) = codec.compress(token);
250 centroid_assignments[centroid_id].push((*chunk_id, token_idx as u16, residual));
251 }
252 }
253
254 let bytes_per_residual = self.config.packed_residual_size();
256
257 self.sizes = centroid_assignments.iter().map(|v| v.len()).collect();
258 self.offsets = self
259 .sizes
260 .iter()
261 .scan(0, |acc, &size| {
262 let offset = *acc;
263 *acc += size;
264 Some(offset)
265 })
266 .collect();
267
268 let total_tokens: usize = self.sizes.iter().sum();
269 self.chunk_ids = Vec::with_capacity(total_tokens);
270 self.token_indices = Vec::with_capacity(total_tokens);
271 self.residuals = Vec::with_capacity(total_tokens * bytes_per_residual);
272
273 for assignments in centroid_assignments {
274 for (chunk_id, token_idx, residual) in assignments {
275 self.chunk_ids.push(chunk_id);
276 self.token_indices.push(token_idx);
277 self.residuals.extend(residual);
278 }
279 }
280
281 self.pending.clear();
282 self.is_built = true;
283
284 Ok(())
285 }
286
287 pub fn clear_index(&mut self) {
292 self.sizes.clear();
293 self.offsets.clear();
294 self.chunk_ids.clear();
295 self.token_indices.clear();
296 self.residuals.clear();
297 self.is_built = false;
298 }
299
300 pub fn search(
315 &self,
316 query: &MultiVectorEmbedding,
317 search_config: &WarpSearchConfig,
318 ) -> Result<Vec<(ChunkId, f32)>> {
319 let codec = self
320 .codec
321 .as_ref()
322 .ok_or_else(|| crate::Error::InvalidInput("Codec not trained".to_string()))?;
323
324 if !self.is_built {
325 return Err(crate::Error::InvalidInput(
326 "Index not built - call build() first".to_string(),
327 ));
328 }
329
330 let selected_centroids = CentroidSelector::select(
332 query,
333 codec.centroids(),
334 self.config.token_dim,
335 search_config,
336 );
337
338 let mut total_centroids = 0;
340 let max_tokens = search_config.t_prime.unwrap_or(usize::MAX);
341 let bounded_centroids: Vec<Vec<(usize, f32)>> = selected_centroids
342 .into_iter()
343 .take(max_tokens)
344 .map(|centroids| {
345 let take =
346 (search_config.bound.saturating_sub(total_centroids)).min(centroids.len());
347 total_centroids += take;
348 centroids.into_iter().take(take).collect()
349 })
350 .collect();
351
352 let bytes_per_residual = self.config.packed_residual_size();
354
355 let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = bounded_centroids
356 .into_iter()
357 .enumerate()
358 .map(|(query_token_idx, centroids)| {
359 let query_token = query.token(query_token_idx);
360
361 centroids
362 .into_iter()
363 .flat_map(|(centroid_id, centroid_score)| {
364 CandidateScorer::score(
365 query_token,
366 centroid_id,
367 centroid_score,
368 codec,
369 &self.sizes,
370 &self.offsets,
371 &self.chunk_ids,
372 &self.token_indices,
373 &self.residuals,
374 bytes_per_residual,
375 )
376 })
377 .collect()
378 })
379 .collect();
380
381 Ok(ScoreMerger::merge(token_scores, search_config.k))
383 }
384
385 #[must_use]
387 pub fn centroid_size(&self, centroid_id: usize) -> usize {
388 self.sizes.get(centroid_id).copied().unwrap_or(0)
389 }
390
391 #[must_use]
393 pub fn centroid_offset(&self, centroid_id: usize) -> usize {
394 self.offsets.get(centroid_id).copied().unwrap_or(0)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    /// Token dimension shared by every test index and embedding.
    const DIM: usize = 16;

    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Deterministic pseudo-random `num_tokens x dim` embedding with values in `[-1, 1]`.
    ///
    /// Uses a 64-bit LCG and keeps the top 32 bits of its state, the best-mixed ones.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut state = seed;
        let values: Vec<f32> = (0..num_tokens * dim)
            .map(|_| {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                // `>> 32` spans the whole u32 range. The previous `>> 33` only reached half
                // of it, which confined every generated value to [-1, 0].
                let unit = (state >> 32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect();

        MultiVectorEmbedding::new(values, num_tokens, dim)
    }

    /// Index with a small codec (2 bits, 8 centroids) trained on 100 samples of 10 tokens.
    fn trained_index() -> WarpIndex {
        let config = WarpIndexConfig::new(2, 8, DIM).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(config);
        let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, DIM, i)).collect();
        index.train(&samples).unwrap();
        index
    }

    /// Inserts `count` five-token chunks whose embeddings are seeded `first_seed..`.
    fn insert_chunks(index: &mut WarpIndex, count: u64, first_seed: u64) {
        for i in 0..count {
            let chunk = create_test_chunk(&format!("doc {i}"));
            let embedding = generate_embedding(5, DIM, first_seed + i);
            index.insert(chunk, embedding).unwrap();
        }
    }

    fn is_sorted_descending(results: &[(ChunkId, f32)]) -> bool {
        results.windows(2).all(|pair| pair[0].1 >= pair[1].1)
    }

    #[test]
    fn test_generate_embedding_spans_full_range() {
        let embedding = generate_embedding(10, DIM, 7);
        let values = embedding.as_slice();

        assert!(values.iter().any(|&v| v > 0.0));
        assert!(values.iter().any(|&v| v < 0.0));
        assert!(values.iter().all(|&v| (-1.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_index_new() {
        let config = WarpIndexConfig::new(2, 16, 32);
        let index = WarpIndex::new(config);

        assert!(!index.is_trained());
        assert!(!index.is_built());
        assert!(index.is_empty());
    }

    #[test]
    fn test_index_config() {
        let config = WarpIndexConfig::new(4, 32, 64);
        let index = WarpIndex::new(config);

        assert_eq!(index.config().nbits, 4);
        assert_eq!(index.config().num_centroids, 32);
        assert_eq!(index.config().token_dim, 64);
    }

    #[test]
    fn test_index_train() {
        let index = trained_index();

        assert!(index.is_trained());
        assert!(index.codec().is_some());
    }

    #[test]
    fn test_index_train_insufficient_samples() {
        let config = WarpIndexConfig::new(2, 100, DIM);
        let mut index = WarpIndex::new(config);

        let samples: Vec<_> = (0..5).map(|i| generate_embedding(10, DIM, i)).collect();

        assert!(index.train(&samples).is_err());
    }

    #[test]
    fn test_index_insert() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        let embedding = generate_embedding(5, DIM, 999);
        index.insert(chunk, embedding).unwrap();

        assert_eq!(index.num_chunks(), 1);
    }

    #[test]
    fn test_index_insert_without_training() {
        let config = WarpIndexConfig::new(2, 8, DIM);
        let mut index = WarpIndex::new(config);

        let chunk = create_test_chunk("test");
        let embedding = generate_embedding(5, DIM, 0);

        assert!(index.insert(chunk, embedding).is_err());
    }

    #[test]
    fn test_index_build() {
        let mut index = trained_index();
        insert_chunks(&mut index, 10, 1000);

        index.build().unwrap();

        assert!(index.is_built());
        assert_eq!(index.num_chunks(), 10);
        // 10 chunks x 5 tokens each.
        assert_eq!(index.num_tokens(), 50);
    }

    #[test]
    fn test_index_cannot_insert_after_build() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test");
        let embedding = generate_embedding(5, DIM, 0);
        index.insert(chunk, embedding).unwrap();

        index.build().unwrap();

        let chunk2 = create_test_chunk("test2");
        let embedding2 = generate_embedding(5, DIM, 1);

        assert!(index.insert(chunk2, embedding2).is_err());
    }

    #[test]
    fn test_index_search() {
        let mut index = trained_index();
        insert_chunks(&mut index, 20, 1000);
        index.build().unwrap();

        let query = generate_embedding(3, DIM, 9999);
        let search_config = WarpSearchConfig::with_k(5);
        let results = index.search(&query, &search_config).unwrap();

        assert!(results.len() <= 5);
        assert!(!results.is_empty());
        assert!(is_sorted_descending(&results));
    }

    #[test]
    fn test_index_search_without_build() {
        let index = trained_index();

        let query = generate_embedding(3, DIM, 0);
        let search_config = WarpSearchConfig::with_k(5);

        assert!(index.search(&query, &search_config).is_err());
    }

    #[test]
    fn test_index_memory_usage() {
        let mut index = trained_index();
        insert_chunks(&mut index, 10, 1000);
        index.build().unwrap();

        assert!(index.memory_usage() > 0);
    }

    #[test]
    fn test_index_centroid_stats() {
        let mut index = trained_index();
        insert_chunks(&mut index, 10, 1000);
        index.build().unwrap();

        // Every indexed token belongs to exactly one centroid.
        let total: usize = (0..index.config().num_centroids)
            .map(|c| index.centroid_size(c))
            .sum();
        assert_eq!(total, index.num_tokens());
    }

    #[test]
    fn test_index_clear_and_rebuild() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test");
        let embedding = generate_embedding(5, DIM, 0);
        index.insert(chunk, embedding).unwrap();
        index.build().unwrap();

        assert!(index.is_built());

        index.clear_index();

        assert!(!index.is_built());
        assert_eq!(index.num_tokens(), 0);
        // Clearing the index keeps registered chunks.
        assert_eq!(index.num_chunks(), 1);
    }

    #[test]
    fn test_index_get_chunk() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        let chunk_id = chunk.id;
        let embedding = generate_embedding(5, DIM, 0);
        index.insert(chunk, embedding).unwrap();

        let retrieved = index.get_chunk(&chunk_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "test content");
    }

    proptest! {
        #[test]
        fn prop_search_returns_at_most_k(k in 1usize..20) {
            let mut index = trained_index();
            insert_chunks(&mut index, 30, 1000);
            index.build().unwrap();

            let query = generate_embedding(3, DIM, 9999);
            let search_config = WarpSearchConfig::with_k(k);
            let results = index.search(&query, &search_config).unwrap();

            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_search_results_sorted_descending(seed in 0u64..1000) {
            let mut index = trained_index();
            insert_chunks(&mut index, 20, seed);
            index.build().unwrap();

            let query = generate_embedding(3, DIM, seed + 1000);
            let search_config = WarpSearchConfig::with_k(10);
            let results = index.search(&query, &search_config).unwrap();

            prop_assert!(is_sorted_descending(&results));
        }
    }
}