Skip to main content

oximedia_codec/simd/av1/
intra.rs

1//! AV1 intra prediction SIMD operations.
2//!
3//! Implements various intra prediction modes used in AV1 encoding/decoding.
4
5use crate::simd::traits::SimdOps;
6use crate::simd::types::U8x16;
7
/// AV1 intra prediction modes.
///
/// The directional variants (`DiagonalDownLeft` through `HorizontalUp`) have no
/// dedicated predictor in [`IntraPredSimd`] and are currently predicted as DC.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IntraMode {
    /// DC prediction (average of neighbors).
    Dc,
    /// Horizontal prediction (each row repeats its left neighbor).
    Horizontal,
    /// Vertical prediction (each column repeats its top neighbor).
    Vertical,
    /// Diagonal down-left prediction.
    DiagonalDownLeft,
    /// Diagonal down-right prediction.
    DiagonalDownRight,
    /// Vertical right prediction.
    VerticalRight,
    /// Horizontal down prediction.
    HorizontalDown,
    /// Vertical left prediction.
    VerticalLeft,
    /// Horizontal up prediction.
    HorizontalUp,
    /// Paeth prediction: each pixel copies whichever of left, top or top-left
    /// is closest to the gradient estimate `left + top - top_left`.
    Paeth,
    /// Smooth prediction (blends vertical and horizontal interpolation).
    Smooth,
    /// Smooth vertical prediction (blends the top row toward the bottom-left pixel).
    SmoothV,
    /// Smooth horizontal prediction (blends the left column toward the top-right pixel).
    SmoothH,
}
38
/// AV1 intra prediction SIMD operations.
///
/// Wraps a SIMD backend `S`; construct it with `IntraPredSimd::new`.
#[derive(Debug, Clone, Copy)]
pub struct IntraPredSimd<S> {
    // None of the predictors read the backend yet; the field is kept so the type
    // (and its `new` constructor) is already parameterised by it.
    #[allow(dead_code)]
    simd: S,
}
44
45impl<S: SimdOps> IntraPredSimd<S> {
46    /// Create a new intra prediction SIMD instance.
47    #[inline]
48    pub const fn new(simd: S) -> Self {
49        Self { simd }
50    }
51
52    /// Perform intra prediction for a 4x4 block.
53    ///
54    /// # Arguments
55    /// * `mode` - Prediction mode
56    /// * `top` - Top reference pixels (4 pixels)
57    /// * `left` - Left reference pixels (4 pixels)
58    /// * `top_left` - Top-left corner pixel
59    /// * `dst` - Destination buffer
60    /// * `stride` - Destination stride
61    pub fn predict_4x4(
62        &self,
63        mode: IntraMode,
64        top: &[u8],
65        left: &[u8],
66        top_left: u8,
67        dst: &mut [u8],
68        stride: usize,
69    ) {
70        match mode {
71            IntraMode::Dc => self.predict_dc_4x4(top, left, dst, stride),
72            IntraMode::Horizontal => self.predict_h_4x4(left, dst, stride),
73            IntraMode::Vertical => self.predict_v_4x4(top, dst, stride),
74            IntraMode::Paeth => self.predict_paeth_4x4(top, left, top_left, dst, stride),
75            IntraMode::Smooth => self.predict_smooth_4x4(top, left, top_left, dst, stride),
76            IntraMode::SmoothV => self.predict_smooth_v_4x4(top, left, dst, stride),
77            IntraMode::SmoothH => self.predict_smooth_h_4x4(top, left, dst, stride),
78            _ => self.predict_dc_4x4(top, left, dst, stride), // Default to DC
79        }
80    }
81
82    /// Perform intra prediction for an 8x8 block.
83    pub fn predict_8x8(
84        &self,
85        mode: IntraMode,
86        top: &[u8],
87        left: &[u8],
88        top_left: u8,
89        dst: &mut [u8],
90        stride: usize,
91    ) {
92        match mode {
93            IntraMode::Dc => self.predict_dc_8x8(top, left, dst, stride),
94            IntraMode::Horizontal => self.predict_h_8x8(left, dst, stride),
95            IntraMode::Vertical => self.predict_v_8x8(top, dst, stride),
96            IntraMode::Paeth => self.predict_paeth_8x8(top, left, top_left, dst, stride),
97            IntraMode::Smooth => self.predict_smooth_8x8(top, left, top_left, dst, stride),
98            _ => self.predict_dc_8x8(top, left, dst, stride),
99        }
100    }
101
102    // ========================================================================
103    // 4x4 Prediction Modes
104    // ========================================================================
105
106    /// DC prediction for 4x4 block.
107    fn predict_dc_4x4(&self, top: &[u8], left: &[u8], dst: &mut [u8], stride: usize) {
108        // Calculate DC value as average of top and left
109        let mut sum = 0u32;
110        for i in 0..4 {
111            if i < top.len() {
112                sum += u32::from(top[i]);
113            }
114            if i < left.len() {
115                sum += u32::from(left[i]);
116            }
117        }
118        let dc = ((sum + 4) / 8) as u8;
119
120        // Fill block with DC value
121        for y in 0..4 {
122            let offset = y * stride;
123            if dst.len() >= offset + 4 {
124                for x in 0..4 {
125                    dst[offset + x] = dc;
126                }
127            }
128        }
129    }
130
131    /// Horizontal prediction for 4x4 block.
132    fn predict_h_4x4(&self, left: &[u8], dst: &mut [u8], stride: usize) {
133        for y in 0..4 {
134            let offset = y * stride;
135            if dst.len() >= offset + 4 && y < left.len() {
136                let pixel = left[y];
137                for x in 0..4 {
138                    dst[offset + x] = pixel;
139                }
140            }
141        }
142    }
143
144    /// Vertical prediction for 4x4 block.
145    fn predict_v_4x4(&self, top: &[u8], dst: &mut [u8], stride: usize) {
146        if top.len() < 4 {
147            return;
148        }
149
150        for y in 0..4 {
151            let offset = y * stride;
152            if dst.len() >= offset + 4 {
153                dst[offset..offset + 4].copy_from_slice(&top[..4]);
154            }
155        }
156    }
157
158    /// Paeth (gradient) prediction for 4x4 block.
159    fn predict_paeth_4x4(
160        &self,
161        top: &[u8],
162        left: &[u8],
163        top_left: u8,
164        dst: &mut [u8],
165        stride: usize,
166    ) {
167        for y in 0..4 {
168            for x in 0..4 {
169                let offset = y * stride + x;
170                if offset >= dst.len() || y >= left.len() || x >= top.len() {
171                    continue;
172                }
173
174                let t = top[x];
175                let l = left[y];
176                let tl = top_left;
177
178                dst[offset] = self.paeth_predictor(l, t, tl);
179            }
180        }
181    }
182
183    /// Smooth prediction for 4x4 block.
184    fn predict_smooth_4x4(
185        &self,
186        top: &[u8],
187        left: &[u8],
188        _top_left: u8,
189        dst: &mut [u8],
190        stride: usize,
191    ) {
192        // Smooth prediction blends horizontal and vertical predictions
193        for y in 0..4 {
194            for x in 0..4 {
195                let offset = y * stride + x;
196                if offset >= dst.len() || y >= left.len() || x >= top.len() {
197                    continue;
198                }
199
200                let h_weight = ((4 - x) * 64 / 4) as u32;
201                let v_weight = ((4 - y) * 64 / 4) as u32;
202                let h_pred = u32::from(left[y]) * h_weight;
203                let v_pred = u32::from(top[x]) * v_weight;
204
205                let pred = (h_pred + v_pred + 64) / 128;
206                dst[offset] = pred as u8;
207            }
208        }
209    }
210
211    /// Smooth vertical prediction for 4x4 block.
212    fn predict_smooth_v_4x4(&self, top: &[u8], left: &[u8], dst: &mut [u8], stride: usize) {
213        if top.len() < 4 || left.len() < 4 {
214            return;
215        }
216
217        let bottom = left[3]; // Bottom-most left pixel
218
219        for y in 0..4 {
220            let weight = ((4 - y) * 64 / 4) as u32;
221            for x in 0..4 {
222                let offset = y * stride + x;
223                if offset >= dst.len() || x >= top.len() {
224                    continue;
225                }
226
227                let pred =
228                    (u32::from(top[x]) * weight + u32::from(bottom) * (64 - weight) + 32) / 64;
229                dst[offset] = pred as u8;
230            }
231        }
232    }
233
234    /// Smooth horizontal prediction for 4x4 block.
235    fn predict_smooth_h_4x4(&self, top: &[u8], left: &[u8], dst: &mut [u8], stride: usize) {
236        if top.len() < 4 || left.len() < 4 {
237            return;
238        }
239
240        let right = top[3]; // Right-most top pixel
241
242        for y in 0..4 {
243            for x in 0..4 {
244                let offset = y * stride + x;
245                if offset >= dst.len() || y >= left.len() {
246                    continue;
247                }
248
249                let weight = ((4 - x) * 64 / 4) as u32;
250                let pred =
251                    (u32::from(left[y]) * weight + u32::from(right) * (64 - weight) + 32) / 64;
252                dst[offset] = pred as u8;
253            }
254        }
255    }
256
257    // ========================================================================
258    // 8x8 Prediction Modes
259    // ========================================================================
260
261    /// DC prediction for 8x8 block (SIMD accelerated).
262    fn predict_dc_8x8(&self, top: &[u8], left: &[u8], dst: &mut [u8], stride: usize) {
263        // Calculate DC value
264        let mut sum = 0u32;
265        for i in 0..8 {
266            if i < top.len() {
267                sum += u32::from(top[i]);
268            }
269            if i < left.len() {
270                sum += u32::from(left[i]);
271            }
272        }
273        let dc = ((sum + 8) / 16) as u8;
274
275        // Fill block using SIMD
276        let dc_vec = U8x16::splat(dc);
277        let dc_array = dc_vec.to_array();
278        for y in 0..8 {
279            let offset = y * stride;
280            if dst.len() >= offset + 8 {
281                dst[offset..offset + 8].copy_from_slice(&dc_array[..8]);
282            }
283        }
284    }
285
286    /// Horizontal prediction for 8x8 block.
287    fn predict_h_8x8(&self, left: &[u8], dst: &mut [u8], stride: usize) {
288        for y in 0..8 {
289            let offset = y * stride;
290            if dst.len() >= offset + 8 && y < left.len() {
291                let pixel_vec = U8x16::splat(left[y]);
292                let pixel_array = pixel_vec.to_array();
293                dst[offset..offset + 8].copy_from_slice(&pixel_array[..8]);
294            }
295        }
296    }
297
298    /// Vertical prediction for 8x8 block.
299    fn predict_v_8x8(&self, top: &[u8], dst: &mut [u8], stride: usize) {
300        if top.len() < 8 {
301            return;
302        }
303
304        for y in 0..8 {
305            let offset = y * stride;
306            if dst.len() >= offset + 8 {
307                dst[offset..offset + 8].copy_from_slice(&top[..8]);
308            }
309        }
310    }
311
312    /// Paeth prediction for 8x8 block.
313    fn predict_paeth_8x8(
314        &self,
315        top: &[u8],
316        left: &[u8],
317        top_left: u8,
318        dst: &mut [u8],
319        stride: usize,
320    ) {
321        for y in 0..8 {
322            for x in 0..8 {
323                let offset = y * stride + x;
324                if offset >= dst.len() || y >= left.len() || x >= top.len() {
325                    continue;
326                }
327
328                let t = top[x];
329                let l = left[y];
330                let tl = top_left;
331
332                dst[offset] = self.paeth_predictor(l, t, tl);
333            }
334        }
335    }
336
337    /// Smooth prediction for 8x8 block.
338    fn predict_smooth_8x8(
339        &self,
340        top: &[u8],
341        left: &[u8],
342        _top_left: u8,
343        dst: &mut [u8],
344        stride: usize,
345    ) {
346        for y in 0..8 {
347            for x in 0..8 {
348                let offset = y * stride + x;
349                if offset >= dst.len() || y >= left.len() || x >= top.len() {
350                    continue;
351                }
352
353                let h_weight = ((8 - x) * 64 / 8) as u32;
354                let v_weight = ((8 - y) * 64 / 8) as u32;
355                let h_pred = u32::from(left[y]) * h_weight;
356                let v_pred = u32::from(top[x]) * v_weight;
357
358                let pred = (h_pred + v_pred + 64) / 128;
359                dst[offset] = pred as u8;
360            }
361        }
362    }
363
364    // ========================================================================
365    // Helper Functions
366    // ========================================================================
367
368    /// Paeth predictor (gradient prediction).
369    ///
370    /// Selects the neighbor (left, top, or top-left) that is closest
371    /// to the gradient prediction.
372    fn paeth_predictor(&self, left: u8, top: u8, top_left: u8) -> u8 {
373        let l = i32::from(left);
374        let t = i32::from(top);
375        let tl = i32::from(top_left);
376
377        let base = l + t - tl;
378        let dist_l = (base - l).abs();
379        let dist_t = (base - t).abs();
380        let dist_tl = (base - tl).abs();
381
382        if dist_l <= dist_t && dist_l <= dist_tl {
383            left
384        } else if dist_t <= dist_tl {
385            top
386        } else {
387            top_left
388        }
389    }
390}