/// Masked 8th-power pixel loss over a `block_width x block_height` block:
///
/// `sum(((mask[row_base + py*stride + px] + mask_offset) * pixel_error[py*width + px])^8)`
///
/// `pixel_error` is stored densely (row stride == `block_width`), while
/// `mask` is read with an independent `mask_stride` starting at
/// `mask_row_base`. Dispatches to the best SIMD kernel available on the
/// current target (AVX2 / NEON / WASM SIMD128) and falls back to the scalar
/// reference implementation otherwise. All kernels compute each term the
/// same way: square once in f32, widen to f64, then square twice more.
///
/// # Panics
/// Panics (via slice indexing) if the input slices are shorter than the
/// requested block geometry. The `debug_assert!`s below report such misuse
/// eagerly in debug builds.
#[inline]
pub fn pixel_domain_loss(
    pixel_error: &[f32],
    mask: &[f32],
    mask_row_base: usize,
    mask_stride: usize,
    mask_offset: f32,
    block_width: usize,
    block_height: usize,
) -> f64 {
    // AVX2 kernel consumes 8 lanes at a time with no remainder handling.
    debug_assert!(
        block_width.is_multiple_of(8),
        "block_width must be multiple of 8"
    );
    debug_assert!(pixel_error.len() >= block_width * block_height);
    // The mask is strided: make sure the *last* row is fully in bounds too,
    // otherwise the kernels panic with an opaque slice-index error.
    debug_assert!(
        block_height == 0
            || mask.len() >= mask_row_base + (block_height - 1) * mask_stride + block_width,
        "mask too small for requested block geometry"
    );
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(token) = archmage::X64V3Token::summon() {
            return pixel_domain_loss_avx2(
                token,
                pixel_error,
                mask,
                mask_row_base,
                mask_stride,
                mask_offset,
                block_width,
                block_height,
            );
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        use archmage::SimdToken;
        if let Some(token) = archmage::NeonToken::summon() {
            return pixel_domain_loss_neon(
                token,
                pixel_error,
                mask,
                mask_row_base,
                mask_stride,
                mask_offset,
                block_width,
                block_height,
            );
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        use archmage::SimdToken;
        if let Some(token) = archmage::Wasm128Token::summon() {
            return pixel_domain_loss_wasm128(
                token,
                pixel_error,
                mask,
                mask_row_base,
                mask_stride,
                mask_offset,
                block_width,
                block_height,
            );
        }
    }
    pixel_domain_loss_scalar(
        pixel_error,
        mask,
        mask_row_base,
        mask_stride,
        mask_offset,
        block_width,
        block_height,
    )
}
/// Portable reference implementation of the masked 8th-power pixel loss.
///
/// For every pixel of the block, forms `(mask + mask_offset) * error`,
/// squares it once in f32, widens the square to f64, and squares twice more
/// (i.e. raises the masked error to the 8th power), accumulating the terms
/// in f64. The per-term arithmetic matches the SIMD kernels exactly.
/// `mask` is read with row stride `mask_stride` starting at `mask_row_base`;
/// `pixel_error` is dense with row stride `block_width`.
#[inline]
pub fn pixel_domain_loss_scalar(
    pixel_error: &[f32],
    mask: &[f32],
    mask_row_base: usize,
    mask_stride: usize,
    mask_offset: f32,
    block_width: usize,
    block_height: usize,
) -> f64 {
    (0..block_height)
        .flat_map(|py| {
            // Slice out one row from each input; the zip pairs the strided
            // mask row with the matching dense error row.
            let mask_row = &mask[mask_row_base + py * mask_stride..][..block_width];
            let error_row = &pixel_error[py * block_width..][..block_width];
            mask_row.iter().zip(error_row.iter())
        })
        .map(|(&mask_val, &error_val)| {
            let masked = (mask_val + mask_offset) * error_val;
            // First square in f32, remaining two squarings in f64.
            let m2 = f64::from(masked * masked);
            let m4 = m2 * m2;
            m4 * m4
        })
        .sum()
}
/// AVX2 kernel for the masked 8th-power pixel loss.
///
/// Processes 8 f32 lanes per iteration; `block_width` must be a multiple of
/// 8 (the dispatcher asserts this) — `chunks_exact(8)` silently drops any
/// remainder. Matches the scalar kernel's per-term arithmetic: square once
/// in f32, promote to f64, square twice more, accumulate in f64.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn pixel_domain_loss_avx2(
    token: archmage::X64V3Token,
    pixel_error: &[f32],
    mask: &[f32],
    mask_row_base: usize,
    mask_stride: usize,
    mask_offset: f32,
    block_width: usize,
    block_height: usize,
) -> f64 {
    use core::arch::x86_64::*;
    use magetypes::simd::{f32x8, f64x4};
    let offset_v = f32x8::splat(token, mask_offset);
    // Separate f64 accumulators for the low/high halves of each 8-lane
    // chunk; summed together after the final reduction.
    let mut acc_lo = f64x4::zero(token);
    let mut acc_hi = f64x4::zero(token);
    for py in 0..block_height {
        // mask uses an independent row stride; pixel_error is dense.
        let mask_row_start = mask_row_base + py * mask_stride;
        let error_row_start = py * block_width;
        let mask_row = &mask[mask_row_start..mask_row_start + block_width];
        let error_row = &pixel_error[error_row_start..error_row_start + block_width];
        for (mask_chunk, error_chunk) in mask_row.chunks_exact(8).zip(error_row.chunks_exact(8)) {
            let mask_v = f32x8::from_slice(token, mask_chunk);
            let error_v = f32x8::from_slice(token, error_chunk);
            let masked_v = (mask_v + offset_v) * error_v;
            // First square in f32 (8 lanes)...
            let m2_v = masked_v * masked_v;
            // ...then split into two 128-bit halves and promote each 4-lane
            // half to f64x4 for the remaining squarings.
            let m2_lo_128 = _mm256_castps256_ps128(m2_v.raw());
            let m2_lo = f64x4::from_m256d(token, _mm256_cvtps_pd(m2_lo_128));
            let m2_hi_128 = _mm256_extractf128_ps::<1>(m2_v.raw());
            let m2_hi = f64x4::from_m256d(token, _mm256_cvtps_pd(m2_hi_128));
            // (m^2)^2^2 = m^8, computed in f64.
            let m4_lo = m2_lo * m2_lo;
            let m4_hi = m2_hi * m2_hi;
            let m8_lo = m4_lo * m4_lo;
            let m8_hi = m4_hi * m4_hi;
            acc_lo += m8_lo;
            acc_hi += m8_hi;
        }
    }
    acc_lo.reduce_add() + acc_hi.reduce_add()
}
/// NEON kernel for the masked 8th-power pixel loss.
///
/// Processes 4 f32 lanes per iteration; `block_width` must be a multiple of
/// 4 (the public dispatcher guarantees a multiple of 8). Matches the scalar
/// kernel's per-term arithmetic: square once in f32, promote to f64, square
/// twice more, accumulate in f64.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn pixel_domain_loss_neon(
    token: archmage::NeonToken,
    pixel_error: &[f32],
    mask: &[f32],
    mask_row_base: usize,
    mask_stride: usize,
    mask_offset: f32,
    block_width: usize,
    block_height: usize,
) -> f64 {
    use core::arch::aarch64::*;
    use magetypes::simd::{f32x4, f64x2};
    debug_assert!(
        block_width.is_multiple_of(4),
        "block_width must be multiple of 4"
    );
    let offset_v = f32x4::splat(token, mask_offset);
    // Separate f64 accumulators for the low/high halves of each 4-lane
    // chunk; summed together after the final reduction.
    let mut acc_lo = f64x2::zero(token);
    let mut acc_hi = f64x2::zero(token);
    for py in 0..block_height {
        // mask uses an independent row stride; pixel_error is dense.
        let mask_row_start = mask_row_base + py * mask_stride;
        let error_row_start = py * block_width;
        let mask_row = &mask[mask_row_start..mask_row_start + block_width];
        let error_row = &pixel_error[error_row_start..error_row_start + block_width];
        // chunks_exact mirrors the AVX2 kernel's iteration style; the former
        // manual index loop (`&mask_row[px..]`) would panic inside from_slice
        // on a trailing partial chunk rather than flagging the bad width.
        for (mask_chunk, error_chunk) in mask_row.chunks_exact(4).zip(error_row.chunks_exact(4)) {
            let mask_v = f32x4::from_slice(token, mask_chunk);
            let error_v = f32x4::from_slice(token, error_chunk);
            let masked_v = (mask_v + offset_v) * error_v;
            // First square in f32 (4 lanes)...
            let m2_v = masked_v * masked_v;
            let m2_raw = m2_v.raw();
            // ...then promote each 2-lane half to f64x2 for the remaining
            // squarings: (m^2)^2^2 = m^8.
            let m2_lo = f64x2::from_float64x2_t(token, vcvt_f64_f32(vget_low_f32(m2_raw)));
            let m2_hi = f64x2::from_float64x2_t(token, vcvt_high_f64_f32(m2_raw));
            let m4_lo = m2_lo * m2_lo;
            let m4_hi = m2_hi * m2_hi;
            let m8_lo = m4_lo * m4_lo;
            let m8_hi = m4_hi * m4_hi;
            acc_lo += m8_lo;
            acc_hi += m8_hi;
        }
    }
    acc_lo.reduce_add() + acc_hi.reduce_add()
}
/// WASM SIMD128 kernel for the masked 8th-power pixel loss.
///
/// Processes 4 f32 lanes per iteration; `block_width` must be a multiple of
/// 4 (the public dispatcher guarantees a multiple of 8). Matches the scalar
/// kernel's per-term arithmetic: square once in f32, promote to f64, square
/// twice more, accumulate in f64.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn pixel_domain_loss_wasm128(
    token: archmage::Wasm128Token,
    pixel_error: &[f32],
    mask: &[f32],
    mask_row_base: usize,
    mask_stride: usize,
    mask_offset: f32,
    block_width: usize,
    block_height: usize,
) -> f64 {
    use core::arch::wasm32::*;
    use magetypes::simd::{f32x4, f64x2};
    debug_assert!(
        block_width.is_multiple_of(4),
        "block_width must be multiple of 4"
    );
    let offset_v = f32x4::splat(token, mask_offset);
    // Separate f64 accumulators for the low/high halves of each 4-lane
    // chunk; summed together after the final reduction.
    let mut acc_lo = f64x2::zero(token);
    let mut acc_hi = f64x2::zero(token);
    for py in 0..block_height {
        // mask uses an independent row stride; pixel_error is dense.
        let mask_row_start = mask_row_base + py * mask_stride;
        let error_row_start = py * block_width;
        let mask_row = &mask[mask_row_start..mask_row_start + block_width];
        let error_row = &pixel_error[error_row_start..error_row_start + block_width];
        // chunks_exact mirrors the AVX2 kernel's iteration style; the former
        // manual index loop (`&mask_row[px..]`) would panic inside from_slice
        // on a trailing partial chunk rather than flagging the bad width.
        for (mask_chunk, error_chunk) in mask_row.chunks_exact(4).zip(error_row.chunks_exact(4)) {
            let mask_v = f32x4::from_slice(token, mask_chunk);
            let error_v = f32x4::from_slice(token, error_chunk);
            let masked_v = (mask_v + offset_v) * error_v;
            // First square in f32 (4 lanes)...
            let m2_v = masked_v * masked_v;
            let m2_raw = m2_v.raw();
            // ...then promote the low pair directly, and the high pair by
            // first shuffling lanes 2,3 into the low half. Remaining
            // squarings happen in f64: (m^2)^2^2 = m^8.
            let m2_lo = f64x2::from_v128(token, f64x2_promote_low_f32x4(m2_raw));
            let high_shuffled = i32x4_shuffle::<2, 3, 0, 1>(m2_raw, m2_raw);
            let m2_hi = f64x2::from_v128(token, f64x2_promote_low_f32x4(high_shuffled));
            let m4_lo = m2_lo * m2_lo;
            let m4_hi = m2_hi * m2_hi;
            let m8_lo = m4_lo * m4_lo;
            let m8_hi = m4_hi * m4_hi;
            acc_lo += m8_lo;
            acc_hi += m8_hi;
        }
    }
    acc_lo.reduce_add() + acc_hi.reduce_add()
}
#[cfg(test)]
mod tests {
    use super::*;
    // Crate appears to be no_std (`alloc::vec` below); pull in std for the
    // test-only eprintln!.
    extern crate std;
    use alloc::vec;
    // Uniform 8x8 block: (0.5 + 0.5) * 1.0 = 1.0 per pixel, 1^8 = 1,
    // 64 pixels -> expected loss 64.0.
    #[test]
    fn test_pixel_domain_loss_uniform() {
        let block_width = 8;
        let block_height = 8;
        let mask_stride = 16;
        let pixel_error = vec![1.0f32; block_width * block_height];
        let mask = vec![0.5f32; mask_stride * 16];
        let mask_offset = 0.5f32;
        let result = pixel_domain_loss(
            &pixel_error,
            &mask,
            0,
            mask_stride,
            mask_offset,
            block_width,
            block_height,
        );
        assert!(
            (result - 64.0).abs() < 1e-6,
            "Expected 64.0, got {}",
            result
        );
    }
    // Cross-check the SIMD dispatch path against the scalar reference on
    // non-uniform, sign-varying data, across every token permutation the
    // archmage test harness can exercise on this host.
    #[test]
    fn test_pixel_domain_loss_matches_scalar() {
        let block_width = 16;
        let block_height = 8;
        let mask_stride = 32;
        let mut pixel_error = vec![0.0f32; block_width * block_height];
        let mut mask = vec![0.0f32; mask_stride * 16];
        for (i, val) in pixel_error.iter_mut().enumerate() {
            *val = (i as f32 * 0.1 + 0.5) * if i % 3 == 0 { -1.0 } else { 1.0 };
        }
        for (i, val) in mask.iter_mut().enumerate() {
            *val = (i as f32 * 0.01 + 0.3).sin().abs();
        }
        let mask_offset = 0.7f32;
        let scalar_result = pixel_domain_loss_scalar(
            &pixel_error,
            &mask,
            0,
            mask_stride,
            mask_offset,
            block_width,
            block_height,
        );
        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let simd_result = pixel_domain_loss(
                    &pixel_error,
                    &mask,
                    0,
                    mask_stride,
                    mask_offset,
                    block_width,
                    block_height,
                );
                // max(1e-20) guards the division when the scalar loss is ~0.
                let rel_err = ((simd_result - scalar_result) / scalar_result.max(1e-20)).abs();
                assert!(
                    rel_err < 1e-6,
                    "SIMD ({}) vs scalar ({}) relative error {} too large [{perm}]",
                    simd_result,
                    scalar_result,
                    rel_err
                );
            },
        );
        std::eprintln!("{report}");
    }
    // 16x16 block: (1.0 + 0.0) * 0.5 = 0.5 per pixel, 0.5^8 = 2^-8,
    // 256 pixels -> expected loss 256 * 2^-8 = 1.0.
    #[test]
    fn test_pixel_domain_loss_16x16() {
        let block_width = 16;
        let block_height = 16;
        let mask_stride = 32;
        let pixel_error = vec![0.5f32; block_width * block_height];
        let mask = vec![1.0f32; mask_stride * 32];
        let mask_offset = 0.0f32;
        let result = pixel_domain_loss(
            &pixel_error,
            &mask,
            0,
            mask_stride,
            mask_offset,
            block_width,
            block_height,
        );
        assert!((result - 1.0).abs() < 1e-6, "Expected 1.0, got {}", result);
    }
}