Skip to main content

lattice_embed/simd/
int4.rs

1//! INT4 quantization for ultra-compact embedding storage.
2//!
3//! Two 4-bit values packed per byte (8x compression vs f32).
4//! Uses symmetric unsigned quantization: maps [-max_abs, max_abs] to [0, 15].
5//!
6//! ## Packing format
7//!
8//! High nibble = even index, low nibble = odd index.
9//! For D dimensions, storage is `ceil(D / 2)` bytes.
10//!
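//! ## Dot product
//!
//! Dot products are defined on the dequantized values (`q / scale - max_abs`), so the
//! unsigned INT4 offset is handled identically on every target. The scalar path
//! dequantizes each element before multiplying; the aarch64 NEON path accumulates raw
//! nibble products and per-vector nibble sums as integers and removes the offset with a
//! closed-form correction, which agrees with the scalar result up to `f32` rounding.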
15
16#[cfg(target_arch = "aarch64")]
17use std::arch::aarch64::*;
18
19#[cfg(target_arch = "aarch64")]
20use super::simd_config;
21
/// **Unstable**: INT4 quantization internals; scale/bias scheme may change.
///
/// Per-vector parameters for the INT4 mapping.
///
/// The mapping is symmetric around zero but stored unsigned: a value in
/// `[-max_abs, max_abs]` is shifted by `+max_abs` and multiplied by `scale`, which lands it
/// in the integer range `[0, 15]`.
#[derive(Debug, Clone, Copy)]
pub struct Int4Params {
    /// **Unstable**: scale factor; formula may change with quantization scheme update.
    ///
    /// `15 / (2 * max_abs)`, or `1.0` when `max_abs` is effectively zero.
    pub scale: f32,
    /// **Unstable**: maximum absolute value; field may be removed.
    ///
    /// Largest `|v|` over the finite elements of the source vector (`0.0` if there are none).
    pub max_abs: f32,
}

impl Int4Params {
    /// **Unstable**: quantization parameter computation; may be folded into `Int4Vector::from_f32`.
    ///
    /// Derives the quantization range from `vector`. NaN and ±infinity are ignored when
    /// searching for the largest magnitude, so a single bad element cannot blow up the range.
    pub fn from_vector(vector: &[f32]) -> Self {
        let max_abs = vector
            .iter()
            .filter(|x| x.is_finite())
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);

        // A (near-)empty range would make the scale divide by ~0; fall back to a unit scale,
        // which quantizes every value of such a vector to level 0.
        let scale = if max_abs <= 1e-10 {
            1.0
        } else {
            15.0 / (2.0 * max_abs)
        };

        Int4Params { scale, max_abs }
    }
}
56
57/// **Unstable**: INT4 quantization format is under active design; struct layout may change.
58///
59/// Quantized INT4 vector with packed nibble storage.
60#[derive(Debug, Clone)]
61pub struct Int4Vector {
62    /// **Unstable**: packed nibble data; bit packing scheme may change.
63    pub data: Vec<u8>,
64    /// **Unstable**: number of original dimensions.
65    pub dims: usize,
66    /// **Unstable**: quantization parameters; may be separated from the vector.
67    pub params: Int4Params,
68    /// **Unstable**: L2 norm; may be removed or moved.
69    pub norm: f32,
70}
71
72impl Int4Vector {
73    /// **Unstable**: quantization format; nibble packing may change.
74    ///
75    /// Each pair of consecutive dimensions is packed into one byte:
76    /// - High nibble (bits 7..4) = even-indexed value
77    /// - Low nibble (bits 3..0) = odd-indexed value
78    pub fn from_f32(vector: &[f32]) -> Self {
79        let params = Int4Params::from_vector(vector);
80        let dims = vector.len();
81
82        // Compute L2 norm
83        let mut norm_sq = 0.0f32;
84        for &v in vector {
85            if v.is_finite() {
86                norm_sq += v * v;
87            }
88        }
89        let norm = norm_sq.sqrt();
90
91        // Quantize each value to [0, 15] and pack pairs into bytes
92        let packed_len = dims.div_ceil(2);
93        let mut data = vec![0u8; packed_len];
94
95        for (i, &elem) in vector[..dims].iter().enumerate() {
96            let v = if elem.is_finite() { elem } else { 0.0 };
97            // Map [-max_abs, max_abs] -> [0, 15]
98            let q = ((v + params.max_abs) * params.scale)
99                .round()
100                .clamp(0.0, 15.0) as u8;
101
102            let byte_idx = i / 2;
103            if i % 2 == 0 {
104                // Even index -> high nibble
105                data[byte_idx] |= q << 4;
106            } else {
107                // Odd index -> low nibble
108                data[byte_idx] |= q;
109            }
110        }
111
112        Self {
113            data,
114            dims,
115            params,
116            norm,
117        }
118    }
119
120    /// **Unstable**: dequantization output semantics may change.
121    ///
122    /// Reverses the quantization: `v[i] = q[i] / scale - max_abs`
123    ///
124    /// # Precision
125    ///
126    /// INT4 unsigned symmetric quantization maps `[-max_abs, max_abs]` to `[0, 15]`
127    /// (16 levels), so the quantization step size is `2 * max_abs / 15`. The maximum
128    /// per-element round-trip error is bounded by half a step: `max_abs / 15`.
129    ///
130    /// For a 384-dim unit-norm embedding (`max_abs` ≈ 1.0), expect element-wise
131    /// absolute error ≤ 0.067 and relative dot-product error ≤ 15% (see
132    /// `test_int4_dot_product_vs_f32` and `test_int4_roundtrip_accuracy`).
133    /// Use `Int8` tier when higher fidelity is required.
134    pub fn to_f32(&self) -> Vec<f32> {
135        let scale = if self.params.scale.is_finite() && self.params.scale != 0.0 {
136            self.params.scale
137        } else {
138            1.0
139        };
140
141        let mut result = Vec::with_capacity(self.dims);
142        for i in 0..self.dims {
143            let byte_idx = i / 2;
144            let q = if i % 2 == 0 {
145                (self.data[byte_idx] >> 4) & 0x0F
146            } else {
147                self.data[byte_idx] & 0x0F
148            };
149            result.push(q as f32 / scale - self.params.max_abs);
150        }
151        result
152    }
153
154    /// **Unstable**: INT4 dot product approximation; formula may change.
155    ///
156    /// Returns the dequantized dot product suitable for cosine distance computation.
157    #[inline]
158    pub fn dot_product(&self, other: &Int4Vector) -> f32 {
159        dot_product_int4(self, other)
160    }
161
162    /// **Unstable**: INT4 cosine similarity approximation; delegates to `dot_product`.
163    #[inline]
164    pub fn cosine_similarity(&self, other: &Int4Vector) -> f32 {
165        let denom = self.norm * other.norm;
166        if denom == 0.0 || !denom.is_finite() {
167            return 0.0;
168        }
169        self.dot_product(other) / denom
170    }
171
172    /// **Unstable**: complement of `cosine_similarity`; definition may evolve.
173    #[inline]
174    pub fn cosine_distance(&self, other: &Int4Vector) -> f32 {
175        1.0 - self.cosine_similarity(other)
176    }
177}
178
179/// **Unstable**: SIMD INT4 dot product; NEON/scalar dispatch may change.
180///
181/// Unpacks nibbles, computes dot product of quantized values, then applies
182/// dequantization scaling: `result = (raw_dot / (scale_a * scale_b)) - correction`
183///
184/// The correction accounts for the unsigned offset in the quantization formula.
185#[inline]
186pub fn dot_product_int4(a: &Int4Vector, b: &Int4Vector) -> f32 {
187    if a.dims != b.dims {
188        return 0.0;
189    }
190
191    let scale_a = a.params.scale;
192    let scale_b = b.params.scale;
193    if scale_a == 0.0 || scale_b == 0.0 || !scale_a.is_finite() || !scale_b.is_finite() {
194        return 0.0;
195    }
196
197    let packed_len = a.dims.div_ceil(2);
198    if a.data.len() < packed_len || b.data.len() < packed_len {
199        return 0.0;
200    }
201
202    #[cfg(target_arch = "aarch64")]
203    {
204        let config = simd_config();
205        if config.neon_enabled {
206            // SAFETY: aarch64 NEON is available by config, the packed data length guard
207            // above prevents out-of-bounds loads, and the callee handles odd dimensions
208            // without reading the padding nibble as a real dimension.
209            let (raw_dot, sum_a, sum_b) =
210                unsafe { dot_product_int4_neon_unrolled(&a.data, &b.data, a.dims) };
211            return finish_int4_dot(raw_dot, sum_a, sum_b, a, b);
212        }
213    }
214
215    let a_deq = a.to_f32();
216    let b_deq = b.to_f32();
217    a_deq.iter().zip(b_deq.iter()).map(|(&x, &y)| x * y).sum()
218}
219
/// Converts the integer accumulators from the NEON kernel into the dot product of the
/// dequantized vectors.
///
/// With `x[i] = qa[i] / sa - ma` and `y[i] = qb[i] / sb - mb`, expanding `Σ x[i] * y[i]`
/// gives `Σqa·qb / (sa·sb) - mb·Σqa / sa - ma·Σqb / sb + n·ma·mb`. That expansion needs
/// only the raw product sum (`raw_dot`) and the two per-vector nibble sums
/// (`sum_a`, `sum_b`).
#[cfg(target_arch = "aarch64")]
#[inline]
fn finish_int4_dot(raw_dot: i32, sum_a: i32, sum_b: i32, a: &Int4Vector, b: &Int4Vector) -> f32 {
    let (sa, ma) = (a.params.scale, a.params.max_abs);
    let (sb, mb) = (b.params.scale, b.params.max_abs);
    let n = a.dims as f32;

    let product_term = raw_dot as f32 / (sa * sb);
    let a_offset_term = mb * sum_a as f32 / sa;
    let b_offset_term = ma * sum_b as f32 / sb;
    let constant_term = n * ma * mb;

    product_term - a_offset_term - b_offset_term + constant_term
}
234
/// Integer accumulation kernel behind `dot_product_int4` on aarch64 NEON.
///
/// Unpacks the first `dims` nibbles of `a` and `b` and returns
/// `(Σ qa[i] * qb[i], Σ qa[i], Σ qb[i])`; `finish_int4_dot` turns these into the
/// dequantized dot product.
///
/// Bytes are processed 64 at a time (four 16-byte blocks, each feeding its own product
/// accumulator `raw0..raw3`). Leftover whole bytes go through a scalar loop. For odd
/// `dims`, only the high nibble of the last byte is counted.
///
/// # Safety
///
/// - The CPU must support NEON (the function is compiled with `target_feature = "neon"`).
/// - `a` and `b` must each hold at least `dims.div_ceil(2)` bytes. This is only checked by
///   `debug_assert!`; the vector loads and `get_unchecked` reads below rely on it.
///
/// # Overflow
///
/// Each widened product is at most `15 * 15 = 225`. Per 64-byte chunk, one `u32` lane of
/// a `raw*` accumulator grows by at most `4 * 2 * 225 = 1800`, and one lane of
/// `sum_a`/`sum_b` by at most `4 * 2 * 60 = 480`. The final lane reductions are `u32`
/// additions followed by `as i32`, which are exact while `225 * dims` fits in `i32`
/// (`dims` below roughly 9.5 million).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_int4_neon_unrolled(a: &[u8], b: &[u8], dims: usize) -> (i32, i32, i32) {
    debug_assert!(a.len() >= dims.div_ceil(2));
    debug_assert!(b.len() >= dims.div_ceil(2));

    // One NEON register holds 16 bytes (32 nibbles); four blocks form one loop iteration.
    const BLOCK_BYTES: usize = 16;
    const UNROLL: usize = 4;
    const CHUNK_BYTES: usize = BLOCK_BYTES * UNROLL;

    // Only bytes containing two valid dimensions are processed in SIMD.
    // If dims is odd, the final high nibble is handled separately and the low
    // padding nibble is ignored to preserve current to_f32 semantics.
    let full_bytes = dims / 2;
    let chunks = full_bytes / CHUNK_BYTES;

    // One product accumulator per unrolled block, presumably so consecutive blocks do not
    // serialize on a single register; they are merged after the loop.
    let mut raw0 = vdupq_n_u32(0);
    let mut raw1 = vdupq_n_u32(0);
    let mut raw2 = vdupq_n_u32(0);
    let mut raw3 = vdupq_n_u32(0);
    // Per-vector nibble sums, needed by finish_int4_dot to remove the unsigned offset.
    let mut sum_a = vdupq_n_u32(0);
    let mut sum_b = vdupq_n_u32(0);
    // Low-nibble mask.
    let mask = vdupq_n_u8(0x0f);

    // Processes one 16-byte block at byte offset `$base`: splits every byte into its high
    // and low nibble, widen-multiplies matching nibbles (u8 x u8 -> u16), and pairwise-adds
    // the u16 products into the u32 lanes of `$raw`. The nibbles themselves are
    // pairwise-widened (u8 -> u16 -> u32) into `sum_a` / `sum_b`.
    macro_rules! accumulate_block {
        ($base:expr, $raw:ident) => {{
            let a_bytes = vld1q_u8(a.as_ptr().add($base));
            let b_bytes = vld1q_u8(b.as_ptr().add($base));

            let a_hi = vshrq_n_u8::<4>(a_bytes);
            let b_hi = vshrq_n_u8::<4>(b_bytes);
            let a_lo = vandq_u8(a_bytes, mask);
            let b_lo = vandq_u8(b_bytes, mask);

            $raw = vpadalq_u16($raw, vmull_u8(vget_low_u8(a_hi), vget_low_u8(b_hi)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_high_u8(a_hi), vget_high_u8(b_hi)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_low_u8(a_lo), vget_low_u8(b_lo)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_high_u8(a_lo), vget_high_u8(b_lo)));

            sum_a = vpadalq_u16(sum_a, vpaddlq_u8(a_hi));
            sum_a = vpadalq_u16(sum_a, vpaddlq_u8(a_lo));
            sum_b = vpadalq_u16(sum_b, vpaddlq_u8(b_hi));
            sum_b = vpadalq_u16(sum_b, vpaddlq_u8(b_lo));
        }};
    }

    // `chunks * CHUNK_BYTES <= full_bytes <= a.len()`, so every 16-byte load stays in bounds
    // under the caller's length contract.
    for i in 0..chunks {
        let base = i * CHUNK_BYTES;
        accumulate_block!(base, raw0);
        accumulate_block!(base + BLOCK_BYTES, raw1);
        accumulate_block!(base + BLOCK_BYTES * 2, raw2);
        accumulate_block!(base + BLOCK_BYTES * 3, raw3);
    }

    // Horizontal reduction of the four u32 lanes of each accumulator.
    let raw_vec = vaddq_u32(vaddq_u32(raw0, raw1), vaddq_u32(raw2, raw3));
    let mut raw_total = (vgetq_lane_u32::<0>(raw_vec)
        + vgetq_lane_u32::<1>(raw_vec)
        + vgetq_lane_u32::<2>(raw_vec)
        + vgetq_lane_u32::<3>(raw_vec)) as i32;
    let mut sum_a_total = (vgetq_lane_u32::<0>(sum_a)
        + vgetq_lane_u32::<1>(sum_a)
        + vgetq_lane_u32::<2>(sum_a)
        + vgetq_lane_u32::<3>(sum_a)) as i32;
    let mut sum_b_total = (vgetq_lane_u32::<0>(sum_b)
        + vgetq_lane_u32::<1>(sum_b)
        + vgetq_lane_u32::<2>(sum_b)
        + vgetq_lane_u32::<3>(sum_b)) as i32;

    // Scalar tail: whole bytes after the last 64-byte chunk.
    // SAFETY (indexing): byte_idx < full_bytes <= dims.div_ceil(2) <= a.len(), b.len().
    let remainder_start = chunks * CHUNK_BYTES;
    for byte_idx in remainder_start..full_bytes {
        let av = *a.get_unchecked(byte_idx);
        let bv = *b.get_unchecked(byte_idx);
        let ah = ((av >> 4) & 0x0f) as i32;
        let al = (av & 0x0f) as i32;
        let bh = ((bv >> 4) & 0x0f) as i32;
        let bl = (bv & 0x0f) as i32;

        raw_total += ah * bh + al * bl;
        sum_a_total += ah + al;
        sum_b_total += bh + bl;
    }

    // Odd `dims`: the last byte holds one real value in its high nibble; the low nibble is
    // padding and must not be counted.
    // SAFETY (indexing): for odd dims, full_bytes == dims / 2 < dims.div_ceil(2) <= len.
    if dims % 2 == 1 {
        let av = *a.get_unchecked(full_bytes);
        let bv = *b.get_unchecked(full_bytes);
        let ah = ((av >> 4) & 0x0f) as i32;
        let bh = ((bv >> 4) & 0x0f) as i32;

        raw_total += ah * bh;
        sum_a_total += ah;
        sum_b_total += bh;
    }

    (raw_total, sum_a_total, sum_b_total)
}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with elements in [-1, 1] (LCG-based, seeded by
    /// `seed` and `dim`).
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_int4_roundtrip_accuracy() {
        let original = generate_vector(384, 42);
        let quantized = Int4Vector::from_f32(&original);
        let dequantized = quantized.to_f32();

        assert_eq!(dequantized.len(), original.len());

        // INT4 has only 16 levels, so error is larger than INT8. Rounding to the nearest
        // level bounds the error by half a quantization step:
        // (2 * max_abs / 15) / 2 = max_abs / 15, as documented on `to_f32`.
        // The 1e-5 slack absorbs f32 rounding in the quantize/dequantize arithmetic.
        let max_abs = original
            .iter()
            .filter(|v| v.is_finite())
            .map(|v| v.abs())
            .fold(0.0f32, f32::max);
        let expected_max_error = max_abs / 15.0;

        for (i, (orig, deq)) in original.iter().zip(dequantized.iter()).enumerate() {
            let error = (orig - deq).abs();
            assert!(
                error <= expected_max_error + 1e-5,
                "INT4 roundtrip error too large at index {i}: orig={orig}, deq={deq}, error={error}, max_allowed={expected_max_error}"
            );
        }
    }

    #[test]
    fn test_int4_packing_correctness() {
        // Verify nibble packing: even index -> high nibble, odd -> low
        let v = vec![0.5, -0.5, 0.0, 1.0]; // 4 values -> 2 packed bytes
        let q = Int4Vector::from_f32(&v);
        assert_eq!(q.data.len(), 2);
        assert_eq!(q.dims, 4);

        // max_abs = 1.0 -> scale = 7.5, so the levels are
        // round(1.5 * 7.5) = 11, round(0.5 * 7.5) = 4, round(1.0 * 7.5) = 8, round(2.0 * 7.5) = 15.
        // Even indices occupy the high nibble, odd indices the low nibble.
        assert_eq!(q.data, vec![(11 << 4) | 4, (8 << 4) | 15]);

        // Verify roundtrip preserves approximate values
        let deq = q.to_f32();
        assert_eq!(deq.len(), 4);
        // 0.5 should map to roughly the right region
        assert!((deq[0] - 0.5).abs() < 0.15, "deq[0]={}", deq[0]);
        assert!((deq[1] - (-0.5)).abs() < 0.15, "deq[1]={}", deq[1]);
    }

    #[test]
    fn test_int4_odd_dimensions() {
        // Odd number of dimensions: last nibble has a padding zero
        let v = generate_vector(383, 77);
        let q = Int4Vector::from_f32(&v);
        assert_eq!(q.data.len(), 192); // ceil(383/2) = 192
        assert_eq!(q.dims, 383);
        assert_eq!(q.data[191] & 0x0F, 0, "padding nibble must stay zero");

        let deq = q.to_f32();
        assert_eq!(deq.len(), 383);
    }

    #[test]
    fn test_int4_zero_vector() {
        let v = vec![0.0; 384];
        let q = Int4Vector::from_f32(&v);
        let deq = q.to_f32();
        for &val in &deq {
            assert!(
                val.abs() < 1e-5,
                "Zero vector should dequantize to near-zero"
            );
        }
    }

    #[test]
    fn test_int4_dot_product_vs_f32() {
        // Use correlated vectors so the true dot product is large relative to noise.
        // For uncorrelated random vectors, the expected dot product is ~0 while
        // quantization noise is O(dims * step^2), so relative error is unbounded.
        let a = generate_vector(384, 101);
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.2 * (i as f32 * 0.3).sin())
            .collect();

        // f32 reference
        let f32_dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        let qa = Int4Vector::from_f32(&a);
        let qb = Int4Vector::from_f32(&b);
        let int4_dot = qa.dot_product(&qb);

        // INT4 has 16 levels; for correlated vectors the relative error should be
        // within ~15% (quantization step = 2*max_abs/15 per component).
        let rel_error = (f32_dot - int4_dot).abs() / f32_dot.abs().max(1.0);
        assert!(
            rel_error < 0.15,
            "INT4 dot product relative error too large: f32={f32_dot}, int4={int4_dot}, rel_error={rel_error}"
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_int4_neon_matches_dequantized_scalar() {
        for dim in [1, 2, 31, 64, 127, 384, 768] {
            let a = generate_vector(dim, 501);
            let b = generate_vector(dim, 777);
            let qa = Int4Vector::from_f32(&a);
            let qb = Int4Vector::from_f32(&b);

            let a_deq = qa.to_f32();
            let b_deq = qb.to_f32();
            let expected: f32 = a_deq.iter().zip(b_deq.iter()).map(|(&x, &y)| x * y).sum();
            let got = qa.dot_product(&qb);

            // The NEON path removes the unsigned offset in closed form, cancelling terms
            // whose magnitude grows with `dim`, so its f32 rounding error grows too.
            let tolerance = (dim as f32 * 1e-6).max(1e-4);
            assert!(
                (expected - got).abs() < tolerance,
                "INT4 NEON mismatch for dim={dim}: expected={expected}, got={got}"
            );
        }
    }

    #[test]
    fn test_int4_cosine_similarity() {
        let a = generate_vector(384, 301);
        let b = generate_vector(384, 302);

        let qa = Int4Vector::from_f32(&a);
        let qb = Int4Vector::from_f32(&b);
        let int4_cos = qa.cosine_similarity(&qb);

        // Compute f32 reference cosine
        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let f32_cos = dot / (norm_a * norm_b);

        assert!(
            (f32_cos - int4_cos).abs() < 0.1,
            "INT4 cosine too far from f32: f32={f32_cos}, int4={int4_cos}"
        );
    }

    #[test]
    fn test_int4_memory_savings() {
        let v = generate_vector(384, 999);
        let q = Int4Vector::from_f32(&v);

        // f32: 384 * 4 = 1536 bytes
        // INT4: ceil(384/2) = 192 bytes = 8x compression
        assert_eq!(q.data.len(), 192);
        assert_eq!(v.len() * 4, 1536);
    }

    #[test]
    fn test_int4_nan_inf_handling() {
        let v = vec![
            1.0,
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            -1.0,
            0.5,
            0.0,
            -0.3,
        ];
        let q = Int4Vector::from_f32(&v);
        let deq = q.to_f32();
        assert_eq!(deq.len(), 8);
        // NaN and Inf should be treated as 0
        // The dequantized value for the "0" slot should be near -max_abs + something,
        // but the key invariant is no panics and finite output.
        for &val in &deq {
            assert!(val.is_finite(), "Dequantized value should be finite");
        }
    }
}