1use super::binary::BinaryVector;
17use super::int4::Int4Vector;
18use super::quantized::{QuantizedVector, cosine_similarity_i8_trusted, dot_product_i8_trusted};
19use super::{cosine_similarity, dot_product};
20use crate::error::{EmbedError, Result};
21
/// Caller-supplied knowledge about whether a vector has unit L2 norm.
///
/// Distance routines may use this to take cheaper paths. For example, when both
/// operands are `Unit` and stored at full precision, cosine similarity reduces to
/// a plain dot product. `Unknown` never enables such shortcuts, which makes it
/// the safe [`Default`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum NormalizationHint {
    /// Norm is not known; routines must not assume normalization.
    #[default]
    Unknown,
    /// The caller asserts the vector has unit L2 norm (see `is_unit_norm`).
    Unit,
}
33
/// Storage precision of an embedding, declared from most to least precise.
///
/// The derived `Ord` follows declaration order, so
/// `Full < Int8 < Int4 < Binary`; do not reorder the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QuantizationTier {
    /// 32-bit floats: 4 bytes per dimension.
    Full,
    /// 8-bit codes: 1 byte per dimension.
    Int8,
    /// 4-bit codes: half a byte per dimension.
    Int4,
    /// 1 bit per dimension.
    Binary,
}

impl QuantizationTier {
    /// Average payload bytes consumed per dimension at this tier.
    pub fn bytes_per_dim(&self) -> f32 {
        let bits_per_dim: u8 = match self {
            Self::Full => 32,
            Self::Int8 => 8,
            Self::Int4 => 4,
            Self::Binary => 1,
        };
        f32::from(bits_per_dim) / 8.0
    }

    /// How many times smaller this tier is than [`QuantizationTier::Full`]
    /// (1, 4, 8 and 32 respectively).
    pub fn compression_ratio(&self) -> f32 {
        Self::Full.bytes_per_dim() / self.bytes_per_dim()
    }

    /// Exact payload size in bytes for `dims` dimensions; partially filled
    /// trailing bytes (Int4, Binary) are rounded up.
    pub fn storage_bytes(&self, dims: usize) -> usize {
        match *self {
            Self::Binary => dims.div_ceil(8),
            Self::Int4 => dims.div_ceil(2),
            Self::Int8 => dims,
            Self::Full => dims * std::mem::size_of::<f32>(),
        }
    }

    /// Chooses a tier from an item's age in seconds.
    ///
    /// Under an hour → `Full`, under a day → `Int8`, under a week → `Int4`,
    /// anything older → `Binary`. Each boundary itself belongs to the older tier.
    pub fn from_age_seconds(age_secs: u64) -> Self {
        const HOUR: u64 = 60 * 60;
        const DAY: u64 = 24 * HOUR;
        const WEEK: u64 = 7 * DAY;

        match age_secs {
            0..HOUR => Self::Full,
            HOUR..DAY => Self::Int8,
            DAY..WEEK => Self::Int4,
            _ => Self::Binary,
        }
    }
}
97
98#[derive(Debug, Clone)]
103pub enum QuantizedData {
104 Full(Vec<f32>),
106 Int8(QuantizedVector),
108 Int4(Int4Vector),
110 Binary(BinaryVector),
112}
113
114impl QuantizedData {
115 pub fn tier(&self) -> QuantizationTier {
117 match self {
118 Self::Full(_) => QuantizationTier::Full,
119 Self::Int8(_) => QuantizationTier::Int8,
120 Self::Int4(_) => QuantizationTier::Int4,
121 Self::Binary(_) => QuantizationTier::Binary,
122 }
123 }
124
125 pub fn dims(&self) -> usize {
127 match self {
128 Self::Full(v) => v.len(),
129 Self::Int8(q) => q.len(),
130 Self::Int4(q) => q.dims,
131 Self::Binary(q) => q.dims,
132 }
133 }
134
135 pub fn storage_bytes(&self) -> usize {
137 match self {
138 Self::Full(v) => v.len() * 4,
139 Self::Int8(q) => q.len(),
140 Self::Int4(q) => q.data.len(),
141 Self::Binary(q) => q.data.len(),
142 }
143 }
144
145 pub fn from_f32(vector: &[f32], tier: QuantizationTier) -> Self {
147 match tier {
148 QuantizationTier::Full => Self::Full(vector.to_vec()),
149 QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(vector)),
150 QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(vector)),
151 QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(vector)),
152 }
153 }
154
155 pub fn to_f32(&self) -> Vec<f32> {
157 match self {
158 Self::Full(v) => v.clone(),
159 Self::Int8(q) => q.to_f32(),
160 Self::Int4(q) => q.to_f32(),
161 Self::Binary(q) => q.to_f32(),
162 }
163 }
164
165 pub fn promote(&self, target: QuantizationTier) -> Self {
172 let f32_data = self.to_f32();
173 Self::from_f32(&f32_data, target)
174 }
175
176 pub fn demote(&self, target: QuantizationTier) -> Self {
178 self.promote(target) }
180}
181
182#[derive(Debug, Clone)]
187pub enum PreparedQuery {
188 Full(Vec<f32>),
190 Int8(QuantizedVector),
192 Int4(Int4Vector),
194 Binary(BinaryVector),
196}
197
198impl PreparedQuery {
199 #[inline]
201 pub fn from_f32(query_f32: &[f32], tier: QuantizationTier) -> Self {
202 match tier {
203 QuantizationTier::Full => Self::Full(query_f32.to_vec()),
204 QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(query_f32)),
205 QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(query_f32)),
206 QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(query_f32)),
207 }
208 }
209
210 #[inline]
212 pub fn tier(&self) -> QuantizationTier {
213 match self {
214 Self::Full(_) => QuantizationTier::Full,
215 Self::Int8(_) => QuantizationTier::Int8,
216 Self::Int4(_) => QuantizationTier::Int4,
217 Self::Binary(_) => QuantizationTier::Binary,
218 }
219 }
220
221 #[inline]
223 pub fn dims(&self) -> usize {
224 match self {
225 Self::Full(v) => v.len(),
226 Self::Int8(q) => q.len(),
227 Self::Int4(q) => q.dims,
228 Self::Binary(q) => q.dims,
229 }
230 }
231}
232
233#[inline]
235pub fn prepare_query(query_f32: &[f32], tier: QuantizationTier) -> PreparedQuery {
236 PreparedQuery::from_f32(query_f32, tier)
237}
238
239#[derive(Debug, Clone)]
245pub struct PreparedQueryWithMeta {
246 pub query: PreparedQuery,
248 pub norm: NormalizationHint,
250}
251
252impl PreparedQueryWithMeta {
253 #[inline]
255 pub fn from_f32(query_f32: &[f32], tier: QuantizationTier, norm: NormalizationHint) -> Self {
256 Self {
257 query: PreparedQuery::from_f32(query_f32, tier),
258 norm,
259 }
260 }
261
262 #[inline]
264 pub fn tier(&self) -> QuantizationTier {
265 self.query.tier()
266 }
267
268 #[inline]
270 pub fn dims(&self) -> usize {
271 self.query.dims()
272 }
273}
274
/// Returns `true` when the squared L2 norm of `v` lies within `1e-4` of `1.0`.
///
/// An empty slice (norm 0) or any slice whose squared sum is NaN or infinite
/// yields `false`.
#[inline]
pub fn is_unit_norm(v: &[f32]) -> bool {
    const TOLERANCE: f32 = 1e-4;
    let squared_norm = v.iter().fold(0.0_f32, |acc, &x| acc + x * x);
    (squared_norm - 1.0).abs() < TOLERANCE
}
281
282#[inline]
284pub fn prepare_query_with_norm(
285 query_f32: &[f32],
286 tier: QuantizationTier,
287 norm: NormalizationHint,
288) -> PreparedQueryWithMeta {
289 PreparedQueryWithMeta::from_f32(query_f32, tier, norm)
290}
291
292#[inline]
301pub fn approximate_cosine_distance_prepared(
302 query: &PreparedQuery,
303 stored: &QuantizedData,
304) -> Result<f32> {
305 match (query, stored) {
306 (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(1.0 - cosine_similarity(q, s)),
307 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => {
308 Ok(1.0 - cosine_similarity_i8_trusted(s, q))
309 }
310 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.cosine_distance(q)),
311 (PreparedQuery::Binary(q), QuantizedData::Binary(s)) => Ok(s.cosine_distance_approx(q)),
312 _ => Err(EmbedError::TierMismatch {
313 op: "approximate_cosine_distance_prepared",
314 expected: stored.tier(),
315 actual: query.tier(),
316 }),
317 }
318}
319
320#[inline]
326pub fn try_approximate_cosine_distance_prepared(
327 query: &PreparedQuery,
328 stored: &QuantizedData,
329) -> Result<f32> {
330 approximate_cosine_distance_prepared(query, stored)
331}
332
333#[inline]
339pub fn try_approximate_dot_product_prepared(
340 query: &PreparedQuery,
341 stored: &QuantizedData,
342) -> Result<f32> {
343 approximate_dot_product_prepared(query, stored)
344}
345
346#[inline]
364pub fn approximate_cosine_distance_prepared_with_meta(
365 meta: &PreparedQueryWithMeta,
366 stored: &QuantizedData,
367 stored_norm: NormalizationHint,
368) -> Result<f32> {
369 if meta.norm == NormalizationHint::Unit
370 && stored_norm == NormalizationHint::Unit
371 && let (PreparedQuery::Full(q), QuantizedData::Full(s)) = (&meta.query, stored)
372 {
373 let dot = dot_product(q, s);
374 return Ok(1.0 - dot.clamp(-1.0, 1.0));
375 }
376 approximate_cosine_distance_prepared(&meta.query, stored)
377}
378
379#[inline]
387pub fn approximate_dot_product_prepared(
388 query: &PreparedQuery,
389 stored: &QuantizedData,
390) -> Result<f32> {
391 match (query, stored) {
392 (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(dot_product(q, s)),
393 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => Ok(dot_product_i8_trusted(q, s)),
394 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.dot_product(q)),
395 (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => Err(EmbedError::Internal(
396 "Binary has no prepared dot product; use approximate_cosine_distance_prepared".into(),
397 )),
398 _ => Err(EmbedError::TierMismatch {
399 op: "approximate_dot_product_prepared",
400 expected: stored.tier(),
401 actual: query.tier(),
402 }),
403 }
404}
405
406#[inline]
414pub fn batch_approximate_cosine_distance_prepared(
415 query: &PreparedQuery,
416 stored: &[QuantizedData],
417) -> Result<Vec<f32>> {
418 stored
419 .iter()
420 .map(|item| approximate_cosine_distance_prepared(query, item))
421 .collect()
422}
423
424#[inline]
435pub fn batch_approximate_cosine_distance_prepared_into(
436 query: &PreparedQuery,
437 stored: &[QuantizedData],
438 out: &mut Vec<f32>,
439) -> Result<()> {
440 out.clear();
441 out.reserve(stored.len());
442 for item in stored {
443 match approximate_cosine_distance_prepared(query, item) {
444 Ok(distance) => out.push(distance),
445 Err(e) => {
446 out.clear();
447 return Err(e);
448 }
449 }
450 }
451 Ok(())
452}
453
454#[inline]
462pub fn approximate_int8_batch_prepared(
463 query: &PreparedQuery,
464 candidates: &[QuantizedVector],
465) -> Result<Vec<f32>> {
466 let PreparedQuery::Int8(q) = query else {
467 return Err(EmbedError::TierMismatch {
468 op: "approximate_int8_batch_prepared",
469 expected: QuantizationTier::Int8,
470 actual: query.tier(),
471 });
472 };
473 Ok(candidates
474 .iter()
475 .map(|candidate| 1.0 - cosine_similarity_i8_trusted(candidate, q))
476 .collect())
477}
478
479#[inline]
487pub fn approximate_int8_batch_prepared_into(
488 query: &PreparedQuery,
489 candidates: &[QuantizedVector],
490 out: &mut Vec<f32>,
491) -> Result<()> {
492 out.clear();
493 let PreparedQuery::Int8(q) = query else {
494 return Err(EmbedError::TierMismatch {
495 op: "approximate_int8_batch_prepared_into",
496 expected: QuantizationTier::Int8,
497 actual: query.tier(),
498 });
499 };
500 out.reserve(candidates.len());
501 out.extend(
502 candidates
503 .iter()
504 .map(|candidate| 1.0 - cosine_similarity_i8_trusted(candidate, q)),
505 );
506 Ok(())
507}
508
509#[inline]
517pub fn approximate_int4_batch_prepared(
518 query: &PreparedQuery,
519 candidates: &[Int4Vector],
520) -> Result<Vec<f32>> {
521 let PreparedQuery::Int4(q) = query else {
522 return Err(EmbedError::TierMismatch {
523 op: "approximate_int4_batch_prepared",
524 expected: QuantizationTier::Int4,
525 actual: query.tier(),
526 });
527 };
528 Ok(candidates
529 .iter()
530 .map(|candidate| candidate.cosine_distance(q))
531 .collect())
532}
533
534#[inline]
542pub fn approximate_int4_batch_prepared_into(
543 query: &PreparedQuery,
544 candidates: &[Int4Vector],
545 out: &mut Vec<f32>,
546) -> Result<()> {
547 out.clear();
548 let PreparedQuery::Int4(q) = query else {
549 return Err(EmbedError::TierMismatch {
550 op: "approximate_int4_batch_prepared_into",
551 expected: QuantizationTier::Int4,
552 actual: query.tier(),
553 });
554 };
555 out.reserve(candidates.len());
556 out.extend(
557 candidates
558 .iter()
559 .map(|candidate| candidate.cosine_distance(q)),
560 );
561 Ok(())
562}
563
564pub fn approximate_cosine_distance(query_f32: &[f32], stored: &QuantizedData) -> f32 {
576 debug_assert_eq!(
577 query_f32.len(),
578 stored.dims(),
579 "approximate_cosine_distance: query length {} != stored dims {}",
580 query_f32.len(),
581 stored.dims(),
582 );
583 match stored {
584 QuantizedData::Full(v) => {
585 1.0 - cosine_similarity(query_f32, v)
587 }
588 QuantizedData::Int8(q) => {
589 let query_q = QuantizedVector::from_f32(query_f32);
591 1.0 - q.cosine_similarity(&query_q)
592 }
593 QuantizedData::Int4(q) => {
594 let query_q = Int4Vector::from_f32(query_f32);
596 q.cosine_distance(&query_q)
597 }
598 QuantizedData::Binary(q) => {
599 let query_q = BinaryVector::from_f32(query_f32);
601 q.cosine_distance_approx(&query_q)
602 }
603 }
604}
605
606pub fn approximate_dot_product(query_f32: &[f32], stored: &QuantizedData) -> f32 {
608 match stored {
609 QuantizedData::Full(v) => dot_product(query_f32, v),
610 QuantizedData::Int8(q) => {
611 let query_q = QuantizedVector::from_f32(query_f32);
612 q.dot_product(&query_q)
613 }
614 QuantizedData::Int4(q) => {
615 let query_q = Int4Vector::from_f32(query_f32);
616 q.dot_product(&query_q)
617 }
618 QuantizedData::Binary(_q) => {
619 let stored_f32 = _q.to_f32();
621 dot_product(query_f32, &stored_f32)
622 }
623 }
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tier, in declaration order.
    const ALL_TIERS: [QuantizationTier; 4] = [
        QuantizationTier::Full,
        QuantizationTier::Int8,
        QuantizationTier::Int4,
        QuantizationTier::Binary,
    ];

    /// Deterministic pseudo-random vector with components in `[-1, 1]`,
    /// driven by an LCG seeded from `seed` and `dim`.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Scales `v` to unit L2 norm (assumes `v` is not all zeros).
    fn normalized(mut v: Vec<f32>) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    #[test]
    fn test_tier_bytes_per_dim() {
        assert_eq!(QuantizationTier::Full.bytes_per_dim(), 4.0);
        assert_eq!(QuantizationTier::Int8.bytes_per_dim(), 1.0);
        assert_eq!(QuantizationTier::Int4.bytes_per_dim(), 0.5);
        assert_eq!(QuantizationTier::Binary.bytes_per_dim(), 0.125);
    }

    #[test]
    fn test_tier_compression_ratios() {
        assert_eq!(QuantizationTier::Full.compression_ratio(), 1.0);
        assert_eq!(QuantizationTier::Int8.compression_ratio(), 4.0);
        assert_eq!(QuantizationTier::Int4.compression_ratio(), 8.0);
        assert_eq!(QuantizationTier::Binary.compression_ratio(), 32.0);
    }

    #[test]
    fn test_tier_storage_bytes() {
        assert_eq!(QuantizationTier::Full.storage_bytes(384), 1536);
        assert_eq!(QuantizationTier::Int8.storage_bytes(384), 384);
        assert_eq!(QuantizationTier::Int4.storage_bytes(384), 192);
        assert_eq!(QuantizationTier::Binary.storage_bytes(384), 48);

        // Odd dimension counts round partially filled bytes up.
        assert_eq!(QuantizationTier::Full.storage_bytes(385), 1540);
        assert_eq!(QuantizationTier::Int8.storage_bytes(385), 385);
        assert_eq!(QuantizationTier::Int4.storage_bytes(385), 193);
        assert_eq!(QuantizationTier::Binary.storage_bytes(385), 49);
    }

    #[test]
    fn test_tier_from_age() {
        assert_eq!(QuantizationTier::from_age_seconds(0), QuantizationTier::Full);
        assert_eq!(QuantizationTier::from_age_seconds(1800), QuantizationTier::Full);
        assert_eq!(QuantizationTier::from_age_seconds(7200), QuantizationTier::Int8);
        assert_eq!(QuantizationTier::from_age_seconds(172800), QuantizationTier::Int4);
        assert_eq!(QuantizationTier::from_age_seconds(1_000_000), QuantizationTier::Binary);
    }

    #[test]
    fn test_tier_from_age_boundaries() {
        assert_eq!(QuantizationTier::from_age_seconds(3599), QuantizationTier::Full);
        assert_eq!(QuantizationTier::from_age_seconds(3600), QuantizationTier::Int8);
        assert_eq!(QuantizationTier::from_age_seconds(86399), QuantizationTier::Int8);
        assert_eq!(QuantizationTier::from_age_seconds(86400), QuantizationTier::Int4);
        assert_eq!(QuantizationTier::from_age_seconds(604799), QuantizationTier::Int4);
        assert_eq!(QuantizationTier::from_age_seconds(604800), QuantizationTier::Binary);
        assert_eq!(QuantizationTier::from_age_seconds(u64::MAX), QuantizationTier::Binary);
    }

    #[test]
    fn test_tier_ordering_follows_precision() {
        for pair in ALL_TIERS.windows(2) {
            assert!(pair[0] < pair[1], "{:?} should sort before {:?}", pair[0], pair[1]);
        }
    }

    #[test]
    fn test_quantized_data_from_f32_all_tiers() {
        let v = generate_vector(384, 42);

        for tier in ALL_TIERS {
            let data = QuantizedData::from_f32(&v, tier);
            assert_eq!(data.tier(), tier, "tier mismatch for {tier:?}");
            assert_eq!(data.dims(), 384, "dims mismatch for {tier:?}");

            let expected_bytes = tier.storage_bytes(384);
            assert_eq!(
                data.storage_bytes(),
                expected_bytes,
                "storage bytes mismatch for {tier:?}"
            );
        }
    }

    #[test]
    fn test_approximate_cosine_distance_ordering() {
        let a = generate_vector(384, 1);
        // `b` is a small perturbation of `a`; `c` is unrelated.
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.05 * (i as f32 * 0.3).sin())
            .collect();
        let c = generate_vector(384, 999);

        for tier in ALL_TIERS {
            let stored_b = QuantizedData::from_f32(&b, tier);
            let stored_c = QuantizedData::from_f32(&c, tier);

            let dist_ab = approximate_cosine_distance(&a, &stored_b);
            let dist_ac = approximate_cosine_distance(&a, &stored_c);

            assert!(
                dist_ab < dist_ac,
                "{tier:?}: dist(a,b)={dist_ab} should be < dist(a,c)={dist_ac}"
            );
        }
    }

    #[test]
    fn test_promote_demote_roundtrip() {
        let v = generate_vector(384, 42);
        let binary = QuantizedData::from_f32(&v, QuantizationTier::Binary);

        let int4 = binary.promote(QuantizationTier::Int4);
        assert_eq!(int4.tier(), QuantizationTier::Int4);

        let int8 = int4.promote(QuantizationTier::Int8);
        assert_eq!(int8.tier(), QuantizationTier::Int8);

        let full = int8.promote(QuantizationTier::Full);
        assert_eq!(full.tier(), QuantizationTier::Full);
        assert_eq!(full.dims(), 384);

        let back_to_int8 = full.demote(QuantizationTier::Int8);
        assert_eq!(back_to_int8.tier(), QuantizationTier::Int8);

        let back_to_binary = back_to_int8.demote(QuantizationTier::Binary);
        assert_eq!(back_to_binary.tier(), QuantizationTier::Binary);
        assert_eq!(back_to_binary.dims(), 384);
    }

    #[test]
    fn test_int8_batch_prepared_matches_per_item_prepared() {
        let query = generate_vector(384, 42);
        let prepared = PreparedQuery::from_f32(&query, QuantizationTier::Int8);
        let candidates: Vec<QuantizedVector> = (0..32)
            .map(|i| QuantizedVector::from_f32(&generate_vector(384, i + 1)))
            .collect();
        let wrapped: Vec<QuantizedData> = candidates
            .iter()
            .cloned()
            .map(QuantizedData::Int8)
            .collect();

        let got = approximate_int8_batch_prepared(&prepared, &candidates).unwrap();
        for (i, item) in wrapped.iter().enumerate() {
            let expected = approximate_cosine_distance_prepared(&prepared, item).unwrap();
            assert!(
                (got[i] - expected).abs() < 1e-6,
                "int8 batch prepared mismatch at candidate {i}: got={}, expected={}",
                got[i],
                expected
            );
        }
    }

    #[test]
    fn test_int4_batch_prepared_matches_per_item_prepared() {
        let query = generate_vector(384, 42);
        let prepared = PreparedQuery::from_f32(&query, QuantizationTier::Int4);
        let candidates: Vec<Int4Vector> = (0..32)
            .map(|i| Int4Vector::from_f32(&generate_vector(384, i + 1)))
            .collect();
        let wrapped: Vec<QuantizedData> = candidates
            .iter()
            .cloned()
            .map(QuantizedData::Int4)
            .collect();

        let got = approximate_int4_batch_prepared(&prepared, &candidates).unwrap();
        for (i, item) in wrapped.iter().enumerate() {
            let expected = approximate_cosine_distance_prepared(&prepared, item).unwrap();
            assert!(
                (got[i] - expected).abs() < 1e-5,
                "int4 batch prepared mismatch at candidate {i}: got={}, expected={}",
                got[i],
                expected
            );
        }
    }

    #[test]
    fn test_int4_batch_prepared_api_dispatch_parity() {
        // Includes odd and non-multiple-of-8 sizes to exercise packing tails.
        for dim in [1usize, 3, 31, 127, 383, 384] {
            let query = generate_vector(dim, 700 + dim as u64);
            let candidate = generate_vector(dim, 800 + dim as u64);
            let prepared = PreparedQuery::from_f32(&query, QuantizationTier::Int4);
            let q_cand = Int4Vector::from_f32(&candidate);
            let wrapped = QuantizedData::Int4(q_cand.clone());

            let batch_result = approximate_int4_batch_prepared(&prepared, &[q_cand]).unwrap();
            let per_item_result =
                approximate_cosine_distance_prepared(&prepared, &wrapped).unwrap();

            assert!(
                (batch_result[0] - per_item_result).abs() < 1e-5,
                "int4 batch prepared dispatch mismatch at dim={dim}: batch={}, per_item={}",
                batch_result[0],
                per_item_result
            );
        }
    }

    #[test]
    fn test_quantized_data_to_f32_roundtrip() {
        let v = generate_vector(384, 55);

        // Full tier must be bit-for-bit lossless.
        let full_rt = QuantizedData::from_f32(&v, QuantizationTier::Full).to_f32();
        assert_eq!(full_rt, v, "Full tier should be lossless");

        // Lossy tiers must at least preserve the dimension count.
        for tier in [
            QuantizationTier::Int8,
            QuantizationTier::Int4,
            QuantizationTier::Binary,
        ] {
            let decoded = QuantizedData::from_f32(&v, tier).to_f32();
            assert_eq!(decoded.len(), v.len(), "decoded length mismatch for {tier:?}");
        }
    }

    #[test]
    fn test_is_unit_norm() {
        assert!(is_unit_norm(&[1.0, 0.0, 0.0]));
        assert!(is_unit_norm(&[0.6, 0.8]));
        assert!(!is_unit_norm(&[1.0, 1.0]));
        assert!(!is_unit_norm(&[]));
        assert!(is_unit_norm(&normalized(generate_vector(384, 9))));
    }

    #[test]
    fn test_with_meta_unit_fast_path_matches_general_path() {
        let a = normalized(generate_vector(128, 11));
        let b = normalized(generate_vector(128, 12));
        assert!(is_unit_norm(&a) && is_unit_norm(&b));

        let stored = QuantizedData::from_f32(&b, QuantizationTier::Full);
        let meta =
            PreparedQueryWithMeta::from_f32(&a, QuantizationTier::Full, NormalizationHint::Unit);
        let fast =
            approximate_cosine_distance_prepared_with_meta(&meta, &stored, NormalizationHint::Unit)
                .unwrap();
        let general = approximate_cosine_distance_prepared(&meta.query, &stored).unwrap();
        assert!(
            (fast - general).abs() < 1e-4,
            "unit fast path diverged: fast={fast}, general={general}"
        );

        // Non-Full tiers ignore the hints and defer to the general path.
        let stored_i8 = QuantizedData::from_f32(&b, QuantizationTier::Int8);
        let meta_i8 =
            PreparedQueryWithMeta::from_f32(&a, QuantizationTier::Int8, NormalizationHint::Unit);
        assert_eq!(
            approximate_cosine_distance_prepared_with_meta(
                &meta_i8,
                &stored_i8,
                NormalizationHint::Unit
            )
            .unwrap(),
            approximate_cosine_distance_prepared(&meta_i8.query, &stored_i8).unwrap()
        );
    }

    #[test]
    fn test_batch_into_success_replaces_stale_contents() {
        let query_vec = generate_vector(64, 21);
        let query = PreparedQuery::from_f32(&query_vec, QuantizationTier::Int8);
        let stored: Vec<QuantizedData> = (0..4)
            .map(|i| QuantizedData::from_f32(&generate_vector(64, 30 + i), QuantizationTier::Int8))
            .collect();

        let expected = batch_approximate_cosine_distance_prepared(&query, &stored).unwrap();
        let mut out = vec![9.0; 10];
        batch_approximate_cosine_distance_prepared_into(&query, &stored, &mut out).unwrap();
        assert_eq!(out, expected);
    }

    #[test]
    fn test_approximate_dot_product_full_is_exact() {
        let a = generate_vector(96, 41);
        let b = generate_vector(96, 42);
        let exact = dot_product(&a, &b);
        let stored = QuantizedData::from_f32(&b, QuantizationTier::Full);

        assert_eq!(approximate_dot_product(&a, &stored), exact);
        let prepared = PreparedQuery::from_f32(&a, QuantizationTier::Full);
        assert_eq!(approximate_dot_product_prepared(&prepared, &stored).unwrap(), exact);
    }

    #[test]
    fn test_cosine_distance_prepared_tier_mismatch_returns_typed_error() {
        let v = generate_vector(64, 1);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Int8);
        let stored = QuantizedData::from_f32(&v, QuantizationTier::Int4);

        let err = approximate_cosine_distance_prepared(&query, &stored).unwrap_err();
        match err {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => {
                assert_eq!(op, "approximate_cosine_distance_prepared");
                assert_eq!(expected, QuantizationTier::Int4);
                assert_eq!(actual, QuantizationTier::Int8);
            }
            other => panic!("expected TierMismatch, got {other:?}"),
        }

        assert!(try_approximate_cosine_distance_prepared(&query, &stored).is_err());
    }

    #[test]
    fn test_dot_product_prepared_tier_mismatch_returns_typed_error() {
        let v = generate_vector(64, 2);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Full);
        let stored = QuantizedData::from_f32(&v, QuantizationTier::Int8);

        let err = approximate_dot_product_prepared(&query, &stored).unwrap_err();
        assert!(
            matches!(
                err,
                EmbedError::TierMismatch {
                    op: "approximate_dot_product_prepared",
                    ..
                }
            ),
            "unexpected error variant: {err:?}"
        );

        assert!(try_approximate_dot_product_prepared(&query, &stored).is_err());
    }

    #[test]
    fn test_dot_product_prepared_binary_returns_typed_error_not_panic() {
        let v = generate_vector(64, 3);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Binary);
        let stored = QuantizedData::from_f32(&v, QuantizationTier::Binary);

        let err = approximate_dot_product_prepared(&query, &stored).unwrap_err();
        assert!(
            matches!(err, EmbedError::Internal(_)),
            "unexpected error variant: {err:?}"
        );
    }

    #[test]
    fn test_cosine_distance_prepared_with_meta_tier_mismatch_returns_typed_error() {
        let v = generate_vector(64, 4);
        let meta =
            PreparedQueryWithMeta::from_f32(&v, QuantizationTier::Full, NormalizationHint::Unknown);
        let stored = QuantizedData::from_f32(&v, QuantizationTier::Int8);

        let err = approximate_cosine_distance_prepared_with_meta(
            &meta,
            &stored,
            NormalizationHint::Unknown,
        )
        .unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));
    }

    #[test]
    fn test_batch_cosine_distance_prepared_tier_mismatch_returns_typed_error() {
        let v = generate_vector(64, 5);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Int8);
        // The second entry has the wrong tier, so the batch must fail part-way.
        let stored = vec![
            QuantizedData::from_f32(&v, QuantizationTier::Int8),
            QuantizedData::from_f32(&v, QuantizationTier::Int4),
        ];

        let err = batch_approximate_cosine_distance_prepared(&query, &stored).unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));

        // Pre-filled with stale values that must not survive the failure.
        let mut out = vec![9.0, 9.0, 9.0];
        let err =
            batch_approximate_cosine_distance_prepared_into(&query, &stored, &mut out).unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));
        assert!(
            out.is_empty(),
            "buffer must be cleared, not left with stale data"
        );
    }

    #[test]
    fn test_int8_batch_prepared_wrong_tier_returns_typed_error() {
        let v = generate_vector(64, 6);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Int4);
        let candidates = vec![QuantizedVector::from_f32(&v)];

        let err = approximate_int8_batch_prepared(&query, &candidates).unwrap_err();
        match err {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => {
                assert_eq!(op, "approximate_int8_batch_prepared");
                assert_eq!(expected, QuantizationTier::Int8);
                assert_eq!(actual, QuantizationTier::Int4);
            }
            other => panic!("expected TierMismatch, got {other:?}"),
        }

        let mut out = vec![9.0];
        let err = approximate_int8_batch_prepared_into(&query, &candidates, &mut out).unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));
        assert!(
            out.is_empty(),
            "buffer must be cleared, not left with stale data"
        );
    }

    #[test]
    fn test_int4_batch_prepared_wrong_tier_returns_typed_error() {
        let v = generate_vector(64, 7);
        let query = PreparedQuery::from_f32(&v, QuantizationTier::Int8);
        let candidates = vec![Int4Vector::from_f32(&v)];

        let err = approximate_int4_batch_prepared(&query, &candidates).unwrap_err();
        match err {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => {
                assert_eq!(op, "approximate_int4_batch_prepared");
                assert_eq!(expected, QuantizationTier::Int4);
                assert_eq!(actual, QuantizationTier::Int8);
            }
            other => panic!("expected TierMismatch, got {other:?}"),
        }

        let mut out = vec![9.0];
        let err = approximate_int4_batch_prepared_into(&query, &candidates, &mut out).unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));
        assert!(
            out.is_empty(),
            "buffer must be cleared, not left with stale data"
        );
    }
}