1use super::binary::BinaryVector;
17use super::int4::Int4Vector;
18use super::quantized::{QuantizedVector, cosine_similarity_i8_trusted, dot_product_i8_trusted};
19use super::{cosine_similarity, dot_product};
20use crate::error::{EmbedError, Result};
21
/// Caller-supplied hint about the L2 norm of a vector.
///
/// The hint is trusted, not verified: the distance helpers in this module only
/// compare it against [`NormalizationHint::Unit`] to decide whether a cheaper
/// code path may be taken.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NormalizationHint {
    /// Nothing is known about the vector's norm.
    Unknown,
    /// The caller asserts the vector has unit length.
    Unit,
}
33
/// Storage precision levels for a vector, declared from highest fidelity to
/// most compressed.
///
/// The derived ordering follows declaration order, so
/// `Full < Int8 < Int4 < Binary`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QuantizationTier {
    /// Four bytes per dimension.
    Full,
    /// One byte per dimension.
    Int8,
    /// Half a byte per dimension (two dimensions per byte).
    Int4,
    /// One bit per dimension (eight dimensions per byte).
    Binary,
}

impl QuantizationTier {
    /// Average number of bytes one dimension occupies at this tier.
    pub fn bytes_per_dim(&self) -> f32 {
        // Work in bits; dividing any of these four values by 8 is exact in f32.
        let bits_per_dim: f32 = match *self {
            Self::Full => 32.0,
            Self::Int8 => 8.0,
            Self::Int4 => 4.0,
            Self::Binary => 1.0,
        };
        bits_per_dim / 8.0
    }

    /// Size reduction relative to the `Full` tier (`Full` itself is `1.0`).
    pub fn compression_ratio(&self) -> f32 {
        const FULL_BYTES_PER_DIM: f32 = 4.0;
        FULL_BYTES_PER_DIM / self.bytes_per_dim()
    }

    /// Number of payload bytes needed to hold `dims` dimensions at this tier.
    ///
    /// Sub-byte tiers round up to a whole byte.
    pub fn storage_bytes(&self, dims: usize) -> usize {
        match *self {
            Self::Binary => dims.div_ceil(8),
            Self::Int4 => dims.div_ceil(2),
            Self::Int8 => dims,
            Self::Full => dims * 4,
        }
    }

    /// Picks a tier from the age of a vector, in seconds.
    ///
    /// * younger than one hour: `Full`
    /// * younger than one day: `Int8`
    /// * younger than one week: `Int4`
    /// * anything older: `Binary`
    pub fn from_age_seconds(age_secs: u64) -> Self {
        const HOUR: u64 = 60 * 60;
        const DAY: u64 = 24 * HOUR;
        const WEEK: u64 = 7 * DAY;

        match age_secs {
            age if age < HOUR => Self::Full,
            age if age < DAY => Self::Int8,
            age if age < WEEK => Self::Int4,
            _ => Self::Binary,
        }
    }
}
97
98#[derive(Debug, Clone)]
103pub enum QuantizedData {
104 Full(Vec<f32>),
106 Int8(QuantizedVector),
108 Int4(Int4Vector),
110 Binary(BinaryVector),
112}
113
114impl QuantizedData {
115 pub fn tier(&self) -> QuantizationTier {
117 match self {
118 Self::Full(_) => QuantizationTier::Full,
119 Self::Int8(_) => QuantizationTier::Int8,
120 Self::Int4(_) => QuantizationTier::Int4,
121 Self::Binary(_) => QuantizationTier::Binary,
122 }
123 }
124
125 pub fn dims(&self) -> usize {
127 match self {
128 Self::Full(v) => v.len(),
129 Self::Int8(q) => q.data.len(),
130 Self::Int4(q) => q.dims,
131 Self::Binary(q) => q.dims,
132 }
133 }
134
135 pub fn storage_bytes(&self) -> usize {
137 match self {
138 Self::Full(v) => v.len() * 4,
139 Self::Int8(q) => q.data.len(),
140 Self::Int4(q) => q.data.len(),
141 Self::Binary(q) => q.data.len(),
142 }
143 }
144
145 pub fn from_f32(vector: &[f32], tier: QuantizationTier) -> Self {
147 match tier {
148 QuantizationTier::Full => Self::Full(vector.to_vec()),
149 QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(vector)),
150 QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(vector)),
151 QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(vector)),
152 }
153 }
154
155 pub fn to_f32(&self) -> Vec<f32> {
157 match self {
158 Self::Full(v) => v.clone(),
159 Self::Int8(q) => q.to_f32(),
160 Self::Int4(q) => q.to_f32(),
161 Self::Binary(q) => q.to_f32(),
162 }
163 }
164
165 pub fn promote(&self, target: QuantizationTier) -> Self {
172 let f32_data = self.to_f32();
173 Self::from_f32(&f32_data, target)
174 }
175
176 pub fn demote(&self, target: QuantizationTier) -> Self {
178 self.promote(target) }
180}
181
182#[derive(Debug, Clone)]
187pub enum PreparedQuery {
188 Full(Vec<f32>),
190 Int8(QuantizedVector),
192 Int4(Int4Vector),
194 Binary(BinaryVector),
196}
197
198impl PreparedQuery {
199 #[inline]
201 pub fn from_f32(query_f32: &[f32], tier: QuantizationTier) -> Self {
202 match tier {
203 QuantizationTier::Full => Self::Full(query_f32.to_vec()),
204 QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(query_f32)),
205 QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(query_f32)),
206 QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(query_f32)),
207 }
208 }
209
210 #[inline]
212 pub fn tier(&self) -> QuantizationTier {
213 match self {
214 Self::Full(_) => QuantizationTier::Full,
215 Self::Int8(_) => QuantizationTier::Int8,
216 Self::Int4(_) => QuantizationTier::Int4,
217 Self::Binary(_) => QuantizationTier::Binary,
218 }
219 }
220
221 #[inline]
223 pub fn dims(&self) -> usize {
224 match self {
225 Self::Full(v) => v.len(),
226 Self::Int8(q) => q.data.len(),
227 Self::Int4(q) => q.dims,
228 Self::Binary(q) => q.dims,
229 }
230 }
231}
232
233#[inline]
235pub fn prepare_query(query_f32: &[f32], tier: QuantizationTier) -> PreparedQuery {
236 PreparedQuery::from_f32(query_f32, tier)
237}
238
239#[derive(Debug, Clone)]
245pub struct PreparedQueryWithMeta {
246 pub query: PreparedQuery,
248 pub norm: NormalizationHint,
250}
251
252impl PreparedQueryWithMeta {
253 #[inline]
255 pub fn from_f32(query_f32: &[f32], tier: QuantizationTier, norm: NormalizationHint) -> Self {
256 Self {
257 query: PreparedQuery::from_f32(query_f32, tier),
258 norm,
259 }
260 }
261
262 #[inline]
264 pub fn tier(&self) -> QuantizationTier {
265 self.query.tier()
266 }
267
268 #[inline]
270 pub fn dims(&self) -> usize {
271 self.query.dims()
272 }
273}
274
/// Returns `true` when the squared L2 norm of `v` is within `1e-4` of `1.0`.
///
/// The sum of squares is accumulated in `f64`. With an `f32` accumulator, once
/// the running sum is close to `1.0` every addend smaller than half an ulp is
/// rounded away, so a long vector whose many small components push the norm
/// well outside the tolerance could still be reported as unit-length.
///
/// An empty slice has norm zero and is therefore not unit-length. Any `NaN`
/// component makes the result `false`.
#[inline]
pub fn is_unit_norm(v: &[f32]) -> bool {
    let sq: f64 = v
        .iter()
        .map(|&x| {
            // Widening first makes each square exact and keeps the running sum
            // from absorbing small terms.
            let x = f64::from(x);
            x * x
        })
        .sum();
    (sq - 1.0).abs() < 1e-4
}
281
282#[inline]
284pub fn prepare_query_with_norm(
285 query_f32: &[f32],
286 tier: QuantizationTier,
287 norm: NormalizationHint,
288) -> PreparedQueryWithMeta {
289 PreparedQueryWithMeta::from_f32(query_f32, tier, norm)
290}
291
292#[inline]
302pub fn approximate_cosine_distance_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
303 match (query, stored) {
304 (PreparedQuery::Full(q), QuantizedData::Full(s)) => 1.0 - cosine_similarity(q, s),
305 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => {
306 1.0 - cosine_similarity_i8_trusted(s, q)
307 }
308 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => s.cosine_distance(q),
309 (PreparedQuery::Binary(q), QuantizedData::Binary(s)) => s.cosine_distance_approx(q),
310 _ => panic!("PreparedQuery tier must match QuantizedData tier"),
311 }
312}
313
314#[inline]
320pub fn try_approximate_cosine_distance_prepared(
321 query: &PreparedQuery,
322 stored: &QuantizedData,
323) -> Result<f32> {
324 match (query, stored) {
325 (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(1.0 - cosine_similarity(q, s)),
326 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => {
327 Ok(1.0 - cosine_similarity_i8_trusted(s, q))
328 }
329 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.cosine_distance(q)),
330 (PreparedQuery::Binary(q), QuantizedData::Binary(s)) => Ok(s.cosine_distance_approx(q)),
331 _ => Err(EmbedError::Internal(
332 "PreparedQuery tier must match QuantizedData tier for cosine distance".into(),
333 )),
334 }
335}
336
337#[inline]
342pub fn try_approximate_dot_product_prepared(
343 query: &PreparedQuery,
344 stored: &QuantizedData,
345) -> Result<f32> {
346 match (query, stored) {
347 (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(dot_product(q, s)),
348 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => Ok(dot_product_i8_trusted(q, s)),
349 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.dot_product(q)),
350 (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => Err(EmbedError::Internal(
351 "Binary has no prepared dot product; use try_approximate_cosine_distance_prepared"
352 .into(),
353 )),
354 _ => Err(EmbedError::Internal(
355 "PreparedQuery tier must match QuantizedData tier for dot product".into(),
356 )),
357 }
358}
359
360#[inline]
371pub fn approximate_cosine_distance_prepared_with_meta(
372 meta: &PreparedQueryWithMeta,
373 stored: &QuantizedData,
374 stored_norm: NormalizationHint,
375) -> f32 {
376 if meta.norm == NormalizationHint::Unit && stored_norm == NormalizationHint::Unit {
377 if let (PreparedQuery::Full(q), QuantizedData::Full(s)) = (&meta.query, stored) {
378 let dot = dot_product(q, s);
379 return 1.0 - dot.clamp(-1.0, 1.0);
380 }
381 }
382 approximate_cosine_distance_prepared(&meta.query, stored)
383}
384
385#[inline]
394pub fn approximate_dot_product_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
395 match (query, stored) {
396 (PreparedQuery::Full(q), QuantizedData::Full(s)) => dot_product(q, s),
397 (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => dot_product_i8_trusted(q, s),
398 (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => s.dot_product(q),
399 (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => {
400 panic!("Binary has no prepared dot product; use approximate_cosine_distance_prepared")
401 }
402 _ => panic!("PreparedQuery tier must match QuantizedData tier"),
403 }
404}
405
406#[inline]
408pub fn batch_approximate_cosine_distance_prepared(
409 query: &PreparedQuery,
410 stored: &[QuantizedData],
411) -> Vec<f32> {
412 stored
413 .iter()
414 .map(|item| approximate_cosine_distance_prepared(query, item))
415 .collect()
416}
417
418#[inline]
422pub fn batch_approximate_cosine_distance_prepared_into(
423 query: &PreparedQuery,
424 stored: &[QuantizedData],
425 out: &mut Vec<f32>,
426) {
427 out.clear();
428 out.reserve(stored.len());
429 out.extend(
430 stored
431 .iter()
432 .map(|item| approximate_cosine_distance_prepared(query, item)),
433 );
434}
435
436pub fn approximate_cosine_distance(query_f32: &[f32], stored: &QuantizedData) -> f32 {
443 match stored {
444 QuantizedData::Full(v) => {
445 1.0 - cosine_similarity(query_f32, v)
447 }
448 QuantizedData::Int8(q) => {
449 let query_q = QuantizedVector::from_f32(query_f32);
451 1.0 - q.cosine_similarity(&query_q)
452 }
453 QuantizedData::Int4(q) => {
454 let query_q = Int4Vector::from_f32(query_f32);
456 q.cosine_distance(&query_q)
457 }
458 QuantizedData::Binary(q) => {
459 let query_q = BinaryVector::from_f32(query_f32);
461 q.cosine_distance_approx(&query_q)
462 }
463 }
464}
465
466pub fn approximate_dot_product(query_f32: &[f32], stored: &QuantizedData) -> f32 {
468 match stored {
469 QuantizedData::Full(v) => dot_product(query_f32, v),
470 QuantizedData::Int8(q) => {
471 let query_q = QuantizedVector::from_f32(query_f32);
472 q.dot_product(&query_q)
473 }
474 QuantizedData::Int4(q) => {
475 let query_q = Int4Vector::from_f32(query_f32);
476 q.dot_product(&query_q)
477 }
478 QuantizedData::Binary(_q) => {
479 let stored_f32 = _q.to_f32();
481 dot_product(query_f32, &stored_f32)
482 }
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tier, in declaration order, for table-driven tests.
    const ALL_TIERS: [QuantizationTier; 4] = [
        QuantizationTier::Full,
        QuantizationTier::Int8,
        QuantizationTier::Int4,
        QuantizationTier::Binary,
    ];

    /// Deterministic pseudo-random vector with components in `[-1.0, 1.0]`.
    ///
    /// The same `(dim, seed)` pair always yields the same vector, which keeps
    /// the tests reproducible without an external RNG.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_tier_bytes_per_dim() {
        assert_eq!(QuantizationTier::Full.bytes_per_dim(), 4.0);
        assert_eq!(QuantizationTier::Int8.bytes_per_dim(), 1.0);
        assert_eq!(QuantizationTier::Int4.bytes_per_dim(), 0.5);
        assert_eq!(QuantizationTier::Binary.bytes_per_dim(), 0.125);
    }

    #[test]
    fn test_tier_compression_ratios() {
        assert_eq!(QuantizationTier::Full.compression_ratio(), 1.0);
        assert_eq!(QuantizationTier::Int8.compression_ratio(), 4.0);
        assert_eq!(QuantizationTier::Int4.compression_ratio(), 8.0);
        assert_eq!(QuantizationTier::Binary.compression_ratio(), 32.0);
    }

    #[test]
    fn test_tier_storage_bytes() {
        assert_eq!(QuantizationTier::Full.storage_bytes(384), 1536);
        assert_eq!(QuantizationTier::Int8.storage_bytes(384), 384);
        assert_eq!(QuantizationTier::Int4.storage_bytes(384), 192);
        assert_eq!(QuantizationTier::Binary.storage_bytes(384), 48);
    }

    #[test]
    fn test_tier_from_age() {
        // Representative ages inside each band, followed by the exact
        // thresholds: each limit is exclusive for the finer tier.
        let cases = [
            (0, QuantizationTier::Full),
            (1800, QuantizationTier::Full),
            (7200, QuantizationTier::Int8),
            (172800, QuantizationTier::Int4),
            (1_000_000, QuantizationTier::Binary),
            (3599, QuantizationTier::Full),
            (3600, QuantizationTier::Int8),
            (86_399, QuantizationTier::Int8),
            (86_400, QuantizationTier::Int4),
            (604_799, QuantizationTier::Int4),
            (604_800, QuantizationTier::Binary),
        ];
        for (age_secs, expected) in cases {
            assert_eq!(
                QuantizationTier::from_age_seconds(age_secs),
                expected,
                "wrong tier for age {age_secs}s"
            );
        }
    }

    #[test]
    fn test_quantized_data_from_f32_all_tiers() {
        let v = generate_vector(384, 42);

        for tier in ALL_TIERS {
            let data = QuantizedData::from_f32(&v, tier);
            assert_eq!(data.tier(), tier, "tier mismatch for {tier:?}");
            assert_eq!(data.dims(), 384, "dims mismatch for {tier:?}");

            let expected_bytes = tier.storage_bytes(384);
            assert_eq!(
                data.storage_bytes(),
                expected_bytes,
                "storage bytes mismatch for {tier:?}"
            );
        }
    }

    #[test]
    fn test_approximate_cosine_distance_ordering() {
        let a = generate_vector(384, 1);
        // `b` is a small perturbation of `a`; `c` is an unrelated vector.
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.05 * (i as f32 * 0.3).sin())
            .collect();
        let c = generate_vector(384, 999);

        for tier in ALL_TIERS {
            let stored_b = QuantizedData::from_f32(&b, tier);
            let stored_c = QuantizedData::from_f32(&c, tier);

            let dist_ab = approximate_cosine_distance(&a, &stored_b);
            let dist_ac = approximate_cosine_distance(&a, &stored_c);

            assert!(
                dist_ab < dist_ac,
                "{tier:?}: dist(a,b)={dist_ab} should be < dist(a,c)={dist_ac}"
            );
        }
    }

    #[test]
    fn test_promote_demote_roundtrip() {
        let v = generate_vector(384, 42);
        let binary = QuantizedData::from_f32(&v, QuantizationTier::Binary);

        let int4 = binary.promote(QuantizationTier::Int4);
        assert_eq!(int4.tier(), QuantizationTier::Int4);

        let int8 = int4.promote(QuantizationTier::Int8);
        assert_eq!(int8.tier(), QuantizationTier::Int8);

        let full = int8.promote(QuantizationTier::Full);
        assert_eq!(full.tier(), QuantizationTier::Full);
        assert_eq!(full.dims(), 384);

        // Exercise the demote direction as well; it must preserve the
        // dimension count and land on the requested tier.
        let back_to_binary = full.demote(QuantizationTier::Binary);
        assert_eq!(back_to_binary.tier(), QuantizationTier::Binary);
        assert_eq!(back_to_binary.dims(), 384);
    }

    #[test]
    fn test_quantized_data_to_f32_roundtrip() {
        let v = generate_vector(384, 55);

        let full_data = QuantizedData::from_f32(&v, QuantizationTier::Full);
        let full_rt = full_data.to_f32();
        // Comparing whole vectors checks the length too; a zipped comparison
        // would silently accept a truncated result.
        assert_eq!(full_rt, v, "Full tier should be lossless");
    }
}