/// Scales every RGB triple in `sdr` by a per-pixel gain and stores the
/// products in `output`.
///
/// The gain for pixel `i` is `lut[gainmap[i]]`, so `gainmap` holds one
/// quantized byte per pixel and `lut` maps each byte to its linear factor.
/// This is the portable reference path that the SIMD tests compare against.
///
/// # Panics
///
/// Panics if `output` or `gainmap` does not have the same length as `sdr`.
pub fn apply_gain_row_scalar(
    sdr: &[[f32; 3]],
    gainmap: &[u8],
    lut: &[f32; 256],
    output: &mut [[f32; 3]],
) {
    assert_eq!(sdr.len(), output.len());
    assert_eq!(sdr.len(), gainmap.len());
    // All three slices have equal length (checked above), so zipping them
    // visits exactly the same pixels as indexing would, without bounds checks.
    let pixels = output.iter_mut().zip(sdr).zip(gainmap);
    for ((dst, src), &code) in pixels {
        // A u8 index can never exceed the 256-entry table.
        let gain = lut[usize::from(code)];
        *dst = src.map(|channel| channel * gain);
    }
}
#[cfg(feature = "simd")]
use archmage::prelude::*;
#[cfg(feature = "simd")]
use magetypes::simd::generic::f32x8 as GenericF32x8;
// Vectorised kernel that multiplies each RGB triple in `sdr` by the gain
// `lut[gainmap[i]]` and writes the products to `output`.
//
// `#[magetypes(v3, neon, wasm128, scalar)]` is expected to expand this single
// definition into one variant per listed target, with `Token` standing in for
// that target's token type. NOTE(review): inferred from the attribute and the
// matching `incant!` call in `apply_gain_row_simd`; confirm against the
// magetypes documentation.
//
// The kernel handles each full group of 8 pixels in five steps:
// - gather the 8 gains from `lut`;
// - split the interleaved RGB triples into separate R, G and B vectors;
// - multiply each channel vector by the gain vector;
// - convert the results back to arrays;
// - write them out as RGB triples.
// Pixels left over after the last full group are handled one at a time with
// the same per-channel multiply.
//
// Panics if `output` or `gainmap` is not the same length as `sdr`.
#[cfg(feature = "simd")]
#[magetypes(v3, neon, wasm128, scalar)]
fn apply_gain_inner(
    token: Token,
    sdr: &[[f32; 3]],
    gainmap: &[u8],
    lut: &[f32; 256],
    output: &mut [[f32; 3]],
) {
    // Re-alias the generic vector back to its `f32x8` name, specialised to
    // this variant's `Token`. The lowercase type name trips the CamelCase lint.
    #[allow(non_camel_case_types)]
    type f32x8 = GenericF32x8<Token>;
    // Pixels processed per vector iteration (the lane count of `f32x8`).
    const LANES: usize = 8;
    assert_eq!(sdr.len(), output.len());
    assert_eq!(sdr.len(), gainmap.len());
    // Number of full 8-pixel groups; the tail is handled after the loop.
    let chunks = sdr.len() / LANES;
    for chunk_idx in 0..chunks {
        let base = chunk_idx * LANES;
        // Gather: one table lookup per pixel, then load the gains as a vector.
        let gains: [f32; LANES] = core::array::from_fn(|i| lut[gainmap[base + i] as usize]);
        let g = f32x8::from_array(token, gains);
        // Deinterleave the RGB triples into one vector per channel.
        let r: [f32; LANES] = core::array::from_fn(|i| sdr[base + i][0]);
        let r_v = f32x8::from_array(token, r);
        let g_ch: [f32; LANES] = core::array::from_fn(|i| sdr[base + i][1]);
        let g_v = f32x8::from_array(token, g_ch);
        let b: [f32; LANES] = core::array::from_fn(|i| sdr[base + i][2]);
        let b_v = f32x8::from_array(token, b);
        // Same gain applied to all three channels of a pixel.
        let r_out = r_v * g;
        let g_out = g_v * g;
        let b_out = b_v * g;
        let r_arr = r_out.to_array();
        let g_arr = g_out.to_array();
        let b_arr = b_out.to_array();
        // Re-interleave back into RGB triples.
        for i in 0..LANES {
            output[base + i] = [r_arr[i], g_arr[i], b_arr[i]];
        }
    }
    // Tail: fewer than LANES pixels remain; apply the gain one pixel at a time.
    let remainder_start = chunks * LANES;
    for i in remainder_start..sdr.len() {
        let g_val = lut[gainmap[i] as usize];
        output[i][0] = sdr[i][0] * g_val;
        output[i][1] = sdr[i][1] * g_val;
        output[i][2] = sdr[i][2] * g_val;
    }
}
/// SIMD-dispatched counterpart of [`apply_gain_row_scalar`]: scales each RGB
/// triple in `sdr` by `lut[gainmap[i]]` and writes the result to `output`.
///
/// The call is forwarded via `incant!` to one of the `apply_gain_inner`
/// variants generated for `v3`, `neon`, `wasm128` and `scalar`.
/// NOTE(review): presumably `incant!` selects the best variant supported by the
/// running CPU and falls back to `scalar` otherwise. Confirm this against the
/// archmage documentation.
///
/// # Panics
///
/// Panics if `output` or `gainmap` is not the same length as `sdr`, because
/// the inner kernel asserts both lengths.
#[cfg(feature = "simd")]
pub fn apply_gain_row_simd(
    sdr: &[[f32; 3]],
    gainmap: &[u8],
    lut: &[f32; 256],
    output: &mut [[f32; 3]],
) {
    incant!(
        apply_gain_inner(sdr, gainmap, lut, output),
        [v3, neon, wasm128, scalar]
    );
}
#[cfg(test)]
mod tests {
    extern crate std;
    use std::vec;
    #[cfg(feature = "simd")]
    use std::vec::Vec;

    use super::*;

    /// Builds a LUT that ramps linearly from `min_gain` at byte 0 to
    /// `max_gain` at byte 255.
    fn build_test_lut(min_gain: f32, max_gain: f32) -> [f32; 256] {
        let mut lut = [0.0f32; 256];
        for (i, entry) in lut.iter_mut().enumerate() {
            *entry = min_gain + (max_gain - min_gain) * (i as f32 / 255.0);
        }
        lut
    }

    #[test]
    fn test_scalar_basic() {
        let sdr = vec![[0.5f32, 0.25, 0.75], [1.0, 0.0, 0.5]];
        let gainmap = vec![128u8, 255];
        let lut = build_test_lut(1.0, 4.0);
        let mut output = vec![[0.0f32; 3]; 2];
        apply_gain_row_scalar(&sdr, &gainmap, &lut, &mut output);
        let g0 = lut[128];
        assert!(
            (output[0][0] - 0.5 * g0).abs() < 1e-6,
            "R0: {}",
            output[0][0]
        );
        assert!(
            (output[0][1] - 0.25 * g0).abs() < 1e-6,
            "G0: {}",
            output[0][1]
        );
        assert!(
            (output[0][2] - 0.75 * g0).abs() < 1e-6,
            "B0: {}",
            output[0][2]
        );
        let g1 = lut[255];
        assert!(
            (output[1][0] - 1.0 * g1).abs() < 1e-6,
            "R1: {}",
            output[1][0]
        );
        assert!(
            (output[1][1] - 0.0 * g1).abs() < 1e-6,
            "G1: {}",
            output[1][1]
        );
        assert!(
            (output[1][2] - 0.5 * g1).abs() < 1e-6,
            "B1: {}",
            output[1][2]
        );
    }

    /// The scalar path must reject a gain map that does not cover every pixel.
    #[test]
    #[should_panic]
    fn test_scalar_rejects_length_mismatch() {
        let sdr = vec![[0.5f32; 3]; 2];
        let gainmap = vec![0u8; 1];
        let lut = build_test_lut(1.0, 4.0);
        let mut output = vec![[0.0f32; 3]; 2];
        apply_gain_row_scalar(&sdr, &gainmap, &lut, &mut output);
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_matches_scalar() {
        let pixel_count = 256;
        let sdr: Vec<[f32; 3]> = (0..pixel_count)
            .map(|i| {
                let v = i as f32 / 255.0;
                [v, v * 0.5, 1.0 - v]
            })
            .collect();
        let gainmap: Vec<u8> = (0..pixel_count).map(|i| i as u8).collect();
        let lut = build_test_lut(0.5, 8.0);
        let mut scalar_output = vec![[0.0f32; 3]; pixel_count];
        let mut simd_output = vec![[0.0f32; 3]; pixel_count];
        apply_gain_row_scalar(&sdr, &gainmap, &lut, &mut scalar_output);
        apply_gain_row_simd(&sdr, &gainmap, &lut, &mut simd_output);
        for i in 0..pixel_count {
            for ch in 0..3 {
                assert!(
                    (scalar_output[i][ch] - simd_output[i][ch]).abs() < 1e-6,
                    "Mismatch at pixel {} channel {}: scalar={}, simd={}",
                    i,
                    ch,
                    scalar_output[i][ch],
                    simd_output[i][ch],
                );
            }
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_non_aligned_length() {
        for width in [1, 3, 7, 9, 13, 15, 17, 31, 33] {
            let sdr: Vec<[f32; 3]> = (0..width)
                .map(|i| {
                    // Use a non-integer step. An integer multiple taken `% 1.0`
                    // is always 0.0, so every output would be zero and the
                    // scalar/SIMD comparison would be vacuous. Distinct
                    // per-channel values also catch R/G/B deinterleave mix-ups.
                    let v = ((i as f32 + 0.5) * 0.37) % 1.0;
                    [v, 1.0 - v, 0.25 + 0.5 * v]
                })
                .collect();
            let gainmap: Vec<u8> = (0..width).map(|i| ((i * 13) % 256) as u8).collect();
            let lut = build_test_lut(1.0, 4.0);
            let mut scalar_output = vec![[0.0f32; 3]; width];
            let mut simd_output = vec![[0.0f32; 3]; width];
            apply_gain_row_scalar(&sdr, &gainmap, &lut, &mut scalar_output);
            apply_gain_row_simd(&sdr, &gainmap, &lut, &mut simd_output);
            for i in 0..width {
                for ch in 0..3 {
                    assert!(
                        (scalar_output[i][ch] - simd_output[i][ch]).abs() < 1e-6,
                        "width={}, pixel={}, ch={}: scalar={}, simd={}",
                        width,
                        i,
                        ch,
                        scalar_output[i][ch],
                        simd_output[i][ch],
                    );
                }
            }
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_empty() {
        let sdr: &[[f32; 3]] = &[];
        let gainmap: &[u8] = &[];
        let lut = build_test_lut(1.0, 4.0);
        let mut output: Vec<[f32; 3]> = vec![];
        apply_gain_row_simd(sdr, gainmap, &lut, &mut output);
        assert!(output.is_empty());
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_single_pixel() {
        let sdr = vec![[0.8f32, 0.4, 0.2]];
        let gainmap = vec![200u8];
        let lut = build_test_lut(1.0, 4.0);
        let mut scalar_output = vec![[0.0f32; 3]; 1];
        let mut simd_output = vec![[0.0f32; 3]; 1];
        apply_gain_row_scalar(&sdr, &gainmap, &lut, &mut scalar_output);
        apply_gain_row_simd(&sdr, &gainmap, &lut, &mut simd_output);
        for ch in 0..3 {
            assert!(
                (scalar_output[0][ch] - simd_output[0][ch]).abs() < 1e-6,
                "ch={}: scalar={}, simd={}",
                ch,
                scalar_output[0][ch],
                simd_output[0][ch],
            );
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_gain_endpoints() {
        let min_gain = 0.5f32;
        let max_gain = 8.0f32;
        let lut = build_test_lut(min_gain, max_gain);
        // 18 pixels = two full 8-lane groups plus a 2-pixel tail, so both the
        // vector path and the remainder path see both endpoint bytes.
        // Alternating bytes also check that each lane gets its own gain.
        let width = 18;
        let sdr = vec![[1.0f32; 3]; width];
        let gainmap: Vec<u8> = (0..width)
            .map(|i| if i % 2 == 0 { 0u8 } else { 255 })
            .collect();
        let mut output = vec![[0.0f32; 3]; width];
        apply_gain_row_simd(&sdr, &gainmap, &lut, &mut output);
        for (i, (px, &byte)) in output.iter().zip(&gainmap).enumerate() {
            let expected = if byte == 0 { min_gain } else { max_gain };
            for (ch, val) in px.iter().enumerate() {
                assert!(
                    (val - expected).abs() < 1e-6,
                    "pixel={} byte {} ch={}: expected {}, got {}",
                    i,
                    byte,
                    ch,
                    expected,
                    val,
                );
            }
        }
    }
}