#[cfg(target_arch = "x86_64")]
use std::ptr::write_unaligned;
#[cfg(target_arch = "x86_64")]
use crate::superfile::vector::simd_dispatch::{avx2_enabled, avx512_enabled};
/// Largest SQ8 code; the encoders clamp each value to `[0.0, SQ8_CODE_MAX]` before converting to `u8`.
const SQ8_CODE_MAX: f32 = 255.0;
/// Bias folded into the additive encode constant. The float-to-integer conversions in this file
/// truncate, so adding 0.5 beforehand turns that truncation into round-half-up for non-negative values.
const SQ8_ROUND_HALF_BIAS: f32 = 0.5;
/// Number of `f32` lanes in one 256-bit vector (`wide::f32x8` / `__m256`).
const F32X8_LANES: usize = 8;
/// Number of `f32` lanes in one 512-bit vector (`__m512`).
// Only the x86_64-gated AVX-512 kernels use this constant, hence the conditional `dead_code` allowance.
#[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
const AVX512_F32_LANES: usize = 16;
/// Per-dimension constants consumed by the SQ8 encode kernels, which evaluate
/// `row[d] * inv_scale[d] + c2[d]` for every dimension `d`.
pub(super) struct Sq8EncodeConsts {
/// Reciprocal of the per-dimension scale: `1.0 / scale[d]`.
pub(super) inv_scale: Vec<f32>,
/// Additive term `-offset[d] * inv_scale[d] + 0.5`, computed with a fused multiply-add.
pub(super) c2: Vec<f32>,
}
impl Sq8EncodeConsts {
    /// Precomputes the encode constants from per-dimension `scale` and `offset`.
    ///
    /// For each dimension this stores `inv_scale = 1.0 / scale` and
    /// `c2 = -offset * inv_scale + 0.5`, so that `x * inv_scale + c2` equals
    /// `(x - offset) / scale + 0.5` up to floating-point rounding.
    ///
    /// The two inputs are expected to have the same length; this is only checked in debug
    /// builds. NOTE(review): a zero `scale` entry produces an infinite `inv_scale` —
    /// presumably callers guarantee non-zero scales; confirm at the call site.
    pub(super) fn from_scale_offset(scale: &[f32], offset: &[f32]) -> Self {
        debug_assert_eq!(scale.len(), offset.len());

        let mut inv_scale: Vec<f32> = Vec::with_capacity(scale.len());
        for &s in scale {
            inv_scale.push(1.0 / s);
        }

        // Pairing stops at the shorter of the two inputs, as `zip` does.
        let mut c2: Vec<f32> = Vec::with_capacity(offset.len().min(inv_scale.len()));
        for (&off, &inv) in offset.iter().zip(&inv_scale) {
            c2.push((-off).mul_add(inv, SQ8_ROUND_HALF_BIAS));
        }

        Self { inv_scale, c2 }
    }
}
/// Folds `row` into the running per-dimension minima and maxima.
///
/// For every dimension `d`, `min_slice[d]` is replaced by `row[d]` when `row[d]` is smaller
/// and `max_slice[d]` is replaced by `row[d]` when it is larger. On x86_64 the AVX-512 kernel
/// is used when `avx512_enabled()` returns true, otherwise the AVX2 kernel when
/// `avx2_enabled()` returns true; in every other case the portable `wide` kernel runs.
///
/// # Panics
///
/// Panics if `min_slice` or `max_slice` differs in length from `row`.
#[inline]
pub(super) fn update_min_max(row: &[f32], min_slice: &mut [f32], max_slice: &mut [f32]) {
    // Hard asserts rather than debug asserts: the SIMD kernels address all three slices
    // through raw pointers using `row.len()`, so a length mismatch in a release build would
    // be an out-of-bounds read/write reachable from this safe function.
    assert_eq!(row.len(), min_slice.len(), "min_slice length must match row");
    assert_eq!(row.len(), max_slice.len(), "max_slice length must match row");
    #[cfg(target_arch = "x86_64")]
    {
        if avx512_enabled() {
            // SAFETY: slice lengths were verified above. `avx512_enabled()` is presumably the
            // runtime check for the `avx512f` feature the kernel requires — confirm in
            // `simd_dispatch`.
            unsafe { update_min_max_avx512(row, min_slice, max_slice) };
            return;
        }
        if avx2_enabled() {
            // SAFETY: slice lengths were verified above. `avx2_enabled()` is presumably the
            // runtime check for the `avx2` feature the kernel requires — confirm in
            // `simd_dispatch`.
            unsafe { update_min_max_avx2(row, min_slice, max_slice) };
            return;
        }
    }
    update_min_max_wide(row, min_slice, max_slice);
}
/// Portable min/max update built on `wide::f32x8`.
///
/// Processes eight dimensions per step and finishes the last `row.len() % 8` dimensions
/// with a scalar loop. Panics on out-of-range slicing if `min_slice` or `max_slice` is
/// shorter than `row`.
#[inline]
fn update_min_max_wide(row: &[f32], min_slice: &mut [f32], max_slice: &mut [f32]) {
    use wide::f32x8;

    let len = row.len();
    let vec_end = len - len % F32X8_LANES;

    for base in (0..vec_end).step_by(F32X8_LANES) {
        let lanes = base..base + F32X8_LANES;
        let vals: [f32; F32X8_LANES] = row[lanes.clone()].try_into().expect("len 8");
        let lows: [f32; F32X8_LANES] = min_slice[lanes.clone()].try_into().expect("len 8");
        let highs: [f32; F32X8_LANES] = max_slice[lanes.clone()].try_into().expect("len 8");

        let vals = f32x8::from(vals);
        let lows = vals.fast_min(f32x8::from(lows)).to_array();
        let highs = vals.fast_max(f32x8::from(highs)).to_array();

        min_slice[lanes.clone()].copy_from_slice(&lows);
        max_slice[lanes].copy_from_slice(&highs);
    }

    // Scalar remainder.
    for d in vec_end..len {
        let v = row[d];
        if v < min_slice[d] {
            min_slice[d] = v;
        }
        if v > max_slice[d] {
            max_slice[d] = v;
        }
    }
}
/// AVX2 min/max update: eight dimensions per step, scalar loop for the remainder.
///
/// # Safety
///
/// The CPU must support AVX2, and `min_slice` and `max_slice` must each hold at least
/// `row.len()` elements: the vector loop reads and writes them through raw pointers
/// without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn update_min_max_avx2(row: &[f32], min_slice: &mut [f32], max_slice: &mut [f32]) {
    use std::arch::x86_64::*;

    let len = row.len();
    let vec_end = len - len % F32X8_LANES;
    let src = row.as_ptr();
    let lo_ptr = min_slice.as_mut_ptr();
    let hi_ptr = max_slice.as_mut_ptr();

    for base in (0..vec_end).step_by(F32X8_LANES) {
        // SAFETY: `base + 8 <= vec_end <= row.len()`, and the caller guarantees the two
        // accumulator slices are at least that long.
        unsafe {
            let vals = _mm256_loadu_ps(src.add(base));
            let lows = _mm256_loadu_ps(lo_ptr.add(base));
            let highs = _mm256_loadu_ps(hi_ptr.add(base));
            // Operand order matters: the x86 min/max instructions return the second
            // operand when a lane is NaN, so a NaN in `row` leaves the accumulator as is,
            // like the `<` / `>` comparisons in the scalar remainder.
            _mm256_storeu_ps(lo_ptr.add(base), _mm256_min_ps(vals, lows));
            _mm256_storeu_ps(hi_ptr.add(base), _mm256_max_ps(vals, highs));
        }
    }

    // Scalar remainder.
    for d in vec_end..len {
        let v = row[d];
        if v < min_slice[d] {
            min_slice[d] = v;
        }
        if v > max_slice[d] {
            max_slice[d] = v;
        }
    }
}
/// AVX-512 min/max update: sixteen dimensions per step, then at most one eight-lane step,
/// then a scalar loop for whatever is left.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `min_slice` and `max_slice` must each hold at least
/// `row.len()` elements: the vector steps read and write them through raw pointers
/// without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn update_min_max_avx512(row: &[f32], min_slice: &mut [f32], max_slice: &mut [f32]) {
    use std::arch::x86_64::*;

    let len = row.len();
    let wide_end = len - len % AVX512_F32_LANES;
    let src = row.as_ptr();
    let lo_ptr = min_slice.as_mut_ptr();
    let hi_ptr = max_slice.as_mut_ptr();

    for base in (0..wide_end).step_by(AVX512_F32_LANES) {
        // SAFETY: `base + 16 <= wide_end <= row.len()`, and the caller guarantees the two
        // accumulator slices are at least that long.
        unsafe {
            let vals = _mm512_loadu_ps(src.add(base));
            let lows = _mm512_loadu_ps(lo_ptr.add(base));
            let highs = _mm512_loadu_ps(hi_ptr.add(base));
            _mm512_storeu_ps(lo_ptr.add(base), _mm512_min_ps(vals, lows));
            _mm512_storeu_ps(hi_ptr.add(base), _mm512_max_ps(vals, highs));
        }
    }

    // Fewer than 16 dimensions remain; take one 256-bit step if at least 8 are left.
    let mut next = wide_end;
    if len - next >= F32X8_LANES {
        // SAFETY: `next + 8 <= row.len()` was just checked; accumulator lengths as above.
        unsafe {
            let vals = _mm256_loadu_ps(src.add(next));
            let lows = _mm256_loadu_ps(lo_ptr.add(next));
            let highs = _mm256_loadu_ps(hi_ptr.add(next));
            _mm256_storeu_ps(lo_ptr.add(next), _mm256_min_ps(vals, lows));
            _mm256_storeu_ps(hi_ptr.add(next), _mm256_max_ps(vals, highs));
        }
        next += F32X8_LANES;
    }

    // Scalar remainder.
    for d in next..len {
        let v = row[d];
        if v < min_slice[d] {
            min_slice[d] = v;
        }
        if v > max_slice[d] {
            max_slice[d] = v;
        }
    }
}
/// Encodes one `f32` row into SQ8 codes.
///
/// Each output byte is `row[d] * inv_scale[d] + c2[d]`, clamped to `[0, 255]` and truncated
/// to `u8`. On x86_64 the AVX-512 kernel is used when `avx512_enabled()` returns true,
/// otherwise the AVX2 kernel when `avx2_enabled()` returns true; in every other case the
/// portable `wide` kernel runs.
///
/// # Panics
///
/// Panics if `inv_scale`, `c2` or `dst` differs in length from `row`.
#[inline]
pub(super) fn sq8_encode_row(row: &[f32], inv_scale: &[f32], c2: &[f32], dst: &mut [u8]) {
    // Hard asserts rather than debug asserts: the SIMD kernels load from `inv_scale`/`c2`
    // and store into `dst` through raw pointers using `row.len()`, so a length mismatch in a
    // release build would be an out-of-bounds access reachable from this safe function.
    assert_eq!(row.len(), inv_scale.len(), "inv_scale length must match row");
    assert_eq!(row.len(), c2.len(), "c2 length must match row");
    assert_eq!(row.len(), dst.len(), "dst length must match row");
    #[cfg(target_arch = "x86_64")]
    {
        if avx512_enabled() {
            // SAFETY: slice lengths were verified above. `avx512_enabled()` is presumably the
            // runtime check for the `avx512f` feature the kernel requires — confirm in
            // `simd_dispatch`.
            unsafe { sq8_encode_row_avx512(row, inv_scale, c2, dst) };
            return;
        }
        if avx2_enabled() {
            // SAFETY: slice lengths were verified above.
            // NOTE(review): the AVX2 kernel is compiled with `avx2,fma`; confirm that
            // `avx2_enabled()` also implies FMA support.
            unsafe { sq8_encode_row_avx2(row, inv_scale, c2, dst) };
            return;
        }
    }
    sq8_encode_row_wide(row, inv_scale, c2, dst);
}
/// Portable SQ8 encoder built on `wide::f32x8`, with a scalar loop for the last
/// `row.len() % 8` values.
///
/// Each output byte is `row[d] * inv_scale[d] + c2[d]`, limited to `[0, 255]` by a
/// max-then-min pair and then truncated by the `as u8` cast.
#[allow(clippy::manual_clamp)]
fn sq8_encode_row_wide(row: &[f32], inv_scale: &[f32], c2: &[f32], dst: &mut [u8]) {
    use wide::f32x8;

    let len = row.len();
    let vec_end = len - len % F32X8_LANES;
    let floor = f32x8::splat(0.0);
    let ceiling = f32x8::splat(SQ8_CODE_MAX);

    for base in (0..vec_end).step_by(F32X8_LANES) {
        let lanes = base..base + F32X8_LANES;
        let vals: [f32; F32X8_LANES] = row[lanes.clone()].try_into().expect("len 8");
        let mults: [f32; F32X8_LANES] = inv_scale[lanes.clone()].try_into().expect("len 8");
        let adds: [f32; F32X8_LANES] = c2[lanes].try_into().expect("len 8");

        let codes = f32x8::from(vals)
            .mul_add(f32x8::from(mults), f32x8::from(adds))
            .fast_max(floor)
            .fast_min(ceiling)
            .to_array();

        for (k, code) in codes.iter().enumerate() {
            dst[base + k] = *code as u8;
        }
    }

    // Scalar remainder.
    for d in vec_end..len {
        let code = row[d]
            .mul_add(inv_scale[d], c2[d])
            .max(0.0)
            .min(SQ8_CODE_MAX);
        dst[d] = code as u8;
    }
}
/// AVX2 + FMA SQ8 encoder: eight values per step, scalar loop for the remainder.
///
/// Each output byte is `row[d] * inv_scale[d] + c2[d]`, clamped to `[0, 255]` and truncated
/// to `u8`. A NaN intermediate encodes as 0 in both the vector body and the scalar tail.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA, and `inv_scale`, `c2` and `dst` must each hold at
/// least `row.len()` elements: the vector loop accesses them through raw pointers without
/// bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[allow(clippy::manual_clamp)]
unsafe fn sq8_encode_row_avx2(row: &[f32], inv_scale: &[f32], c2: &[f32], dst: &mut [u8]) {
    use std::arch::x86_64::*;

    let dim = row.len();
    let zero = _mm256_setzero_ps();
    let max255 = _mm256_set1_ps(SQ8_CODE_MAX);
    let mut i = 0;
    while i + F32X8_LANES <= dim {
        // SAFETY: `i + 8 <= row.len()`, and the caller guarantees the other slices are at
        // least as long as `row`.
        unsafe {
            let r = _mm256_loadu_ps(row.as_ptr().add(i));
            let inv = _mm256_loadu_ps(inv_scale.as_ptr().add(i));
            let c = _mm256_loadu_ps(c2.as_ptr().add(i));
            let q = _mm256_fmadd_ps(r, inv, c);
            // Max first, then min, exactly like the scalar tail. x86 MAXPS/MINPS return the
            // second operand for a NaN lane, so `max(q, 0)` turns NaN into 0; the reverse
            // order (`min(q, 255)` first) would turn NaN into 255 and disagree with the
            // tail. For every non-NaN `q` both orders give the same result.
            let q_clamped = _mm256_min_ps(_mm256_max_ps(q, zero), max255);
            let q_i32 = _mm256_cvttps_epi32(q_clamped);
            // Narrow 8 x i32 -> 8 x u16 -> 8 x u8; the low 64 bits hold the eight codes in
            // lane order.
            let lo = _mm256_castsi256_si128(q_i32);
            let hi = _mm256_extracti128_si256::<1>(q_i32);
            let packed_u16 = _mm_packus_epi32(lo, hi);
            let packed_u8 = _mm_packus_epi16(packed_u16, packed_u16);
            let dst_ptr = dst.as_mut_ptr().add(i) as *mut i64;
            write_unaligned(dst_ptr, _mm_cvtsi128_si64(packed_u8));
        }
        i += F32X8_LANES;
    }
    // Scalar remainder.
    while i < dim {
        let q = row[i].mul_add(inv_scale[i], c2[i]);
        let q_clamped = q.max(0.0).min(SQ8_CODE_MAX);
        dst[i] = q_clamped as u8;
        i += 1;
    }
}
/// AVX-512 SQ8 encoder: sixteen values per step, then at most one eight-lane step through
/// `sq8_encode_row_avx2_unsafe_tail8`, then a scalar loop for whatever is left.
///
/// Each output byte is `row[d] * inv_scale[d] + c2[d]`, clamped to `[0, 255]` and truncated
/// to `u8`. In the sixteen-lane body and the scalar tail a NaN intermediate encodes as 0.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `inv_scale`, `c2` and `dst` must each hold at least
/// `row.len()` elements: the vector steps access them through raw pointers without bounds
/// checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(clippy::manual_clamp)]
unsafe fn sq8_encode_row_avx512(row: &[f32], inv_scale: &[f32], c2: &[f32], dst: &mut [u8]) {
    use std::arch::x86_64::*;

    let dim = row.len();
    let zero = _mm512_setzero_ps();
    let max255 = _mm512_set1_ps(SQ8_CODE_MAX);
    let mut i = 0;
    while i + AVX512_F32_LANES <= dim {
        // SAFETY: `i + 16 <= row.len()`, and the caller guarantees the other slices are at
        // least as long as `row`.
        unsafe {
            let r = _mm512_loadu_ps(row.as_ptr().add(i));
            let inv = _mm512_loadu_ps(inv_scale.as_ptr().add(i));
            let c = _mm512_loadu_ps(c2.as_ptr().add(i));
            let q = _mm512_fmadd_ps(r, inv, c);
            // Max first, then min, exactly like the scalar tail. The x86 max/min
            // instructions return the second operand for a NaN lane, so `max(q, 0)` turns
            // NaN into 0; the reverse order would turn NaN into 255 and disagree with the
            // tail. For every non-NaN `q` both orders give the same result.
            let q_clamped = _mm512_min_ps(_mm512_max_ps(q, zero), max255);
            let q_i32 = _mm512_cvttps_epi32(q_clamped);
            // Saturating 16 x i32 -> 16 x u8 narrowing, stored as one unaligned 128-bit write.
            let packed_u8 = _mm512_cvtusepi32_epi8(q_i32);
            _mm_storeu_si128(dst.as_mut_ptr().add(i) as *mut __m128i, packed_u8);
        }
        i += AVX512_F32_LANES;
    }
    // Fewer than 16 values remain; hand one group of 8 to the 256-bit helper if possible.
    if i + F32X8_LANES <= dim {
        // SAFETY: `i + 8 <= row.len()` was just checked, slice lengths are the caller's
        // guarantee, and the helper needs `avx2,fma`, both of which rustc treats as implied
        // by `avx512f`.
        unsafe { sq8_encode_row_avx2_unsafe_tail8(row, inv_scale, c2, dst, &mut i) };
    }
    // Scalar remainder.
    while i < dim {
        let q = row[i].mul_add(inv_scale[i], c2[i]);
        let q_clamped = q.max(0.0).min(SQ8_CODE_MAX);
        dst[i] = q_clamped as u8;
        i += 1;
    }
}
/// Encodes exactly eight values starting at index `*i` with AVX2 + FMA and advances `*i`
/// by eight.
///
/// Each output byte is `row[d] * inv_scale[d] + c2[d]`, clamped to `[0, 255]` and truncated
/// to `u8`; a NaN intermediate encodes as 0.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA, and `*i + 8` must not exceed the length of any of
/// `row`, `inv_scale`, `c2` or `dst`: all four are accessed through raw pointers without
/// bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn sq8_encode_row_avx2_unsafe_tail8(
    row: &[f32],
    inv_scale: &[f32],
    c2: &[f32],
    dst: &mut [u8],
    i: &mut usize,
) {
    use std::arch::x86_64::*;

    let zero = _mm256_setzero_ps();
    let max255 = _mm256_set1_ps(SQ8_CODE_MAX);
    let start = *i;
    // SAFETY: the caller guarantees `start + 8` is within all four slices.
    unsafe {
        let r = _mm256_loadu_ps(row.as_ptr().add(start));
        let inv = _mm256_loadu_ps(inv_scale.as_ptr().add(start));
        let c = _mm256_loadu_ps(c2.as_ptr().add(start));
        let q = _mm256_fmadd_ps(r, inv, c);
        // Max first, then min, matching the scalar form `q.max(0.0).min(255.0)`. x86
        // MAXPS/MINPS return the second operand for a NaN lane, so `max(q, 0)` turns NaN
        // into 0; the reverse order would turn NaN into 255. For every non-NaN `q` both
        // orders give the same result.
        let q_clamped = _mm256_min_ps(_mm256_max_ps(q, zero), max255);
        let q_i32 = _mm256_cvttps_epi32(q_clamped);
        // Narrow 8 x i32 -> 8 x u16 -> 8 x u8; the low 64 bits hold the eight codes in
        // lane order.
        let lo = _mm256_castsi256_si128(q_i32);
        let hi = _mm256_extracti128_si256::<1>(q_i32);
        let packed_u16 = _mm_packus_epi32(lo, hi);
        let packed_u8 = _mm_packus_epi16(packed_u16, packed_u16);
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut i64;
        write_unaligned(dst_ptr, _mm_cvtsi128_si64(packed_u8));
    }
    *i = start + F32X8_LANES;
}
#[cfg(test)]
mod tests {
    use std::hint::black_box;

    use super::*;

    /// Builds a deterministic pseudo-random `(row, scale, offset)` triple of length `dim`.
    ///
    /// Per dimension, `offset` lies in `[-2, 2]`, the value span in `[0.1, 4.1]` and
    /// `scale = span / 255`. About 5% of the row values are placed 0.5 below `offset` and
    /// about 5% 0.5 above `offset + span`, so both saturation ends of the encoder are hit;
    /// the rest fall inside the span.
    fn synth_sq8_inputs(dim: usize, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
        let mut next = || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // Take the top 32 bits so the quotient covers [0, 1]. A shift by 33 leaves only
            // 31 bits, which caps the quotient below 0.5: the `pick > 0.95` branch could
            // never fire and in-range values never reached the upper half of the code range.
            ((state >> 32) as u32) as f32 / (u32::MAX as f32)
        };
        let mut scale = vec![0.0f32; dim];
        let mut offset = vec![0.0f32; dim];
        let mut row = vec![0.0f32; dim];
        for d in 0..dim {
            offset[d] = -2.0 + 4.0 * next();
            let span = 0.1 + 4.0 * next();
            scale[d] = span / 255.0;
            let pick = next();
            row[d] = if pick < 0.05 {
                // Below the representable range: must encode as 0.
                offset[d] - 0.5
            } else if pick > 0.95 {
                // Above the representable range: must encode as 255.
                offset[d] + span + 0.5
            } else {
                offset[d] + span * next()
            };
        }
        (row, scale, offset)
    }

    /// Straight scalar encoder used as the expected result for every SIMD tier.
    #[allow(clippy::manual_clamp)]
    fn sq8_encode_row_reference(row: &[f32], inv_scale: &[f32], c2: &[f32], dst: &mut [u8]) {
        debug_assert_eq!(row.len(), dst.len());
        for d in 0..row.len() {
            let q = row[d].mul_add(inv_scale[d], c2[d]);
            let q_clamped = q.max(0.0).min(255.0);
            dst[d] = q_clamped as u8;
        }
    }

    /// `from_scale_offset` must produce exactly `1 / scale` and `-offset * inv + 0.5`.
    #[test]
    fn sq8_encode_consts_match_algebraic_identity() {
        let dim = 384;
        let (_row, scale, offset) = synth_sq8_inputs(dim, 0xC0FFEE);
        let consts = Sq8EncodeConsts::from_scale_offset(&scale, &offset);
        assert_eq!(consts.inv_scale.len(), dim);
        assert_eq!(consts.c2.len(), dim);
        for d in 0..dim {
            let want_inv = 1.0 / scale[d];
            let want_c2 = (-offset[d]).mul_add(want_inv, 0.5);
            assert_eq!(consts.inv_scale[d], want_inv);
            assert_eq!(consts.c2[d], want_c2);
        }
    }

    /// Every available min/max kernel must agree with a plain scalar fold.
    #[test]
    fn update_min_max_simd_paths_match_scalar() {
        let dims = [
            1, 7, 8, 15, 16, 17, 24, 31, 32, 33, 47, 48, 96, 384, 512, 1023,
        ];
        for &dim in &dims {
            // Two rows are folded in sequence: against the infinite initial accumulators a
            // single row always replaces both values, so only the second row exercises the
            // "keep the existing minimum/maximum" outcome.
            let (row_a, _scale_a, _offset_a) = synth_sq8_inputs(dim, dim as u64 * 13);
            let (row_b, _scale_b, _offset_b) = synth_sq8_inputs(dim, dim as u64 * 13 + 7);

            let mut mn_ref = vec![f32::INFINITY; dim];
            let mut mx_ref = vec![f32::NEG_INFINITY; dim];
            for row in [&row_a, &row_b] {
                for d in 0..dim {
                    if row[d] < mn_ref[d] {
                        mn_ref[d] = row[d];
                    }
                    if row[d] > mx_ref[d] {
                        mx_ref[d] = row[d];
                    }
                }
            }

            'tiers: for tier in ["wide", "avx2", "avx512"] {
                let mut mn = vec![f32::INFINITY; dim];
                let mut mx = vec![f32::NEG_INFINITY; dim];
                for row in [&row_a, &row_b] {
                    match tier {
                        "wide" => update_min_max_wide(row, &mut mn, &mut mx),
                        #[cfg(target_arch = "x86_64")]
                        "avx2" if std::is_x86_feature_detected!("avx2") => {
                            unsafe { update_min_max_avx2(row, &mut mn, &mut mx) };
                        }
                        #[cfg(target_arch = "x86_64")]
                        "avx512" if std::is_x86_feature_detected!("avx512f") => {
                            unsafe { update_min_max_avx512(row, &mut mn, &mut mx) };
                        }
                        // Tier unavailable on this target or CPU: nothing to compare.
                        _ => continue 'tiers,
                    }
                }
                assert_eq!(mn, mn_ref, "tier {} min mismatch at dim {}", tier, dim);
                assert_eq!(mx, mx_ref, "tier {} max mismatch at dim {}", tier, dim);
            }
        }
    }

    /// Every available encode kernel must agree byte-for-byte with the scalar reference.
    #[test]
    fn sq8_encode_row_simd_paths_match_scalar() {
        let dims = [
            1, 7, 8, 15, 16, 17, 24, 31, 32, 33, 47, 48, 96, 384, 512, 1023,
        ];
        for &dim in &dims {
            let (row, scale, offset) = synth_sq8_inputs(dim, dim as u64 * 17 + 1);
            let consts = Sq8EncodeConsts::from_scale_offset(&scale, &offset);
            let mut dst_ref = vec![0u8; dim];
            sq8_encode_row_reference(&row, &consts.inv_scale, &consts.c2, &mut dst_ref);

            let mut dst_wide = vec![0u8; dim];
            sq8_encode_row_wide(&row, &consts.inv_scale, &consts.c2, &mut dst_wide);
            assert_eq!(dst_wide, dst_ref, "wide path mismatch at dim {}", dim);

            #[cfg(target_arch = "x86_64")]
            if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
                let mut dst_avx2 = vec![0u8; dim];
                unsafe {
                    sq8_encode_row_avx2(&row, &consts.inv_scale, &consts.c2, &mut dst_avx2);
                }
                assert_eq!(dst_avx2, dst_ref, "avx2 path mismatch at dim {}", dim);
            }

            #[cfg(target_arch = "x86_64")]
            if std::is_x86_feature_detected!("avx512f") {
                let mut dst_avx512 = vec![0u8; dim];
                unsafe {
                    sq8_encode_row_avx512(&row, &consts.inv_scale, &consts.c2, &mut dst_avx512);
                }
                assert_eq!(dst_avx512, dst_ref, "avx512 path mismatch at dim {}", dim);
            }
        }
    }

    /// Prints a per-tier timing table; ignored by default because it measures speed only.
    #[test]
    #[ignore = "perf microbench, not a correctness gate"]
    fn sq8_encode_microbench() {
        use std::time::Instant;
        let dims: &[usize] = &[128, 384, 768, 1024, 1536];
        let iters: usize = 200_000;
        println!("\n### Sq8 f32 → u8 encode per-tier ns / row (dim sweep)\n");
        println!("| dim | scalar ns | wide ns | avx2 ns | avx512 ns |");
        println!("|----:|----------:|--------:|--------:|----------:|");
        for &dim in dims {
            let (row, scale, offset) = synth_sq8_inputs(dim, dim as u64 * 23 + 5);
            let consts = Sq8EncodeConsts::from_scale_offset(&scale, &offset);
            let mut dst = vec![0u8; dim];

            let t0 = Instant::now();
            for _ in 0..iters {
                sq8_encode_row_reference(&row, &consts.inv_scale, &consts.c2, &mut dst);
                // Keep the optimizer from discarding the encode as dead work.
                black_box(&dst);
            }
            let scalar_ns = t0.elapsed().as_nanos() as f64 / iters as f64;

            let t0 = Instant::now();
            for _ in 0..iters {
                sq8_encode_row_wide(&row, &consts.inv_scale, &consts.c2, &mut dst);
                black_box(&dst);
            }
            let wide_ns = t0.elapsed().as_nanos() as f64 / iters as f64;

            #[cfg(target_arch = "x86_64")]
            let avx2_ns =
                if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
                    let t0 = Instant::now();
                    for _ in 0..iters {
                        unsafe {
                            sq8_encode_row_avx2(&row, &consts.inv_scale, &consts.c2, &mut dst);
                        }
                        black_box(&dst);
                    }
                    Some(t0.elapsed().as_nanos() as f64 / iters as f64)
                } else {
                    None
                };
            #[cfg(not(target_arch = "x86_64"))]
            let avx2_ns: Option<f64> = None;

            #[cfg(target_arch = "x86_64")]
            let avx512_ns = if std::is_x86_feature_detected!("avx512f") {
                let t0 = Instant::now();
                for _ in 0..iters {
                    unsafe {
                        sq8_encode_row_avx512(&row, &consts.inv_scale, &consts.c2, &mut dst);
                    }
                    black_box(&dst);
                }
                Some(t0.elapsed().as_nanos() as f64 / iters as f64)
            } else {
                None
            };
            #[cfg(not(target_arch = "x86_64"))]
            let avx512_ns: Option<f64> = None;

            // Tiers that could not run on this machine are shown as "n/a".
            let fmt = |x: Option<f64>| match x {
                Some(v) => format!("{:>7.1}", v),
                None => "    n/a".to_string(),
            };
            println!(
                "| {:>3} | {:>9.1} | {:>7.1} | {} | {} |",
                dim,
                scalar_ns,
                wide_ns,
                fmt(avx2_ns),
                fmt(avx512_ns),
            );
        }
    }
}