Skip to main content

oximedia_codec/simd/
sad.rs

1//! Sum of Absolute Differences (SAD) operations.
2//!
3//! SAD is fundamental to motion estimation in video codecs. It measures
4//! the similarity between blocks of pixels, with lower values indicating
5//! better matches.
6//!
7//! This module provides optimized SAD calculations for common block sizes:
8//! - 4x4 (used in H.264/AV1 for small partitions)
9//! - 8x8 (common block size)
10//! - 16x16 (macroblock size)
11//! - 32x32 (used in HEVC/AV1)
12//!
13//! All functions are designed to map efficiently to SIMD instructions.
14
15#![forbid(unsafe_code)]
16// Allow loop indexing for SIMD-like element-wise operations
17#![allow(clippy::needless_range_loop)]
18
19use super::scalar::ScalarFallback;
20use super::traits::SimdOps;
21use super::types::U8x16;
22
23/// SAD operations using SIMD.
24pub struct SadOps<S: SimdOps> {
25    simd: S,
26}
27
28impl<S: SimdOps + Default> Default for SadOps<S> {
29    fn default() -> Self {
30        Self::new(S::default())
31    }
32}
33
34impl<S: SimdOps> SadOps<S> {
35    /// Create a new SAD operations instance.
36    #[inline]
37    #[must_use]
38    pub const fn new(simd: S) -> Self {
39        Self { simd }
40    }
41
42    /// Get the underlying SIMD implementation.
43    #[inline]
44    #[must_use]
45    pub const fn simd(&self) -> &S {
46        &self.simd
47    }
48
49    /// Calculate SAD for a 4x4 block.
50    ///
51    /// # Arguments
52    /// * `src` - Source block data (row-major, stride = `src_stride`)
53    /// * `src_stride` - Stride between source rows
54    /// * `ref_block` - Reference block data (row-major, stride = `ref_stride`)
55    /// * `ref_stride` - Stride between reference rows
56    ///
57    /// # Returns
58    /// Sum of absolute differences for all 16 pixels.
59    #[inline]
60    pub fn sad_4x4(
61        &self,
62        src: &[u8],
63        src_stride: usize,
64        ref_block: &[u8],
65        ref_stride: usize,
66    ) -> u32 {
67        let mut sum = 0u32;
68
69        for row in 0..4 {
70            let src_offset = row * src_stride;
71            let ref_offset = row * ref_stride;
72
73            if src_offset + 4 <= src.len() && ref_offset + 4 <= ref_block.len() {
74                for col in 0..4 {
75                    let diff =
76                        i32::from(src[src_offset + col]) - i32::from(ref_block[ref_offset + col]);
77                    sum += diff.unsigned_abs();
78                }
79            }
80        }
81
82        sum
83    }
84
85    /// Calculate SAD for an 8x8 block.
86    #[inline]
87    pub fn sad_8x8(
88        &self,
89        src: &[u8],
90        src_stride: usize,
91        ref_block: &[u8],
92        ref_stride: usize,
93    ) -> u32 {
94        let mut sum = 0u32;
95
96        for row in 0..8 {
97            let src_offset = row * src_stride;
98            let ref_offset = row * ref_stride;
99
100            if src_offset + 8 <= src.len() && ref_offset + 8 <= ref_block.len() {
101                sum += self.simd.sad_8(
102                    &src[src_offset..src_offset + 8],
103                    &ref_block[ref_offset..ref_offset + 8],
104                );
105            }
106        }
107
108        sum
109    }
110
111    /// Calculate SAD for a 16x16 block.
112    #[inline]
113    pub fn sad_16x16(
114        &self,
115        src: &[u8],
116        src_stride: usize,
117        ref_block: &[u8],
118        ref_stride: usize,
119    ) -> u32 {
120        let mut sum = 0u32;
121
122        for row in 0..16 {
123            let src_offset = row * src_stride;
124            let ref_offset = row * ref_stride;
125
126            if src_offset + 16 <= src.len() && ref_offset + 16 <= ref_block.len() {
127                let src_row = U8x16::from_array(
128                    src[src_offset..src_offset + 16]
129                        .try_into()
130                        .unwrap_or([0; 16]),
131                );
132                let ref_row = U8x16::from_array(
133                    ref_block[ref_offset..ref_offset + 16]
134                        .try_into()
135                        .unwrap_or([0; 16]),
136                );
137                sum += self.simd.sad_u8x16(src_row, ref_row);
138            }
139        }
140
141        sum
142    }
143
144    /// Calculate SAD for a 32x32 block.
145    #[inline]
146    pub fn sad_32x32(
147        &self,
148        src: &[u8],
149        src_stride: usize,
150        ref_block: &[u8],
151        ref_stride: usize,
152    ) -> u32 {
153        let mut sum = 0u32;
154
155        for row in 0..32 {
156            let src_offset = row * src_stride;
157            let ref_offset = row * ref_stride;
158
159            if src_offset + 32 <= src.len() && ref_offset + 32 <= ref_block.len() {
160                // Process as two 16-byte chunks
161                for chunk in 0..2 {
162                    let chunk_offset = chunk * 16;
163                    let src_row = U8x16::from_array(
164                        src[src_offset + chunk_offset..src_offset + chunk_offset + 16]
165                            .try_into()
166                            .unwrap_or([0; 16]),
167                    );
168                    let ref_row = U8x16::from_array(
169                        ref_block[ref_offset + chunk_offset..ref_offset + chunk_offset + 16]
170                            .try_into()
171                            .unwrap_or([0; 16]),
172                    );
173                    sum += self.simd.sad_u8x16(src_row, ref_row);
174                }
175            }
176        }
177
178        sum
179    }
180
181    /// Calculate SAD for an arbitrary block size.
182    ///
183    /// Less efficient than size-specific functions but more flexible.
184    #[allow(dead_code)]
185    pub fn sad_nxn(
186        &self,
187        src: &[u8],
188        src_stride: usize,
189        ref_block: &[u8],
190        ref_stride: usize,
191        width: usize,
192        height: usize,
193    ) -> u32 {
194        let mut sum = 0u32;
195
196        for row in 0..height {
197            let src_offset = row * src_stride;
198            let ref_offset = row * ref_stride;
199
200            if src_offset + width <= src.len() && ref_offset + width <= ref_block.len() {
201                // Process 16-byte chunks
202                let mut col = 0;
203                while col + 16 <= width {
204                    let src_chunk = U8x16::from_array(
205                        src[src_offset + col..src_offset + col + 16]
206                            .try_into()
207                            .unwrap_or([0; 16]),
208                    );
209                    let ref_chunk = U8x16::from_array(
210                        ref_block[ref_offset + col..ref_offset + col + 16]
211                            .try_into()
212                            .unwrap_or([0; 16]),
213                    );
214                    sum += self.simd.sad_u8x16(src_chunk, ref_chunk);
215                    col += 16;
216                }
217
218                // Process remaining bytes
219                while col < width {
220                    let diff =
221                        i32::from(src[src_offset + col]) - i32::from(ref_block[ref_offset + col]);
222                    sum += diff.unsigned_abs();
223                    col += 1;
224                }
225            }
226        }
227
228        sum
229    }
230
231    /// Calculate SATD (Sum of Absolute Transformed Differences) for 4x4 block.
232    ///
233    /// SATD applies a Hadamard transform before summing, providing a better
234    /// cost metric for rate-distortion optimization.
235    #[allow(dead_code)]
236    pub fn satd_4x4(
237        &self,
238        src: &[u8],
239        src_stride: usize,
240        ref_block: &[u8],
241        ref_stride: usize,
242    ) -> u32 {
243        // Calculate differences
244        let mut diff = [[0i16; 4]; 4];
245        for row in 0..4 {
246            let src_offset = row * src_stride;
247            let ref_offset = row * ref_stride;
248            for col in 0..4 {
249                if src_offset + col < src.len() && ref_offset + col < ref_block.len() {
250                    diff[row][col] =
251                        i16::from(src[src_offset + col]) - i16::from(ref_block[ref_offset + col]);
252                }
253            }
254        }
255
256        // Horizontal Hadamard
257        let mut tmp = [[0i16; 4]; 4];
258        for row in 0..4 {
259            let a = diff[row][0] + diff[row][1];
260            let b = diff[row][2] + diff[row][3];
261            let c = diff[row][0] - diff[row][1];
262            let d = diff[row][2] - diff[row][3];
263
264            tmp[row][0] = a + b;
265            tmp[row][1] = c + d;
266            tmp[row][2] = a - b;
267            tmp[row][3] = c - d;
268        }
269
270        // Vertical Hadamard
271        let mut result = [[0i16; 4]; 4];
272        for col in 0..4 {
273            let a = tmp[0][col] + tmp[1][col];
274            let b = tmp[2][col] + tmp[3][col];
275            let c = tmp[0][col] - tmp[1][col];
276            let d = tmp[2][col] - tmp[3][col];
277
278            result[0][col] = a + b;
279            result[1][col] = c + d;
280            result[2][col] = a - b;
281            result[3][col] = c - d;
282        }
283
284        // Sum absolute values
285        let mut sum = 0u32;
286        for row in 0..4 {
287            for col in 0..4 {
288                sum += u32::from(result[row][col].unsigned_abs());
289            }
290        }
291
292        // Normalize (divide by 2 as Hadamard doubles values)
293        (sum + 1) >> 1
294    }
295}
296
297/// Create a SAD operations instance with scalar fallback.
298#[inline]
299#[must_use]
300pub fn sad_ops() -> SadOps<ScalarFallback> {
301    SadOps::new(ScalarFallback::new())
302}
303
304/// Calculate SAD for multiple candidate positions (motion search).
305///
306/// Returns the index of the best (lowest SAD) position.
307#[allow(dead_code, clippy::cast_sign_loss)]
308#[must_use]
309pub fn find_best_match_4x4(
310    src: &[u8],
311    src_stride: usize,
312    ref_frame: &[u8],
313    ref_stride: usize,
314    candidates: &[(i32, i32)],
315    ref_width: usize,
316    ref_height: usize,
317) -> Option<(usize, u32)> {
318    let ops = sad_ops();
319    let mut best_idx = None;
320    let mut best_sad = u32::MAX;
321
322    for (idx, &(dx, dy)) in candidates.iter().enumerate() {
323        // Check bounds
324        if dx < 0 || dy < 0 {
325            continue;
326        }
327        let x = dx as usize;
328        let y = dy as usize;
329
330        if x + 4 > ref_width || y + 4 > ref_height {
331            continue;
332        }
333
334        let ref_offset = y * ref_stride + x;
335        if ref_offset + 3 * ref_stride + 4 > ref_frame.len() {
336            continue;
337        }
338
339        let sad = ops.sad_4x4(src, src_stride, &ref_frame[ref_offset..], ref_stride);
340
341        if sad < best_sad {
342            best_sad = sad;
343            best_idx = Some(idx);
344        }
345    }
346
347    best_idx.map(|idx| (idx, best_sad))
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical 4x4 blocks have zero SAD.
    #[test]
    fn test_sad_4x4_identical() {
        let ops = sad_ops();

        let block = [
            100u8, 110, 120, 130, 105, 115, 125, 135, 110, 120, 130, 140, 115, 125, 135, 145,
        ];

        let sad = ops.sad_4x4(&block, 4, &block, 4);
        assert_eq!(sad, 0);
    }

    /// A uniform per-pixel difference scales with the pixel count.
    #[test]
    fn test_sad_4x4_constant_diff() {
        let ops = sad_ops();

        let src = [100u8; 16];
        let ref_block = [110u8; 16];

        // Each pixel differs by 10, 16 pixels = 160
        let sad = ops.sad_4x4(&src, 4, &ref_block, 4);
        assert_eq!(sad, 160);
    }

    /// Rows that do not fit inside a buffer are skipped rather than panicking.
    #[test]
    fn test_sad_4x4_skips_rows_outside_buffer() {
        let ops = sad_ops();

        let src = [100u8; 12]; // only 3 complete rows
        let ref_block = [110u8; 16];

        // 3 readable rows * 4 pixels * 10 diff = 120
        let sad = ops.sad_4x4(&src, 4, &ref_block, 4);
        assert_eq!(sad, 120);
    }

    #[test]
    fn test_sad_8x8_identical() {
        let ops = sad_ops();

        let block = [128u8; 64];
        let sad = ops.sad_8x8(&block, 8, &block, 8);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_sad_8x8_constant_diff() {
        let ops = sad_ops();

        let src = [100u8; 64];
        let ref_block = [105u8; 64];

        // Each pixel differs by 5, 64 pixels = 320
        let sad = ops.sad_8x8(&src, 8, &ref_block, 8);
        assert_eq!(sad, 320);
    }

    #[test]
    fn test_sad_16x16_identical() {
        let ops = sad_ops();

        let block = [128u8; 256];
        let sad = ops.sad_16x16(&block, 16, &block, 16);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_sad_16x16_constant_diff() {
        let ops = sad_ops();

        let src = [100u8; 256];
        let ref_block = [102u8; 256];

        // Each pixel differs by 2, 256 pixels = 512
        let sad = ops.sad_16x16(&src, 16, &ref_block, 16);
        assert_eq!(sad, 512);
    }

    #[test]
    fn test_sad_32x32_identical() {
        let ops = sad_ops();

        let block = [128u8; 1024];
        let sad = ops.sad_32x32(&block, 32, &block, 32);
        assert_eq!(sad, 0);
    }

    #[test]
    fn test_sad_32x32_constant_diff() {
        let ops = sad_ops();

        let src = [100u8; 1024];
        let ref_block = [101u8; 1024];

        // Each pixel differs by 1, 1024 pixels = 1024
        let sad = ops.sad_32x32(&src, 32, &ref_block, 32);
        assert_eq!(sad, 1024);
    }

    /// Only the 4 block columns of each row count when the stride is wider.
    #[test]
    fn test_sad_with_stride() {
        let ops = sad_ops();

        // Create a larger buffer with stride > block width
        let stride = 8;
        let mut src = [0u8; 32]; // 4 rows * 8 stride
        let mut ref_block = [0u8; 32];

        for row in 0..4 {
            for col in 0..4 {
                src[row * stride + col] = 100;
                ref_block[row * stride + col] = 110;
            }
        }

        let sad = ops.sad_4x4(&src, stride, &ref_block, stride);
        assert_eq!(sad, 160); // 16 pixels * 10 diff
    }

    #[test]
    fn test_satd_4x4_identical() {
        let ops = sad_ops();

        let block = [128u8; 16];
        let satd = ops.satd_4x4(&block, 4, &block, 4);
        assert_eq!(satd, 0);
    }

    #[test]
    fn test_satd_4x4_constant_diff() {
        let ops = sad_ops();

        let src = [100u8; 16];
        let ref_block = [110u8; 16];

        // A constant difference of -10 puts all energy in the DC coefficient:
        // |16 * -10| = 160, every other coefficient is 0, and the rounded
        // halving gives (160 + 1) >> 1 = 80.
        let satd = ops.satd_4x4(&src, 4, &ref_block, 4);
        assert_eq!(satd, 80);
    }

    #[test]
    fn test_find_best_match() {
        let src = [100u8; 16];

        // Create a reference frame with the matching block at (4, 4)
        let mut ref_frame = [50u8; 256]; // 16x16
        for row in 0..4 {
            for col in 0..4 {
                ref_frame[(row + 4) * 16 + col + 4] = 100;
            }
        }

        let candidates = [
            (0, 0),
            (4, 0),
            (0, 4),
            (4, 4), // This should be the best match
            (8, 8),
        ];

        let result = find_best_match_4x4(&src, 4, &ref_frame, 16, &candidates, 16, 16);
        // (4, 4) is at index 3 and is a perfect match (SAD 0)
        assert_eq!(result, Some((3, 0)));
    }

    #[test]
    fn test_sad_nxn() {
        let ops = sad_ops();

        // Test 12x12 block (non-power-of-2)
        let src = [100u8; 144]; // 12x12
        let ref_block = [103u8; 144];

        let sad = ops.sad_nxn(&src, 12, &ref_block, 12, 12, 12);
        // 144 pixels * 3 diff = 432
        assert_eq!(sad, 432);
    }
}
507        let sad = ops.sad_nxn(&src, 12, &ref_block, 12, 12, 12);
508        // 144 pixels * 3 diff = 432
509        assert_eq!(sad, 432);
510    }
511}