aprender_rag/multivector/
types.rs1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct MultiVectorEmbedding {
34 embeddings: Vec<f32>,
36 num_tokens: usize,
38 dim: usize,
40}
41
42impl MultiVectorEmbedding {
43 #[must_use]
49 pub fn new(embeddings: Vec<f32>, num_tokens: usize, dim: usize) -> Self {
50 assert_eq!(
51 embeddings.len(),
52 num_tokens * dim,
53 "Embedding size mismatch: expected {} ({}×{}), got {}",
54 num_tokens * dim,
55 num_tokens,
56 dim,
57 embeddings.len()
58 );
59 contract_pre_embedding_lookup!(embeddings);
61 Self { embeddings, num_tokens, dim }
62 }
63
64 #[must_use]
66 pub fn from_tokens(tokens: &[Vec<f32>]) -> Self {
67 if tokens.is_empty() {
68 return Self { embeddings: Vec::new(), num_tokens: 0, dim: 0 };
69 }
70
71 let dim = tokens[0].len();
72 let num_tokens = tokens.len();
73 let mut embeddings = Vec::with_capacity(num_tokens * dim);
74
75 for token in tokens {
76 assert_eq!(token.len(), dim, "All tokens must have the same dimension");
77 embeddings.extend_from_slice(token);
78 }
79
80 Self { embeddings, num_tokens, dim }
81 }
82
83 #[must_use]
85 pub fn num_tokens(&self) -> usize {
86 self.num_tokens
87 }
88
89 #[must_use]
91 pub fn dim(&self) -> usize {
92 self.dim
93 }
94
95 #[must_use]
101 pub fn token(&self, i: usize) -> &[f32] {
102 assert!(i < self.num_tokens, "Token index out of bounds");
103 let start = i * self.dim;
104 &self.embeddings[start..start + self.dim]
105 }
106
107 pub fn tokens(&self) -> impl Iterator<Item = &[f32]> {
112 if self.dim == 0 {
113 [].chunks_exact(1)
115 } else {
116 self.embeddings.chunks_exact(self.dim)
117 }
118 }
119
120 #[must_use]
122 pub fn as_slice(&self) -> &[f32] {
123 &self.embeddings
124 }
125
126 pub fn as_mut_slice(&mut self) -> &mut [f32] {
128 &mut self.embeddings
129 }
130
131 #[must_use]
133 pub fn size_bytes(&self) -> usize {
134 self.embeddings.len() * size_of::<f32>()
135 }
136
137 #[must_use]
139 pub fn is_empty(&self) -> bool {
140 self.num_tokens == 0
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct WarpIndexConfig {
159 pub nbits: u8,
164
165 pub num_centroids: usize,
170
171 pub token_dim: usize,
173
174 pub min_training_samples: Option<usize>,
179
180 pub kmeans_iterations: usize,
182}
183
184impl Default for WarpIndexConfig {
185 fn default() -> Self {
186 Self {
187 nbits: 2,
188 num_centroids: 1024,
189 token_dim: 128,
190 min_training_samples: None,
191 kmeans_iterations: 20,
192 }
193 }
194}
195
196impl WarpIndexConfig {
197 #[must_use]
199 pub fn new(nbits: u8, num_centroids: usize, token_dim: usize) -> Self {
200 Self { nbits, num_centroids, token_dim, ..Default::default() }
201 }
202
203 #[must_use]
205 pub fn with_min_training_samples(mut self, samples: usize) -> Self {
206 self.min_training_samples = Some(samples);
207 self
208 }
209
210 #[must_use]
212 pub fn with_kmeans_iterations(mut self, iterations: usize) -> Self {
213 self.kmeans_iterations = iterations;
214 self
215 }
216
217 #[must_use]
219 pub fn effective_min_training_samples(&self) -> usize {
220 self.min_training_samples.unwrap_or(10 * self.num_centroids)
221 }
222
223 #[must_use]
225 pub fn packed_residual_size(&self) -> usize {
226 (self.token_dim * self.nbits as usize + 7) / 8
227 }
228
229 pub fn validate(&self) -> Result<(), &'static str> {
231 if self.nbits != 2 && self.nbits != 4 {
232 return Err("nbits must be 2 or 4");
233 }
234 if self.num_centroids == 0 {
235 return Err("num_centroids must be > 0");
236 }
237 if self.token_dim == 0 {
238 return Err("token_dim must be > 0");
239 }
240 if self.kmeans_iterations == 0 {
241 return Err("kmeans_iterations must be > 0");
242 }
243 Ok(())
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct WarpSearchConfig {
253 pub k: usize,
255
256 pub nprobe: u32,
261
262 pub bound: usize,
266
267 pub t_prime: Option<usize>,
272
273 pub centroid_score_threshold: f32,
278}
279
280impl Default for WarpSearchConfig {
281 fn default() -> Self {
282 Self { k: 10, nprobe: 4, bound: 128, t_prime: None, centroid_score_threshold: 0.4 }
283 }
284}
285
286impl WarpSearchConfig {
287 #[must_use]
289 pub fn with_k(k: usize) -> Self {
290 Self { k, ..Default::default() }
291 }
292
293 #[must_use]
295 pub fn nprobe(mut self, nprobe: u32) -> Self {
296 self.nprobe = nprobe;
297 self
298 }
299
300 #[must_use]
302 pub fn bound(mut self, bound: usize) -> Self {
303 self.bound = bound;
304 self
305 }
306
307 #[must_use]
309 pub fn t_prime(mut self, t_prime: usize) -> Self {
310 self.t_prime = Some(t_prime);
311 self
312 }
313
314 #[must_use]
316 pub fn centroid_score_threshold(mut self, threshold: f32) -> Self {
317 self.centroid_score_threshold = threshold;
318 self
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ---------------------------------------------------------------------------------
    // MultiVectorEmbedding
    // ---------------------------------------------------------------------------------

    #[test]
    fn test_multivector_new() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 3);

        assert_eq!(mv.num_tokens(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.token(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.token(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Embedding size mismatch")]
    fn test_multivector_size_mismatch() {
        let embeddings = vec![1.0, 2.0, 3.0];
        let _ = MultiVectorEmbedding::new(embeddings, 2, 3);
    }

    #[test]
    #[should_panic(expected = "Token index out of bounds")]
    fn test_multivector_token_out_of_bounds() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0], 1, 2);
        let _ = mv.token(1);
    }

    #[test]
    fn test_multivector_from_tokens() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 3);
        assert_eq!(mv.dim(), 2);
        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_multivector_from_tokens_empty() {
        let tokens: Vec<Vec<f32>> = vec![];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 0);
        assert!(mv.is_empty());
    }

    #[test]
    fn test_multivector_dim_zero_tokens_no_panic() {
        let mv = MultiVectorEmbedding::from_tokens(&[]);
        assert_eq!(mv.dim(), 0);
        assert_eq!(mv.tokens().count(), 0);
    }

    #[test]
    fn test_multivector_new_zero_dim_zero_tokens() {
        let mv = MultiVectorEmbedding::new(vec![], 0, 0);
        assert_eq!(mv.tokens().count(), 0);
        assert!(mv.is_empty());
    }

    #[test]
    fn test_multivector_tokens_iterator() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 2);

        let tokens: Vec<&[f32]> = mv.tokens().collect();
        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens[0], &[1.0, 2.0]);
        assert_eq!(tokens[1], &[3.0, 4.0]);
    }

    #[test]
    fn test_multivector_size_bytes() {
        let embeddings = vec![0.0; 100];
        let mv = MultiVectorEmbedding::new(embeddings, 10, 10);

        assert_eq!(mv.size_bytes(), 100 * std::mem::size_of::<f32>());
    }

    #[test]
    fn test_multivector_as_slice() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0], 1, 3);

        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_multivector_as_mut_slice_writes_through() {
        let mut mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        mv.as_mut_slice()[3] = 9.0;

        assert_eq!(mv.token(1), &[3.0, 9.0]);
    }

    #[test]
    fn test_multivector_serialization() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let json = serde_json::to_string(&mv).unwrap();
        let deserialized: MultiVectorEmbedding = serde_json::from_str(&json).unwrap();

        assert_eq!(mv.num_tokens(), deserialized.num_tokens());
        assert_eq!(mv.dim(), deserialized.dim());
        assert_eq!(mv.as_slice(), deserialized.as_slice());
    }

    // ---------------------------------------------------------------------------------
    // WarpIndexConfig
    // ---------------------------------------------------------------------------------

    #[test]
    fn test_index_config_default() {
        let config = WarpIndexConfig::default();

        assert_eq!(config.nbits, 2);
        assert_eq!(config.num_centroids, 1024);
        assert_eq!(config.token_dim, 128);
        assert_eq!(config.min_training_samples, None);
        assert_eq!(config.kmeans_iterations, 20);
    }

    #[test]
    fn test_index_config_new() {
        let config = WarpIndexConfig::new(4, 256, 64);

        assert_eq!(config.nbits, 4);
        assert_eq!(config.num_centroids, 256);
        assert_eq!(config.token_dim, 64);
    }

    #[test]
    fn test_index_config_builders() {
        let config = WarpIndexConfig::new(2, 512, 128)
            .with_min_training_samples(5000)
            .with_kmeans_iterations(30);

        assert_eq!(config.min_training_samples, Some(5000));
        assert_eq!(config.kmeans_iterations, 30);
    }

    #[test]
    fn test_index_config_effective_min_samples() {
        let config = WarpIndexConfig::new(2, 100, 128);
        // Without an explicit value the minimum defaults to 10 × num_centroids.
        assert_eq!(config.effective_min_training_samples(), 1000);

        let config = config.with_min_training_samples(500);
        assert_eq!(config.effective_min_training_samples(), 500);
    }

    #[test]
    fn test_index_config_packed_size() {
        // 128 dims × 2 bits = 256 bits = 32 bytes.
        let config = WarpIndexConfig::new(2, 1024, 128);
        assert_eq!(config.packed_residual_size(), 32);

        // 128 dims × 4 bits = 512 bits = 64 bytes.
        let config = WarpIndexConfig::new(4, 1024, 128);
        assert_eq!(config.packed_residual_size(), 64);
    }

    #[test]
    fn test_index_config_validate() {
        let config = WarpIndexConfig::default();
        assert!(config.validate().is_ok());

        let four_bits = WarpIndexConfig { nbits: 4, ..Default::default() };
        assert!(four_bits.validate().is_ok());

        let bad_nbits = WarpIndexConfig { nbits: 3, ..Default::default() };
        assert_eq!(bad_nbits.validate(), Err("nbits must be 2 or 4"));

        let bad_centroids = WarpIndexConfig { num_centroids: 0, ..Default::default() };
        assert_eq!(bad_centroids.validate(), Err("num_centroids must be > 0"));

        let bad_dim = WarpIndexConfig { token_dim: 0, ..Default::default() };
        assert_eq!(bad_dim.validate(), Err("token_dim must be > 0"));

        let bad_iters = WarpIndexConfig { kmeans_iterations: 0, ..Default::default() };
        assert_eq!(bad_iters.validate(), Err("kmeans_iterations must be > 0"));
    }

    #[test]
    fn test_index_config_serialization() {
        let config = WarpIndexConfig::new(4, 512, 64)
            .with_min_training_samples(123)
            .with_kmeans_iterations(7);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpIndexConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.nbits, deserialized.nbits);
        assert_eq!(config.num_centroids, deserialized.num_centroids);
        assert_eq!(config.token_dim, deserialized.token_dim);
        assert_eq!(config.min_training_samples, deserialized.min_training_samples);
        assert_eq!(config.kmeans_iterations, deserialized.kmeans_iterations);
    }

    // ---------------------------------------------------------------------------------
    // WarpSearchConfig
    // ---------------------------------------------------------------------------------

    #[test]
    fn test_search_config_default() {
        let config = WarpSearchConfig::default();

        assert_eq!(config.k, 10);
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
        assert!(config.t_prime.is_none());
        assert!((config.centroid_score_threshold - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_search_config_with_k() {
        let config = WarpSearchConfig::with_k(20);
        assert_eq!(config.k, 20);
        // Every other field keeps its default.
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
    }

    #[test]
    fn test_search_config_builders() {
        let config = WarpSearchConfig::with_k(5)
            .nprobe(8)
            .bound(256)
            .t_prime(10)
            .centroid_score_threshold(0.5);

        assert_eq!(config.k, 5);
        assert_eq!(config.nprobe, 8);
        assert_eq!(config.bound, 256);
        assert_eq!(config.t_prime, Some(10));
        assert!((config.centroid_score_threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_search_config_serialization() {
        let config = WarpSearchConfig::with_k(15).nprobe(6).bound(64).t_prime(3);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpSearchConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.k, deserialized.k);
        assert_eq!(config.nprobe, deserialized.nprobe);
        assert_eq!(config.bound, deserialized.bound);
        assert_eq!(config.t_prime, deserialized.t_prime);
        assert!(
            (config.centroid_score_threshold - deserialized.centroid_score_threshold).abs()
                < 1e-6
        );
    }

    // ---------------------------------------------------------------------------------
    // Property tests
    // ---------------------------------------------------------------------------------

    proptest! {
        #[test]
        fn prop_multivector_tokens_count_matches(
            num_tokens in 1usize..20,
            dim in 1usize..64
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            prop_assert_eq!(mv.num_tokens(), num_tokens);
            prop_assert_eq!(mv.dim(), dim);
            prop_assert_eq!(mv.tokens().count(), num_tokens);
        }

        #[test]
        fn prop_multivector_token_slices_correct_size(
            num_tokens in 1usize..10,
            dim in 1usize..32
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            for i in 0..num_tokens {
                prop_assert_eq!(mv.token(i).len(), dim);
            }
        }

        #[test]
        fn prop_index_config_packed_size_formula(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 1usize..256
        ) {
            let config = WarpIndexConfig::new(nbits, 1024, dim);
            let expected = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(config.packed_residual_size(), expected);
        }
    }
}