lattice-embed 0.2.2

SIMD-accelerated vector operations and embedding generation
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
//! Quantization tier management and unified distance dispatch.
//!
//! Provides a `QuantizationTier` enum for selecting precision levels and
//! a `QuantizedData` enum for storing vectors at any tier with unified
//! distance computation.
//!
//! ## Tier hierarchy
//!
//! | Tier   | Precision | Bytes/dim | Compression | Use case                |
//! |--------|-----------|-----------|-------------|-------------------------|
//! | Full   | f32       | 4.0       | 1x          | Hot data, exact search   |
//! | Int8   | 8-bit     | 1.0       | 4x          | Warm data, HNSW search   |
//! | Int4   | 4-bit     | 0.5       | 8x          | Cool data, pre-filtering |
//! | Binary | 1-bit     | 0.125     | 32x         | Cold data, coarse filter |

use super::binary::BinaryVector;
use super::int4::Int4Vector;
use super::quantized::{QuantizedVector, cosine_similarity_i8_trusted, dot_product_i8_trusted};
use super::{cosine_similarity, dot_product};
use crate::error::{EmbedError, Result};

/// Normalization state that a caller asserts for a vector.
///
/// If the query and the stored vector are both marked [`NormalizationHint::Unit`],
/// their cosine similarity is simply their dot product, so the division by the
/// two norms can be skipped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizationHint {
    /// Nothing is known about the norm; the full cosine formula (dividing by
    /// both norms) has to be used.
    Unknown,
    /// The caller asserts the vector has unit L2 norm (≈ 1 within 1e-4).
    Unit,
}

/// **Unstable**: tier design is under active iteration; tier boundaries may change.
///
/// Precision level at which a vector is stored. Variants are declared from the
/// most faithful (`Full`) to the most compressed (`Binary`), so the derived
/// ordering satisfies `Full < Int8 < Int4 < Binary`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum QuantizationTier {
    /// Uncompressed f32 storage: 4 bytes per dimension (1x baseline).
    Full,
    /// Symmetric 8-bit integer storage: 1 byte per dimension (4x smaller).
    Int8,
    /// Two 4-bit values packed per byte: 0.5 bytes per dimension (8x smaller).
    Int4,
    /// One sign bit per dimension: 0.125 bytes per dimension (32x smaller).
    Binary,
}

impl QuantizationTier {
    /// **Unstable**: bytes-per-dimension constant; may change with new tiers.
    ///
    /// Storage cost of a single dimension at this tier. The value is
    /// fractional for the packed `Int4` (0.5) and `Binary` (0.125) tiers.
    pub fn bytes_per_dim(&self) -> f32 {
        match self {
            Self::Full => 4.0,
            Self::Int8 => 1.0,
            Self::Int4 => 0.5,
            Self::Binary => 0.125,
        }
    }

    /// **Unstable**: compression ratio; derived from `bytes_per_dim`, may be removed.
    ///
    /// Size reduction relative to f32 storage (4 bytes/dim): `1.0` for `Full`
    /// up to `32.0` for `Binary`.
    pub fn compression_ratio(&self) -> f32 {
        4.0 / self.bytes_per_dim()
    }

    /// **Unstable**: storage byte computation; may change with new tiers.
    ///
    /// Exact number of payload bytes needed for `dims` dimensions at this tier.
    /// Packed tiers round up to a whole byte, so a trailing partial nibble
    /// (`Int4`) or partial group of bits (`Binary`) still takes a full byte.
    pub fn storage_bytes(&self, dims: usize) -> usize {
        match self {
            Self::Full => dims * 4,
            Self::Int8 => dims,
            Self::Int4 => dims.div_ceil(2),
            Self::Binary => dims.div_ceil(8),
        }
    }

    /// **Unstable**: age-based tier heuristic; boundaries (HOUR/DAY/WEEK) may be tuned.
    ///
    /// Maps an age in seconds to a tier. The intended meaning is time since
    /// last access. Each range is half-open, so an age exactly on a boundary
    /// goes to the colder tier:
    ///
    /// - Hot, `age < 1 hour`: `Full`
    /// - Warm, `1 hour <= age < 1 day`: `Int8`
    /// - Cool, `1 day <= age < 1 week`: `Int4`
    /// - Cold, `age >= 1 week`: `Binary`
    pub fn from_age_seconds(age_secs: u64) -> Self {
        const HOUR: u64 = 3600;
        const DAY: u64 = 86400;
        const WEEK: u64 = 604800;

        if age_secs < HOUR {
            Self::Full
        } else if age_secs < DAY {
            Self::Int8
        } else if age_secs < WEEK {
            Self::Int4
        } else {
            Self::Binary
        }
    }
}

/// **Unstable**: unified quantized data container; variants may change with tier redesign.
///
/// One stored vector at any [`QuantizationTier`]. Wrapping the per-tier vector
/// types in a single enum lets mixed-tier data be stored side by side and lets
/// distance computations dispatch with a single `match`.
#[derive(Clone, Debug)]
pub enum QuantizedData {
    /// Uncompressed f32 components.
    Full(Vec<f32>),
    /// 8-bit quantized vector.
    Int8(QuantizedVector),
    /// 4-bit quantized vector, two values packed per byte.
    Int4(Int4Vector),
    /// One sign bit per dimension.
    Binary(BinaryVector),
}

impl QuantizedData {
    /// **Unstable**: returns `QuantizationTier` which is itself Unstable.
    ///
    /// Tier corresponding to the variant currently held.
    pub fn tier(&self) -> QuantizationTier {
        match self {
            Self::Full(_) => QuantizationTier::Full,
            Self::Int8(_) => QuantizationTier::Int8,
            Self::Int4(_) => QuantizationTier::Int4,
            Self::Binary(_) => QuantizationTier::Binary,
        }
    }

    /// **Unstable**: dimension accessor; may be removed if `QuantizedData` gains a dims field.
    ///
    /// Number of logical dimensions. `Full` and `Int8` hold one element per
    /// dimension. The packed `Int4`/`Binary` tiers read their stored `dims`
    /// field, because their byte length is rounded up and does not determine
    /// the dimension count.
    pub fn dims(&self) -> usize {
        match self {
            Self::Full(v) => v.len(),
            Self::Int8(q) => q.data.len(),
            Self::Int4(q) => q.dims,
            Self::Binary(q) => q.dims,
        }
    }

    /// **Unstable**: storage byte count; may change with tier redesign.
    ///
    /// Size of the payload buffer only: the f32 elements for `Full`, or the
    /// `data` byte buffer for the quantized tiers. Any other fields of the
    /// quantized types are not counted.
    pub fn storage_bytes(&self) -> usize {
        match self {
            Self::Full(v) => v.len() * 4,
            Self::Int8(q) => q.data.len(),
            Self::Int4(q) => q.data.len(),
            Self::Binary(q) => q.data.len(),
        }
    }

    /// **Unstable**: quantization factory; tier dispatch logic may change.
    ///
    /// Quantizes `vector` at `tier`. `Full` copies the data unchanged.
    pub fn from_f32(vector: &[f32], tier: QuantizationTier) -> Self {
        match tier {
            QuantizationTier::Full => Self::Full(vector.to_vec()),
            QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(vector)),
            QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(vector)),
            QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(vector)),
        }
    }

    /// **Unstable**: dequantization; output precision is tier-dependent.
    ///
    /// Returns an owned f32 copy. It is exact for `Full` and an approximation
    /// for every other tier.
    pub fn to_f32(&self) -> Vec<f32> {
        match self {
            Self::Full(v) => v.clone(),
            Self::Int8(q) => q.to_f32(),
            Self::Int4(q) => q.to_f32(),
            Self::Binary(q) => q.to_f32(),
        }
    }

    /// **Unstable**: tier promotion; re-quantizes via f32; may be rethought.
    ///
    /// Converts to `target` by dequantizing to f32 and re-quantizing at the
    /// target tier. This does NOT recover information lost at a lower tier; it
    /// only changes the storage format. For example, INT4 -> INT8 promotion
    /// fills the extra bits from the dequantized approximation.
    ///
    /// If `target` is already the current tier, the data is returned unchanged
    /// (cloned). It is not run through a pointless and possibly lossy
    /// dequantize/re-quantize cycle.
    pub fn promote(&self, target: QuantizationTier) -> Self {
        if self.tier() == target {
            return self.clone();
        }
        match self {
            // Already f32: quantize straight from the stored slice instead of
            // cloning it first through `to_f32()`.
            Self::Full(v) => Self::from_f32(v, target),
            _ => Self::from_f32(&self.to_f32(), target),
        }
    }

    /// **Unstable**: tier demotion; delegates to `promote`; may be removed.
    ///
    /// Same conversion as [`QuantizedData::promote`]. The separate name only
    /// documents intent (moving to a lower-fidelity tier).
    pub fn demote(&self, target: QuantizationTier) -> Self {
        self.promote(target)
    }
}

/// **Unstable**: pre-quantized query for repeated distance computation.
///
/// Quantizing the query once and reusing it across a candidate list that all
/// share one tier avoids paying the `from_f32` cost on every comparison. The
/// query's tier has to be the same as the tier of the stored data it is
/// compared against.
#[derive(Clone, Debug)]
pub enum PreparedQuery {
    /// Query kept as f32.
    Full(Vec<f32>),
    /// Query quantized to 8 bits.
    Int8(QuantizedVector),
    /// Query quantized to packed 4-bit values.
    Int4(Int4Vector),
    /// Query reduced to sign bits.
    Binary(BinaryVector),
}

impl PreparedQuery {
    /// Quantize a query at the given tier for repeated distance calls.
    #[inline]
    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier) -> Self {
        match tier {
            QuantizationTier::Full => Self::Full(query_f32.to_vec()),
            QuantizationTier::Int8 => Self::Int8(QuantizedVector::from_f32(query_f32)),
            QuantizationTier::Int4 => Self::Int4(Int4Vector::from_f32(query_f32)),
            QuantizationTier::Binary => Self::Binary(BinaryVector::from_f32(query_f32)),
        }
    }

    /// Returns the quantization tier of this prepared query.
    #[inline]
    pub fn tier(&self) -> QuantizationTier {
        match self {
            Self::Full(_) => QuantizationTier::Full,
            Self::Int8(_) => QuantizationTier::Int8,
            Self::Int4(_) => QuantizationTier::Int4,
            Self::Binary(_) => QuantizationTier::Binary,
        }
    }

    /// Returns the number of dimensions.
    #[inline]
    pub fn dims(&self) -> usize {
        match self {
            Self::Full(v) => v.len(),
            Self::Int8(q) => q.data.len(),
            Self::Int4(q) => q.dims,
            Self::Binary(q) => q.dims,
        }
    }
}

/// Free-function form of [`PreparedQuery::from_f32`].
///
/// Quantizes `query_f32` at `tier` once, for reuse against many stored vectors
/// that all live at that tier.
#[inline]
pub fn prepare_query(query_f32: &[f32], tier: QuantizationTier) -> PreparedQuery {
    PreparedQuery::from_f32(query_f32, tier)
}

/// A [`PreparedQuery`] paired with the caller's normalization claim about it.
///
/// If the query is marked [`NormalizationHint::Unit`] and the stored vector is
/// also declared unit-normalized, `approximate_cosine_distance_prepared_with_meta`
/// computes `1.0 - dot_product(q, s)` on the Full tier. It skips dividing by
/// the norms, which is the source of its speedup.
#[derive(Clone, Debug)]
pub struct PreparedQueryWithMeta {
    /// Owned, already-quantized query.
    pub query: PreparedQuery,
    /// Normalization state the caller asserts for the query (not verified).
    pub norm: NormalizationHint,
}

impl PreparedQueryWithMeta {
    /// Create a prepared query from an f32 vector, asserting its normalization state.
    #[inline]
    pub fn from_f32(query_f32: &[f32], tier: QuantizationTier, norm: NormalizationHint) -> Self {
        Self {
            query: PreparedQuery::from_f32(query_f32, tier),
            norm,
        }
    }

    /// Returns the quantization tier.
    #[inline]
    pub fn tier(&self) -> QuantizationTier {
        self.query.tier()
    }

    /// Returns the number of dimensions.
    #[inline]
    pub fn dims(&self) -> usize {
        self.query.dims()
    }
}

/// Checks whether `v` is (approximately) L2-unit-normalized.
///
/// Accumulates the squared norm and accepts it when it lies strictly within
/// 1e-4 of 1.0. An empty slice (squared norm 0) and any input containing NaN
/// are rejected.
#[inline]
pub fn is_unit_norm(v: &[f32]) -> bool {
    let mut norm_sq = 0.0_f32;
    for &component in v {
        norm_sq += component * component;
    }
    let deviation = (norm_sq - 1.0).abs();
    deviation < 1e-4
}

/// Free-function form of [`PreparedQueryWithMeta::from_f32`].
///
/// Quantizes `query_f32` at `tier` and attaches the caller-supplied `norm` hint,
/// which is stored as given and not verified.
#[inline]
pub fn prepare_query_with_norm(
    query_f32: &[f32],
    tier: QuantizationTier,
    norm: NormalizationHint,
) -> PreparedQueryWithMeta {
    PreparedQueryWithMeta::from_f32(query_f32, tier, norm)
}

/// **Unstable**: prepared cosine distance; query tier must match stored data tier.
///
/// Computed as `1 - cosine similarity` on every tier except `Int4` and
/// `Binary`, which use their own distance kernels. The intended range is
/// [0, 2]: 0 for the same direction, 2 for opposite directions.
///
/// # Panics
///
/// Panics when `query` and `stored` are at different tiers. Use
/// [`try_approximate_cosine_distance_prepared`] instead when that cannot be
/// guaranteed up front.
#[inline]
pub fn approximate_cosine_distance_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
    match (query, stored) {
        (PreparedQuery::Full(query_vec), QuantizedData::Full(stored_vec)) => {
            let similarity = cosine_similarity(query_vec, stored_vec);
            1.0 - similarity
        }
        (PreparedQuery::Int8(query_vec), QuantizedData::Int8(stored_vec)) => {
            // The trusted INT8 kernel takes the stored vector first.
            let similarity = cosine_similarity_i8_trusted(stored_vec, query_vec);
            1.0 - similarity
        }
        (PreparedQuery::Int4(query_vec), QuantizedData::Int4(stored_vec)) => {
            stored_vec.cosine_distance(query_vec)
        }
        (PreparedQuery::Binary(query_vec), QuantizedData::Binary(stored_vec)) => {
            stored_vec.cosine_distance_approx(query_vec)
        }
        (_, _) => panic!("PreparedQuery tier must match QuantizedData tier"),
    }
}

/// Fallible counterpart of [`approximate_cosine_distance_prepared`].
///
/// Produces the same distances. A tier mismatch between `query` and `stored`
/// returns `Err(EmbedError::Internal(...))` instead of panicking, which makes
/// this the right choice when tier agreement is not guaranteed statically.
#[inline]
pub fn try_approximate_cosine_distance_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    let distance = match (query, stored) {
        (PreparedQuery::Full(query_vec), QuantizedData::Full(stored_vec)) => {
            Some(1.0 - cosine_similarity(query_vec, stored_vec))
        }
        (PreparedQuery::Int8(query_vec), QuantizedData::Int8(stored_vec)) => {
            Some(1.0 - cosine_similarity_i8_trusted(stored_vec, query_vec))
        }
        (PreparedQuery::Int4(query_vec), QuantizedData::Int4(stored_vec)) => {
            Some(stored_vec.cosine_distance(query_vec))
        }
        (PreparedQuery::Binary(query_vec), QuantizedData::Binary(stored_vec)) => {
            Some(stored_vec.cosine_distance_approx(query_vec))
        }
        _ => None,
    };
    distance.ok_or_else(|| {
        EmbedError::Internal(
            "PreparedQuery tier must match QuantizedData tier for cosine distance".into(),
        )
    })
}

/// Fallible counterpart of [`approximate_dot_product_prepared`].
///
/// Returns `Err(EmbedError::Internal(...))` instead of panicking in two cases:
/// both sides are `Binary` (there is no prepared binary dot product), or the
/// query and stored tiers differ.
#[inline]
pub fn try_approximate_dot_product_prepared(
    query: &PreparedQuery,
    stored: &QuantizedData,
) -> Result<f32> {
    // A Binary/Binary pair has matching tiers but is still unsupported here.
    if matches!(
        (query, stored),
        (PreparedQuery::Binary(_), QuantizedData::Binary(_))
    ) {
        return Err(EmbedError::Internal(
            "Binary has no prepared dot product; use try_approximate_cosine_distance_prepared"
                .into(),
        ));
    }
    match (query, stored) {
        (PreparedQuery::Full(query_vec), QuantizedData::Full(stored_vec)) => {
            Ok(dot_product(query_vec, stored_vec))
        }
        (PreparedQuery::Int8(query_vec), QuantizedData::Int8(stored_vec)) => {
            Ok(dot_product_i8_trusted(query_vec, stored_vec))
        }
        (PreparedQuery::Int4(query_vec), QuantizedData::Int4(stored_vec)) => {
            Ok(stored_vec.dot_product(query_vec))
        }
        _ => Err(EmbedError::Internal(
            "PreparedQuery tier must match QuantizedData tier for dot product".into(),
        )),
    }
}

/// Cosine distance with a unit-norm fast path.
///
/// The fast path applies when three conditions hold:
/// - the query hint `meta.norm` is [`NormalizationHint::Unit`];
/// - `stored_norm` is [`NormalizationHint::Unit`];
/// - both query and stored data are on the `Full` tier.
///
/// In that case the norm division is skipped and the function returns
/// `1.0 - clamp(dot(q, s), -1, 1)`. Every other combination falls back to
/// [`approximate_cosine_distance_prepared`].
///
/// Both hints are trusted as given: neither vector's norm is checked here. An
/// incorrect `Unit` claim silently yields a wrong distance. Validate vectors
/// (e.g. with [`is_unit_norm`]) once when they are ingested, not on every call.
///
/// # Panics
///
/// Panics if the query tier does not match the stored-data tier (raised by the
/// fallback path).
#[inline]
pub fn approximate_cosine_distance_prepared_with_meta(
    meta: &PreparedQueryWithMeta,
    stored: &QuantizedData,
    stored_norm: NormalizationHint,
) -> f32 {
    // For unit vectors cos(q, s) == dot(q, s). The clamp keeps rounding error
    // from pushing the dot product outside [-1, 1], so the distance stays in [0, 2].
    if let (
        NormalizationHint::Unit,
        NormalizationHint::Unit,
        PreparedQuery::Full(q),
        QuantizedData::Full(s),
    ) = (meta.norm, stored_norm, &meta.query, stored)
    {
        return 1.0 - dot_product(q, s).clamp(-1.0, 1.0);
    }
    approximate_cosine_distance_prepared(&meta.query, stored)
}

/// **Unstable**: prepared dot product dispatch; query tier must match stored data tier.
///
/// # Panics
///
/// Panics in two situations:
/// - both sides are `Binary` (binary vectors have no meaningful dot product;
///   use cosine distance instead);
/// - the query and stored tiers differ.
///
/// [`try_approximate_dot_product_prepared`] reports both as errors instead.
#[inline]
pub fn approximate_dot_product_prepared(query: &PreparedQuery, stored: &QuantizedData) -> f32 {
    if matches!(
        (query, stored),
        (PreparedQuery::Binary(_), QuantizedData::Binary(_))
    ) {
        panic!("Binary has no prepared dot product; use approximate_cosine_distance_prepared");
    }
    match (query, stored) {
        (PreparedQuery::Full(query_vec), QuantizedData::Full(stored_vec)) => {
            dot_product(query_vec, stored_vec)
        }
        (PreparedQuery::Int8(query_vec), QuantizedData::Int8(stored_vec)) => {
            dot_product_i8_trusted(query_vec, stored_vec)
        }
        (PreparedQuery::Int4(query_vec), QuantizedData::Int4(stored_vec)) => {
            stored_vec.dot_product(query_vec)
        }
        _ => panic!("PreparedQuery tier must match QuantizedData tier"),
    }
}

/// Scores one prepared query against every entry of `stored`.
///
/// Element `i` of the result is the cosine distance to `stored[i]`. Like the
/// single-item function, this panics if any entry's tier differs from the
/// query's.
#[inline]
pub fn batch_approximate_cosine_distance_prepared(
    query: &PreparedQuery,
    stored: &[QuantizedData],
) -> Vec<f32> {
    let mut distances = Vec::with_capacity(stored.len());
    for item in stored {
        distances.push(approximate_cosine_distance_prepared(query, item));
    }
    distances
}

/// Buffer-reusing form of [`batch_approximate_cosine_distance_prepared`].
///
/// `out` is emptied and then refilled with one distance per entry of `stored`.
/// Its existing allocation is kept, so repeated searches can avoid
/// reallocating.
#[inline]
pub fn batch_approximate_cosine_distance_prepared_into(
    query: &PreparedQuery,
    stored: &[QuantizedData],
    out: &mut Vec<f32>,
) {
    out.clear();
    out.reserve(stored.len());
    for item in stored {
        out.push(approximate_cosine_distance_prepared(query, item));
    }
}

/// **Unstable**: tiered distance dispatch; tier mix and formula may change.
///
/// Main distance function for HNSW search over tiered storage. The query
/// arrives as f32 and may be compared against stored data at any tier. For
/// quantized tiers the query is quantized to that tier on every call.
///
/// Returns a value in [0, 2] where 0 = identical, 2 = opposite.
pub fn approximate_cosine_distance(query_f32: &[f32], stored: &QuantizedData) -> f32 {
    match stored {
        // Exact path: neither side is quantized.
        QuantizedData::Full(stored_vec) => 1.0 - cosine_similarity(query_f32, stored_vec),
        QuantizedData::Int8(stored_vec) => {
            let quantized_query = QuantizedVector::from_f32(query_f32);
            let similarity = stored_vec.cosine_similarity(&quantized_query);
            1.0 - similarity
        }
        QuantizedData::Int4(stored_vec) => {
            let quantized_query = Int4Vector::from_f32(query_f32);
            stored_vec.cosine_distance(&quantized_query)
        }
        // Hamming-based approximation on sign bits.
        QuantizedData::Binary(stored_vec) => {
            let quantized_query = BinaryVector::from_f32(query_f32);
            stored_vec.cosine_distance_approx(&quantized_query)
        }
    }
}

/// **Unstable**: tiered dot product dispatch; tier mix and formula may change.
///
/// Approximate dot product between an f32 query and a stored vector at any tier:
///
/// - `Full`: exact f32 dot product.
/// - `Int8` / `Int4`: the query is quantized to the stored tier on every call
///   and that tier's kernel is used. Prefer [`PreparedQuery`] with
///   [`approximate_dot_product_prepared`] to amortize the quantization.
/// - `Binary`: there is no native binary dot product. The stored vector is
///   dequantized to f32 and dotted with the unquantized query.
pub fn approximate_dot_product(query_f32: &[f32], stored: &QuantizedData) -> f32 {
    match stored {
        QuantizedData::Full(v) => dot_product(query_f32, v),
        QuantizedData::Int8(q) => {
            let query_q = QuantizedVector::from_f32(query_f32);
            q.dot_product(&query_q)
        }
        QuantizedData::Int4(q) => {
            let query_q = Int4Vector::from_f32(query_f32);
            q.dot_product(&query_q)
        }
        QuantizedData::Binary(q) => {
            // Binary has no meaningful native dot product; dequantize the stored
            // side and use the f32 kernel instead.
            let stored_f32 = q.to_f32();
            dot_product(query_f32, &stored_f32)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every tier, from highest to lowest fidelity.
    const ALL_TIERS: [QuantizationTier; 4] = [
        QuantizationTier::Full,
        QuantizationTier::Int8,
        QuantizationTier::Int4,
        QuantizationTier::Binary,
    ];

    /// Deterministic pseudo-random vector with components in [-1, 1], seeded by
    /// `seed` and `dim`, so tests are reproducible without an RNG dependency.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Scales `v` to unit L2 norm.
    fn normalize(v: Vec<f32>) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.into_iter().map(|x| x / norm).collect()
    }

    #[test]
    fn test_tier_bytes_per_dim() {
        assert_eq!(QuantizationTier::Full.bytes_per_dim(), 4.0);
        assert_eq!(QuantizationTier::Int8.bytes_per_dim(), 1.0);
        assert_eq!(QuantizationTier::Int4.bytes_per_dim(), 0.5);
        assert_eq!(QuantizationTier::Binary.bytes_per_dim(), 0.125);
    }

    #[test]
    fn test_tier_compression_ratios() {
        assert_eq!(QuantizationTier::Full.compression_ratio(), 1.0);
        assert_eq!(QuantizationTier::Int8.compression_ratio(), 4.0);
        assert_eq!(QuantizationTier::Int4.compression_ratio(), 8.0);
        assert_eq!(QuantizationTier::Binary.compression_ratio(), 32.0);
    }

    #[test]
    fn test_tier_storage_bytes() {
        assert_eq!(QuantizationTier::Full.storage_bytes(384), 1536);
        assert_eq!(QuantizationTier::Int8.storage_bytes(384), 384);
        assert_eq!(QuantizationTier::Int4.storage_bytes(384), 192);
        assert_eq!(QuantizationTier::Binary.storage_bytes(384), 48);
    }

    #[test]
    fn test_tier_storage_bytes_odd_dims_round_up() {
        assert_eq!(QuantizationTier::Full.storage_bytes(385), 1540);
        assert_eq!(QuantizationTier::Int8.storage_bytes(385), 385);
        assert_eq!(QuantizationTier::Int4.storage_bytes(385), 193);
        assert_eq!(QuantizationTier::Binary.storage_bytes(385), 49);
        assert_eq!(QuantizationTier::Binary.storage_bytes(0), 0);
    }

    #[test]
    fn test_tier_from_age() {
        assert_eq!(
            QuantizationTier::from_age_seconds(0),
            QuantizationTier::Full
        );
        assert_eq!(
            QuantizationTier::from_age_seconds(1800),
            QuantizationTier::Full
        ); // 30 min
        assert_eq!(
            QuantizationTier::from_age_seconds(7200),
            QuantizationTier::Int8
        ); // 2 hours
        assert_eq!(
            QuantizationTier::from_age_seconds(172800),
            QuantizationTier::Int4
        ); // 2 days
        assert_eq!(
            QuantizationTier::from_age_seconds(1_000_000),
            QuantizationTier::Binary
        ); // ~11 days
    }

    #[test]
    fn test_tier_from_age_boundaries_are_half_open() {
        // An age exactly on a boundary belongs to the colder tier.
        assert_eq!(QuantizationTier::from_age_seconds(3_599), QuantizationTier::Full);
        assert_eq!(QuantizationTier::from_age_seconds(3_600), QuantizationTier::Int8);
        assert_eq!(QuantizationTier::from_age_seconds(86_399), QuantizationTier::Int8);
        assert_eq!(QuantizationTier::from_age_seconds(86_400), QuantizationTier::Int4);
        assert_eq!(QuantizationTier::from_age_seconds(604_799), QuantizationTier::Int4);
        assert_eq!(QuantizationTier::from_age_seconds(604_800), QuantizationTier::Binary);
    }

    #[test]
    fn test_quantized_data_from_f32_all_tiers() {
        let v = generate_vector(384, 42);

        for tier in ALL_TIERS {
            let data = QuantizedData::from_f32(&v, tier);
            assert_eq!(data.tier(), tier, "tier mismatch for {tier:?}");
            assert_eq!(data.dims(), 384, "dims mismatch for {tier:?}");

            // Verify storage bytes match expected
            let expected_bytes = tier.storage_bytes(384);
            assert_eq!(
                data.storage_bytes(),
                expected_bytes,
                "storage bytes mismatch for {tier:?}"
            );
        }
    }

    #[test]
    fn test_approximate_cosine_distance_ordering() {
        // Vectors a and b should be "closer" than a and c.
        let a = generate_vector(384, 1);
        // b = a + small noise
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.05 * (i as f32 * 0.3).sin())
            .collect();
        // c = random, uncorrelated
        let c = generate_vector(384, 999);

        for tier in ALL_TIERS {
            let stored_b = QuantizedData::from_f32(&b, tier);
            let stored_c = QuantizedData::from_f32(&c, tier);

            let dist_ab = approximate_cosine_distance(&a, &stored_b);
            let dist_ac = approximate_cosine_distance(&a, &stored_c);

            // a should be closer to b than to c at all tiers
            assert!(
                dist_ab < dist_ac,
                "{tier:?}: dist(a,b)={dist_ab} should be < dist(a,c)={dist_ac}"
            );
        }
    }

    #[test]
    fn test_promote_demote_roundtrip() {
        let v = generate_vector(384, 42);
        let binary = QuantizedData::from_f32(&v, QuantizationTier::Binary);

        // Promote Binary -> Int4 -> Int8 -> Full
        let int4 = binary.promote(QuantizationTier::Int4);
        assert_eq!(int4.tier(), QuantizationTier::Int4);

        let int8 = int4.promote(QuantizationTier::Int8);
        assert_eq!(int8.tier(), QuantizationTier::Int8);

        let full = int8.promote(QuantizationTier::Full);
        assert_eq!(full.tier(), QuantizationTier::Full);
        assert_eq!(full.dims(), 384);
    }

    #[test]
    fn test_demote_full_to_every_lower_tier() {
        let v = generate_vector(384, 77);
        let full = QuantizedData::from_f32(&v, QuantizationTier::Full);

        for tier in [
            QuantizationTier::Int8,
            QuantizationTier::Int4,
            QuantizationTier::Binary,
        ] {
            let demoted = full.demote(tier);
            assert_eq!(demoted.tier(), tier, "tier mismatch for {tier:?}");
            assert_eq!(demoted.dims(), 384, "dims mismatch for {tier:?}");
            assert_eq!(
                demoted.storage_bytes(),
                tier.storage_bytes(384),
                "storage bytes mismatch for {tier:?}"
            );
        }
    }

    #[test]
    fn test_quantized_data_to_f32_roundtrip() {
        let v = generate_vector(384, 55);

        // Full tier stores the f32 values verbatim, so the roundtrip is bit-exact.
        let full_data = QuantizedData::from_f32(&v, QuantizationTier::Full);
        assert_eq!(full_data.to_f32(), v, "Full tier should be lossless");

        // Promoting Full to Full must not alter the data either.
        let same_tier = full_data.promote(QuantizationTier::Full);
        assert_eq!(same_tier.to_f32(), v);
    }

    #[test]
    fn test_is_unit_norm() {
        assert!(is_unit_norm(&[1.0, 0.0, 0.0]));
        assert!(is_unit_norm(&[0.6, 0.8]));
        assert!(is_unit_norm(&normalize(generate_vector(384, 9))));
        assert!(!is_unit_norm(&[]));
        assert!(!is_unit_norm(&[0.5, 0.5]));
        assert!(!is_unit_norm(&[f32::NAN]));
    }

    #[test]
    fn test_prepared_matches_unprepared() {
        let a = generate_vector(384, 7);
        let b = generate_vector(384, 8);

        // Int8 is excluded: the unprepared path uses a different INT8 kernel.
        for tier in [
            QuantizationTier::Full,
            QuantizationTier::Int4,
            QuantizationTier::Binary,
        ] {
            let stored = QuantizedData::from_f32(&b, tier);
            let prepared = prepare_query(&a, tier);
            assert_eq!(prepared.tier(), tier);
            assert_eq!(prepared.dims(), 384);

            let direct = approximate_cosine_distance(&a, &stored);
            let via_prepared = approximate_cosine_distance_prepared(&prepared, &stored);
            assert!(
                (direct - via_prepared).abs() < 1e-6,
                "{tier:?}: direct={direct} prepared={via_prepared}"
            );
        }
    }

    #[test]
    fn test_batch_prepared_matches_single_and_reuses_buffer() {
        let query = generate_vector(64, 3);
        let stored: Vec<QuantizedData> = (0..5)
            .map(|s| QuantizedData::from_f32(&generate_vector(64, 100 + s), QuantizationTier::Full))
            .collect();
        let prepared = prepare_query(&query, QuantizationTier::Full);

        let batch = batch_approximate_cosine_distance_prepared(&prepared, &stored);
        assert_eq!(batch.len(), stored.len());
        for (distance, item) in batch.iter().zip(&stored) {
            assert_eq!(*distance, approximate_cosine_distance_prepared(&prepared, item));
        }

        // Stale contents of the buffer must be discarded.
        let mut out = vec![42.0_f32; 10];
        batch_approximate_cosine_distance_prepared_into(&prepared, &stored, &mut out);
        assert_eq!(out, batch);
    }

    #[test]
    fn test_with_meta_unit_fast_path_agrees_with_cosine() {
        let a = normalize(generate_vector(384, 11));
        let b = normalize(generate_vector(384, 12));
        let stored = QuantizedData::from_f32(&b, QuantizationTier::Full);
        let meta = prepare_query_with_norm(&a, QuantizationTier::Full, NormalizationHint::Unit);
        assert_eq!(meta.tier(), QuantizationTier::Full);
        assert_eq!(meta.dims(), 384);

        let exact = approximate_cosine_distance_prepared(&meta.query, &stored);
        let fast =
            approximate_cosine_distance_prepared_with_meta(&meta, &stored, NormalizationHint::Unit);
        assert!((fast - exact).abs() < 1e-4, "fast={fast} exact={exact}");

        // Without the stored-side Unit hint the general path is used unchanged.
        let general = approximate_cosine_distance_prepared_with_meta(
            &meta,
            &stored,
            NormalizationHint::Unknown,
        );
        assert_eq!(general, exact);
    }

    #[test]
    fn test_try_prepared_reports_unsupported_pairs() {
        let v = generate_vector(32, 5);

        let full_query = prepare_query(&v, QuantizationTier::Full);
        let int8_stored = QuantizedData::from_f32(&v, QuantizationTier::Int8);
        assert!(try_approximate_cosine_distance_prepared(&full_query, &int8_stored).is_err());
        assert!(try_approximate_dot_product_prepared(&full_query, &int8_stored).is_err());

        let binary_query = prepare_query(&v, QuantizationTier::Binary);
        let binary_stored = QuantizedData::from_f32(&v, QuantizationTier::Binary);
        assert!(try_approximate_dot_product_prepared(&binary_query, &binary_stored).is_err());
        assert!(try_approximate_cosine_distance_prepared(&binary_query, &binary_stored).is_ok());
    }

    #[test]
    #[should_panic(expected = "PreparedQuery tier must match QuantizedData tier")]
    fn test_prepared_cosine_panics_on_tier_mismatch() {
        let v = generate_vector(32, 6);
        let query = prepare_query(&v, QuantizationTier::Int4);
        let stored = QuantizedData::from_f32(&v, QuantizationTier::Full);
        let _ = approximate_cosine_distance_prepared(&query, &stored);
    }
}