1#[cfg(target_arch = "aarch64")]
17use std::arch::aarch64::*;
18
19#[cfg(target_arch = "aarch64")]
20use super::simd_config;
21
/// Parameters of the symmetric 4-bit quantizer.
///
/// A value `v` is encoded as `round((v + max_abs) * scale)` clamped to
/// `0..=15`, and decoded as `q / scale - max_abs`.
#[derive(Debug, Clone, Copy)]
pub struct Int4Params {
    /// Multiplier that maps the shifted range `[0, 2 * max_abs]` onto `[0, 15]`.
    pub scale: f32,
    /// Largest absolute value among the finite entries of the source vector.
    pub max_abs: f32,
}

impl Int4Params {
    /// Derives quantization parameters from `vector`.
    ///
    /// Non-finite entries (NaN, ±∞) are ignored. When the largest finite
    /// magnitude is at most `1e-10` (which includes an empty slice), `scale`
    /// falls back to `1.0`.
    pub fn from_vector(vector: &[f32]) -> Self {
        /// Half of the 15 quantization steps.
        const HALF_LEVELS: f32 = 7.5;
        /// Magnitudes at or below this are treated as an all-zero vector.
        const MIN_MAGNITUDE: f32 = 1e-10;

        let max_abs = vector
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .map(f32::abs)
            .fold(0.0f32, f32::max);

        let scale = if max_abs > MIN_MAGNITUDE {
            // Same value as `15.0 / (2.0 * max_abs)`, but without forming
            // `2.0 * max_abs`, which overflows to infinity (yielding a zero
            // scale) once `max_abs` exceeds `f32::MAX / 2`.
            HALF_LEVELS / max_abs
        } else {
            1.0
        };

        Self { scale, max_abs }
    }
}
56
57#[derive(Debug, Clone)]
61pub struct Int4Vector {
62 pub data: Vec<u8>,
64 pub dims: usize,
66 pub params: Int4Params,
68 pub norm: f32,
70}
71
72impl Int4Vector {
73 pub fn from_f32(vector: &[f32]) -> Self {
79 let params = Int4Params::from_vector(vector);
80 let dims = vector.len();
81
82 let mut norm_sq = 0.0f32;
84 for &v in vector {
85 if v.is_finite() {
86 norm_sq += v * v;
87 }
88 }
89 let norm = norm_sq.sqrt();
90
91 let packed_len = dims.div_ceil(2);
93 let mut data = vec![0u8; packed_len];
94
95 for (i, &elem) in vector[..dims].iter().enumerate() {
96 let v = if elem.is_finite() { elem } else { 0.0 };
97 let q = ((v + params.max_abs) * params.scale)
99 .round()
100 .clamp(0.0, 15.0) as u8;
101
102 let byte_idx = i / 2;
103 if i % 2 == 0 {
104 data[byte_idx] |= q << 4;
106 } else {
107 data[byte_idx] |= q;
109 }
110 }
111
112 Self {
113 data,
114 dims,
115 params,
116 norm,
117 }
118 }
119
120 pub fn to_f32(&self) -> Vec<f32> {
135 let scale = if self.params.scale.is_finite() && self.params.scale != 0.0 {
136 self.params.scale
137 } else {
138 1.0
139 };
140
141 let mut result = Vec::with_capacity(self.dims);
142 for i in 0..self.dims {
143 let byte_idx = i / 2;
144 let q = if i % 2 == 0 {
145 (self.data[byte_idx] >> 4) & 0x0F
146 } else {
147 self.data[byte_idx] & 0x0F
148 };
149 result.push(q as f32 / scale - self.params.max_abs);
150 }
151 result
152 }
153
154 #[inline]
158 pub fn dot_product(&self, other: &Int4Vector) -> f32 {
159 dot_product_int4(self, other)
160 }
161
162 #[inline]
164 pub fn cosine_similarity(&self, other: &Int4Vector) -> f32 {
165 let denom = self.norm * other.norm;
166 if denom == 0.0 || !denom.is_finite() {
167 return 0.0;
168 }
169 self.dot_product(other) / denom
170 }
171
172 #[inline]
174 pub fn cosine_distance(&self, other: &Int4Vector) -> f32 {
175 1.0 - self.cosine_similarity(other)
176 }
177}
178
179#[inline]
186pub fn dot_product_int4(a: &Int4Vector, b: &Int4Vector) -> f32 {
187 if a.dims != b.dims {
188 return 0.0;
189 }
190
191 let scale_a = a.params.scale;
192 let scale_b = b.params.scale;
193 if scale_a == 0.0 || scale_b == 0.0 || !scale_a.is_finite() || !scale_b.is_finite() {
194 return 0.0;
195 }
196
197 let packed_len = a.dims.div_ceil(2);
198 if a.data.len() < packed_len || b.data.len() < packed_len {
199 return 0.0;
200 }
201
202 #[cfg(target_arch = "aarch64")]
203 {
204 let config = simd_config();
205 if config.neon_enabled {
206 let (raw_dot, sum_a, sum_b) =
210 unsafe { dot_product_int4_neon_unrolled(&a.data, &b.data, a.dims) };
211 return finish_int4_dot(raw_dot, sum_a, sum_b, a, b);
212 }
213 }
214
215 let a_deq = a.to_f32();
216 let b_deq = b.to_f32();
217 a_deq.iter().zip(b_deq.iter()).map(|(&x, &y)| x * y).sum()
218}
219
/// Converts the integer accumulators from the NEON kernel into the
/// dequantized dot product.
///
/// With `x = qa / sa - ma` and `y = qb / sb - mb`, the sum of `x * y` expands
/// to `Σqa·qb / (sa·sb) − mb·Σqa / sa − ma·Σqb / sb + n·ma·mb`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn finish_int4_dot(raw_dot: i32, sum_a: i32, sum_b: i32, a: &Int4Vector, b: &Int4Vector) -> f32 {
    let (scale_a, offset_a) = (a.params.scale, a.params.max_abs);
    let (scale_b, offset_b) = (b.params.scale, b.params.max_abs);

    // Each term is evaluated left to right, as in the expanded formula.
    let code_term = raw_dot as f32 / (scale_a * scale_b);
    let cross_term_a = offset_b * sum_a as f32 / scale_a;
    let cross_term_b = offset_a * sum_b as f32 / scale_b;
    let offset_term = a.dims as f32 * offset_a * offset_b;

    code_term - cross_term_a - cross_term_b + offset_term
}
234
/// NEON kernel returning `(Σ a_i·b_i, Σ a_i, Σ b_i)` over the 4-bit codes of
/// two packed vectors.
///
/// Processing order: 64-byte groups (four 16-byte blocks with independent
/// accumulators), then any remaining whole 16-byte blocks, then a scalar loop
/// for the last `< 16` bytes, and finally the lone high nibble when `dims`
/// is odd.
///
/// # Safety
///
/// The CPU must support NEON, and both `a` and `b` must contain at least
/// `dims.div_ceil(2)` bytes; the vector loads are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_int4_neon_unrolled(a: &[u8], b: &[u8], dims: usize) -> (i32, i32, i32) {
    debug_assert!(a.len() >= dims.div_ceil(2));
    debug_assert!(b.len() >= dims.div_ceil(2));

    const BLOCK_BYTES: usize = 16;
    const UNROLL: usize = 4;
    const CHUNK_BYTES: usize = BLOCK_BYTES * UNROLL;

    // Bytes in which both nibbles carry an element.
    let full_bytes = dims / 2;
    let chunks = full_bytes / CHUNK_BYTES;

    let mut raw0 = vdupq_n_u32(0);
    let mut raw1 = vdupq_n_u32(0);
    let mut raw2 = vdupq_n_u32(0);
    let mut raw3 = vdupq_n_u32(0);
    let mut sum_a = vdupq_n_u32(0);
    let mut sum_b = vdupq_n_u32(0);
    let mask = vdupq_n_u8(0x0f);

    // Loads one 16-byte block from each input at byte offset `$base`, splits
    // it into high/low nibbles and accumulates products into `$raw` and the
    // code sums into `sum_a` / `sum_b`.
    macro_rules! accumulate_block {
        ($base:expr, $raw:ident) => {{
            let a_bytes = vld1q_u8(a.as_ptr().add($base));
            let b_bytes = vld1q_u8(b.as_ptr().add($base));

            let a_hi = vshrq_n_u8::<4>(a_bytes);
            let b_hi = vshrq_n_u8::<4>(b_bytes);
            let a_lo = vandq_u8(a_bytes, mask);
            let b_lo = vandq_u8(b_bytes, mask);

            // u8 x u8 -> u16 products, pairwise-added into the u32 lanes.
            $raw = vpadalq_u16($raw, vmull_u8(vget_low_u8(a_hi), vget_low_u8(b_hi)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_high_u8(a_hi), vget_high_u8(b_hi)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_low_u8(a_lo), vget_low_u8(b_lo)));
            $raw = vpadalq_u16($raw, vmull_u8(vget_high_u8(a_lo), vget_high_u8(b_lo)));

            sum_a = vpadalq_u16(sum_a, vpaddlq_u8(a_hi));
            sum_a = vpadalq_u16(sum_a, vpaddlq_u8(a_lo));
            sum_b = vpadalq_u16(sum_b, vpaddlq_u8(b_hi));
            sum_b = vpadalq_u16(sum_b, vpaddlq_u8(b_lo));
        }};
    }

    // Adds the four u32 lanes of a vector and converts the total to i32.
    macro_rules! lane_total {
        ($vec:expr) => {{
            let lanes = $vec;
            (vgetq_lane_u32::<0>(lanes)
                + vgetq_lane_u32::<1>(lanes)
                + vgetq_lane_u32::<2>(lanes)
                + vgetq_lane_u32::<3>(lanes)) as i32
        }};
    }

    // Every load below ends at or before `full_bytes`, which the caller
    // guarantees is within both slices.
    for i in 0..chunks {
        let base = i * CHUNK_BYTES;
        accumulate_block!(base, raw0);
        accumulate_block!(base + BLOCK_BYTES, raw1);
        accumulate_block!(base + BLOCK_BYTES * 2, raw2);
        accumulate_block!(base + BLOCK_BYTES * 3, raw3);
    }

    // Whole 16-byte blocks left over after the unrolled loop stay on the
    // vector path instead of falling through to the scalar tail.
    let mut offset = chunks * CHUNK_BYTES;
    while full_bytes - offset >= BLOCK_BYTES {
        accumulate_block!(offset, raw0);
        offset += BLOCK_BYTES;
    }

    let mut raw_total = lane_total!(vaddq_u32(vaddq_u32(raw0, raw1), vaddq_u32(raw2, raw3)));
    let mut sum_a_total = lane_total!(sum_a);
    let mut sum_b_total = lane_total!(sum_b);

    // Scalar tail: fewer than 16 bytes remain. Checked slicing is used here,
    // so a violated length contract panics rather than reading out of bounds.
    for (&av, &bv) in a[offset..full_bytes].iter().zip(&b[offset..full_bytes]) {
        let ah = i32::from(av >> 4);
        let al = i32::from(av & 0x0f);
        let bh = i32::from(bv >> 4);
        let bl = i32::from(bv & 0x0f);

        raw_total += ah * bh + al * bl;
        sum_a_total += ah + al;
        sum_b_total += bh + bl;
    }

    // Odd `dims`: the final element sits alone in the high nibble of the
    // last byte; the low nibble is padding and must be ignored.
    if dims % 2 == 1 {
        let ah = i32::from(a[full_bytes] >> 4);
        let bh = i32::from(b[full_bytes] >> 4);

        raw_total += ah * bh;
        sum_a_total += ah;
        sum_b_total += bh;
    }

    (raw_total, sum_a_total, sum_b_total)
}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with entries in `[-1.0, 1.0]`,
    /// driven by a 64-bit linear congruential generator seeded from `seed`
    /// and `dim`.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_int4_roundtrip_accuracy() {
        let original = generate_vector(384, 42);
        let quantized = Int4Vector::from_f32(&original);
        let dequantized = quantized.to_f32();

        assert_eq!(dequantized.len(), original.len());

        let max_abs = original
            .iter()
            .filter(|v| v.is_finite())
            .map(|v| v.abs())
            .fold(0.0f32, f32::max);
        // One full quantization step.
        let expected_max_error = 2.0 * max_abs / 15.0;

        for (i, (orig, deq)) in original.iter().zip(dequantized.iter()).enumerate() {
            let error = (orig - deq).abs();
            assert!(
                error <= expected_max_error + 1e-5,
                "INT4 roundtrip error too large at index {i}: orig={orig}, deq={deq}, error={error}, max_allowed={expected_max_error}"
            );
        }
    }

    #[test]
    fn test_int4_packing_correctness() {
        let v = vec![0.5, -0.5, 0.0, 1.0];
        let q = Int4Vector::from_f32(&v);
        assert_eq!(q.data.len(), 2);
        assert_eq!(q.dims, 4);

        // max_abs = 1.0 and scale = 7.5, so the codes are 11, 4, 8, 15.
        // Even indices occupy the high nibble, odd indices the low nibble.
        assert_eq!(q.data, vec![0xB4, 0x8F]);

        let deq = q.to_f32();
        assert_eq!(deq.len(), 4);
        assert!((deq[0] - 0.5).abs() < 0.15, "deq[0]={}", deq[0]);
        assert!((deq[1] - (-0.5)).abs() < 0.15, "deq[1]={}", deq[1]);
        assert!(deq[2].abs() < 0.15, "deq[2]={}", deq[2]);
        assert!((deq[3] - 1.0).abs() < 0.15, "deq[3]={}", deq[3]);
    }

    #[test]
    fn test_int4_odd_dimensions() {
        let v = generate_vector(383, 77);
        let q = Int4Vector::from_f32(&v);
        assert_eq!(q.data.len(), 192);
        assert_eq!(q.dims, 383);
        // The unused low nibble of the final byte must stay zero.
        assert_eq!(q.data[191] & 0x0F, 0);

        let deq = q.to_f32();
        assert_eq!(deq.len(), 383);
    }

    #[test]
    fn test_int4_zero_vector() {
        let v = vec![0.0; 384];
        let q = Int4Vector::from_f32(&v);
        let deq = q.to_f32();
        for &val in &deq {
            assert!(
                val.abs() < 1e-5,
                "Zero vector should dequantize to near-zero"
            );
        }
    }

    #[test]
    fn test_int4_dot_product_vs_f32() {
        let a = generate_vector(384, 101);
        // A perturbed copy of `a`, so the two vectors are strongly correlated.
        let b: Vec<f32> = a
            .iter()
            .enumerate()
            .map(|(i, &x)| x + 0.2 * (i as f32 * 0.3).sin())
            .collect();

        let f32_dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        let qa = Int4Vector::from_f32(&a);
        let qb = Int4Vector::from_f32(&b);
        let int4_dot = qa.dot_product(&qb);

        let rel_error = (f32_dot - int4_dot).abs() / f32_dot.abs().max(1.0);
        assert!(
            rel_error < 0.15,
            "INT4 dot product relative error too large: f32={f32_dot}, int4={int4_dot}, rel_error={rel_error}"
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_int4_neon_matches_dequantized_scalar() {
        for dim in [1, 2, 31, 64, 127, 384, 768] {
            let a = generate_vector(dim, 501);
            let b = generate_vector(dim, 777);
            let qa = Int4Vector::from_f32(&a);
            let qb = Int4Vector::from_f32(&b);

            let a_deq = qa.to_f32();
            let b_deq = qb.to_f32();
            let expected: f32 = a_deq.iter().zip(b_deq.iter()).map(|(&x, &y)| x * y).sum();
            let got = qa.dot_product(&qb);

            assert!(
                (expected - got).abs() < 1e-4,
                "INT4 NEON mismatch for dim={dim}: expected={expected}, got={got}"
            );
        }
    }

    #[test]
    fn test_int4_cosine_similarity() {
        let a = generate_vector(384, 301);
        let b = generate_vector(384, 302);

        let qa = Int4Vector::from_f32(&a);
        let qb = Int4Vector::from_f32(&b);
        let int4_cos = qa.cosine_similarity(&qb);

        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let f32_cos = dot / (norm_a * norm_b);

        assert!(
            (f32_cos - int4_cos).abs() < 0.1,
            "INT4 cosine too far from f32: f32={f32_cos}, int4={int4_cos}"
        );
    }

    #[test]
    fn test_int4_memory_savings() {
        let v = generate_vector(384, 999);
        let q = Int4Vector::from_f32(&v);

        assert_eq!(q.data.len(), 192);
        // Packed codes take one eighth of the f32 storage.
        assert_eq!(v.len() * std::mem::size_of::<f32>(), 8 * q.data.len());
    }

    #[test]
    fn test_int4_nan_inf_handling() {
        let v = vec![
            1.0,
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            -1.0,
            0.5,
            0.0,
            -0.3,
        ];
        let q = Int4Vector::from_f32(&v);
        let deq = q.to_f32();
        assert_eq!(deq.len(), 8);
        for &val in &deq {
            assert!(val.is_finite(), "Dequantized value should be finite");
        }
        // Non-finite inputs are encoded as 0.0, so they must decode to
        // within one quantization step (2 / 15) of zero.
        for &val in &deq[1..=3] {
            assert!(val.abs() < 2.0 / 15.0, "non-finite input decoded to {val}");
        }
    }
}