Skip to main content

oximedia_codec/simd/av1/
motion_comp.rs

1//! AV1 motion compensation SIMD operations.
2//!
3//! Implements interpolation filters and motion compensation for AV1.
4
5use crate::simd::traits::SimdOps;
6use crate::simd::types::{I16x8, I32x4, U8x16};
7
/// AV1 motion compensation SIMD operations.
///
/// Thin wrapper around a SIMD backend `S`; the backend is stored by value and
/// used by the filter and averaging kernels implemented on this type.
///
/// `Debug` and `Clone` are derived so the type can be logged and duplicated
/// whenever the backend itself supports those traits.
#[derive(Debug, Clone)]
pub struct MotionCompSimd<S> {
    /// SIMD backend used for the vector arithmetic.
    simd: S,
}
12
13impl<S: SimdOps> MotionCompSimd<S> {
14    /// Create a new motion compensation SIMD instance.
15    #[inline]
16    pub const fn new(simd: S) -> Self {
17        Self { simd }
18    }
19
20    /// Copy a block without interpolation (integer-pel motion).
21    pub fn copy_block(
22        &self,
23        src: &[u8],
24        src_stride: usize,
25        dst: &mut [u8],
26        dst_stride: usize,
27        width: usize,
28        height: usize,
29    ) {
30        for y in 0..height {
31            let src_offset = y * src_stride;
32            let dst_offset = y * dst_stride;
33
34            if src.len() >= src_offset + width && dst.len() >= dst_offset + width {
35                dst[dst_offset..dst_offset + width]
36                    .copy_from_slice(&src[src_offset..src_offset + width]);
37            }
38        }
39    }
40
41    /// Horizontal 8-tap interpolation filter.
42    ///
43    /// Applies an 8-tap filter horizontally for sub-pixel motion compensation.
44    pub fn filter_h_8tap(
45        &self,
46        src: &[u8],
47        src_stride: usize,
48        dst: &mut [u8],
49        dst_stride: usize,
50        coeffs: &[i16; 8],
51        width: usize,
52        height: usize,
53    ) {
54        for y in 0..height {
55            for x in 0..width {
56                let src_offset = y * src_stride + x;
57                let dst_offset = y * dst_stride + x;
58
59                if src.len() < src_offset + 8 || dst_offset >= dst.len() {
60                    continue;
61                }
62
63                // Load 8 source pixels
64                let mut pixels = I16x8::zero();
65                for i in 0..8 {
66                    if src_offset + i < src.len() {
67                        pixels[i] = i16::from(src[src_offset + i]);
68                    }
69                }
70
71                // Load filter coefficients
72                let filter = I16x8::from_array(*coeffs);
73
74                // Multiply and accumulate
75                let products = self.simd.mul_i16x8(pixels, filter);
76                let sum = self.simd.horizontal_sum_i16x8(products);
77
78                // Round and shift
79                let result = (sum + 64) >> 7;
80                dst[dst_offset] = result.clamp(0, 255) as u8;
81            }
82        }
83    }
84
85    /// Vertical 8-tap interpolation filter.
86    pub fn filter_v_8tap(
87        &self,
88        src: &[u8],
89        src_stride: usize,
90        dst: &mut [u8],
91        dst_stride: usize,
92        coeffs: &[i16; 8],
93        width: usize,
94        height: usize,
95    ) {
96        for y in 0..height {
97            for x in 0..width {
98                let dst_offset = y * dst_stride + x;
99
100                if dst_offset >= dst.len() {
101                    continue;
102                }
103
104                // Load 8 vertical pixels
105                let mut pixels = I16x8::zero();
106                for i in 0..8 {
107                    let src_offset = (y + i) * src_stride + x;
108                    if src_offset < src.len() {
109                        pixels[i] = i16::from(src[src_offset]);
110                    }
111                }
112
113                // Load filter coefficients
114                let filter = I16x8::from_array(*coeffs);
115
116                // Multiply and accumulate
117                let products = self.simd.mul_i16x8(pixels, filter);
118                let sum = self.simd.horizontal_sum_i16x8(products);
119
120                // Round and shift
121                let result = (sum + 64) >> 7;
122                dst[dst_offset] = result.clamp(0, 255) as u8;
123            }
124        }
125    }
126
127    /// 2D 8-tap interpolation (both horizontal and vertical).
128    #[allow(clippy::too_many_arguments)]
129    pub fn filter_2d_8tap(
130        &self,
131        src: &[u8],
132        src_stride: usize,
133        dst: &mut [u8],
134        dst_stride: usize,
135        h_coeffs: &[i16; 8],
136        v_coeffs: &[i16; 8],
137        width: usize,
138        height: usize,
139    ) {
140        // Allocate temporary buffer for horizontal filtering
141        let temp_size = (height + 7) * width;
142        let mut temp = vec![0i16; temp_size];
143
144        // Horizontal filtering to temp buffer
145        for y in 0..height + 7 {
146            for x in 0..width {
147                let src_offset = y * src_stride + x;
148                let temp_offset = y * width + x;
149
150                if temp_offset >= temp.len() {
151                    continue;
152                }
153
154                // Load 8 horizontal pixels
155                let mut pixels = I16x8::zero();
156                for i in 0..8 {
157                    if src_offset + i < src.len() {
158                        pixels[i] = i16::from(src[src_offset + i]);
159                    }
160                }
161
162                // Apply horizontal filter
163                let filter = I16x8::from_array(*h_coeffs);
164                let products = self.simd.mul_i16x8(pixels, filter);
165                let sum = self.simd.horizontal_sum_i16x8(products);
166
167                temp[temp_offset] = ((sum + 64) >> 7) as i16;
168            }
169        }
170
171        // Vertical filtering from temp to dst
172        for y in 0..height {
173            for x in 0..width {
174                let dst_offset = y * dst_stride + x;
175
176                if dst_offset >= dst.len() {
177                    continue;
178                }
179
180                // Load 8 vertical pixels from temp
181                let mut pixels = I16x8::zero();
182                for i in 0..8 {
183                    let temp_offset = (y + i) * width + x;
184                    if temp_offset < temp.len() {
185                        pixels[i] = temp[temp_offset];
186                    }
187                }
188
189                // Apply vertical filter
190                let filter = I16x8::from_array(*v_coeffs);
191                let products = self.simd.mul_i16x8(pixels, filter);
192                let sum = self.simd.horizontal_sum_i16x8(products);
193
194                // Round and shift
195                let result = (sum + 64) >> 7;
196                dst[dst_offset] = result.clamp(0, 255) as u8;
197            }
198        }
199    }
200
201    /// Bilinear interpolation (simple 2-tap filter).
202    pub fn bilinear_h(
203        &self,
204        src: &[u8],
205        src_stride: usize,
206        dst: &mut [u8],
207        dst_stride: usize,
208        fraction: u8,
209        width: usize,
210        height: usize,
211    ) {
212        let w1 = fraction;
213        let w0 = 64 - w1;
214
215        for y in 0..height {
216            for x in 0..width {
217                let src_offset = y * src_stride + x;
218                let dst_offset = y * dst_stride + x;
219
220                if src_offset + 1 >= src.len() || dst_offset >= dst.len() {
221                    continue;
222                }
223
224                let p0 = u32::from(src[src_offset]);
225                let p1 = u32::from(src[src_offset + 1]);
226
227                let result = (p0 * u32::from(w0) + p1 * u32::from(w1) + 32) / 64;
228                dst[dst_offset] = result as u8;
229            }
230        }
231    }
232
233    /// Bilinear vertical interpolation.
234    pub fn bilinear_v(
235        &self,
236        src: &[u8],
237        src_stride: usize,
238        dst: &mut [u8],
239        dst_stride: usize,
240        fraction: u8,
241        width: usize,
242        height: usize,
243    ) {
244        let w1 = fraction;
245        let w0 = 64 - w1;
246
247        for y in 0..height {
248            for x in 0..width {
249                let src_offset = y * src_stride + x;
250                let dst_offset = y * dst_stride + x;
251
252                if src_offset + src_stride >= src.len() || dst_offset >= dst.len() {
253                    continue;
254                }
255
256                let p0 = u32::from(src[src_offset]);
257                let p1 = u32::from(src[src_offset + src_stride]);
258
259                let result = (p0 * u32::from(w0) + p1 * u32::from(w1) + 32) / 64;
260                dst[dst_offset] = result as u8;
261            }
262        }
263    }
264
265    /// Average two blocks for bi-directional prediction.
266    pub fn average_blocks(
267        &self,
268        src1: &[u8],
269        src2: &[u8],
270        dst: &mut [u8],
271        width: usize,
272        height: usize,
273        stride: usize,
274    ) {
275        for y in 0..height {
276            let offset = y * stride;
277
278            // Process 16 pixels at a time using SIMD
279            let chunks = width / 16;
280            for i in 0..chunks {
281                let pos = offset + i * 16;
282
283                if src1.len() < pos + 16 || src2.len() < pos + 16 || dst.len() < pos + 16 {
284                    continue;
285                }
286
287                let mut v1 = U8x16::zero();
288                let mut v2 = U8x16::zero();
289                v1.copy_from_slice(&src1[pos..pos + 16]);
290                v2.copy_from_slice(&src2[pos..pos + 16]);
291
292                let avg = self.simd.avg_u8x16(v1, v2);
293                let avg_array = avg.to_array();
294                dst[pos..pos + 16].copy_from_slice(&avg_array);
295            }
296
297            // Handle remaining pixels
298            for x in (chunks * 16)..width {
299                let pos = offset + x;
300                if src1.len() > pos && src2.len() > pos && dst.len() > pos {
301                    dst[pos] = ((u16::from(src1[pos]) + u16::from(src2[pos]) + 1) / 2) as u8;
302                }
303            }
304        }
305    }
306
307    /// Weighted prediction (combine two blocks with weights).
308    #[allow(clippy::too_many_arguments)]
309    pub fn weighted_pred(
310        &self,
311        src1: &[u8],
312        src2: &[u8],
313        dst: &mut [u8],
314        weight1: u8,
315        weight2: u8,
316        width: usize,
317        height: usize,
318        stride: usize,
319    ) {
320        let total_weight = u32::from(weight1) + u32::from(weight2);
321
322        for y in 0..height {
323            for x in 0..width {
324                let offset = y * stride + x;
325
326                if src1.len() <= offset || src2.len() <= offset || dst.len() <= offset {
327                    continue;
328                }
329
330                let p1 = u32::from(src1[offset]) * u32::from(weight1);
331                let p2 = u32::from(src2[offset]) * u32::from(weight2);
332
333                let result = (p1 + p2 + total_weight / 2) / total_weight;
334                dst[offset] = result.clamp(0, 255) as u8;
335            }
336        }
337    }
338
339    /// OBMC (Overlapped Block Motion Compensation) blending.
340    #[allow(clippy::too_many_arguments)]
341    pub fn obmc_blend(
342        &self,
343        pred: &[u8],
344        obmc: &[u8],
345        dst: &mut [u8],
346        width: usize,
347        height: usize,
348        stride: usize,
349        weights: &[u8],
350    ) {
351        for y in 0..height {
352            for x in 0..width {
353                let offset = y * stride + x;
354                let weight_idx = (y * width + x).min(weights.len().saturating_sub(1));
355
356                if pred.len() <= offset || obmc.len() <= offset || dst.len() <= offset {
357                    continue;
358                }
359
360                let w = u32::from(weights[weight_idx]);
361                let p1 = u32::from(pred[offset]) * w;
362                let p2 = u32::from(obmc[offset]) * (64 - w);
363
364                let result = (p1 + p2 + 32) / 64;
365                dst[offset] = result as u8;
366            }
367        }
368    }
369
370    /// SIMD-optimized horizontal filtering for 4-pixel wide blocks.
371    #[allow(dead_code)]
372    fn filter_h_4_simd(&self, src: &[u8], coeffs: &[i16; 8]) -> [u8; 4] {
373        let mut pixels = I16x8::zero();
374        for i in 0..8.min(src.len()) {
375            pixels[i] = i16::from(src[i]);
376        }
377
378        let filter = I16x8::from_array(*coeffs);
379        let result = self.simd.pmaddwd(pixels, filter);
380
381        let sum = self.simd.horizontal_sum_i32x4(result);
382        let final_val = (sum + 64) >> 7;
383
384        [
385            final_val.clamp(0, 255) as u8,
386            final_val.clamp(0, 255) as u8,
387            final_val.clamp(0, 255) as u8,
388            final_val.clamp(0, 255) as u8,
389        ]
390    }
391}
392
/// 8-tap interpolation filter coefficients.
///
/// The filter routines in this file normalise with `(sum + 64) >> 7`, so
/// every kernel must sum to 128 to keep flat areas at their original level.
pub mod filter_coeffs {
    /// Regular filter.
    pub const REGULAR: [i16; 8] = [-1, 3, -7, 127, 8, -3, 1, 0];

    /// Sharp filter (preserves edges).
    pub const SHARP: [i16; 8] = [-1, 3, -8, 127, 8, -2, 1, 0];

    /// Smooth filter (reduces high frequencies).
    ///
    /// All taps are non-negative (a low-pass kernel) and sum to 128. The
    /// previous taps summed to 120, which scaled every filtered pixel by
    /// 120/128 and darkened the output.
    ///
    /// NOTE(review): values follow the AV1 smooth sub-pixel kernel at the
    /// 1/16-sample phase — confirm against the specification table.
    pub const SMOOTH: [i16; 8] = [0, 2, 28, 62, 34, 2, 0, 0];

    /// Bilinear filter.
    ///
    /// With a single tap of 128 this kernel passes the pixel at tap index 3
    /// through unchanged.
    pub const BILINEAR: [i16; 8] = [0, 0, 0, 128, 0, 0, 0, 0];
}