1use alloc::vec;
13use alloc::vec::Vec;
14
/// Per-channel squared-error weights, indexed to match the channel order of
/// the planes passed to `compute_block_l2_errors` (the code below names them
/// x, y, b). The middle (y) weight is exactly 1.0, which lets the SIMD
/// kernels skip that channel's weight multiply.
const CHANNEL_WEIGHTS: [f32; 3] = [12.339_445, 1.0, 0.2];
17
18#[inline]
25pub fn compute_block_l2_errors(
26 original: [&[f32]; 3],
27 reconstructed: [&[f32]; 3],
28 mask1x1: &[f32],
29 xsize_blocks: usize,
30 ysize_blocks: usize,
31) -> Vec<f32> {
32 let padded_width = xsize_blocks * 8;
33 let nblocks = xsize_blocks * ysize_blocks;
34
35 #[cfg(target_arch = "x86_64")]
36 {
37 use archmage::SimdToken;
38 if let Some(token) = archmage::X64V3Token::summon() {
39 return compute_block_l2_errors_avx2(
40 token,
41 original,
42 reconstructed,
43 mask1x1,
44 xsize_blocks,
45 ysize_blocks,
46 padded_width,
47 nblocks,
48 );
49 }
50 }
51
52 #[cfg(target_arch = "aarch64")]
53 {
54 use archmage::SimdToken;
55 if let Some(token) = archmage::NeonToken::summon() {
56 return compute_block_l2_errors_neon(
57 token,
58 original,
59 reconstructed,
60 mask1x1,
61 xsize_blocks,
62 ysize_blocks,
63 padded_width,
64 nblocks,
65 );
66 }
67 }
68
69 compute_block_l2_errors_scalar(
70 original,
71 reconstructed,
72 mask1x1,
73 xsize_blocks,
74 ysize_blocks,
75 padded_width,
76 nblocks,
77 )
78}
79
80#[inline]
81pub fn compute_block_l2_errors_scalar(
82 original: [&[f32]; 3],
83 reconstructed: [&[f32]; 3],
84 mask1x1: &[f32],
85 xsize_blocks: usize,
86 ysize_blocks: usize,
87 padded_width: usize,
88 nblocks: usize,
89) -> Vec<f32> {
90 let mut errors = vec![0.0f32; nblocks];
91
92 for by in 0..ysize_blocks {
93 for bx in 0..xsize_blocks {
94 let block_idx = by * xsize_blocks + bx;
95 let mut total_err = 0.0f32;
96
97 for py in 0..8 {
98 for px in 0..8 {
99 let y = by * 8 + py;
100 let x = bx * 8 + px;
101 let pixel_idx = y * padded_width + x;
102 let mask = mask1x1[pixel_idx];
103 let mask_sq = mask * mask;
104
105 for c in 0..3 {
106 let diff = original[c][pixel_idx] - reconstructed[c][pixel_idx];
107 total_err += CHANNEL_WEIGHTS[c] * mask_sq * diff * diff;
108 }
109 }
110 }
111
112 errors[block_idx] = total_err;
113 }
114 }
115
116 errors
117}
118
/// AVX2 implementation of [`compute_block_l2_errors`] for x86_64.
///
/// One 8-lane `f32x8` spans an entire 8-pixel block row, so each 8x8 block is
/// processed as 8 full-width vector rows with no horizontal remainder. The
/// weighted squared differences are accumulated per lane and reduced to a
/// single scalar once per block.
///
/// Layout contract (same as the scalar path): every plane and `mask1x1` use a
/// row stride of `padded_width` pixels, and `nblocks` is
/// `xsize_blocks * ysize_blocks` (as computed by the dispatcher).
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn compute_block_l2_errors_avx2(
    token: archmage::X64V3Token,
    original: [&[f32]; 3],
    reconstructed: [&[f32]; 3],
    mask1x1: &[f32],
    xsize_blocks: usize,
    ysize_blocks: usize,
    padded_width: usize,
    nblocks: usize,
) -> Vec<f32> {
    use magetypes::simd::f32x8;

    // Broadcast the x and b channel weights. CHANNEL_WEIGHTS[1] (y) is 1.0,
    // so the y term below deliberately skips the weight multiply.
    let w_x = f32x8::splat(token, CHANNEL_WEIGHTS[0]);
    let w_b = f32x8::splat(token, CHANNEL_WEIGHTS[2]);

    let mut errors = vec![0.0f32; nblocks];

    for by in 0..ysize_blocks {
        for bx in 0..xsize_blocks {
            let block_idx = by * xsize_blocks + bx;
            // Per-lane accumulator of the block's weighted squared error.
            let mut acc = f32x8::zero(token);

            for py in 0..8 {
                // Index of the first pixel of this block row.
                let row_start = (by * 8 + py) * padded_width + bx * 8;

                // NOTE: from_slice is assumed to load the first 8 f32s of the
                // given tail slice — one full block row per load.
                let mask_v = f32x8::from_slice(token, &mask1x1[row_start..]);
                let mask_sq = mask_v * mask_v;

                // x channel, weighted by w_x.
                let orig_x = f32x8::from_slice(token, &original[0][row_start..]);
                let recon_x = f32x8::from_slice(token, &reconstructed[0][row_start..]);
                let diff_x = orig_x - recon_x;
                acc += w_x * mask_sq * diff_x * diff_x;

                // y channel, weight 1.0 (no multiply).
                let orig_y = f32x8::from_slice(token, &original[1][row_start..]);
                let recon_y = f32x8::from_slice(token, &reconstructed[1][row_start..]);
                let diff_y = orig_y - recon_y;
                acc += mask_sq * diff_y * diff_y;

                // b channel, weighted by w_b.
                let orig_b = f32x8::from_slice(token, &original[2][row_start..]);
                let recon_b = f32x8::from_slice(token, &reconstructed[2][row_start..]);
                let diff_b = orig_b - recon_b;
                acc += w_b * mask_sq * diff_b * diff_b;
            }

            // Horizontal sum across the 8 lanes yields the block total.
            errors[block_idx] = acc.reduce_add();
        }
    }

    errors
}
180
/// NEON implementation of [`compute_block_l2_errors`] for aarch64.
///
/// NEON vectors are 4 lanes wide, so each 8-pixel block row is processed as
/// two `f32x4` halves. Otherwise identical in structure to the AVX2 path:
/// weighted squared differences accumulate per lane and are reduced to one
/// scalar per block.
///
/// Layout contract (same as the scalar path): every plane and `mask1x1` use a
/// row stride of `padded_width` pixels, and `nblocks` is
/// `xsize_blocks * ysize_blocks` (as computed by the dispatcher).
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn compute_block_l2_errors_neon(
    token: archmage::NeonToken,
    original: [&[f32]; 3],
    reconstructed: [&[f32]; 3],
    mask1x1: &[f32],
    xsize_blocks: usize,
    ysize_blocks: usize,
    padded_width: usize,
    nblocks: usize,
) -> Vec<f32> {
    use magetypes::simd::f32x4;

    // Broadcast the x and b channel weights. CHANNEL_WEIGHTS[1] (y) is 1.0,
    // so the y term below deliberately skips the weight multiply.
    let w_x = f32x4::splat(token, CHANNEL_WEIGHTS[0]);
    let w_b = f32x4::splat(token, CHANNEL_WEIGHTS[2]);

    let mut errors = vec![0.0f32; nblocks];

    for by in 0..ysize_blocks {
        for bx in 0..xsize_blocks {
            let block_idx = by * xsize_blocks + bx;
            // Per-lane accumulator of the block's weighted squared error.
            let mut acc = f32x4::zero(token);

            for py in 0..8 {
                // Index of the first pixel of this block row.
                let row_start = (by * 8 + py) * padded_width + bx * 8;

                // Two 4-lane halves cover the 8-pixel row.
                for half in 0..2usize {
                    let off = row_start + half * 4;

                    // NOTE: from_slice is assumed to load the first 4 f32s of
                    // the given tail slice.
                    let mask_v = f32x4::from_slice(token, &mask1x1[off..]);
                    let mask_sq = mask_v * mask_v;

                    // x channel, weighted by w_x.
                    let orig_x = f32x4::from_slice(token, &original[0][off..]);
                    let recon_x = f32x4::from_slice(token, &reconstructed[0][off..]);
                    let diff_x = orig_x - recon_x;
                    acc += w_x * mask_sq * diff_x * diff_x;

                    // y channel, weight 1.0 (no multiply).
                    let orig_y = f32x4::from_slice(token, &original[1][off..]);
                    let recon_y = f32x4::from_slice(token, &reconstructed[1][off..]);
                    let diff_y = orig_y - recon_y;
                    acc += mask_sq * diff_y * diff_y;

                    // b channel, weighted by w_b.
                    let orig_b = f32x4::from_slice(token, &original[2][off..]);
                    let recon_b = f32x4::from_slice(token, &reconstructed[2][off..]);
                    let diff_b = orig_b - recon_b;
                    acc += w_b * mask_sq * diff_b * diff_b;
                }
            }

            // Horizontal sum across the 4 lanes yields the block total.
            errors[block_idx] = acc.reduce_add();
        }
    }

    errors
}
244
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// With a unit mask and a difference of exactly 1.0 in every channel,
    /// every block must accumulate 64 * (sum of the channel weights).
    #[test]
    fn test_block_l2_errors_uniform() {
        let (xsize_blocks, ysize_blocks) = (2, 2);
        let n = xsize_blocks * 8 * ysize_blocks * 8;

        // All-ones originals/mask against all-zero reconstructions.
        let ones = vec![1.0f32; n];
        let zeros = vec![0.0f32; n];

        let errors = compute_block_l2_errors(
            [&ones, &ones, &ones],
            [&zeros, &zeros, &zeros],
            &ones,
            xsize_blocks,
            ysize_blocks,
        );

        let expected = 64.0 * (CHANNEL_WEIGHTS[0] + CHANNEL_WEIGHTS[1] + CHANNEL_WEIGHTS[2]);
        for (i, &err) in errors.iter().enumerate() {
            assert!(
                (err - expected).abs() < 0.1,
                "Block {} error {} != expected {}",
                i,
                err,
                expected
            );
        }
    }

    /// The dispatched (possibly SIMD) path must agree with the scalar
    /// reference on smoothly varying synthetic data.
    #[test]
    fn test_block_l2_errors_matches_scalar() {
        let xsize_blocks = 4;
        let ysize_blocks = 4;
        let padded_width = xsize_blocks * 8;
        let n = padded_width * ysize_blocks * 8;

        // Deterministic, smoothly varying pseudo-signals per channel.
        let orig0: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.013).sin() * 0.5).collect();
        let orig1: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.017).cos() * 0.8).collect();
        let orig2: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.023).sin() * 0.3).collect();

        // Reconstructions are the originals plus a small perturbation.
        let recon0: Vec<f32> = (0..n)
            .map(|i| orig0[i] + ((i as f32) * 0.031).sin() * 0.1)
            .collect();
        let recon1: Vec<f32> = (0..n)
            .map(|i| orig1[i] + ((i as f32) * 0.037).cos() * 0.05)
            .collect();
        let recon2: Vec<f32> = (0..n)
            .map(|i| orig2[i] + ((i as f32) * 0.041).sin() * 0.02)
            .collect();

        // Mask values in [0.5, 1.0].
        let mask: Vec<f32> = (0..n)
            .map(|i| 0.5 + ((i as f32) * 0.007).sin().abs() * 0.5)
            .collect();

        let simd_result = compute_block_l2_errors(
            [&orig0, &orig1, &orig2],
            [&recon0, &recon1, &recon2],
            &mask,
            xsize_blocks,
            ysize_blocks,
        );

        let scalar_result = compute_block_l2_errors_scalar(
            [&orig0, &orig1, &orig2],
            [&recon0, &recon1, &recon2],
            &mask,
            xsize_blocks,
            ysize_blocks,
            padded_width,
            xsize_blocks * ysize_blocks,
        );

        assert_eq!(simd_result.len(), scalar_result.len());
        for (i, (&s, &sc)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
            // Relative tolerance when the scalar value is meaningfully
            // non-zero, absolute tolerance otherwise.
            let rel_err = if sc.abs() > 1e-10 {
                ((s - sc) / sc).abs()
            } else {
                (s - sc).abs()
            };
            assert!(
                rel_err < 1e-5,
                "Block {} SIMD {} vs scalar {} rel_err {}",
                i,
                s,
                sc,
                rel_err
            );
        }
    }
}