#![allow(dead_code)]
#![allow(clippy::wrong_self_convention)]
use archmage::prelude::*;
use magetypes::simd::generic::f32x8 as GenericF32x8;
use magetypes::simd::generic::i32x8 as GenericI32x8;
/// An 8x8 block of `f32` coefficients stored row-major as eight rows of eight
/// values.
///
/// `repr(C, align(32))` fixes the field layout and aligns the block to a
/// 32-byte boundary. Each row is exactly 32 bytes, so every row also starts on
/// a 32-byte boundary. `Pod`/`Zeroable` let `bytemuck` treat the block as
/// plain bytes.
#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
#[repr(C, align(32))]
pub struct Block8x8f {
/// Coefficients indexed as `rows[row][col]`.
pub rows: [[f32; 8]; 8],
}
impl Block8x8f {
pub const ZERO: Self = Self {
rows: [[0.0; 8]; 8],
};
#[inline]
pub fn from_array(arr: &[f32; 64]) -> Self {
let mut rows = [[0.0f32; 8]; 8];
for (row_idx, row) in rows.iter_mut().enumerate() {
let start = row_idx * 8;
*row = arr[start..start + 8].try_into().unwrap();
}
Self { rows }
}
#[inline]
pub fn to_array(&self) -> [f32; 64] {
let mut arr = [0.0f32; 64];
for (row_idx, row) in self.rows.iter().enumerate() {
arr[row_idx * 8..row_idx * 8 + 8].copy_from_slice(row);
}
arr
}
#[inline]
pub fn get(&self, row: usize, col: usize) -> f32 {
self.rows[row][col]
}
#[inline]
pub fn set(&mut self, row: usize, col: usize, value: f32) {
self.rows[row][col] = value;
}
#[inline]
pub fn scale(&self, factor: f32) -> Self {
let mut result = Self::ZERO;
for i in 0..8 {
for j in 0..8 {
result.rows[i][j] = self.rows[i][j] * factor;
}
}
result
}
#[inline]
pub fn mul(&self, other: &Self) -> Self {
let mut result = Self::ZERO;
for i in 0..8 {
for j in 0..8 {
result.rows[i][j] = self.rows[i][j] * other.rows[i][j];
}
}
result
}
#[inline]
pub fn add(&self, other: &Self) -> Self {
let mut result = Self::ZERO;
for i in 0..8 {
for j in 0..8 {
result.rows[i][j] = self.rows[i][j] + other.rows[i][j];
}
}
result
}
}
impl Default for Block8x8f {
fn default() -> Self {
Self::ZERO
}
}
/// An 8x8 block of `i16` values stored row-major.
///
/// Each row is 16 bytes; `repr(C, align(16))` fixes the layout and aligns the
/// whole block to a 16-byte boundary.
#[derive(Clone, Copy, Debug)]
#[repr(C, align(16))]
pub struct Block8x8i16 {
    /// Values indexed as `rows[row][col]`.
    pub rows: [[i16; 8]; 8],
}

impl Block8x8i16 {
    /// A block with every value equal to `0`.
    pub const ZERO: Self = Self { rows: [[0; 8]; 8] };

    /// Builds a block from 64 row-major values, so `arr[r * 8 + c]` lands in
    /// `rows[r][c]`.
    #[inline]
    pub fn from_array(arr: &[i16; 64]) -> Self {
        let mut block = Self::ZERO;
        for (src, dst) in arr.chunks_exact(8).zip(block.rows.iter_mut()) {
            dst.copy_from_slice(src);
        }
        block
    }

    /// Flattens the block into 64 row-major values (inverse of
    /// [`Block8x8i16::from_array`]).
    #[inline]
    pub fn to_array(&self) -> [i16; 64] {
        let mut flat = [0i16; 64];
        for (src, dst) in self.rows.iter().zip(flat.chunks_exact_mut(8)) {
            dst.copy_from_slice(src);
        }
        flat
    }

    /// Returns the value at (`row`, `col`).
    ///
    /// # Panics
    /// Panics if `row >= 8` or `col >= 8`.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> i16 {
        let line = &self.rows[row];
        line[col]
    }
}

impl Default for Block8x8i16 {
    /// Returns [`Block8x8i16::ZERO`].
    fn default() -> Self {
        Block8x8i16::ZERO
    }
}
/// Precomputed quantization table used by the block quantizers in this file.
///
/// `mul_rows[row][col]` holds `8.0 / q` for the table entry `q` at
/// `row * 8 + col` (see `from_values` / `from_f32_values`). Quantizing is
/// therefore a multiply instead of a divide. `values` keeps the table entries
/// as `u16` in the same row-major order.
#[derive(Clone, Debug)]
#[repr(C, align(32))]
pub struct QuantTableSimd {
/// Per-coefficient multipliers, `8.0 / q`, indexed as `mul_rows[row][col]`.
pub mul_rows: [[f32; 8]; 8],
/// Table entries in row-major order (`values[row * 8 + col]`).
pub values: [u16; 64],
}
/// Per-coefficient zero-bias thresholds laid out as 8x8 rows.
///
/// In the quantizers of this file, a scaled coefficient `qval` is kept only when
/// `|qval| >= offset_rows[r][c] + mul_rows[r][c] * aq_strength`; otherwise the
/// output coefficient is zero.
#[derive(Clone, Debug)]
#[repr(C, align(32))]
pub struct ZeroBiasSimd {
/// Constant part of the threshold, indexed as `offset_rows[row][col]`.
pub offset_rows: [[f32; 8]; 8],
/// Part of the threshold scaled by `aq_strength`, indexed as `mul_rows[row][col]`.
pub mul_rows: [[f32; 8]; 8],
}
impl ZeroBiasSimd {
    /// Reshapes the flat 64-entry `offset` and `mul` arrays of `params` into
    /// row-major 8x8 form, so `rows[r][c] == flat[r * 8 + c]`.
    pub fn from_params(params: &crate::quant::ZeroBiasParams) -> Self {
        let mut bias = Self {
            offset_rows: [[0.0f32; 8]; 8],
            mul_rows: [[0.0f32; 8]; 8],
        };
        let source_rows = params
            .offset
            .chunks_exact(8)
            .zip(params.mul.chunks_exact(8));
        let dest_rows = bias.offset_rows.iter_mut().zip(bias.mul_rows.iter_mut());
        for ((offset_dst, mul_dst), (offset_src, mul_src)) in dest_rows.zip(source_rows) {
            offset_dst.copy_from_slice(offset_src);
            mul_dst.copy_from_slice(mul_src);
        }
        bias
    }
}
impl QuantTableSimd {
    /// Builds a table from integer quantization values in row-major order
    /// (`values[row * 8 + col]`).
    ///
    /// Each multiplier is `8.0 / q`, so quantization becomes a multiply.
    /// NOTE(review): a zero entry yields an infinite multiplier; this assumes
    /// callers only pass non-zero table values — confirm upstream validation.
    pub fn from_values(values: &[u16; 64]) -> Self {
        let mut mul_rows = [[0.0f32; 8]; 8];
        for (dst, src) in mul_rows.iter_mut().zip(values.chunks_exact(8)) {
            for (m, &q) in dst.iter_mut().zip(src) {
                *m = 8.0 / f32::from(q);
            }
        }
        Self {
            mul_rows,
            values: *values,
        }
    }

    /// Builds a table from floating-point quantization values in row-major
    /// order.
    ///
    /// The multipliers use the exact `f32` values (`8.0 / q`). The stored
    /// `values` are rounded to `u16` with saturating `as` semantics: negatives
    /// and NaN become 0, and anything above 65535 becomes 65535.
    pub fn from_f32_values(values: &[f32; 64]) -> Self {
        let mut mul_rows = [[0.0f32; 8]; 8];
        let mut u16_values = [0u16; 64];
        for (i, &q) in values.iter().enumerate() {
            mul_rows[i / 8][i % 8] = 8.0 / q;
            u16_values[i] = q.round() as u16;
        }
        Self {
            mul_rows,
            values: u16_values,
        }
    }

    /// Quantizes `block` without zero-biasing.
    ///
    /// Each output is `round(coeff * mul)`, rounding ties away from zero, as
    /// an `i32`.
    #[inline]
    pub fn quantize(&self, block: &Block8x8f) -> Block8x8i32 {
        let mut result = Block8x8i32::ZERO;
        for ((dst, src), mul) in result
            .rows
            .iter_mut()
            .zip(&block.rows)
            .zip(&self.mul_rows)
        {
            for ((d, &c), &m) in dst.iter_mut().zip(src).zip(mul) {
                *d = fast_round_i32(c * m);
            }
        }
        result
    }

    /// Zero-bias quantization of `block`, with the output in zigzag order.
    ///
    /// Uses the runtime-dispatched SIMD path (`quantize_block_zigzag`).
    #[inline]
    pub fn quantize_with_zero_bias_zigzag(
        &self,
        block: &Block8x8f,
        zero_bias: &ZeroBiasSimd,
        aq_strength: f32,
    ) -> [i16; 64] {
        quantize_block_zigzag(&self.mul_rows, block, zero_bias, aq_strength)
    }

    /// Zero-bias quantization of `block`, with the output in natural
    /// (row-major) order.
    ///
    /// Uses the runtime-dispatched SIMD path (`quantize_block`).
    #[inline]
    pub fn quantize_with_zero_bias(
        &self,
        block: &Block8x8f,
        zero_bias: &ZeroBiasSimd,
        aq_strength: f32,
    ) -> [i16; 64] {
        quantize_block(&self.mul_rows, block, zero_bias, aq_strength)
    }

    /// Scalar zero-bias quantization of a flat row-major coefficient array.
    /// The output is in natural (row-major) order.
    ///
    /// A coefficient is kept when
    /// `|coeff * mul| >= offset + bias_mul * aq_strength`. A kept value is
    /// rounded (ties away from zero) and saturated to the `i16` range;
    /// otherwise the output is zero.
    ///
    /// NOTE(review): the threshold here is an unfused multiply-add, while the
    /// SIMD path uses `mul_add`. Results may differ by one ulp at the
    /// threshold — confirm whether bit-exact parity is required.
    #[inline]
    pub fn quantize_array_with_zero_bias(
        &self,
        coeffs: &[f32; 64],
        zero_bias: &ZeroBiasSimd,
        aq_strength: f32,
    ) -> [i16; 64] {
        const I16_MIN: i32 = i16::MIN as i32;
        const I16_MAX: i32 = i16::MAX as i32;
        let mut result = [0i16; 64];
        for row in 0..8 {
            let k = row * 8;
            for col in 0..8 {
                let qval = coeffs[k + col] * self.mul_rows[row][col];
                let threshold =
                    zero_bias.offset_rows[row][col] + zero_bias.mul_rows[row][col] * aq_strength;
                if qval.abs() >= threshold {
                    // Saturate rather than wrap: a bare `as i16` would turn an
                    // out-of-range coefficient into one of the opposite sign.
                    // This matches `Block8x8i32::to_i16`.
                    result[k + col] = fast_round_i32(qval).clamp(I16_MIN, I16_MAX) as i16;
                }
            }
        }
        result
    }
}
/// An 8x8 block of `i32` values stored row-major and aligned to 32 bytes.
#[derive(Clone, Copy, Debug)]
#[repr(C, align(32))]
pub struct Block8x8i32 {
    /// Values indexed as `rows[row][col]`.
    pub rows: [[i32; 8]; 8],
}

impl Block8x8i32 {
    /// A block with every value equal to `0`.
    pub const ZERO: Self = Self { rows: [[0; 8]; 8] };

    /// Narrows every value to `i16`, saturating anything outside
    /// `i16::MIN..=i16::MAX` instead of wrapping.
    #[inline]
    pub fn to_i16(&self) -> Block8x8i16 {
        let mut narrowed = Block8x8i16::ZERO;
        for (dst, src) in narrowed.rows.iter_mut().zip(self.rows.iter()) {
            for (d, &s) in dst.iter_mut().zip(src.iter()) {
                *d = s.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
            }
        }
        narrowed
    }

    /// Saturating narrow to `i16` (see [`Block8x8i32::to_i16`]), flattened to
    /// 64 row-major values.
    #[inline]
    pub fn to_i16_array(&self) -> [i16; 64] {
        let narrowed = self.to_i16();
        narrowed.to_array()
    }
}

impl Default for Block8x8i32 {
    /// Returns [`Block8x8i32::ZERO`].
    fn default() -> Self {
        Block8x8i32::ZERO
    }
}
/// Zero-bias quantization of `block`, writing the result in zigzag order.
///
/// For every row this:
/// - computes `qval = block * mul_rows`;
/// - keeps `qval` (rounded to `i32` by `to_i32_round`) when
///   `|qval| >= zero_bias.mul * aq_strength + zero_bias.offset`, and zeroes it
///   otherwise;
/// - stores the value at natural index `k` (`k = row * 8 + col`) into
///   `result[JPEG_ZIGZAG_ORDER[k]]`.
///
/// Kept values are saturated to the `i16` range before narrowing, matching
/// `Block8x8i32::to_i16`. A plain `as i16` would wrap and flip the sign.
///
/// NOTE(review): `to_i32_round` may resolve ties differently from
/// `f32::round`, which the scalar helpers use — confirm against the magetypes
/// docs if bit-exact parity with `quantize_array_with_zero_bias` matters.
#[magetypes(v3, neon, wasm128, scalar)]
#[inline(always)]
fn mage_quantize_block_zigzag(
    token: Token,
    block: &Block8x8f,
    mul_rows: &[[f32; 8]; 8],
    zero_bias: &ZeroBiasSimd,
    aq_strength: f32,
) -> [i16; 64] {
    use crate::foundation::consts::JPEG_ZIGZAG_ORDER;
    #[allow(non_camel_case_types)]
    type f32x8 = GenericF32x8<Token>;
    #[allow(non_camel_case_types)]
    type i32x8 = GenericI32x8<Token>;
    let aq_m = f32x8::splat(token, aq_strength);
    let zero_i32 = i32x8::zero(token);
    let mut result = [0i16; 64];
    for row in 0..8 {
        let block_m = f32x8::from_array(token, block.rows[row]);
        let mul_m = f32x8::from_array(token, mul_rows[row]);
        let offset_m = f32x8::from_array(token, zero_bias.offset_rows[row]);
        let bias_mul_m = f32x8::from_array(token, zero_bias.mul_rows[row]);
        let qval = block_m * mul_m;
        let threshold = bias_mul_m.mul_add(aq_m, offset_m);
        let abs_qval = qval.abs();
        let mask = abs_qval.simd_ge(threshold);
        let rounded = qval.to_i32_round();
        let mask_i32 = mask.bitcast_to_i32();
        // Lanes below the threshold become zero; the rest keep the rounded value.
        let blended = i32x8::blend(mask_i32, rounded, zero_i32);
        let arr = blended.to_array();
        let k = row * 8;
        for (col, &v) in arr.iter().enumerate() {
            result[JPEG_ZIGZAG_ORDER[k + col] as usize] =
                v.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
        }
    }
    result
}
/// Zero-bias quantization of `block`, writing the result in natural
/// (row-major) order.
///
/// For each coefficient, `qval = block * mul_rows`. It is kept (rounded to
/// `i32` by `to_i32_round`) when
/// `|qval| >= zero_bias.mul * aq_strength + zero_bias.offset`, and zeroed
/// otherwise.
///
/// Kept values are saturated to the `i16` range before narrowing, matching
/// `Block8x8i32::to_i16`. A plain `as i16` would wrap and flip the sign.
///
/// NOTE(review): `to_i32_round` may resolve ties differently from
/// `f32::round`, which the scalar helpers use — confirm against the magetypes
/// docs if bit-exact parity with `quantize_array_with_zero_bias` matters.
#[magetypes(v3, neon, wasm128, scalar)]
#[inline(always)]
fn mage_quantize_block(
    token: Token,
    block: &Block8x8f,
    mul_rows: &[[f32; 8]; 8],
    zero_bias: &ZeroBiasSimd,
    aq_strength: f32,
) -> [i16; 64] {
    #[allow(non_camel_case_types)]
    type f32x8 = GenericF32x8<Token>;
    #[allow(non_camel_case_types)]
    type i32x8 = GenericI32x8<Token>;
    let aq_m = f32x8::splat(token, aq_strength);
    let zero_i32 = i32x8::zero(token);
    let mut result = [0i16; 64];
    for row in 0..8 {
        let block_m = f32x8::from_array(token, block.rows[row]);
        let mul_m = f32x8::from_array(token, mul_rows[row]);
        let offset_m = f32x8::from_array(token, zero_bias.offset_rows[row]);
        let bias_mul_m = f32x8::from_array(token, zero_bias.mul_rows[row]);
        let qval = block_m * mul_m;
        let threshold = bias_mul_m.mul_add(aq_m, offset_m);
        let abs_qval = qval.abs();
        let mask = abs_qval.simd_ge(threshold);
        let rounded = qval.to_i32_round();
        let mask_i32 = mask.bitcast_to_i32();
        // Lanes below the threshold become zero; the rest keep the rounded value.
        let blended = i32x8::blend(mask_i32, rounded, zero_i32);
        let arr = blended.to_array();
        let k = row * 8;
        for (dst, &v) in result[k..k + 8].iter_mut().zip(arr.iter()) {
            *dst = v.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
        }
    }
    result
}
/// Rounds `v` to the nearest integer (ties away from zero) and converts to
/// `i32`. Per `as` semantics, NaN maps to 0 and out-of-range values saturate
/// at `i32::MIN` / `i32::MAX`.
#[inline(always)]
fn fast_round_i32(v: f32) -> i32 {
    let nearest = f32::round(v);
    nearest as i32
}
/// Dispatching wrapper around `mage_quantize_block_zigzag`, which quantizes a
/// block with zero-bias thresholding and returns zigzag-ordered output.
///
/// NOTE(review): `incant!` presumably selects, at runtime, one of the variants
/// generated by `#[magetypes(v3, neon, wasm128, scalar)]` for the current
/// CPU — confirm against the archmage docs.
///
/// This wrapper takes `mul_rows` before `block`. The macro call forwards them
/// in the generated function's order (`block` first).
#[inline]
fn quantize_block_zigzag(
mul_rows: &[[f32; 8]; 8],
block: &Block8x8f,
zero_bias: &ZeroBiasSimd,
aq_strength: f32,
) -> [i16; 64] {
incant!(mage_quantize_block_zigzag(
block,
mul_rows,
zero_bias,
aq_strength
))
}
/// Dispatching wrapper around `mage_quantize_block`, which quantizes a block
/// with zero-bias thresholding and returns natural (row-major) output.
///
/// NOTE(review): `incant!` presumably selects, at runtime, one of the variants
/// generated by `#[magetypes(v3, neon, wasm128, scalar)]` for the current
/// CPU — confirm against the archmage docs.
///
/// This wrapper takes `mul_rows` before `block`. The macro call forwards them
/// in the generated function's order (`block` first).
#[inline]
fn quantize_block(
mul_rows: &[[f32; 8]; 8],
block: &Block8x8f,
zero_bias: &ZeroBiasSimd,
aq_strength: f32,
) -> [i16; 64] {
incant!(mage_quantize_block(block, mul_rows, zero_bias, aq_strength))
}
/// Unit tests for the block types and the quantization paths.
///
/// The `*_dispatch_parity` tests (x86_64 only) run each quantizer under every
/// token permutation reported by `archmage::testing`. They compare each result
/// with the result of a normal call.
#[cfg(test)]
mod tests {
use super::*;
/// Deterministic fixture shared by the parity tests:
/// - coefficients start at 100 and fall by 8 per unit of `row + col`, with the
///   sign flipped at every index divisible by 3;
/// - quant values are `i / 4 + 2` (2..=17, so the `.min(255)` never binds);
/// - zero-bias params are offset 0.5 and mul 0.15 everywhere;
/// - `aq_strength` is 1.0.
fn quantize_test_data() -> (Block8x8f, QuantTableSimd, ZeroBiasSimd, f32) {
let mut coeffs = [0.0f32; 64];
for i in 0..64 {
let row = i / 8;
let col = i % 8;
let freq = (row + col) as f32;
coeffs[i] = (100.0 - freq * 8.0) * if i % 3 == 0 { -1.0 } else { 1.0 };
}
let block = Block8x8f::from_array(&coeffs);
let mut qvals = [1u16; 64];
for i in 0..64 {
qvals[i] = ((i as u16 / 4) + 2).min(255);
}
let quant = QuantTableSimd::from_values(&qvals);
let mut bias_params = crate::quant::ZeroBiasParams {
offset: [0.0; 64],
mul: [0.0; 64],
};
for i in 0..64 {
bias_params.offset[i] = 0.5;
bias_params.mul[i] = 0.15;
}
let zero_bias = ZeroBiasSimd::from_params(&bias_params);
(block, quant, zero_bias, 1.0)
}
/// Zigzag quantizer: every token permutation must match the reference result,
/// and at least two permutations must actually run.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_quantize_zigzag_dispatch_parity() {
use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
let (block, quant, zero_bias, aq) = quantize_test_data();
let reference = quantize_block_zigzag(&quant.mul_rows, &block, &zero_bias, aq);
let report = for_each_token_permutation(CompileTimePolicy::Warn, |perm| {
let result = quantize_block_zigzag(&quant.mul_rows, &block, &zero_bias, aq);
assert_eq!(
result, reference,
"quantize_block_zigzag mismatch at permutation: {perm}"
);
});
eprintln!("quantize_zigzag: {report}");
assert!(
report.permutations_run >= 2,
"expected at least 2 permutations"
);
}
/// Natural-order quantizer: same parity check as the zigzag test.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_quantize_natural_dispatch_parity() {
use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
let (block, quant, zero_bias, aq) = quantize_test_data();
let reference = quantize_block(&quant.mul_rows, &block, &zero_bias, aq);
let report = for_each_token_permutation(CompileTimePolicy::Warn, |perm| {
let result = quantize_block(&quant.mul_rows, &block, &zero_bias, aq);
assert_eq!(
result, reference,
"quantize_block mismatch at permutation: {perm}"
);
});
eprintln!("quantize_natural: {report}");
assert!(
report.permutations_run >= 2,
"expected at least 2 permutations"
);
}
/// Parity through the public `QuantTableSimd` methods, for both output
/// orders. Unlike the two tests above, this one does not assert a minimum
/// permutation count.
#[cfg(target_arch = "x86_64")]
#[test]
fn test_quantize_api_dispatch_parity() {
use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
let (block, quant, zero_bias, aq) = quantize_test_data();
let ref_zigzag = quant.quantize_with_zero_bias_zigzag(&block, &zero_bias, aq);
let ref_natural = quant.quantize_with_zero_bias(&block, &zero_bias, aq);
let report = for_each_token_permutation(CompileTimePolicy::Warn, |perm| {
let zigzag = quant.quantize_with_zero_bias_zigzag(&block, &zero_bias, aq);
let natural = quant.quantize_with_zero_bias(&block, &zero_bias, aq);
assert_eq!(zigzag, ref_zigzag, "zigzag API mismatch at: {perm}");
assert_eq!(natural, ref_natural, "natural API mismatch at: {perm}");
});
eprintln!("quantize_api: {report}");
}
/// `from_array` followed by `to_array` must reproduce the input.
#[test]
fn test_block8x8f_roundtrip() {
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32 * 1.5;
}
let block = Block8x8f::from_array(&arr);
let result = block.to_array();
for i in 0..64 {
assert!((arr[i] - result[i]).abs() < 1e-6);
}
}
/// A value written with `set` is read back by `get` at the same position.
#[test]
fn test_block8x8f_get_set() {
let mut block = Block8x8f::ZERO;
block.set(3, 5, 42.0);
assert!((block.get(3, 5) - 42.0).abs() < 1e-6);
}
/// `scale(2.0)` doubles every coefficient, checked through row/col indexing.
#[test]
fn test_block8x8f_scale() {
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32;
}
let block = Block8x8f::from_array(&arr);
let scaled = block.scale(2.0);
for i in 0..64 {
let row = i / 8;
let col = i % 8;
assert!((scaled.get(row, col) - (i as f32 * 2.0)).abs() < 1e-6);
}
}
/// `from_values` stores `8 / q` for each table entry `q` (here `q = i + 1`).
#[test]
fn test_quant_table_simd() {
let mut values = [1u16; 64];
for i in 0..64 {
values[i] = (i + 1) as u16;
}
let quant = QuantTableSimd::from_values(&values);
for row in 0..8 {
for col in 0..8 {
let expected = 8.0 / (row * 8 + col + 1) as f32;
assert!((quant.mul_rows[row][col] - expected).abs() < 1e-6);
}
}
}
/// Coefficient `(i + 1) * 10` with quant value `i + 1` gives
/// `(i + 1) * 10 * 8 / (i + 1)`, which rounds to 80 for every entry.
#[test]
fn test_quantize_simple() {
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = (i + 1) as f32 * 10.0; }
let block = Block8x8f::from_array(&arr);
let mut values = [1u16; 64];
for i in 0..64 {
values[i] = (i + 1) as u16;
}
let quant = QuantTableSimd::from_values(&values);
let result = quant.quantize(&block);
let arr = result.to_i16_array();
for i in 0..64 {
assert_eq!(arr[i], 80, "Mismatch at index {}", i);
}
}
}