Skip to main content

oximedia_codec/motion/
subpel.rs

1//! Sub-pixel motion estimation refinement.
2//!
3//! This module provides:
4//! - Half-pel and quarter-pel interpolation filters
5//! - Sub-pixel refinement search
6//! - SATD (Sum of Absolute Transformed Differences) computation
7//! - Hadamard transform for SATD
8//!
9//! Sub-pixel motion estimation significantly improves prediction quality
10//! at the cost of additional computation for interpolation.
11
12#![forbid(unsafe_code)]
13#![allow(dead_code)]
14#![allow(clippy::too_many_arguments)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17#![allow(clippy::cast_possible_wrap)]
18#![allow(clippy::needless_range_loop)]
19#![allow(clippy::similar_names)]
20#![allow(clippy::must_use_candidate)]
21#![allow(clippy::items_after_statements)]
22#![allow(clippy::bool_to_int_with_if)]
23#![allow(clippy::unnecessary_cast)]
24#![allow(clippy::let_and_return)]
25#![allow(clippy::redundant_closure_for_method_calls)]
26#![allow(clippy::trivially_copy_pass_by_ref)]
27#![allow(clippy::unused_self)]
28
29use super::types::{BlockMatch, BlockSize, MotionVector, MvCost, MvPrecision};
30
/// 6-tap filter coefficients for half-pel interpolation.
/// Used in H.264/AVC and similar codecs.
///
/// The taps sum to 32, which is why [`HalfPelInterpolator::sixtap`] rounds
/// with `+ 16` and normalizes with a right shift by 5.
pub const HALF_PEL_FILTER_6TAP: [i16; 6] = [1, -5, 20, 20, -5, 1];

/// 8-tap filter coefficients for quarter-pel interpolation.
///
/// The taps sum to 64. NOTE(review): no code in this file applies this
/// table; it is only exported.
pub const QUARTER_PEL_FILTER_8TAP: [i16; 8] = [-1, 4, -10, 58, 17, -5, 1, 0];

/// Bilinear filter for simple half-pel.
///
/// Two equal taps, i.e. a plain average of two neighbours. NOTE(review):
/// [`HalfPelInterpolator::bilinear`] computes the average directly and does
/// not read this table.
pub const BILINEAR_HALF: [i16; 2] = [1, 1];
40
/// Configuration for sub-pixel refinement.
///
/// Consumed by [`SubpelRefiner`]. Within this file only `precision`,
/// `use_satd` and `mv_cost` are read; the two filter selectors are stored
/// but never consulted by the refiner's interpolation code.
#[derive(Clone, Debug)]
pub struct SubpelConfig {
    /// Target precision.
    ///
    /// `SubpelRefiner::refine` compares this (cast to `u8`) against
    /// `MvPrecision::HalfPel` / `MvPrecision::QuarterPel` to decide which
    /// refinement stages run.
    pub precision: MvPrecision,
    /// Use SATD instead of SAD.
    pub use_satd: bool,
    /// MV cost for RD optimization.
    ///
    /// Its `rd_cost(mv, distortion)` turns a distortion into the cost that
    /// candidates are ranked by.
    pub mv_cost: MvCost,
    /// Filter type for half-pel.
    ///
    /// NOTE(review): not read anywhere in this file.
    pub half_pel_filter: HalfPelFilter,
    /// Filter type for quarter-pel.
    ///
    /// NOTE(review): not read anywhere in this file.
    pub quarter_pel_filter: QuarterPelFilter,
}

impl Default for SubpelConfig {
    /// Quarter-pel precision with SATD distortion, the default MV cost model,
    /// the 6-tap half-pel filter and the bilinear quarter-pel filter.
    fn default() -> Self {
        Self {
            precision: MvPrecision::QuarterPel,
            use_satd: true,
            mv_cost: MvCost::default(),
            half_pel_filter: HalfPelFilter::Sixtap,
            quarter_pel_filter: QuarterPelFilter::Bilinear,
        }
    }
}
67
/// Half-pel interpolation filter type.
///
/// Selector stored in [`SubpelConfig::half_pel_filter`]; the default variant
/// is [`HalfPelFilter::Sixtap`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum HalfPelFilter {
    /// Bilinear (2-tap) filter.
    Bilinear,
    /// 6-tap filter (H.264 style).
    #[default]
    Sixtap,
}
77
/// Quarter-pel interpolation filter type.
///
/// Selector stored in [`SubpelConfig::quarter_pel_filter`]; the default
/// variant is [`QuarterPelFilter::Bilinear`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum QuarterPelFilter {
    /// Bilinear interpolation.
    #[default]
    Bilinear,
    /// 8-tap filter.
    Eighttap,
}
87
88/// Half-pel interpolation at a single position.
89#[derive(Clone, Copy, Debug, Default)]
90pub struct HalfPelInterpolator;
91
92impl HalfPelInterpolator {
93    /// Creates a new interpolator.
94    #[must_use]
95    pub const fn new() -> Self {
96        Self
97    }
98
99    /// Interpolates a half-pel position using bilinear filter.
100    #[must_use]
101    pub fn bilinear(a: u8, b: u8) -> u8 {
102        (u16::from(a) + u16::from(b)).div_ceil(2) as u8
103    }
104
105    /// Interpolates a half-pel position using 6-tap filter.
106    ///
107    /// Input: 6 samples centered on the interpolation position.
108    #[must_use]
109    pub fn sixtap(samples: &[u8; 6]) -> u8 {
110        let mut sum: i32 = 0;
111        for (i, &coef) in HALF_PEL_FILTER_6TAP.iter().enumerate() {
112            sum += i32::from(coef) * i32::from(samples[i]);
113        }
114        // Round and normalize (divide by 32)
115        ((sum + 16) >> 5).clamp(0, 255) as u8
116    }
117
118    /// Interpolates horizontal half-pel for a row.
119    pub fn interpolate_h(src: &[u8], stride: usize, dst: &mut [u8], width: usize, height: usize) {
120        for y in 0..height {
121            let row_offset = y * stride;
122            for x in 0..width {
123                let src_x = row_offset + x;
124                if src_x + 1 < src.len() {
125                    dst[y * width + x] = Self::bilinear(src[src_x], src[src_x + 1]);
126                }
127            }
128        }
129    }
130
131    /// Interpolates vertical half-pel for a column.
132    pub fn interpolate_v(src: &[u8], stride: usize, dst: &mut [u8], width: usize, height: usize) {
133        for y in 0..height {
134            for x in 0..width {
135                let src_idx = y * stride + x;
136                let src_idx_next = (y + 1) * stride + x;
137                if src_idx_next < src.len() {
138                    dst[y * width + x] = Self::bilinear(src[src_idx], src[src_idx_next]);
139                }
140            }
141        }
142    }
143
144    /// Interpolates diagonal (HV) half-pel.
145    pub fn interpolate_hv(src: &[u8], stride: usize, dst: &mut [u8], width: usize, height: usize) {
146        for y in 0..height {
147            for x in 0..width {
148                let p00 = src.get(y * stride + x).copied().unwrap_or(0);
149                let p01 = src.get(y * stride + x + 1).copied().unwrap_or(0);
150                let p10 = src.get((y + 1) * stride + x).copied().unwrap_or(0);
151                let p11 = src.get((y + 1) * stride + x + 1).copied().unwrap_or(0);
152
153                let sum = u16::from(p00) + u16::from(p01) + u16::from(p10) + u16::from(p11);
154                dst[y * width + x] = ((sum + 2) / 4) as u8;
155            }
156        }
157    }
158}
159
/// Quarter-pel interpolation.
///
/// Stateless helper that blends two samples with integer weights.
#[derive(Clone, Copy, Debug, Default)]
pub struct QuarterPelInterpolator;

impl QuarterPelInterpolator {
    /// Creates a new interpolator.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Bilinear interpolation for quarter-pel.
    ///
    /// Returns the weighted mean `(a * weight_a + b * weight_b) / (weight_a +
    /// weight_b)`, rounded to nearest. If both weights are zero the samples
    /// are weighted equally instead of dividing by zero.
    #[must_use]
    pub fn bilinear(a: u8, b: u8, weight_a: u8, weight_b: u8) -> u8 {
        // Accumulate in u32: 255 * 255 + 255 * 255 does not fit in u16.
        let (wa, wb) = match (u32::from(weight_a), u32::from(weight_b)) {
            (0, 0) => (1, 1),
            weights => weights,
        };
        let total = wa + wb;
        let blended = (u32::from(a) * wa + u32::from(b) * wb + total / 2) / total;
        // A rounded weighted mean of two u8 values never exceeds 255.
        blended as u8
    }

    /// Interpolates quarter-pel from full-pel and half-pel samples.
    ///
    /// Equal-weight blend of the two inputs, rounding halves upwards.
    #[must_use]
    pub fn interpolate_qpel(full: u8, half: u8) -> u8 {
        Self::bilinear(full, half, 1, 1)
    }
}
186
/// Hadamard transform for SATD computation.
#[derive(Clone, Copy, Debug, Default)]
pub struct HadamardTransform;

impl HadamardTransform {
    /// One-dimensional 4-point butterfly shared by the row and column passes.
    fn butterfly_4(v: [i16; 4]) -> [i16; 4] {
        let [v0, v1, v2, v3] = v;
        let lo_sum = v0 + v1;
        let hi_sum = v2 + v3;
        let lo_diff = v0 - v1;
        let hi_diff = v2 - v3;
        [
            lo_sum + hi_sum,
            lo_diff + hi_diff,
            lo_sum - hi_sum,
            lo_diff - hi_diff,
        ]
    }

    /// 4x4 Hadamard transform (in-place).
    ///
    /// Runs the 4-point butterfly over every row, then over every column.
    /// No normalization is applied, so `block[0][0]` ends up holding the sum
    /// of all sixteen inputs.
    pub fn hadamard_4x4(block: &mut [[i16; 4]; 4]) {
        // Row pass.
        for line in block.iter_mut() {
            *line = Self::butterfly_4(*line);
        }

        // Column pass.
        for c in 0..4 {
            let column = Self::butterfly_4([block[0][c], block[1][c], block[2][c], block[3][c]]);
            for (r, value) in column.into_iter().enumerate() {
                block[r][c] = value;
            }
        }
    }

    /// Block-wise transform of an 8x8 array (in-place).
    ///
    /// Each 4x4 quadrant is transformed independently with
    /// [`Self::hadamard_4x4`]; the quadrants are not combined with each other.
    pub fn hadamard_8x8(block: &mut [[i16; 8]; 8]) {
        // (row, column) origin of each quadrant: TL, TR, BL, BR.
        const QUADRANTS: [(usize, usize); 4] = [(0, 0), (0, 4), (4, 0), (4, 4)];

        for (top, left) in QUADRANTS {
            let mut tile = [[0i16; 4]; 4];
            for (r, tile_row) in tile.iter_mut().enumerate() {
                tile_row.copy_from_slice(&block[top + r][left..left + 4]);
            }

            Self::hadamard_4x4(&mut tile);

            for (r, tile_row) in tile.iter().enumerate() {
                block[top + r][left..left + 4].copy_from_slice(tile_row);
            }
        }
    }
}
279
280/// SATD (Sum of Absolute Transformed Differences) calculator.
281#[derive(Clone, Copy, Debug, Default)]
282pub struct SatdCalculator;
283
284impl SatdCalculator {
285    /// Creates a new calculator.
286    #[must_use]
287    pub const fn new() -> Self {
288        Self
289    }
290
291    /// Calculates SATD for a 4x4 block.
292    #[must_use]
293    pub fn satd_4x4(src: &[u8], src_stride: usize, ref_block: &[u8], ref_stride: usize) -> u32 {
294        // Calculate differences
295        let mut diff = [[0i16; 4]; 4];
296        for row in 0..4 {
297            let src_offset = row * src_stride;
298            let ref_offset = row * ref_stride;
299            for col in 0..4 {
300                if src_offset + col < src.len() && ref_offset + col < ref_block.len() {
301                    diff[row][col] =
302                        i16::from(src[src_offset + col]) - i16::from(ref_block[ref_offset + col]);
303                }
304            }
305        }
306
307        // Apply Hadamard transform
308        HadamardTransform::hadamard_4x4(&mut diff);
309
310        // Sum absolute values
311        let mut sum = 0u32;
312        for row in &diff {
313            for &val in row {
314                sum += u32::from(val.unsigned_abs());
315            }
316        }
317
318        // Normalize (divide by 2 as Hadamard doubles values)
319        (sum + 1) >> 1
320    }
321
322    /// Calculates SATD for an 8x8 block.
323    #[must_use]
324    pub fn satd_8x8(src: &[u8], src_stride: usize, ref_block: &[u8], ref_stride: usize) -> u32 {
325        let mut total = 0u32;
326
327        // Process as four 4x4 blocks
328        for block_row in 0..2 {
329            for block_col in 0..2 {
330                let src_offset = block_row * 4 * src_stride + block_col * 4;
331                let ref_offset = block_row * 4 * ref_stride + block_col * 4;
332
333                if src_offset < src.len() && ref_offset < ref_block.len() {
334                    total += Self::satd_4x4(
335                        &src[src_offset..],
336                        src_stride,
337                        &ref_block[ref_offset..],
338                        ref_stride,
339                    );
340                }
341            }
342        }
343
344        total
345    }
346
347    /// Calculates SATD for a 16x16 block.
348    #[must_use]
349    pub fn satd_16x16(src: &[u8], src_stride: usize, ref_block: &[u8], ref_stride: usize) -> u32 {
350        let mut total = 0u32;
351
352        // Process as four 8x8 blocks
353        for block_row in 0..2 {
354            for block_col in 0..2 {
355                let src_offset = block_row * 8 * src_stride + block_col * 8;
356                let ref_offset = block_row * 8 * ref_stride + block_col * 8;
357
358                if src_offset < src.len() && ref_offset < ref_block.len() {
359                    total += Self::satd_8x8(
360                        &src[src_offset..],
361                        src_stride,
362                        &ref_block[ref_offset..],
363                        ref_stride,
364                    );
365                }
366            }
367        }
368
369        total
370    }
371
372    /// Calculates SATD for arbitrary block size.
373    #[must_use]
374    pub fn satd(
375        src: &[u8],
376        src_stride: usize,
377        ref_block: &[u8],
378        ref_stride: usize,
379        block_size: BlockSize,
380    ) -> u32 {
381        match block_size {
382            BlockSize::Block4x4 => Self::satd_4x4(src, src_stride, ref_block, ref_stride),
383            BlockSize::Block8x8 => Self::satd_8x8(src, src_stride, ref_block, ref_stride),
384            BlockSize::Block16x16 => Self::satd_16x16(src, src_stride, ref_block, ref_stride),
385            _ => {
386                // For other sizes, use 4x4 SATD blocks
387                let width = block_size.width();
388                let height = block_size.height();
389                let mut total = 0u32;
390
391                for by in (0..height).step_by(4) {
392                    for bx in (0..width).step_by(4) {
393                        let src_offset = by * src_stride + bx;
394                        let ref_offset = by * ref_stride + bx;
395
396                        if src_offset < src.len() && ref_offset < ref_block.len() {
397                            total += Self::satd_4x4(
398                                &src[src_offset..],
399                                src_stride,
400                                &ref_block[ref_offset..],
401                                ref_stride,
402                            );
403                        }
404                    }
405                }
406
407                total
408            }
409        }
410    }
411}
412
/// Sub-pixel refinement search.
///
/// Holds the search configuration plus a scratch buffer that receives the
/// interpolated reference block for each candidate motion vector.
#[derive(Clone, Debug)]
pub struct SubpelRefiner {
    /// Configuration.
    config: SubpelConfig,
    /// Interpolation buffer.
    ///
    /// Written by `interpolate_block` as a packed block (row stride equal to
    /// the block width) and read back by the SAD/SATD computation.
    interp_buffer: Vec<u8>,
    /// SATD calculator.
    ///
    /// NOTE(review): the methods in this file call `SatdCalculator`'s
    /// associated functions directly and never read this field.
    satd: SatdCalculator,
}

impl Default for SubpelRefiner {
    /// Same as [`SubpelRefiner::new`].
    fn default() -> Self {
        Self::new()
    }
}
429
430impl SubpelRefiner {
    /// Half-pel search pattern offsets (in 1/8 pel units).
    ///
    /// The eight neighbours of the centre at a distance of 4 units, given as
    /// `(dx, dy)`: the four edge neighbours first, then the four diagonals.
    /// The centre itself is not listed; `search_subpel` evaluates it
    /// separately.
    const HALF_PEL_PATTERN: [(i32, i32); 8] = [
        (0, -4),  // Top
        (-4, 0),  // Left
        (4, 0),   // Right
        (0, 4),   // Bottom
        (-4, -4), // Top-left
        (4, -4),  // Top-right
        (-4, 4),  // Bottom-left
        (4, 4),   // Bottom-right
    ];

    /// Quarter-pel search pattern offsets (in 1/8 pel units).
    ///
    /// Same layout as the half-pel pattern at a distance of 2 units.
    const QUARTER_PEL_PATTERN: [(i32, i32); 8] = [
        (0, -2),  // Top
        (-2, 0),  // Left
        (2, 0),   // Right
        (0, 2),   // Bottom
        (-2, -2), // Top-left
        (2, -2),  // Top-right
        (-2, 2),  // Bottom-left
        (2, 2),   // Bottom-right
    ];
454
455    /// Creates a new sub-pixel refiner.
456    #[must_use]
457    pub fn new() -> Self {
458        Self {
459            config: SubpelConfig::default(),
460            interp_buffer: vec![0u8; 256 * 256],
461            satd: SatdCalculator::new(),
462        }
463    }
464
465    /// Sets the configuration.
466    #[must_use]
467    pub fn with_config(mut self, config: SubpelConfig) -> Self {
468        self.config = config;
469        self
470    }
471
472    /// Refines a full-pel motion vector to sub-pixel precision.
473    pub fn refine(
474        &mut self,
475        src: &[u8],
476        src_stride: usize,
477        reference: &[u8],
478        ref_stride: usize,
479        block_size: BlockSize,
480        ref_x: usize,
481        ref_y: usize,
482        ref_width: usize,
483        ref_height: usize,
484        mv: MotionVector,
485    ) -> BlockMatch {
486        let mut best_mv = mv;
487        let mut best_cost = self.calculate_cost(
488            src, src_stride, reference, ref_stride, block_size, ref_x, ref_y, ref_width,
489            ref_height, &mv,
490        );
491
492        // Half-pel refinement
493        if self.config.precision as u8 >= MvPrecision::HalfPel as u8 {
494            let (new_mv, new_cost) = self.search_subpel(
495                src,
496                src_stride,
497                reference,
498                ref_stride,
499                block_size,
500                ref_x,
501                ref_y,
502                ref_width,
503                ref_height,
504                best_mv,
505                &Self::HALF_PEL_PATTERN,
506            );
507            if new_cost < best_cost {
508                best_mv = new_mv;
509                best_cost = new_cost;
510            }
511        }
512
513        // Quarter-pel refinement
514        if self.config.precision as u8 >= MvPrecision::QuarterPel as u8 {
515            let (new_mv, new_cost) = self.search_subpel(
516                src,
517                src_stride,
518                reference,
519                ref_stride,
520                block_size,
521                ref_x,
522                ref_y,
523                ref_width,
524                ref_height,
525                best_mv,
526                &Self::QUARTER_PEL_PATTERN,
527            );
528            if new_cost < best_cost {
529                best_mv = new_mv;
530                best_cost = new_cost;
531            }
532        }
533
534        // Calculate final SAD for the best MV
535        let sad = self.calculate_distortion(
536            src, src_stride, reference, ref_stride, block_size, ref_x, ref_y, ref_width,
537            ref_height, &best_mv,
538        );
539
540        BlockMatch::new(best_mv, sad, best_cost)
541    }
542
543    /// Searches sub-pixel positions around a center.
544    fn search_subpel(
545        &mut self,
546        src: &[u8],
547        src_stride: usize,
548        reference: &[u8],
549        ref_stride: usize,
550        block_size: BlockSize,
551        ref_x: usize,
552        ref_y: usize,
553        ref_width: usize,
554        ref_height: usize,
555        center: MotionVector,
556        pattern: &[(i32, i32)],
557    ) -> (MotionVector, u32) {
558        let mut best_mv = center;
559        let mut best_cost = self.calculate_cost(
560            src, src_stride, reference, ref_stride, block_size, ref_x, ref_y, ref_width,
561            ref_height, &center,
562        );
563
564        for &(dx, dy) in pattern {
565            let candidate = MotionVector::new(center.dx + dx, center.dy + dy);
566
567            let cost = self.calculate_cost(
568                src, src_stride, reference, ref_stride, block_size, ref_x, ref_y, ref_width,
569                ref_height, &candidate,
570            );
571
572            if cost < best_cost {
573                best_mv = candidate;
574                best_cost = cost;
575            }
576        }
577
578        (best_mv, best_cost)
579    }
580
581    /// Calculates RD cost for a motion vector.
582    fn calculate_cost(
583        &mut self,
584        src: &[u8],
585        src_stride: usize,
586        reference: &[u8],
587        ref_stride: usize,
588        block_size: BlockSize,
589        ref_x: usize,
590        ref_y: usize,
591        ref_width: usize,
592        ref_height: usize,
593        mv: &MotionVector,
594    ) -> u32 {
595        let distortion = self.calculate_distortion(
596            src, src_stride, reference, ref_stride, block_size, ref_x, ref_y, ref_width,
597            ref_height, mv,
598        );
599
600        self.config.mv_cost.rd_cost(mv, distortion)
601    }
602
603    /// Calculates distortion (SAD or SATD) for a motion vector.
604    #[allow(clippy::too_many_arguments)]
605    fn calculate_distortion(
606        &mut self,
607        src: &[u8],
608        src_stride: usize,
609        reference: &[u8],
610        ref_stride: usize,
611        block_size: BlockSize,
612        ref_x: usize,
613        ref_y: usize,
614        ref_width: usize,
615        ref_height: usize,
616        mv: &MotionVector,
617    ) -> u32 {
618        let width = block_size.width();
619        let height = block_size.height();
620
621        // Get interpolated reference block
622        if !self.interpolate_block(
623            reference, ref_stride, ref_x, ref_y, ref_width, ref_height, mv, width, height,
624        ) {
625            return u32::MAX;
626        }
627
628        // Calculate distortion
629        if self.config.use_satd {
630            SatdCalculator::satd(src, src_stride, &self.interp_buffer, width, block_size)
631        } else {
632            self.calculate_sad(src, src_stride, width, height)
633        }
634    }
635
636    /// Interpolates a block at sub-pixel position.
637    #[allow(clippy::too_many_arguments)]
638    fn interpolate_block(
639        &mut self,
640        reference: &[u8],
641        ref_stride: usize,
642        ref_x: usize,
643        ref_y: usize,
644        ref_width: usize,
645        ref_height: usize,
646        mv: &MotionVector,
647        width: usize,
648        height: usize,
649    ) -> bool {
650        let full_x = ref_x as i32 + mv.full_pel_x();
651        let full_y = ref_y as i32 + mv.full_pel_y();
652        let frac_x = mv.frac_x();
653        let frac_y = mv.frac_y();
654
655        // Bounds check
656        if full_x < 0 || full_y < 0 {
657            return false;
658        }
659        let full_x = full_x as usize;
660        let full_y = full_y as usize;
661
662        if full_x + width > ref_width || full_y + height > ref_height {
663            return false;
664        }
665
666        // No sub-pixel interpolation needed
667        if frac_x == 0 && frac_y == 0 {
668            for row in 0..height {
669                let src_offset = (full_y + row) * ref_stride + full_x;
670                let dst_offset = row * width;
671                if src_offset + width <= reference.len() {
672                    self.interp_buffer[dst_offset..dst_offset + width]
673                        .copy_from_slice(&reference[src_offset..src_offset + width]);
674                }
675            }
676            return true;
677        }
678
679        // Half-pel interpolation
680        let hx = usize::from(frac_x >= 4);
681        let hy = usize::from(frac_y >= 4);
682
683        for row in 0..height {
684            for col in 0..width {
685                let x0 = full_x + col;
686                let y0 = full_y + row;
687                let x1 = (x0 + hx).min(ref_width - 1);
688                let y1 = (y0 + hy).min(ref_height - 1);
689
690                let p00 = reference[y0 * ref_stride + x0];
691                let p01 = reference[y0 * ref_stride + x1];
692                let p10 = reference[y1 * ref_stride + x0];
693                let p11 = reference[y1 * ref_stride + x1];
694
695                // Bilinear interpolation weights
696                let wx = (frac_x & 3) as u16;
697                let wy = (frac_y & 3) as u16;
698                let wx_inv = 4 - wx;
699                let wy_inv = 4 - wy;
700
701                let val = (u16::from(p00) * wx_inv * wy_inv
702                    + u16::from(p01) * wx * wy_inv
703                    + u16::from(p10) * wx_inv * wy
704                    + u16::from(p11) * wx * wy
705                    + 8)
706                    / 16;
707
708                self.interp_buffer[row * width + col] = val as u8;
709            }
710        }
711
712        true
713    }
714
715    /// Calculates SAD from interpolation buffer.
716    fn calculate_sad(&self, src: &[u8], src_stride: usize, width: usize, height: usize) -> u32 {
717        let mut sad = 0u32;
718        for row in 0..height {
719            let src_offset = row * src_stride;
720            let ref_offset = row * width;
721            for col in 0..width {
722                if src_offset + col < src.len() {
723                    let diff = i32::from(src[src_offset + col])
724                        - i32::from(self.interp_buffer[ref_offset + col]);
725                    sad += diff.unsigned_abs();
726                }
727            }
728        }
729        sad
730    }
731}
732
/// Sub-pixel search patterns.
///
/// Offset tables given as `(dx, dy)` steps around a centre; each table lists
/// the centre `(0, 0)` first. NOTE(review): `SubpelRefiner` uses its own
/// private patterns and does not read these tables; the unit of a step is
/// not stated here.
pub struct SubpelPatterns;

impl SubpelPatterns {
    /// Square pattern for exhaustive sub-pel search.
    ///
    /// The centre plus its eight neighbours at distance 1.
    pub const SQUARE_9: [(i32, i32); 9] = [
        (0, 0),
        (-1, -1),
        (0, -1),
        (1, -1),
        (-1, 0),
        (1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    ];

    /// Diamond pattern for fast sub-pel search.
    ///
    /// The centre plus its four edge neighbours.
    pub const DIAMOND_5: [(i32, i32); 5] = [(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)];

    /// Extended pattern for thorough search.
    ///
    /// The nine `SQUARE_9` positions followed by the sixteen positions on the
    /// surrounding ring at distance 2, i.e. the full 5x5 grid.
    pub const EXTENDED_25: [(i32, i32); 25] = [
        (0, 0),
        (-1, -1),
        (0, -1),
        (1, -1),
        (-1, 0),
        (1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
        (-2, -2),
        (-1, -2),
        (0, -2),
        (1, -2),
        (2, -2),
        (-2, -1),
        (2, -1),
        (-2, 0),
        (2, 0),
        (-2, 1),
        (2, 1),
        (-2, 2),
        (-1, 2),
        (0, 2),
        (1, 2),
        (2, 2),
    ];
}
782
// Unit tests for the interpolators, the Hadamard/SATD helpers and the
// refiner. Several refiner tests are smoke tests that only check success or
// the identity case.
#[cfg(test)]
mod tests {
    use super::*;

    // The average rounds halves upwards: (0 + 255) / 2 -> 128.
    #[test]
    fn test_half_pel_bilinear() {
        assert_eq!(HalfPelInterpolator::bilinear(100, 200), 150);
        assert_eq!(HalfPelInterpolator::bilinear(0, 255), 128);
        assert_eq!(HalfPelInterpolator::bilinear(100, 100), 100);
    }

    #[test]
    fn test_half_pel_sixtap() {
        let samples = [100u8, 100, 100, 100, 100, 100];
        let result = HalfPelInterpolator::sixtap(&samples);
        // Constant input should give approximately the same output
        assert!((result as i32 - 100).abs() < 5);
    }

    // Weighted means with weights 3:1, 1:3 and 1:1.
    #[test]
    fn test_quarter_pel_bilinear() {
        assert_eq!(QuarterPelInterpolator::bilinear(100, 200, 3, 1), 125);
        assert_eq!(QuarterPelInterpolator::bilinear(100, 200, 1, 3), 175);
        assert_eq!(QuarterPelInterpolator::bilinear(100, 200, 1, 1), 150);
    }

    #[test]
    fn test_hadamard_4x4() {
        let mut block = [[0i16; 4]; 4];
        block[0] = [1, 2, 3, 4];
        block[1] = [5, 6, 7, 8];
        block[2] = [9, 10, 11, 12];
        block[3] = [13, 14, 15, 16];

        HadamardTransform::hadamard_4x4(&mut block);

        // DC coefficient should be sum of all values
        assert_eq!(block[0][0], 136); // Sum of 1..16 = 136
    }

    #[test]
    fn test_satd_4x4_identical() {
        let block = vec![100u8; 16];
        let satd = SatdCalculator::satd_4x4(&block, 4, &block, 4);
        assert_eq!(satd, 0);
    }

    #[test]
    fn test_satd_4x4_constant_diff() {
        let src = vec![100u8; 16];
        let ref_block = vec![110u8; 16];
        let satd = SatdCalculator::satd_4x4(&src, 4, &ref_block, 4);
        // A uniform difference of -10 lands entirely in the DC coefficient
        // (16 * -10 = -160), which halves to 80. Only non-zero is asserted.
        assert!(satd > 0);
    }

    // Identical blocks: every 4x4 tile contributes 0.
    #[test]
    fn test_satd_8x8() {
        let src = vec![100u8; 64];
        let ref_block = vec![100u8; 64];
        let satd = SatdCalculator::satd_8x8(&src, 8, &ref_block, 8);
        assert_eq!(satd, 0);
    }

    // Identical blocks: every 8x8 tile contributes 0.
    #[test]
    fn test_satd_16x16() {
        let src = vec![100u8; 256];
        let ref_block = vec![100u8; 256];
        let satd = SatdCalculator::satd_16x16(&src, 16, &ref_block, 16);
        assert_eq!(satd, 0);
    }

    // `new` installs the default configuration.
    #[test]
    fn test_subpel_refiner_creation() {
        let refiner = SubpelRefiner::new();
        assert_eq!(refiner.config.precision, MvPrecision::QuarterPel);
        assert!(refiner.config.use_satd);
    }

    // Struct-update syntax overrides only the named fields.
    #[test]
    fn test_subpel_config() {
        let config = SubpelConfig {
            precision: MvPrecision::HalfPel,
            use_satd: false,
            ..Default::default()
        };
        assert_eq!(config.precision, MvPrecision::HalfPel);
        assert!(!config.use_satd);
    }

    // Flat source against a flat reference: the zero vector is expected to
    // stay the best candidate.
    #[test]
    fn test_subpel_refiner_no_motion() {
        let src = vec![100u8; 64];
        let reference = vec![100u8; 256];

        let mut refiner = SubpelRefiner::new();
        let mv = MotionVector::zero();

        let result = refiner.refine(
            &src,
            8,
            &reference,
            16,
            BlockSize::Block8x8,
            0,
            0,
            16,
            16,
            mv,
        );

        // Perfect match, no motion
        assert_eq!(result.mv.dx, 0);
        assert_eq!(result.mv.dy, 0);
    }

    #[test]
    fn test_interpolation_full_pel() {
        let mut refiner = SubpelRefiner::new();
        let reference = vec![128u8; 256];
        let mv = MotionVector::zero();

        let success = refiner.interpolate_block(&reference, 16, 0, 0, 16, 16, &mv, 8, 8);

        assert!(success);
        // Full-pel position should copy exactly
        assert_eq!(refiner.interp_buffer[0], 128);
    }

    // Smoke test: only checks that interpolation reports success; the
    // interpolated sample values are not asserted.
    #[test]
    fn test_interpolation_half_pel() {
        let mut refiner = SubpelRefiner::new();
        let mut reference = vec![100u8; 256];
        // Create gradient
        for i in 0..256 {
            reference[i] = (i % 256) as u8;
        }

        let mv = MotionVector::new(4, 0); // Half-pel in x

        let success = refiner.interpolate_block(&reference, 16, 0, 0, 16, 16, &mv, 8, 8);

        assert!(success);
    }

    // The table lengths match their names.
    #[test]
    fn test_subpel_patterns() {
        assert_eq!(SubpelPatterns::DIAMOND_5.len(), 5);
        assert_eq!(SubpelPatterns::SQUARE_9.len(), 9);
        assert_eq!(SubpelPatterns::EXTENDED_25.len(), 25);
    }

    // One row of eight alternating samples yields seven horizontal half-pel
    // outputs.
    #[test]
    fn test_half_pel_interpolate_h() {
        let src = vec![100u8, 200u8, 100u8, 200u8, 100u8, 200u8, 100u8, 200u8];
        let mut dst = vec![0u8; 7];

        HalfPelInterpolator::interpolate_h(&src, 8, &mut dst, 7, 1);

        // Interpolated values should be between neighbors
        assert_eq!(dst[0], 150);
        assert_eq!(dst[1], 150);
    }

    // `Block8x16` has no dedicated arm in `SatdCalculator::satd`, so this
    // exercises the generic 4x4 tiling branch.
    #[test]
    fn test_satd_block_size() {
        let src = vec![100u8; 128];
        let ref_block = vec![100u8; 128];

        let satd = SatdCalculator::satd(&src, 8, &ref_block, 8, BlockSize::Block8x16);
        assert_eq!(satd, 0);
    }
}
955}