Skip to main content

lattice_embed/simd/
tier.rs

1//! Quantization tier management and unified distance dispatch.
2//!
3//! Provides a `QuantizationTier` enum for selecting precision levels and
4//! a `QuantizedData` enum for storing vectors at any tier with unified
5//! distance computation.
6//!
7//! ## Tier hierarchy
8//!
9//! | Tier   | Precision | Bytes/dim | Compression | Use case                |
10//! |--------|-----------|-----------|-------------|-------------------------|
11//! | Full   | f32       | 4.0       | 1x          | Hot data, exact search   |
12//! | Int8   | 8-bit     | 1.0       | 4x          | Warm data, HNSW search   |
13//! | Int4   | 4-bit     | 0.5       | 8x          | Cool data, pre-filtering |
14//! | Binary | 1-bit     | 0.125     | 32x         | Cold data, coarse filter |
15
16use super::binary::BinaryVector;
17use super::int4::Int4Vector;
18use super::quantized::{QuantizedVector, cosine_similarity_i8_trusted, dot_product_i8_trusted};
19use super::{cosine_similarity, dot_product};
20use crate::error::{EmbedError, Result};
21
/// Caller assertion about whether a vector is L2-unit-normalized (norm ≈ 1).
///
/// When both the query and the stored vector carry [`NormalizationHint::Unit`],
/// cosine similarity equals the dot product — the norm division can be skipped
/// entirely.
///
/// The hint is a claim made by the caller; this type carries no data with which
/// the claim could be checked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationHint {
    /// No guarantee — full cosine (with norm division) is required.
    Unknown,
    /// Caller asserts this vector is L2-unit-normalized (norm ≈ 1 within 1e-4).
    Unit,
}
33
/// **Unstable**: tier design is under active iteration; tier boundaries may change.
///
/// Quantization precision tier. Variants are declared from highest to lowest
/// fidelity, so the derived ordering is `Full < Int8 < Int4 < Binary`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QuantizationTier {
    /// Full f32 precision (4 bytes/dim, 1x baseline).
    Full,
    /// INT8 symmetric quantization (1 byte/dim, 4x compression).
    Int8,
    /// INT4 packed nibble quantization (0.5 bytes/dim, 8x compression).
    Int4,
    /// Binary sign-bit quantization (0.125 bytes/dim, 32x compression).
    Binary,
}

impl QuantizationTier {
    /// **Unstable**: nominal storage cost of one dimension, in bytes.
    ///
    /// Fractional for the sub-byte tiers (`Int4` = 0.5, `Binary` = 0.125).
    pub fn bytes_per_dim(&self) -> f32 {
        match *self {
            Self::Full => 4.0,
            Self::Int8 => 1.0,
            Self::Int4 => 0.5,
            Self::Binary => 0.125,
        }
    }

    /// **Unstable**: size reduction relative to the `Full` (f32) tier.
    ///
    /// Computed as the `Full` tier's bytes-per-dimension divided by this tier's.
    pub fn compression_ratio(&self) -> f32 {
        let baseline = Self::Full.bytes_per_dim();
        baseline / self.bytes_per_dim()
    }

    /// **Unstable**: number of payload bytes needed to store `dims` dimensions.
    ///
    /// Sub-byte tiers round up to a whole byte, so e.g. 3 `Int4` dimensions
    /// occupy 2 bytes and 9 `Binary` dimensions occupy 2 bytes.
    pub fn storage_bytes(&self, dims: usize) -> usize {
        match *self {
            Self::Full => dims * std::mem::size_of::<f32>(),
            Self::Int8 => dims,
            // Two 4-bit values per byte.
            Self::Int4 => dims.div_ceil(2),
            // Eight sign bits per byte.
            Self::Binary => dims.div_ceil(8),
        }
    }

    /// **Unstable**: age-based tier heuristic; the thresholds may be tuned.
    ///
    /// `age_secs` is the time since last access, in seconds:
    ///
    /// - under one hour: `Full`
    /// - one hour up to (but excluding) one day: `Int8`
    /// - one day up to (but excluding) one week: `Int4`
    /// - one week or older: `Binary`
    pub fn from_age_seconds(age_secs: u64) -> Self {
        const HOUR: u64 = 3600;
        const DAY: u64 = 86400;
        const WEEK: u64 = 604800;

        match age_secs {
            age if age < HOUR => Self::Full,
            age if age < DAY => Self::Int8,
            age if age < WEEK => Self::Int4,
            _ => Self::Binary,
        }
    }
}
97
98/// **Unstable**: unified quantized data container; variants may change with tier redesign.
99///
100/// Wraps the tier-specific vector types into a single enum for
101/// uniform storage and distance dispatch.
102#[derive(Debug, Clone)]
103pub enum QuantizedData {
104    /// Full-precision f32 vector.
105    Full(Vec<f32>),
106    /// INT8 quantized vector.
107    Int8(QuantizedVector),
108    /// INT4 packed quantized vector.
109    Int4(Int4Vector),
110    /// Binary sign-bit vector.
111    Binary(BinaryVector),
112}
113
114impl QuantizedData {
115    /// **Unstable**: returns `QuantizationTier` which is itself Unstable.
116    pub fn tier(&self) -> QuantizationTier {
117        match self {
118            Self::Full(_) => QuantizationTier::Full,
119            Self::Int8(_) => QuantizationTier::Int8,
120            Self::Int4(_) => QuantizationTier::Int4,
121            Self::Binary(_) => QuantizationTier::Binary,
122        }
123    }
124
125    /// **Unstable**: dimension accessor; may be removed if `QuantizedData` gains a dims field.
126    pub fn dims(&self) -> usize {
127        match self {
128            Self::Full(v) => v.len(),
129            Self::Int8(q) => q.len(),
130            Self::Int4(q) => q.dims,
131            Self::Binary(q) => q.dims,
132        }
133    }
134
135    /// **Unstable**: storage byte count; may change with tier redesign.
136    pub fn storage_bytes(&self) -> usize {
137        match self {
138            Self::Full(v) => v.len() * 4,
139            Self::Int8(q) => q.len(),
140            Self::Int4(q) => q.data.len(),
141            Self::Binary(q) => q.data.len(),
142        }
143    }
144
145    /// **Unstable**: quantization factory; tier dispatch logic may change.
146    pub fn from_f32(vector: &[f32], tier: QuantizationTier) -> Self {
147        match tier {
148            QuantizationTier::Full => Self::Full(vector.to_vec()),
149            QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(vector)),
150            QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(vector)),
151            QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(vector)),
152        }
153    }
154
155    /// **Unstable**: dequantization; output precision is tier-dependent.
156    pub fn to_f32(&self) -> Vec<f32> {
157        match self {
158            Self::Full(v) => v.clone(),
159            Self::Int8(q) => q.to_f32(),
160            Self::Int4(q) => q.to_f32(),
161            Self::Binary(q) => q.to_f32(),
162        }
163    }
164
165    /// **Unstable**: tier promotion; re-quantizes via f32; may be rethought.
166    ///
167    /// Dequantizes to f32 then re-quantizes at the target tier.
168    /// Note: this does NOT recover lost information -- it merely changes
169    /// the storage format. INT4 -> INT8 promotion fills in new bits
170    /// based on the dequantized approximation.
171    pub fn promote(&self, target: QuantizationTier) -> Self {
172        let f32_data = self.to_f32();
173        Self::from_f32(&f32_data, target)
174    }
175
176    /// **Unstable**: tier demotion; delegates to `promote`; may be removed.
177    pub fn demote(&self, target: QuantizationTier) -> Self {
178        self.promote(target) // Same operation, just going the other direction
179    }
180}
181
182/// **Unstable**: pre-quantized query for repeated distance computation.
183///
184/// Quantize a query vector once and reuse it against a homogeneous candidate list,
185/// eliminating per-call `from_f32` overhead. The query tier must match the stored data tier.
186#[derive(Debug, Clone)]
187pub enum PreparedQuery {
188    /// Full f32 query.
189    Full(Vec<f32>),
190    /// INT8 quantized query.
191    Int8(QuantizedVector),
192    /// INT4 packed quantized query.
193    Int4(Int4Vector),
194    /// Binary sign-bit query.
195    Binary(BinaryVector),
196}
197
198impl PreparedQuery {
199    /// Quantize a query at the given tier for repeated distance calls.
200    #[inline]
201    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier) -> Self {
202        match tier {
203            QuantizationTier::Full => Self::Full(query_f32.to_vec()),
204            QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(query_f32)),
205            QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(query_f32)),
206            QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(query_f32)),
207        }
208    }
209
210    /// Returns the quantization tier of this prepared query.
211    #[inline]
212    pub fn tier(&self) -> QuantizationTier {
213        match self {
214            Self::Full(_) => QuantizationTier::Full,
215            Self::Int8(_) => QuantizationTier::Int8,
216            Self::Int4(_) => QuantizationTier::Int4,
217            Self::Binary(_) => QuantizationTier::Binary,
218        }
219    }
220
221    /// Returns the number of dimensions.
222    #[inline]
223    pub fn dims(&self) -> usize {
224        match self {
225            Self::Full(v) => v.len(),
226            Self::Int8(q) => q.len(),
227            Self::Int4(q) => q.dims,
228            Self::Binary(q) => q.dims,
229        }
230    }
231}
232
/// Prepare a query vector for repeated distance computation against a homogeneous tier.
///
/// Free-function form of [`PreparedQuery::from_f32`]; it forwards both
/// arguments unchanged and returns the result.
#[inline]
pub fn prepare_query(query_f32: &[f32], tier: QuantizationTier) -> PreparedQuery {
    PreparedQuery::from_f32(query_f32, tier)
}
238
239/// A prepared query annotated with a normalization hint for fast-path dispatch.
240///
241/// When `norm == NormalizationHint::Unit` and the stored vector is also unit-normalized,
242/// `approximate_cosine_distance_prepared_with_meta` skips the norm division and uses
243/// `1.0 - dot_product(q, s)` directly — recovering ~26% at 384d on the Full tier.
244#[derive(Debug, Clone)]
245pub struct PreparedQueryWithMeta {
246    /// The quantized query (owns the data).
247    pub query: PreparedQuery,
248    /// Caller assertion about the query vector's normalization state.
249    pub norm: NormalizationHint,
250}
251
252impl PreparedQueryWithMeta {
253    /// Create a prepared query from an f32 vector, asserting its normalization state.
254    #[inline]
255    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier, norm: NormalizationHint) -> Self {
256        Self {
257            query: PreparedQuery::from_f32(query_f32, tier),
258            norm,
259        }
260    }
261
262    /// Returns the quantization tier.
263    #[inline]
264    pub fn tier(&self) -> QuantizationTier {
265        self.query.tier()
266    }
267
268    /// Returns the number of dimensions.
269    #[inline]
270    pub fn dims(&self) -> usize {
271        self.query.dims()
272    }
273}
274
/// Returns `true` when the squared L2 norm of `v` is within 1e-4 of 1.0.
///
/// The tolerance applies to the *squared* norm. An empty slice has squared
/// norm 0 and therefore returns `false`; a vector containing NaN also returns
/// `false` because the comparison fails.
#[inline]
pub fn is_unit_norm(v: &[f32]) -> bool {
    const TOLERANCE: f32 = 1e-4;

    // Accumulate in slice order, matching a sequential sum of squares.
    let mut squared_norm = 0.0_f32;
    for &component in v {
        squared_norm += component * component;
    }

    let deviation = (squared_norm - 1.0).abs();
    deviation < TOLERANCE
}
281
/// Prepare a query annotated with the given normalization hint.
///
/// Free-function form of [`PreparedQueryWithMeta::from_f32`]; it forwards all
/// three arguments unchanged. `norm` is the caller's claim about `query_f32`
/// and is not checked by this function.
#[inline]
pub fn prepare_query_with_norm(
    query_f32: &[f32],
    tier: QuantizationTier,
    norm: NormalizationHint,
) -> PreparedQueryWithMeta {
    PreparedQueryWithMeta::from_f32(query_f32, tier, norm)
}
291
292/// **Unstable**: prepared cosine distance; query tier must match stored data tier.
293///
294/// Returns a value in [0, 2] where 0 = identical, 2 = opposite.
295///
296/// # Errors
297///
298/// Returns [`EmbedError::TierMismatch`] if the query tier does not match the
299/// stored-data tier.
300#[inline]
301pub fn approximate_cosine_distance_prepared(
302    query: &PreparedQuery,
303    stored: &QuantizedData,
304) -> Result<f32> {
305    match (query, stored) {
306        (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(1.0 - cosine_similarity(q, s)),
307        (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => {
308            Ok(1.0 - cosine_similarity_i8_trusted(s, q))
309        }
310        (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.cosine_distance(q)),
311        (PreparedQuery::Binary(q), QuantizedData::Binary(s)) => Ok(s.cosine_distance_approx(q)),
312        _ => Err(EmbedError::TierMismatch {
313            op: "approximate_cosine_distance_prepared",
314            expected: stored.tier(),
315            actual: query.tier(),
316        }),
317    }
318}
319
/// Alias for [`approximate_cosine_distance_prepared`], retained for API compatibility.
///
/// As of the tier-mismatch hardening fix (issue #210), the non-`try_` variant no
/// longer panics and returns the same `Result` as this function. Prefer the
/// non-`try_` name in new code.
///
/// # Errors
///
/// Returns whatever [`approximate_cosine_distance_prepared`] returns for the
/// same arguments; this function adds no error handling of its own.
#[inline]
pub fn try_approximate_cosine_distance_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    approximate_cosine_distance_prepared(query, stored)
}
332
/// Alias for [`approximate_dot_product_prepared`], retained for API compatibility.
///
/// As of the tier-mismatch hardening fix (issue #210), the non-`try_` variant no
/// longer panics and returns the same `Result` as this function. Prefer the
/// non-`try_` name in new code.
///
/// # Errors
///
/// Returns whatever [`approximate_dot_product_prepared`] returns for the same
/// arguments; this function adds no error handling of its own.
#[inline]
pub fn try_approximate_dot_product_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    approximate_dot_product_prepared(query, stored)
}
345
346/// Cosine distance with unit-norm fast path.
347///
348/// When `meta.norm == NormalizationHint::Unit` and `stored` is a `QuantizedData::Full`
349/// vector whose squared norm is ≈ 1, skips the norm division and returns
350/// `1.0 - clamp(dot(q, s), -1, 1)`. For all other tier/norm combinations, falls
351/// back to `approximate_cosine_distance_prepared`.
352///
353/// The stored vector's unit claim is verified lazily via [`is_unit_norm`]; callers
354/// that batch many lookups against a fixed stored set should pre-check once and
355/// use the cheaper overload directly.
356///
357/// # Errors
358///
359/// Returns [`EmbedError::TierMismatch`] (propagated from
360/// [`approximate_cosine_distance_prepared`]) if the query tier does not match the
361/// stored-data tier. The unit-norm `Full` fast path returns directly and never
362/// reaches the delegate.
363#[inline]
364pub fn approximate_cosine_distance_prepared_with_meta(
365    meta: &PreparedQueryWithMeta,
366    stored: &QuantizedData,
367    stored_norm: NormalizationHint,
368) -> Result<f32> {
369    if meta.norm == NormalizationHint::Unit
370        && stored_norm == NormalizationHint::Unit
371        && let (PreparedQuery::Full(q), QuantizedData::Full(s)) = (&meta.query, stored)
372    {
373        let dot = dot_product(q, s);
374        return Ok(1.0 - dot.clamp(-1.0, 1.0));
375    }
376    approximate_cosine_distance_prepared(&meta.query, stored)
377}
378
379/// **Unstable**: prepared dot product dispatch; query tier must match stored data tier.
380///
381/// # Errors
382///
383/// Returns [`EmbedError::TierMismatch`] if the query tier does not match the
384/// stored-data tier, or [`EmbedError::Internal`] if called with `Binary` data
385/// (binary has no meaningful dot product; use cosine distance instead).
386#[inline]
387pub fn approximate_dot_product_prepared(
388    query: &PreparedQuery,
389    stored: &QuantizedData,
390) -> Result<f32> {
391    match (query, stored) {
392        (PreparedQuery::Full(q), QuantizedData::Full(s)) => Ok(dot_product(q, s)),
393        (PreparedQuery::Int8(q), QuantizedData::Int8(s)) => Ok(dot_product_i8_trusted(q, s)),
394        (PreparedQuery::Int4(q), QuantizedData::Int4(s)) => Ok(s.dot_product(q)),
395        (PreparedQuery::Binary(_), QuantizedData::Binary(_)) => Err(EmbedError::Internal(
396            "Binary has no prepared dot product; use approximate_cosine_distance_prepared".into(),
397        )),
398        _ => Err(EmbedError::TierMismatch {
399            op: "approximate_dot_product_prepared",
400            expected: stored.tier(),
401            actual: query.tier(),
402        }),
403    }
404}
405
406/// Compute cosine distances from one prepared query to a slice of stored vectors.
407///
408/// # Errors
409///
410/// Returns [`EmbedError::TierMismatch`] (propagated from
411/// [`approximate_cosine_distance_prepared`]) if the query tier does not match any
412/// stored vector's tier.
413#[inline]
414pub fn batch_approximate_cosine_distance_prepared(
415    query: &PreparedQuery,
416    stored: &[QuantizedData],
417) -> Result<Vec<f32>> {
418    stored
419        .iter()
420        .map(|item| approximate_cosine_distance_prepared(query, item))
421        .collect()
422}
423
424/// Like [`batch_approximate_cosine_distance_prepared`] but writes into a caller-supplied buffer.
425///
426/// Clears and reuses the buffer to avoid allocations across repeated searches. On error the
427/// buffer is left cleared (no partial results are written).
428///
429/// # Errors
430///
431/// Returns [`EmbedError::TierMismatch`] (propagated from
432/// [`approximate_cosine_distance_prepared`]) if the query tier does not match any
433/// stored vector's tier.
434#[inline]
435pub fn batch_approximate_cosine_distance_prepared_into(
436    query: &PreparedQuery,
437    stored: &[QuantizedData],
438    out: &mut Vec<f32>,
439) -> Result<()> {
440    out.clear();
441    out.reserve(stored.len());
442    for item in stored {
443        match approximate_cosine_distance_prepared(query, item) {
444            Ok(distance) => out.push(distance),
445            Err(e) => {
446                out.clear();
447                return Err(e);
448            }
449        }
450    }
451    Ok(())
452}
453
454/// Compute cosine distances from a prepared INT8 query to a slice of INT8 candidates.
455///
456/// The query is quantized once outside this function; no per-iteration `from_f32` is called.
457///
458/// # Errors
459///
460/// Returns [`EmbedError::TierMismatch`] if `query` is not an `Int8` `PreparedQuery`.
461#[inline]
462pub fn approximate_int8_batch_prepared(
463    query: &PreparedQuery,
464    candidates: &[QuantizedVector],
465) -> Result<Vec<f32>> {
466    let PreparedQuery::Int8(q) = query else {
467        return Err(EmbedError::TierMismatch {
468            op: "approximate_int8_batch_prepared",
469            expected: QuantizationTier::Int8,
470            actual: query.tier(),
471        });
472    };
473    Ok(candidates
474        .iter()
475        .map(|candidate| 1.0 - cosine_similarity_i8_trusted(candidate, q))
476        .collect())
477}
478
479/// Like [`approximate_int8_batch_prepared`] but writes into a caller-supplied buffer.
480///
481/// On error the buffer is left cleared (no partial results are written).
482///
483/// # Errors
484///
485/// Returns [`EmbedError::TierMismatch`] if `query` is not an `Int8` `PreparedQuery`.
486#[inline]
487pub fn approximate_int8_batch_prepared_into(
488    query: &PreparedQuery,
489    candidates: &[QuantizedVector],
490    out: &mut Vec<f32>,
491) -> Result<()> {
492    out.clear();
493    let PreparedQuery::Int8(q) = query else {
494        return Err(EmbedError::TierMismatch {
495            op: "approximate_int8_batch_prepared_into",
496            expected: QuantizationTier::Int8,
497            actual: query.tier(),
498        });
499    };
500    out.reserve(candidates.len());
501    out.extend(
502        candidates
503            .iter()
504            .map(|candidate| 1.0 - cosine_similarity_i8_trusted(candidate, q)),
505    );
506    Ok(())
507}
508
509/// Compute cosine distances from a prepared INT4 query to a slice of INT4 candidates.
510///
511/// The query is quantized once outside this function; no per-iteration `from_f32` is called.
512///
513/// # Errors
514///
515/// Returns [`EmbedError::TierMismatch`] if `query` is not an `Int4` `PreparedQuery`.
516#[inline]
517pub fn approximate_int4_batch_prepared(
518    query: &PreparedQuery,
519    candidates: &[Int4Vector],
520) -> Result<Vec<f32>> {
521    let PreparedQuery::Int4(q) = query else {
522        return Err(EmbedError::TierMismatch {
523            op: "approximate_int4_batch_prepared",
524            expected: QuantizationTier::Int4,
525            actual: query.tier(),
526        });
527    };
528    Ok(candidates
529        .iter()
530        .map(|candidate| candidate.cosine_distance(q))
531        .collect())
532}
533
534/// Like [`approximate_int4_batch_prepared`] but writes into a caller-supplied buffer.
535///
536/// On error the buffer is left cleared (no partial results are written).
537///
538/// # Errors
539///
540/// Returns [`EmbedError::TierMismatch`] if `query` is not an `Int4` `PreparedQuery`.
541#[inline]
542pub fn approximate_int4_batch_prepared_into(
543    query: &PreparedQuery,
544    candidates: &[Int4Vector],
545    out: &mut Vec<f32>,
546) -> Result<()> {
547    out.clear();
548    let PreparedQuery::Int4(q) = query else {
549        return Err(EmbedError::TierMismatch {
550            op: "approximate_int4_batch_prepared_into",
551            expected: QuantizationTier::Int4,
552            actual: query.tier(),
553        });
554    };
555    out.reserve(candidates.len());
556    out.extend(
557        candidates
558            .iter()
559            .map(|candidate| candidate.cosine_distance(q)),
560    );
561    Ok(())
562}
563
564/// **Unstable**: tiered distance dispatch; tier mix and formula may change.
565///
566/// This is the primary distance function for HNSW search with tiered storage.
567/// The query is always f32; the stored data may be at any tier.
568///
569/// Returns a value in [0, 2] where 0 = identical, 2 = opposite.
570///
571/// # Precondition
572///
573/// `query_f32.len()` must equal the stored vector's dimensionality. Violating
574/// this is a caller bug; correct HNSW usage never triggers it.
575pub fn approximate_cosine_distance(query_f32: &[f32], stored: &QuantizedData) -> f32 {
576    debug_assert_eq!(
577        query_f32.len(),
578        stored.dims(),
579        "approximate_cosine_distance: query length {} != stored dims {}",
580        query_f32.len(),
581        stored.dims(),
582    );
583    match stored {
584        QuantizedData::Full(v) => {
585            // Exact cosine distance
586            1.0 - cosine_similarity(query_f32, v)
587        }
588        QuantizedData::Int8(q) => {
589            // Quantize query to INT8, compute via INT8 path
590            let query_q = QuantizedVector::from_f32(query_f32);
591            1.0 - q.cosine_similarity(&query_q)
592        }
593        QuantizedData::Int4(q) => {
594            // Quantize query to INT4, compute via INT4 path
595            let query_q = Int4Vector::from_f32(query_f32);
596            q.cosine_distance(&query_q)
597        }
598        QuantizedData::Binary(q) => {
599            // Quantize query to binary, compute Hamming-based approx
600            let query_q = BinaryVector::from_f32(query_f32);
601            q.cosine_distance_approx(&query_q)
602        }
603    }
604}
605
606/// **Unstable**: tiered dot product dispatch; tier mix and formula may change.
607pub fn approximate_dot_product(query_f32: &[f32], stored: &QuantizedData) -> f32 {
608    match stored {
609        QuantizedData::Full(v) => dot_product(query_f32, v),
610        QuantizedData::Int8(q) => {
611            let query_q = QuantizedVector::from_f32(query_f32);
612            q.dot_product(&query_q)
613        }
614        QuantizedData::Int4(q) => {
615            let query_q = Int4Vector::from_f32(query_f32);
616            q.dot_product(&query_q)
617        }
618        QuantizedData::Binary(_q) => {
619            // Binary doesn't have a meaningful dot product; fall back to dequantize
620            let stored_f32 = _q.to_f32();
621            dot_product(query_f32, &stored_f32)
622        }
623    }
624}
625
626#[cfg(test)]
627mod tests {
628    use super::*;
629
630    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
631        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
632        (0..dim)
633            .map(|i| {
634                state = state
635                    .wrapping_mul(6364136223846793005)
636                    .wrapping_add(1442695040888963407)
637                    .wrapping_add(i as u64);
638                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
639                unit * 2.0 - 1.0
640            })
641            .collect()
642    }
643
644    #[test]
645    fn test_tier_bytes_per_dim() {
646        assert_eq!(QuantizationTier::Full.bytes_per_dim(), 4.0);
647        assert_eq!(QuantizationTier::Int8.bytes_per_dim(), 1.0);
648        assert_eq!(QuantizationTier::Int4.bytes_per_dim(), 0.5);
649        assert_eq!(QuantizationTier::Binary.bytes_per_dim(), 0.125);
650    }
651
652    #[test]
653    fn test_tier_compression_ratios() {
654        assert_eq!(QuantizationTier::Full.compression_ratio(), 1.0);
655        assert_eq!(QuantizationTier::Int8.compression_ratio(), 4.0);
656        assert_eq!(QuantizationTier::Int4.compression_ratio(), 8.0);
657        assert_eq!(QuantizationTier::Binary.compression_ratio(), 32.0);
658    }
659
660    #[test]
661    fn test_tier_storage_bytes() {
662        assert_eq!(QuantizationTier::Full.storage_bytes(384), 1536);
663        assert_eq!(QuantizationTier::Int8.storage_bytes(384), 384);
664        assert_eq!(QuantizationTier::Int4.storage_bytes(384), 192);
665        assert_eq!(QuantizationTier::Binary.storage_bytes(384), 48);
666    }
667
668    #[test]
669    fn test_tier_from_age() {
670        assert_eq!(
671            QuantizationTier::from_age_seconds(0),
672            QuantizationTier::Full
673        );
674        assert_eq!(
675            QuantizationTier::from_age_seconds(1800),
676            QuantizationTier::Full
677        ); // 30 min
678        assert_eq!(
679            QuantizationTier::from_age_seconds(7200),
680            QuantizationTier::Int8
681        ); // 2 hours
682        assert_eq!(
683            QuantizationTier::from_age_seconds(172800),
684            QuantizationTier::Int4
685        ); // 2 days
686        assert_eq!(
687            QuantizationTier::from_age_seconds(1_000_000),
688            QuantizationTier::Binary
689        ); // ~11 days
690    }
691
692    #[test]
693    fn test_quantized_data_from_f32_all_tiers() {
694        let v = generate_vector(384, 42);
695
696        for tier in [
697            QuantizationTier::Full,
698            QuantizationTier::Int8,
699            QuantizationTier::Int4,
700            QuantizationTier::Binary,
701        ] {
702            let data = QuantizedData::from_f32(&v, tier);
703            assert_eq!(data.tier(), tier, "tier mismatch for {tier:?}");
704            assert_eq!(data.dims(), 384, "dims mismatch for {tier:?}");
705
706            // Verify storage bytes match expected
707            let expected_bytes = tier.storage_bytes(384);
708            assert_eq!(
709                data.storage_bytes(),
710                expected_bytes,
711                "storage bytes mismatch for {tier:?}"
712            );
713        }
714    }
715
716    #[test]
717    fn test_approximate_cosine_distance_ordering() {
718        // Vectors a and b should be "closer" than a and c.
719        let a = generate_vector(384, 1);
720        // b = a + small noise
721        let b: Vec<f32> = a
722            .iter()
723            .enumerate()
724            .map(|(i, &x)| x + 0.05 * (i as f32 * 0.3).sin())
725            .collect();
726        // c = random, uncorrelated
727        let c = generate_vector(384, 999);
728
729        for tier in [
730            QuantizationTier::Full,
731            QuantizationTier::Int8,
732            QuantizationTier::Int4,
733            QuantizationTier::Binary,
734        ] {
735            let stored_b = QuantizedData::from_f32(&b, tier);
736            let stored_c = QuantizedData::from_f32(&c, tier);
737
738            let dist_ab = approximate_cosine_distance(&a, &stored_b);
739            let dist_ac = approximate_cosine_distance(&a, &stored_c);
740
741            // a should be closer to b than to c at all tiers
742            assert!(
743                dist_ab < dist_ac,
744                "{tier:?}: dist(a,b)={dist_ab} should be < dist(a,c)={dist_ac}"
745            );
746        }
747    }
748
749    #[test]
750    fn test_promote_demote_roundtrip() {
751        let v = generate_vector(384, 42);
752        let binary = QuantizedData::from_f32(&v, QuantizationTier::Binary);
753
754        // Promote Binary -> Int4 -> Int8 -> Full
755        let int4 = binary.promote(QuantizationTier::Int4);
756        assert_eq!(int4.tier(), QuantizationTier::Int4);
757
758        let int8 = int4.promote(QuantizationTier::Int8);
759        assert_eq!(int8.tier(), QuantizationTier::Int8);
760
761        let full = int8.promote(QuantizationTier::Full);
762        assert_eq!(full.tier(), QuantizationTier::Full);
763        assert_eq!(full.dims(), 384);
764    }
765
766    #[test]
767    fn test_int8_batch_prepared_matches_per_item_prepared() {
768        let query = generate_vector(384, 42);
769        let prepared = PreparedQuery::from_f32(&query, QuantizationTier::Int8);
770        let candidates: Vec<QuantizedVector> = (0..32)
771            .map(|i| QuantizedVector::from_f32(&generate_vector(384, i + 1)))
772            .collect();
773        let wrapped: Vec<QuantizedData> = candidates
774            .iter()
775            .cloned()
776            .map(QuantizedData::Int8)
777            .collect();
778
779        let got = approximate_int8_batch_prepared(&prepared, &candidates).unwrap();
780        for (i, item) in wrapped.iter().enumerate() {
781            let expected = approximate_cosine_distance_prepared(&prepared, item).unwrap();
782            assert!(
783                (got[i] - expected).abs() < 1e-6,
784                "int8 batch prepared mismatch at candidate {i}: got={}, expected={}",
785                got[i],
786                expected
787            );
788        }
789    }
790
791    #[test]
792    fn test_int4_batch_prepared_matches_per_item_prepared() {
793        let query = generate_vector(384, 42);
794        let prepared = PreparedQuery::from_f32(&query, QuantizationTier::Int4);
795        let candidates: Vec<Int4Vector> = (0..32)
796            .map(|i| Int4Vector::from_f32(&generate_vector(384, i + 1)))
797            .collect();
798        let wrapped: Vec<QuantizedData> = candidates
799            .iter()
800            .cloned()
801            .map(QuantizedData::Int4)
802            .collect();
803
804        let got = approximate_int4_batch_prepared(&prepared, &candidates).unwrap();
805        for (i, item) in wrapped.iter().enumerate() {
806            let expected = approximate_cosine_distance_prepared(&prepared, item).unwrap();
807            assert!(
808                (got[i] - expected).abs() < 1e-5,
809                "int4 batch prepared mismatch at candidate {i}: got={}, expected={}",
810                got[i],
811                expected
812            );
813        }
814    }
815
#[test]
fn test_int4_batch_prepared_api_dispatch_parity() {
    // API-level parity: the batched INT4 entry point must report the same cosine
    // distance as the single-candidate prepared entry point. On aarch64 both
    // sides dispatch to NEON; on other targets both use the packed scalar
    // fallback, so this is not a scalar-vs-NEON comparison. For that, see
    // int4::tests::test_packed_scalar_matches_neon_exact.
    const DIMS: [usize; 6] = [1, 3, 31, 127, 383, 384];

    for dim in DIMS {
        // Seeds are derived from the dimension so each size gets distinct data.
        let seed_offset = dim as u64;
        let query_f32 = generate_vector(dim, 700 + seed_offset);
        let candidate_f32 = generate_vector(dim, 800 + seed_offset);

        let prepared_query = PreparedQuery::from_f32(&query_f32, QuantizationTier::Int4);
        let batch_input = [Int4Vector::from_f32(&candidate_f32)];
        let single_input = QuantizedData::Int4(batch_input[0].clone());

        let from_batch = approximate_int4_batch_prepared(&prepared_query, &batch_input).unwrap();
        let from_single =
            approximate_cosine_distance_prepared(&prepared_query, &single_input).unwrap();

        let delta = (from_batch[0] - from_single).abs();
        assert!(
            delta < 1e-5,
            "int4 batch prepared dispatch mismatch at dim={dim}: batch={}, per_item={}",
            from_batch[0],
            from_single
        );
    }
}
841
#[test]
fn test_quantized_data_to_f32_roundtrip() {
    // Round-trips a 384-dim vector through the Full tier and checks that the
    // reconstruction matches the input element by element.
    let v = generate_vector(384, 55);

    // Full tier should be lossless
    let full_data = QuantizedData::from_f32(&v, QuantizationTier::Full);
    let full_rt = full_data.to_f32();

    // `zip` stops at the shorter side, so without this check an empty or
    // truncated reconstruction would pass the loop below vacuously.
    assert_eq!(
        full_rt.len(),
        v.len(),
        "Full tier round-trip must preserve the number of dimensions"
    );
    for (a, b) in v.iter().zip(full_rt.iter()) {
        assert!((a - b).abs() < 1e-10, "Full tier should be lossless");
    }
}
853
854    // ------------------------------------------------------------------
855    // Regression tests for issue #210: tier-mismatch in prepared SIMD
856    // dispatch must return a typed error, not panic.
857    // ------------------------------------------------------------------
858
#[test]
fn test_cosine_distance_prepared_tier_mismatch_returns_typed_error() {
    // Query prepared at Int8, stored vector at Int4: the tiers disagree on purpose.
    let source = generate_vector(64, 1);
    let query = PreparedQuery::from_f32(&source, QuantizationTier::Int8);
    let stored = QuantizedData::from_f32(&source, QuantizationTier::Int4);

    // Pull the payload out of the error first, then check each field.
    let (op, expected, actual) =
        match approximate_cosine_distance_prepared(&query, &stored).unwrap_err() {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => (op, expected, actual),
            other => panic!("expected TierMismatch, got {other:?}"),
        };
    assert_eq!(op, "approximate_cosine_distance_prepared");
    assert_eq!(expected, QuantizationTier::Int4);
    assert_eq!(actual, QuantizationTier::Int8);

    // try_ alias must agree.
    assert!(try_approximate_cosine_distance_prepared(&query, &stored).is_err());
}
882
#[test]
fn test_dot_product_prepared_tier_mismatch_returns_typed_error() {
    // Full-precision query against a vector stored at Int8.
    let source = generate_vector(64, 2);
    let query = PreparedQuery::from_f32(&source, QuantizationTier::Full);
    let stored = QuantizedData::from_f32(&source, QuantizationTier::Int8);

    let err = approximate_dot_product_prepared(&query, &stored).unwrap_err();
    // The error must be a TierMismatch that names this operation.
    let names_this_op = matches!(
        err,
        EmbedError::TierMismatch {
            op: "approximate_dot_product_prepared",
            ..
        }
    );
    assert!(names_this_op, "unexpected error variant: {err:?}");

    // The try_ alias must fail on the same input.
    assert!(try_approximate_dot_product_prepared(&query, &stored).is_err());
}
903
#[test]
fn test_dot_product_prepared_binary_returns_typed_error_not_panic() {
    // Query and stored vector are both at the Binary tier; the prepared dot
    // product is expected to report an Internal error instead of panicking.
    let source = generate_vector(64, 3);
    let tier = QuantizationTier::Binary;
    let query = PreparedQuery::from_f32(&source, tier);
    let stored = QuantizedData::from_f32(&source, tier);

    let err = approximate_dot_product_prepared(&query, &stored).unwrap_err();
    let is_internal = matches!(err, EmbedError::Internal(_));
    assert!(is_internal, "unexpected error variant: {err:?}");
}
916
#[test]
fn test_cosine_distance_prepared_with_meta_tier_mismatch_returns_typed_error() {
    // Metadata-carrying query at Full tier against an Int8-stored vector.
    // Neither side claims unit normalization.
    let source = generate_vector(64, 4);
    let hint = NormalizationHint::Unknown;
    let meta = PreparedQueryWithMeta::from_f32(&source, QuantizationTier::Full, hint);
    let stored = QuantizedData::from_f32(&source, QuantizationTier::Int8);

    let outcome = approximate_cosine_distance_prepared_with_meta(&meta, &stored, hint);
    let err = outcome.unwrap_err();
    assert!(matches!(err, EmbedError::TierMismatch { .. }));
}
932
#[test]
fn test_batch_cosine_distance_prepared_tier_mismatch_returns_typed_error() {
    let source = generate_vector(64, 5);
    let query = PreparedQuery::from_f32(&source, QuantizationTier::Int8);

    // First entry matches the query tier; the second is deliberately mismatched.
    let matching = QuantizedData::from_f32(&source, QuantizationTier::Int8);
    let mismatched = QuantizedData::from_f32(&source, QuantizationTier::Int4);
    let stored = vec![matching, mismatched];

    // Allocating variant.
    {
        let err = batch_approximate_cosine_distance_prepared(&query, &stored).unwrap_err();
        assert!(matches!(err, EmbedError::TierMismatch { .. }));
    }

    // Buffer-reusing variant: pre-populated, must be cleared even on error.
    let mut reused = vec![9.0; 3];
    let err = batch_approximate_cosine_distance_prepared_into(&query, &stored, &mut reused)
        .unwrap_err();
    assert!(matches!(err, EmbedError::TierMismatch { .. }));
    assert!(
        reused.is_empty(),
        "buffer must be cleared, not left with stale data"
    );
}
954
#[test]
fn test_int8_batch_prepared_wrong_tier_returns_typed_error() {
    let source = generate_vector(64, 6);
    // The query is prepared at Int4, not the Int8 tier this test expects the
    // batch call to require.
    let query = PreparedQuery::from_f32(&source, QuantizationTier::Int4);
    let candidates = vec![QuantizedVector::from_f32(&source)];

    // Allocating variant: the error must carry the op name and both tiers.
    let (op, expected, actual) =
        match approximate_int8_batch_prepared(&query, &candidates).unwrap_err() {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => (op, expected, actual),
            other => panic!("expected TierMismatch, got {other:?}"),
        };
    assert_eq!(op, "approximate_int8_batch_prepared");
    assert_eq!(expected, QuantizationTier::Int8);
    assert_eq!(actual, QuantizationTier::Int4);

    // Buffer-reusing variant: a pre-filled output must come back empty on error.
    let mut reused = vec![9.0];
    let err = approximate_int8_batch_prepared_into(&query, &candidates, &mut reused).unwrap_err();
    assert!(matches!(err, EmbedError::TierMismatch { .. }));
    assert!(
        reused.is_empty(),
        "buffer must be cleared, not left with stale data"
    );
}
983
#[test]
fn test_int4_batch_prepared_wrong_tier_returns_typed_error() {
    let source = generate_vector(64, 7);
    // The query is prepared at Int8, not the Int4 tier this test expects the
    // batch call to require.
    let query = PreparedQuery::from_f32(&source, QuantizationTier::Int8);
    let candidates = vec![Int4Vector::from_f32(&source)];

    // Allocating variant: the error must carry the op name and both tiers.
    let (op, expected, actual) =
        match approximate_int4_batch_prepared(&query, &candidates).unwrap_err() {
            EmbedError::TierMismatch {
                op,
                expected,
                actual,
            } => (op, expected, actual),
            other => panic!("expected TierMismatch, got {other:?}"),
        };
    assert_eq!(op, "approximate_int4_batch_prepared");
    assert_eq!(expected, QuantizationTier::Int4);
    assert_eq!(actual, QuantizationTier::Int8);

    // Buffer-reusing variant: a pre-filled output must come back empty on error.
    let mut reused = vec![9.0];
    let err = approximate_int4_batch_prepared_into(&query, &candidates, &mut reused).unwrap_err();
    assert!(matches!(err, EmbedError::TierMismatch { .. }));
    assert!(
        reused.is_empty(),
        "buffer must be cleared, not left with stale data"
    );
}
1012}