rav1e/
predict.rs

1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10#![allow(non_upper_case_globals)]
11#![allow(non_camel_case_types)]
12#![allow(dead_code)]
13
14use std::mem::MaybeUninit;
15
16cfg_if::cfg_if! {
17  if #[cfg(nasm_x86_64)] {
18    pub use crate::asm::x86::predict::*;
19  } else if #[cfg(asm_neon)] {
20    pub use crate::asm::aarch64::predict::*;
21  } else {
22    pub use self::rust::*;
23  }
24}
25
26use aligned_vec::{avec, ABox};
27
28use crate::context::{TileBlockOffset, MAX_SB_SIZE_LOG2, MAX_TX_SIZE};
29use crate::cpu_features::CpuFeatureLevel;
30use crate::encoder::FrameInvariants;
31use crate::frame::*;
32use crate::mc::*;
33use crate::partition::*;
34use crate::tiling::*;
35use crate::transform::*;
36use crate::util::*;
37
/// Granularity, in degrees, of one step of a directional intra mode's angle
/// delta (the effective angle is `nominal + angle_delta * ANGLE_STEP`).
pub const ANGLE_STEP: i8 = 3;

/// Intra prediction modes considered by the encoder, in evaluation order.
// TODO: Review the order of this list.
// The order impacts compression efficiency.
pub static RAV1E_INTRA_MODES: &[PredictionMode] = &[
  PredictionMode::DC_PRED,
  PredictionMode::H_PRED,
  PredictionMode::V_PRED,
  PredictionMode::SMOOTH_PRED,
  PredictionMode::SMOOTH_H_PRED,
  PredictionMode::SMOOTH_V_PRED,
  PredictionMode::PAETH_PRED,
  PredictionMode::D45_PRED,
  PredictionMode::D135_PRED,
  PredictionMode::D113_PRED,
  PredictionMode::D157_PRED,
  PredictionMode::D203_PRED,
  PredictionMode::D67_PRED,
];

/// The reduced set of single-reference inter modes.
pub static RAV1E_INTER_MODES_MINIMAL: &[PredictionMode] =
  &[PredictionMode::NEARESTMV];

/// Compound (two-reference) inter modes considered by the encoder.
pub static RAV1E_INTER_COMPOUND_MODES: &[PredictionMode] = &[
  PredictionMode::GLOBAL_GLOBALMV,
  PredictionMode::NEAREST_NEARESTMV,
  PredictionMode::NEW_NEWMV,
  PredictionMode::NEAREST_NEWMV,
  PredictionMode::NEW_NEARESTMV,
  PredictionMode::NEAR_NEAR0MV,
  PredictionMode::NEAR_NEAR1MV,
  PredictionMode::NEAR_NEAR2MV,
];

/// Intra and inter prediction modes.
///
/// There are more modes than in the spec because every allowed DRL index for
/// NEAR modes is considered its own mode (`NEAR0MV`, `NEAR1MV`, ...).
///
/// The declaration order is significant: intra modes come first (up to and
/// including `UV_CFL_PRED`), then single-reference inter modes, then compound
/// modes starting at `NEAREST_NEARESTMV`. Comparisons on the derived `Ord`
/// rely on this.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Default)]
pub enum PredictionMode {
  /// Average of above and left pixels.
  #[default]
  DC_PRED,
  /// Vertical.
  V_PRED,
  /// Horizontal.
  H_PRED,
  /// Directional 45 degree.
  D45_PRED,
  /// Directional 135 degree.
  D135_PRED,
  /// Directional 113 degree.
  D113_PRED,
  /// Directional 157 degree.
  D157_PRED,
  /// Directional 203 degree.
  D203_PRED,
  /// Directional 67 degree.
  D67_PRED,
  /// Combination of horizontal and vertical interpolation.
  SMOOTH_PRED,
  SMOOTH_V_PRED,
  SMOOTH_H_PRED,
  PAETH_PRED,
  UV_CFL_PRED,
  NEARESTMV,
  NEAR0MV,
  NEAR1MV,
  NEAR2MV,
  GLOBALMV,
  NEWMV,
  // Compound ref compound modes
  NEAREST_NEARESTMV,
  NEAR_NEAR0MV,
  NEAR_NEAR1MV,
  NEAR_NEAR2MV,
  NEAREST_NEWMV,
  NEW_NEARESTMV,
  NEAR_NEW0MV,
  NEAR_NEW1MV,
  NEAR_NEW2MV,
  NEW_NEAR0MV,
  NEW_NEAR1MV,
  NEW_NEAR2MV,
  GLOBAL_GLOBALMV,
  // Must remain the last variant: `PREDICTION_MODES` is derived from it.
  NEW_NEWMV,
}

/// Number of [`PredictionMode`] variants.
///
/// This is a higher number than in the spec and cannot be used for bitstream
/// writing purposes. It is derived from the last variant rather than written
/// as a literal so that it cannot drift out of sync with the enum.
pub const PREDICTION_MODES: usize = PredictionMode::NEW_NEWMV as usize + 1;
117
/// Which neighbouring edges of a block are available for intra prediction.
///
/// Named after the edge(s) that can be used: `LEFT` means only the left
/// column is available, `TOP` only the above row, `BOTH` both, `NONE` neither.
#[derive(Copy, Clone, Debug)]
pub enum PredictionVariant {
  NONE,
  LEFT,
  TOP,
  BOTH,
}

impl PredictionVariant {
  /// Classifies a block from its position `(x, y)` relative to the tile.
  ///
  /// A zero `x` means there is no column to the left; a zero `y` means there
  /// is no row above.
  #[inline]
  const fn new(x: usize, y: usize) -> Self {
    let at_left_edge = x == 0;
    let at_top_edge = y == 0;
    match (at_left_edge, at_top_edge) {
      (true, true) => PredictionVariant::NONE,
      (false, true) => PredictionVariant::LEFT,
      (true, false) => PredictionVariant::TOP,
      (false, false) => PredictionVariant::BOTH,
    }
  }
}
137
138pub const fn intra_mode_to_angle(mode: PredictionMode) -> isize {
139  match mode {
140    PredictionMode::V_PRED => 90,
141    PredictionMode::H_PRED => 180,
142    PredictionMode::D45_PRED => 45,
143    PredictionMode::D135_PRED => 135,
144    PredictionMode::D113_PRED => 113,
145    PredictionMode::D157_PRED => 157,
146    PredictionMode::D203_PRED => 203,
147    PredictionMode::D67_PRED => 67,
148    _ => 0,
149  }
150}
151
152impl PredictionMode {
153  #[inline]
154  pub fn is_compound(self) -> bool {
155    self >= PredictionMode::NEAREST_NEARESTMV
156  }
157  #[inline]
158  pub fn has_nearmv(self) -> bool {
159    self == PredictionMode::NEAR0MV
160      || self == PredictionMode::NEAR1MV
161      || self == PredictionMode::NEAR2MV
162      || self == PredictionMode::NEAR_NEAR0MV
163      || self == PredictionMode::NEAR_NEAR1MV
164      || self == PredictionMode::NEAR_NEAR2MV
165      || self == PredictionMode::NEAR_NEW0MV
166      || self == PredictionMode::NEAR_NEW1MV
167      || self == PredictionMode::NEAR_NEW2MV
168      || self == PredictionMode::NEW_NEAR0MV
169      || self == PredictionMode::NEW_NEAR1MV
170      || self == PredictionMode::NEW_NEAR2MV
171  }
172  #[inline]
173  pub fn has_newmv(self) -> bool {
174    self == PredictionMode::NEWMV
175      || self == PredictionMode::NEW_NEWMV
176      || self == PredictionMode::NEAREST_NEWMV
177      || self == PredictionMode::NEW_NEARESTMV
178      || self == PredictionMode::NEAR_NEW0MV
179      || self == PredictionMode::NEAR_NEW1MV
180      || self == PredictionMode::NEAR_NEW2MV
181      || self == PredictionMode::NEW_NEAR0MV
182      || self == PredictionMode::NEW_NEAR1MV
183      || self == PredictionMode::NEW_NEAR2MV
184  }
185  #[inline]
186  pub fn ref_mv_idx(self) -> usize {
187    if self == PredictionMode::NEAR0MV
188      || self == PredictionMode::NEAR1MV
189      || self == PredictionMode::NEAR2MV
190    {
191      self as usize - PredictionMode::NEAR0MV as usize + 1
192    } else if self == PredictionMode::NEAR_NEAR0MV
193      || self == PredictionMode::NEAR_NEAR1MV
194      || self == PredictionMode::NEAR_NEAR2MV
195    {
196      self as usize - PredictionMode::NEAR_NEAR0MV as usize + 1
197    } else {
198      1
199    }
200  }
201
202  /// # Panics
203  ///
204  /// - If called on an inter `PredictionMode`
205  pub fn predict_intra<T: Pixel>(
206    self, tile_rect: TileRect, dst: &mut PlaneRegionMut<'_, T>,
207    tx_size: TxSize, bit_depth: usize, ac: &[i16], intra_param: IntraParam,
208    ief_params: Option<IntraEdgeFilterParameters>, edge_buf: &IntraEdge<T>,
209    cpu: CpuFeatureLevel,
210  ) {
211    assert!(self.is_intra());
212    let &Rect { x: frame_x, y: frame_y, .. } = dst.rect();
213    debug_assert!(frame_x >= 0 && frame_y >= 0);
214    // x and y are expressed relative to the tile
215    let x = frame_x as usize - tile_rect.x;
216    let y = frame_y as usize - tile_rect.y;
217
218    let variant = PredictionVariant::new(x, y);
219
220    let alpha = match intra_param {
221      IntraParam::Alpha(val) => val,
222      _ => 0,
223    };
224    let angle_delta = match intra_param {
225      IntraParam::AngleDelta(val) => val,
226      _ => 0,
227    };
228
229    let mode = match self {
230      PredictionMode::PAETH_PRED => match variant {
231        PredictionVariant::NONE => PredictionMode::DC_PRED,
232        PredictionVariant::TOP => PredictionMode::V_PRED,
233        PredictionVariant::LEFT => PredictionMode::H_PRED,
234        PredictionVariant::BOTH => PredictionMode::PAETH_PRED,
235      },
236      PredictionMode::UV_CFL_PRED if alpha == 0 => PredictionMode::DC_PRED,
237      _ => self,
238    };
239
240    let angle = match mode {
241      PredictionMode::UV_CFL_PRED => alpha as isize,
242      _ => intra_mode_to_angle(mode) + (angle_delta * ANGLE_STEP) as isize,
243    };
244
245    dispatch_predict_intra::<T>(
246      mode, variant, dst, tx_size, bit_depth, ac, angle, ief_params, edge_buf,
247      cpu,
248    );
249  }
250
251  #[inline]
252  pub fn is_intra(self) -> bool {
253    self < PredictionMode::NEARESTMV
254  }
255
256  #[inline]
257  pub fn is_cfl(self) -> bool {
258    self == PredictionMode::UV_CFL_PRED
259  }
260
261  #[inline]
262  pub fn is_directional(self) -> bool {
263    self >= PredictionMode::V_PRED && self <= PredictionMode::D67_PRED
264  }
265
266  #[inline(always)]
267  pub const fn angle_delta_count(self) -> i8 {
268    match self {
269      PredictionMode::V_PRED
270      | PredictionMode::H_PRED
271      | PredictionMode::D45_PRED
272      | PredictionMode::D135_PRED
273      | PredictionMode::D113_PRED
274      | PredictionMode::D157_PRED
275      | PredictionMode::D203_PRED
276      | PredictionMode::D67_PRED => 7,
277      _ => 1,
278    }
279  }
280
281  // Used by inter prediction to extract the fractional component of a mv and
282  // obtain the correct PlaneSlice to operate on.
283  #[inline]
284  fn get_mv_params<T: Pixel>(
285    rec_plane: &Plane<T>, po: PlaneOffset, mv: MotionVector,
286  ) -> (i32, i32, PlaneSlice<T>) {
287    let &PlaneConfig { xdec, ydec, .. } = &rec_plane.cfg;
288    let row_offset = mv.row as i32 >> (3 + ydec);
289    let col_offset = mv.col as i32 >> (3 + xdec);
290    let row_frac = ((mv.row as i32) << (1 - ydec)) & 0xf;
291    let col_frac = ((mv.col as i32) << (1 - xdec)) & 0xf;
292    let qo = PlaneOffset {
293      x: po.x + col_offset as isize - 3,
294      y: po.y + row_offset as isize - 3,
295    };
296    (row_frac, col_frac, rec_plane.slice(qo).clamp().subslice(3, 3))
297  }
298
299  /// Inter prediction with a single reference (i.e. not compound mode)
300  ///
301  /// # Panics
302  ///
303  /// - If called on an intra `PredictionMode`
304  pub fn predict_inter_single<T: Pixel>(
305    self, fi: &FrameInvariants<T>, tile_rect: TileRect, p: usize,
306    po: PlaneOffset, dst: &mut PlaneRegionMut<'_, T>, width: usize,
307    height: usize, ref_frame: RefType, mv: MotionVector,
308  ) {
309    assert!(!self.is_intra());
310    let frame_po = tile_rect.to_frame_plane_offset(po);
311
312    let mode = fi.default_filter;
313
314    if let Some(ref rec) =
315      fi.rec_buffer.frames[fi.ref_frames[ref_frame.to_index()] as usize]
316    {
317      let (row_frac, col_frac, src) =
318        PredictionMode::get_mv_params(&rec.frame.planes[p], frame_po, mv);
319      put_8tap(
320        dst,
321        src,
322        width,
323        height,
324        col_frac,
325        row_frac,
326        mode,
327        mode,
328        fi.sequence.bit_depth,
329        fi.cpu_feature_level,
330      );
331    }
332  }
333
334  /// Inter prediction with two references.
335  ///
336  /// # Panics
337  ///
338  /// - If called on an intra `PredictionMode`
339  pub fn predict_inter_compound<T: Pixel>(
340    self, fi: &FrameInvariants<T>, tile_rect: TileRect, p: usize,
341    po: PlaneOffset, dst: &mut PlaneRegionMut<'_, T>, width: usize,
342    height: usize, ref_frames: [RefType; 2], mvs: [MotionVector; 2],
343    buffer: &mut InterCompoundBuffers,
344  ) {
345    assert!(!self.is_intra());
346    let frame_po = tile_rect.to_frame_plane_offset(po);
347
348    let mode = fi.default_filter;
349
350    for i in 0..2 {
351      if let Some(ref rec) =
352        fi.rec_buffer.frames[fi.ref_frames[ref_frames[i].to_index()] as usize]
353      {
354        let (row_frac, col_frac, src) = PredictionMode::get_mv_params(
355          &rec.frame.planes[p],
356          frame_po,
357          mvs[i],
358        );
359        prep_8tap(
360          buffer.get_buffer_mut(i),
361          src,
362          width,
363          height,
364          col_frac,
365          row_frac,
366          mode,
367          mode,
368          fi.sequence.bit_depth,
369          fi.cpu_feature_level,
370        );
371      }
372    }
373    mc_avg(
374      dst,
375      buffer.get_buffer(0),
376      buffer.get_buffer(1),
377      width,
378      height,
379      fi.sequence.bit_depth,
380      fi.cpu_feature_level,
381    );
382  }
383
384  /// Inter prediction that determines whether compound mode is being used based
385  /// on the second [`RefType`] in [`ref_frames`].
386  pub fn predict_inter<T: Pixel>(
387    self, fi: &FrameInvariants<T>, tile_rect: TileRect, p: usize,
388    po: PlaneOffset, dst: &mut PlaneRegionMut<'_, T>, width: usize,
389    height: usize, ref_frames: [RefType; 2], mvs: [MotionVector; 2],
390    compound_buffer: &mut InterCompoundBuffers,
391  ) {
392    let is_compound = ref_frames[1] != RefType::INTRA_FRAME
393      && ref_frames[1] != RefType::NONE_FRAME;
394
395    if !is_compound {
396      self.predict_inter_single(
397        fi,
398        tile_rect,
399        p,
400        po,
401        dst,
402        width,
403        height,
404        ref_frames[0],
405        mvs[0],
406      )
407    } else {
408      self.predict_inter_compound(
409        fi,
410        tile_rect,
411        p,
412        po,
413        dst,
414        width,
415        height,
416        ref_frames,
417        mvs,
418        compound_buffer,
419      );
420    }
421  }
422}
423
424/// A pair of buffers holding the interpolation of two references. Use for
425/// compound inter prediction.
426#[derive(Debug)]
427pub struct InterCompoundBuffers {
428  data: ABox<[i16]>,
429}
430
431impl InterCompoundBuffers {
432  // Size of one of the two buffers used.
433  const BUFFER_SIZE: usize = 1 << (2 * MAX_SB_SIZE_LOG2);
434
435  /// Get the buffer for eith
436  #[inline]
437  fn get_buffer_mut(&mut self, i: usize) -> &mut [i16] {
438    match i {
439      0 => &mut self.data[0..Self::BUFFER_SIZE],
440      1 => &mut self.data[Self::BUFFER_SIZE..2 * Self::BUFFER_SIZE],
441      _ => panic!(),
442    }
443  }
444
445  #[inline]
446  fn get_buffer(&self, i: usize) -> &[i16] {
447    match i {
448      0 => &self.data[0..Self::BUFFER_SIZE],
449      1 => &self.data[Self::BUFFER_SIZE..2 * Self::BUFFER_SIZE],
450      _ => panic!(),
451    }
452  }
453}
454
455impl Default for InterCompoundBuffers {
456  fn default() -> Self {
457    Self { data: avec![0; 2 * Self::BUFFER_SIZE].into_boxed_slice() }
458  }
459}
460
/// Intra predictor used by inter-intra prediction.
///
/// The last variant is not a mode; it is the count of real modes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum InterIntraMode {
  II_DC_PRED = 0,
  II_V_PRED = 1,
  II_H_PRED = 2,
  II_SMOOTH_PRED = 3,
  /// Number of inter-intra modes.
  INTERINTRA_MODES = 4,
}
469
/// How the two predictions of a compound block are combined.
///
/// The last variant is not a type; it is the count of real types.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum CompoundType {
  COMPOUND_AVERAGE = 0,
  COMPOUND_WEDGE = 1,
  COMPOUND_DIFFWTD = 2,
  /// Number of compound types.
  COMPOUND_TYPES = 3,
}
477
/// Motion mode of an inter block.
///
/// The last variant is not a mode; it is the count of real modes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum MotionMode {
  SIMPLE_TRANSLATION = 0,
  /// 2-sided OBMC.
  OBMC_CAUSAL = 1,
  /// 2-sided WARPED.
  WARPED_CAUSAL = 2,
  /// Number of motion modes.
  MOTION_MODES = 3,
}
485
/// Number of colors in a palette, from two to eight.
///
/// Discriminants are offset by two from the color count (`TWO_COLORS` is 0).
/// The last variant is the number of palette sizes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum PaletteSize {
  TWO_COLORS = 0,
  THREE_COLORS = 1,
  FOUR_COLORS = 2,
  FIVE_COLORS = 3,
  SIX_COLORS = 4,
  SEVEN_COLORS = 5,
  EIGHT_COLORS = 6,
  /// Number of palette sizes.
  PALETTE_SIZES = 7,
}
497
/// Index of a color within a palette (zero-based: `PALETTE_COLOR_ONE` is 0).
///
/// The last variant is the maximum number of palette colors.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum PaletteColor {
  PALETTE_COLOR_ONE = 0,
  PALETTE_COLOR_TWO = 1,
  PALETTE_COLOR_THREE = 2,
  PALETTE_COLOR_FOUR = 3,
  PALETTE_COLOR_FIVE = 4,
  PALETTE_COLOR_SIX = 5,
  PALETTE_COLOR_SEVEN = 6,
  PALETTE_COLOR_EIGHT = 7,
  /// Number of palette color slots.
  PALETTE_COLORS = 8,
}
510
/// Filter-intra prediction modes.
///
/// The last variant is not a mode; it is the count of real modes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
pub enum FilterIntraMode {
  FILTER_DC_PRED = 0,
  FILTER_V_PRED = 1,
  FILTER_H_PRED = 2,
  FILTER_D157_PRED = 3,
  FILTER_PAETH_PRED = 4,
  /// Number of filter-intra modes.
  FILTER_INTRA_MODES = 5,
}
520
/// Extra parameter for an intra prediction call.
#[derive(Copy, Clone, Debug)]
pub enum IntraParam {
  /// Angle delta of a directional mode, in steps of `ANGLE_STEP` degrees.
  AngleDelta(i8),
  /// Chroma-from-luma scaling factor.
  Alpha(i16),
  /// No extra parameter.
  None,
}
527
/// Angle deltas chosen for the luma (`y`) and chroma (`uv`) planes.
#[derive(Clone, Copy, Debug, Default)]
pub struct AngleDelta {
  /// Luma angle delta.
  pub y: i8,
  /// Chroma angle delta.
  pub uv: i8,
}
533
534#[derive(Copy, Clone, Default)]
535pub struct IntraEdgeFilterParameters {
536  pub plane: usize,
537  pub above_ref_frame_types: Option<[RefType; 2]>,
538  pub left_ref_frame_types: Option<[RefType; 2]>,
539  pub above_mode: Option<PredictionMode>,
540  pub left_mode: Option<PredictionMode>,
541}
542
543impl IntraEdgeFilterParameters {
544  pub fn new(
545    plane: usize, above_ctx: Option<CodedBlockInfo>,
546    left_ctx: Option<CodedBlockInfo>,
547  ) -> Self {
548    IntraEdgeFilterParameters {
549      plane,
550      above_mode: match above_ctx {
551        Some(bi) => match plane {
552          0 => bi.luma_mode,
553          _ => bi.chroma_mode,
554        }
555        .into(),
556        None => None,
557      },
558      left_mode: match left_ctx {
559        Some(bi) => match plane {
560          0 => bi.luma_mode,
561          _ => bi.chroma_mode,
562        }
563        .into(),
564        None => None,
565      },
566      above_ref_frame_types: above_ctx.map(|bi| bi.reference_types),
567      left_ref_frame_types: left_ctx.map(|bi| bi.reference_types),
568    }
569  }
570
571  /// # Panics
572  ///
573  /// - If the appropriate ref frame types are not set on `self`
574  pub fn use_smooth_filter(self) -> bool {
575    let above_smooth = match self.above_mode {
576      Some(PredictionMode::SMOOTH_PRED)
577      | Some(PredictionMode::SMOOTH_V_PRED)
578      | Some(PredictionMode::SMOOTH_H_PRED) => {
579        self.plane == 0
580          || self.above_ref_frame_types.unwrap()[0] == RefType::INTRA_FRAME
581      }
582      _ => false,
583    };
584
585    let left_smooth = match self.left_mode {
586      Some(PredictionMode::SMOOTH_PRED)
587      | Some(PredictionMode::SMOOTH_V_PRED)
588      | Some(PredictionMode::SMOOTH_H_PRED) => {
589        self.plane == 0
590          || self.left_ref_frame_types.unwrap()[0] == RefType::INTRA_FRAME
591      }
592      _ => false,
593    };
594
595    above_smooth || left_smooth
596  }
597}
598
599// Weights are quadratic from '1' to '1 / block_size', scaled by 2^sm_weight_log2_scale.
600const sm_weight_log2_scale: u8 = 8;
601
602// Smooth predictor weights
603#[rustfmt::skip]
604static sm_weight_arrays: [u8; 2 * MAX_TX_SIZE] = [
605    // Unused, because we always offset by bs, which is at least 2.
606    0, 0,
607    // bs = 2
608    255, 128,
609    // bs = 4
610    255, 149, 85, 64,
611    // bs = 8
612    255, 197, 146, 105, 73, 50, 37, 32,
613    // bs = 16
614    255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
615    // bs = 32
616    255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
617    66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
618    // bs = 64
619    255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
620    150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, 69,
621    65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, 15,
622    13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4,
623];
624
/// Multiplies a Q3 CFL alpha by a Q3 luma AC value and returns the product
/// in Q0.
///
/// The Q6 product's magnitude is rounded half away from zero (`+32 >> 6`)
/// and its sign is then restored, so rounding is symmetric around zero.
#[inline(always)]
const fn get_scaled_luma_q0(alpha_q3: i16, ac_pred_q3: i16) -> i32 {
  let product_q6 = alpha_q3 as i32 * ac_pred_q3 as i32;
  let magnitude_q0 = (product_q6.abs() + 32) >> 6;
  if product_q6 >= 0 {
    magnitude_q0
  } else {
    -magnitude_q0
  }
}
635
636/// # Returns
637///
638/// Initialized luma AC coefficients
639///
640/// # Panics
641///
642/// - If the block size is invalid for subsampling
643///
644pub fn luma_ac<'ac, T: Pixel>(
645  ac: &'ac mut [MaybeUninit<i16>], ts: &mut TileStateMut<'_, T>,
646  tile_bo: TileBlockOffset, bsize: BlockSize, tx_size: TxSize,
647  fi: &FrameInvariants<T>,
648) -> &'ac mut [i16] {
649  use crate::context::MI_SIZE_LOG2;
650
651  let PlaneConfig { xdec, ydec, .. } = ts.input.planes[1].cfg;
652  let plane_bsize = bsize.subsampled_size(xdec, ydec).unwrap();
653
654  // ensure ac has the right length, so there aren't any uninitialized elements at the end
655  let ac = &mut ac[..plane_bsize.area()];
656
657  let bo = if bsize.is_sub8x8(xdec, ydec) {
658    let offset = bsize.sub8x8_offset(xdec, ydec);
659    tile_bo.with_offset(offset.0, offset.1)
660  } else {
661    tile_bo
662  };
663  let rec = &ts.rec.planes[0];
664  let luma = &rec.subregion(Area::BlockStartingAt { bo: bo.0 });
665  let frame_bo = ts.to_frame_block_offset(bo);
666
667  let frame_clipped_bw: usize =
668    ((fi.w_in_b - frame_bo.0.x) << MI_SIZE_LOG2).min(bsize.width());
669  let frame_clipped_bh: usize =
670    ((fi.h_in_b - frame_bo.0.y) << MI_SIZE_LOG2).min(bsize.height());
671
672  // Similar to 'MaxLumaW' and 'MaxLumaH' stated in https://aomediacodec.github.io/av1-spec/#transform-block-semantics
673  let max_luma_w = if bsize.width() > BlockSize::BLOCK_8X8.width() {
674    let txw_log2 = tx_size.width_log2();
675    ((frame_clipped_bw + (1 << txw_log2) - 1) >> txw_log2) << txw_log2
676  } else {
677    bsize.width()
678  };
679  let max_luma_h = if bsize.height() > BlockSize::BLOCK_8X8.height() {
680    let txh_log2 = tx_size.height_log2();
681    ((frame_clipped_bh + (1 << txh_log2) - 1) >> txh_log2) << txh_log2
682  } else {
683    bsize.height()
684  };
685
686  let w_pad = (bsize.width() - max_luma_w) >> (2 + xdec);
687  let h_pad = (bsize.height() - max_luma_h) >> (2 + ydec);
688  let cpu = fi.cpu_feature_level;
689
690  (match (xdec, ydec) {
691    (0, 0) => pred_cfl_ac::<T, 0, 0>,
692    (1, 0) => pred_cfl_ac::<T, 1, 0>,
693    (_, _) => pred_cfl_ac::<T, 1, 1>,
694  })(ac, luma, plane_bsize, w_pad, h_pad, cpu);
695
696  // SAFETY: it relies on individual pred_cfl_ac implementations to initialize the ac
697  unsafe { slice_assume_init_mut(ac) }
698}
699
700pub(crate) mod rust {
701  use super::*;
702  use std::mem::size_of;
703
704  #[inline(always)]
705  pub fn dispatch_predict_intra<T: Pixel>(
706    mode: PredictionMode, variant: PredictionVariant,
707    dst: &mut PlaneRegionMut<'_, T>, tx_size: TxSize, bit_depth: usize,
708    ac: &[i16], angle: isize, ief_params: Option<IntraEdgeFilterParameters>,
709    edge_buf: &IntraEdge<T>, _cpu: CpuFeatureLevel,
710  ) {
711    let width = tx_size.width();
712    let height = tx_size.height();
713
714    // left pixels are ordered from bottom to top and right-aligned
715    let (left, top_left, above) = edge_buf.as_slices();
716
717    let above_slice = above;
718    let left_slice = &left[left.len().saturating_sub(height)..];
719    let left_and_left_below_slice =
720      &left[left.len().saturating_sub(width + height)..];
721
722    match mode {
723      PredictionMode::DC_PRED => {
724        (match variant {
725          PredictionVariant::NONE => pred_dc_128,
726          PredictionVariant::LEFT => pred_dc_left,
727          PredictionVariant::TOP => pred_dc_top,
728          PredictionVariant::BOTH => pred_dc,
729        })(dst, above_slice, left_slice, width, height, bit_depth)
730      }
731      PredictionMode::V_PRED if angle == 90 => {
732        pred_v(dst, above_slice, width, height)
733      }
734      PredictionMode::H_PRED if angle == 180 => {
735        pred_h(dst, left_slice, width, height)
736      }
737      PredictionMode::H_PRED
738      | PredictionMode::V_PRED
739      | PredictionMode::D45_PRED
740      | PredictionMode::D135_PRED
741      | PredictionMode::D113_PRED
742      | PredictionMode::D157_PRED
743      | PredictionMode::D203_PRED
744      | PredictionMode::D67_PRED => pred_directional(
745        dst,
746        above_slice,
747        left_and_left_below_slice,
748        top_left,
749        angle as usize,
750        width,
751        height,
752        bit_depth,
753        ief_params,
754      ),
755      PredictionMode::SMOOTH_PRED => {
756        pred_smooth(dst, above_slice, left_slice, width, height)
757      }
758      PredictionMode::SMOOTH_V_PRED => {
759        pred_smooth_v(dst, above_slice, left_slice, width, height)
760      }
761      PredictionMode::SMOOTH_H_PRED => {
762        pred_smooth_h(dst, above_slice, left_slice, width, height)
763      }
764      PredictionMode::PAETH_PRED => {
765        pred_paeth(dst, above_slice, left_slice, top_left[0], width, height)
766      }
767      PredictionMode::UV_CFL_PRED => (match variant {
768        PredictionVariant::NONE => pred_cfl_128,
769        PredictionVariant::LEFT => pred_cfl_left,
770        PredictionVariant::TOP => pred_cfl_top,
771        PredictionVariant::BOTH => pred_cfl,
772      })(
773        dst,
774        ac,
775        angle as i16,
776        above_slice,
777        left_slice,
778        width,
779        height,
780        bit_depth,
781      ),
782      _ => unimplemented!(),
783    }
784  }
785
786  pub(crate) fn pred_dc<T: Pixel>(
787    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T], width: usize,
788    height: usize, _bit_depth: usize,
789  ) {
790    let edges = left[..height].iter().chain(above[..width].iter());
791    let len = (width + height) as u32;
792    let avg = (edges.fold(0u32, |acc, &v| {
793      let v: u32 = v.into();
794      v + acc
795    }) + (len >> 1))
796      / len;
797    let avg = T::cast_from(avg);
798
799    for line in output.rows_iter_mut().take(height) {
800      line[..width].fill(avg);
801    }
802  }
803
804  pub(crate) fn pred_dc_128<T: Pixel>(
805    output: &mut PlaneRegionMut<'_, T>, _above: &[T], _left: &[T],
806    width: usize, height: usize, bit_depth: usize,
807  ) {
808    let v = T::cast_from(128u32 << (bit_depth - 8));
809    for line in output.rows_iter_mut().take(height) {
810      line[..width].fill(v);
811    }
812  }
813
814  pub(crate) fn pred_dc_left<T: Pixel>(
815    output: &mut PlaneRegionMut<'_, T>, _above: &[T], left: &[T],
816    width: usize, height: usize, _bit_depth: usize,
817  ) {
818    let sum = left[..].iter().fold(0u32, |acc, &v| {
819      let v: u32 = v.into();
820      v + acc
821    });
822    let avg = T::cast_from((sum + (height >> 1) as u32) / height as u32);
823    for line in output.rows_iter_mut().take(height) {
824      line[..width].fill(avg);
825    }
826  }
827
828  pub(crate) fn pred_dc_top<T: Pixel>(
829    output: &mut PlaneRegionMut<'_, T>, above: &[T], _left: &[T],
830    width: usize, height: usize, _bit_depth: usize,
831  ) {
832    let sum = above[..width].iter().fold(0u32, |acc, &v| {
833      let v: u32 = v.into();
834      v + acc
835    });
836    let avg = T::cast_from((sum + (width >> 1) as u32) / width as u32);
837    for line in output.rows_iter_mut().take(height) {
838      line[..width].fill(avg);
839    }
840  }
841
842  pub(crate) fn pred_h<T: Pixel>(
843    output: &mut PlaneRegionMut<'_, T>, left: &[T], width: usize,
844    height: usize,
845  ) {
846    for (line, l) in output.rows_iter_mut().zip(left[..height].iter().rev()) {
847      line[..width].fill(*l);
848    }
849  }
850
851  pub(crate) fn pred_v<T: Pixel>(
852    output: &mut PlaneRegionMut<'_, T>, above: &[T], width: usize,
853    height: usize,
854  ) {
855    for line in output.rows_iter_mut().take(height) {
856      line[..width].copy_from_slice(&above[..width])
857    }
858  }
859
860  pub(crate) fn pred_paeth<T: Pixel>(
861    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T],
862    above_left: T, width: usize, height: usize,
863  ) {
864    for r in 0..height {
865      let row = &mut output[r];
866      for c in 0..width {
867        // Top-left pixel is fixed in libaom
868        let raw_top_left: i32 = above_left.into();
869        let raw_left: i32 = left[height - 1 - r].into();
870        let raw_top: i32 = above[c].into();
871
872        let p_base = raw_top + raw_left - raw_top_left;
873        let p_left = (p_base - raw_left).abs();
874        let p_top = (p_base - raw_top).abs();
875        let p_top_left = (p_base - raw_top_left).abs();
876
877        // Return nearest to base of left, top and top_left
878        if p_left <= p_top && p_left <= p_top_left {
879          row[c] = T::cast_from(raw_left);
880        } else if p_top <= p_top_left {
881          row[c] = T::cast_from(raw_top);
882        } else {
883          row[c] = T::cast_from(raw_top_left);
884        }
885      }
886    }
887  }
888
889  pub(crate) fn pred_smooth<T: Pixel>(
890    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T], width: usize,
891    height: usize,
892  ) {
893    let below_pred = left[0]; // estimated by bottom-left pixel
894    let right_pred = above[width - 1]; // estimated by top-right pixel
895    let sm_weights_w = &sm_weight_arrays[width..];
896    let sm_weights_h = &sm_weight_arrays[height..];
897
898    let log2_scale = 1 + sm_weight_log2_scale;
899    let scale = 1_u16 << sm_weight_log2_scale;
900
901    // Weights sanity checks
902    assert!((sm_weights_w[0] as u16) < scale);
903    assert!((sm_weights_h[0] as u16) < scale);
904    assert!((scale - sm_weights_w[width - 1] as u16) < scale);
905    assert!((scale - sm_weights_h[height - 1] as u16) < scale);
906    // ensures no overflow when calculating predictor
907    assert!(log2_scale as usize + size_of::<T>() < 31);
908
909    for r in 0..height {
910      let row = &mut output[r];
911      for c in 0..width {
912        let pixels = [above[c], below_pred, left[height - 1 - r], right_pred];
913
914        let weights = [
915          sm_weights_h[r] as u16,
916          scale - sm_weights_h[r] as u16,
917          sm_weights_w[c] as u16,
918          scale - sm_weights_w[c] as u16,
919        ];
920
921        assert!(
922          scale >= (sm_weights_h[r] as u16)
923            && scale >= (sm_weights_w[c] as u16)
924        );
925
926        // Sum up weighted pixels
927        let mut this_pred: u32 = weights
928          .iter()
929          .zip(pixels.iter())
930          .map(|(w, p)| {
931            let p: u32 = (*p).into();
932            (*w as u32) * p
933          })
934          .sum();
935        this_pred = (this_pred + (1 << (log2_scale - 1))) >> log2_scale;
936
937        row[c] = T::cast_from(this_pred);
938      }
939    }
940  }
941
942  pub(crate) fn pred_smooth_h<T: Pixel>(
943    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T], width: usize,
944    height: usize,
945  ) {
946    let right_pred = above[width - 1]; // estimated by top-right pixel
947    let sm_weights = &sm_weight_arrays[width..];
948
949    let log2_scale = sm_weight_log2_scale;
950    let scale = 1_u16 << sm_weight_log2_scale;
951
952    // Weights sanity checks
953    assert!((sm_weights[0] as u16) < scale);
954    assert!((scale - sm_weights[width - 1] as u16) < scale);
955    // ensures no overflow when calculating predictor
956    assert!(log2_scale as usize + size_of::<T>() < 31);
957
958    for r in 0..height {
959      let row = &mut output[r];
960      for c in 0..width {
961        let pixels = [left[height - 1 - r], right_pred];
962        let weights = [sm_weights[c] as u16, scale - sm_weights[c] as u16];
963
964        assert!(scale >= sm_weights[c] as u16);
965
966        let mut this_pred: u32 = weights
967          .iter()
968          .zip(pixels.iter())
969          .map(|(w, p)| {
970            let p: u32 = (*p).into();
971            (*w as u32) * p
972          })
973          .sum();
974        this_pred = (this_pred + (1 << (log2_scale - 1))) >> log2_scale;
975
976        row[c] = T::cast_from(this_pred);
977      }
978    }
979  }
980
981  pub(crate) fn pred_smooth_v<T: Pixel>(
982    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T], width: usize,
983    height: usize,
984  ) {
985    let below_pred = left[0]; // estimated by bottom-left pixel
986    let sm_weights = &sm_weight_arrays[height..];
987
988    let log2_scale = sm_weight_log2_scale;
989    let scale = 1_u16 << sm_weight_log2_scale;
990
991    // Weights sanity checks
992    assert!((sm_weights[0] as u16) < scale);
993    assert!((scale - sm_weights[height - 1] as u16) < scale);
994    // ensures no overflow when calculating predictor
995    assert!(log2_scale as usize + size_of::<T>() < 31);
996
997    for r in 0..height {
998      let row = &mut output[r];
999      for c in 0..width {
1000        let pixels = [above[c], below_pred];
1001        let weights = [sm_weights[r] as u16, scale - sm_weights[r] as u16];
1002
1003        assert!(scale >= sm_weights[r] as u16);
1004
1005        let mut this_pred: u32 = weights
1006          .iter()
1007          .zip(pixels.iter())
1008          .map(|(w, p)| {
1009            let p: u32 = (*p).into();
1010            (*w as u32) * p
1011          })
1012          .sum();
1013        this_pred = (this_pred + (1 << (log2_scale - 1))) >> log2_scale;
1014
1015        row[c] = T::cast_from(this_pred);
1016      }
1017    }
1018  }
1019
1020  pub(crate) fn pred_cfl_ac<T: Pixel, const XDEC: usize, const YDEC: usize>(
1021    ac: &mut [MaybeUninit<i16>], luma: &PlaneRegion<'_, T>,
1022    plane_bsize: BlockSize, w_pad: usize, h_pad: usize, _cpu: CpuFeatureLevel,
1023  ) {
1024    let max_luma_w = (plane_bsize.width() - w_pad * 4) << XDEC;
1025    let max_luma_h = (plane_bsize.height() - h_pad * 4) << YDEC;
1026    let max_luma_x: usize = max_luma_w.max(8) - (1 << XDEC);
1027    let max_luma_y: usize = max_luma_h.max(8) - (1 << YDEC);
1028    let mut sum: i32 = 0;
1029
1030    let ac = &mut ac[..plane_bsize.area()];
1031
1032    for (sub_y, ac_rows) in
1033      ac.chunks_exact_mut(plane_bsize.width()).enumerate()
1034    {
1035      for (sub_x, ac_item) in ac_rows.iter_mut().enumerate() {
1036        // Refer to https://aomediacodec.github.io/av1-spec/#predict-chroma-from-luma-process
1037        let luma_y = sub_y << YDEC;
1038        let luma_x = sub_x << XDEC;
1039        let y = luma_y.min(max_luma_y);
1040        let x = luma_x.min(max_luma_x);
1041        let mut sample: i16 = i16::cast_from(luma[y][x]);
1042        if XDEC != 0 {
1043          sample += i16::cast_from(luma[y][x + 1]);
1044        }
1045        if YDEC != 0 {
1046          debug_assert!(XDEC != 0);
1047          sample += i16::cast_from(luma[y + 1][x])
1048            + i16::cast_from(luma[y + 1][x + 1]);
1049        }
1050        sample <<= 3 - XDEC - YDEC;
1051        ac_item.write(sample);
1052        sum += sample as i32;
1053      }
1054    }
1055    // SAFETY: the loop above has initialized all items
1056    let ac = unsafe { slice_assume_init_mut(ac) };
1057    let shift = plane_bsize.width_log2() + plane_bsize.height_log2();
1058    let average = ((sum + (1 << (shift - 1))) >> shift) as i16;
1059
1060    for val in ac {
1061      *val -= average;
1062    }
1063  }
1064
1065  pub(crate) fn pred_cfl_inner<T: Pixel>(
1066    output: &mut PlaneRegionMut<'_, T>, ac: &[i16], alpha: i16, width: usize,
1067    height: usize, bit_depth: usize,
1068  ) {
1069    if alpha == 0 {
1070      return;
1071    }
1072    debug_assert!(ac.len() >= width * height);
1073    assert!(output.plane_cfg.stride >= width);
1074    assert!(output.rows_iter().len() >= height);
1075
1076    let sample_max = (1 << bit_depth) - 1;
1077    let avg: i32 = output[0][0].into();
1078
1079    for (line, luma) in
1080      output.rows_iter_mut().zip(ac.chunks_exact(width)).take(height)
1081    {
1082      for (v, &l) in line[..width].iter_mut().zip(luma[..width].iter()) {
1083        *v = T::cast_from(
1084          (avg + get_scaled_luma_q0(alpha, l)).clamp(0, sample_max),
1085        );
1086      }
1087    }
1088  }
1089
1090  pub(crate) fn pred_cfl<T: Pixel>(
1091    output: &mut PlaneRegionMut<'_, T>, ac: &[i16], alpha: i16, above: &[T],
1092    left: &[T], width: usize, height: usize, bit_depth: usize,
1093  ) {
1094    pred_dc(output, above, left, width, height, bit_depth);
1095    pred_cfl_inner(output, ac, alpha, width, height, bit_depth);
1096  }
1097
1098  pub(crate) fn pred_cfl_128<T: Pixel>(
1099    output: &mut PlaneRegionMut<'_, T>, ac: &[i16], alpha: i16, above: &[T],
1100    left: &[T], width: usize, height: usize, bit_depth: usize,
1101  ) {
1102    pred_dc_128(output, above, left, width, height, bit_depth);
1103    pred_cfl_inner(output, ac, alpha, width, height, bit_depth);
1104  }
1105
1106  pub(crate) fn pred_cfl_left<T: Pixel>(
1107    output: &mut PlaneRegionMut<'_, T>, ac: &[i16], alpha: i16, above: &[T],
1108    left: &[T], width: usize, height: usize, bit_depth: usize,
1109  ) {
1110    pred_dc_left(output, above, left, width, height, bit_depth);
1111    pred_cfl_inner(output, ac, alpha, width, height, bit_depth);
1112  }
1113
1114  pub(crate) fn pred_cfl_top<T: Pixel>(
1115    output: &mut PlaneRegionMut<'_, T>, ac: &[i16], alpha: i16, above: &[T],
1116    left: &[T], width: usize, height: usize, bit_depth: usize,
1117  ) {
1118    pred_dc_top(output, above, left, width, height, bit_depth);
1119    pred_cfl_inner(output, ac, alpha, width, height, bit_depth);
1120  }
1121
1122  #[allow(clippy::collapsible_if)]
1123  #[allow(clippy::collapsible_else_if)]
1124  #[allow(clippy::needless_return)]
1125  pub(crate) const fn select_ief_strength(
1126    width: usize, height: usize, smooth_filter: bool, angle_delta: isize,
1127  ) -> u8 {
1128    let block_wh = width + height;
1129    let abs_delta = angle_delta.unsigned_abs();
1130
1131    if smooth_filter {
1132      if block_wh <= 8 {
1133        if abs_delta >= 64 {
1134          return 2;
1135        }
1136        if abs_delta >= 40 {
1137          return 1;
1138        }
1139      } else if block_wh <= 16 {
1140        if abs_delta >= 48 {
1141          return 2;
1142        }
1143        if abs_delta >= 20 {
1144          return 1;
1145        }
1146      } else if block_wh <= 24 {
1147        if abs_delta >= 4 {
1148          return 3;
1149        }
1150      } else {
1151        return 3;
1152      }
1153    } else {
1154      if block_wh <= 8 {
1155        if abs_delta >= 56 {
1156          return 1;
1157        }
1158      } else if block_wh <= 16 {
1159        if abs_delta >= 40 {
1160          return 1;
1161        }
1162      } else if block_wh <= 24 {
1163        if abs_delta >= 32 {
1164          return 3;
1165        }
1166        if abs_delta >= 16 {
1167          return 2;
1168        }
1169        if abs_delta >= 8 {
1170          return 1;
1171        }
1172      } else if block_wh <= 32 {
1173        if abs_delta >= 32 {
1174          return 3;
1175        }
1176        if abs_delta >= 4 {
1177          return 2;
1178        }
1179        return 1;
1180      } else {
1181        return 3;
1182      }
1183    }
1184
1185    return 0;
1186  }
1187
1188  pub(crate) const fn select_ief_upsample(
1189    width: usize, height: usize, smooth_filter: bool, angle_delta: isize,
1190  ) -> bool {
1191    let block_wh = width + height;
1192    let abs_delta = angle_delta.unsigned_abs();
1193
1194    if abs_delta == 0 || abs_delta >= 40 {
1195      false
1196    } else if smooth_filter {
1197      block_wh <= 8
1198    } else {
1199      block_wh <= 16
1200    }
1201  }
1202
1203  pub(crate) fn filter_edge<T: Pixel>(
1204    size: usize, strength: u8, edge: &mut [T],
1205  ) {
1206    const INTRA_EDGE_KERNEL: [[u32; 5]; 3] =
1207      [[0, 4, 8, 4, 0], [0, 5, 6, 5, 0], [2, 4, 4, 4, 2]];
1208
1209    if strength == 0 {
1210      return;
1211    }
1212
1213    // Copy the edge buffer to avoid predicting from
1214    // just-filtered samples.
1215    let mut edge_filtered = [MaybeUninit::<T>::uninit(); MAX_TX_SIZE * 4 + 1];
1216    let edge_filtered =
1217      init_slice_repeat_mut(&mut edge_filtered[..edge.len()], T::zero());
1218    edge_filtered.copy_from_slice(&edge[..edge.len()]);
1219
1220    for i in 1..size {
1221      let mut s = 0;
1222
1223      for j in 0..INTRA_EDGE_KERNEL[0].len() {
1224        let k = (i + j).saturating_sub(2).min(size - 1);
1225        s += INTRA_EDGE_KERNEL[(strength - 1) as usize][j]
1226          * edge[k].to_u32().unwrap();
1227      }
1228
1229      edge_filtered[i] = T::cast_from((s + 8) >> 4);
1230    }
1231    edge.copy_from_slice(edge_filtered);
1232  }
1233
1234  pub(crate) fn upsample_edge<T: Pixel>(
1235    size: usize, edge: &mut [T], bit_depth: usize,
1236  ) {
1237    // The input edge should be valid in the -1..size range,
1238    // where the -1 index is the top-left edge pixel. Since
1239    // negative indices are unsafe in Rust, the caller is
1240    // expected to globally offset it by 1, which makes the
1241    // input range 0..=size.
1242    let mut dup = [MaybeUninit::<T>::uninit(); MAX_TX_SIZE];
1243    let dup = init_slice_repeat_mut(&mut dup[..size + 3], T::zero());
1244    dup[0] = edge[0];
1245    dup[1..=size + 1].copy_from_slice(&edge[0..=size]);
1246    dup[size + 2] = edge[size];
1247
1248    // Past here the edge is being filtered, and its
1249    // effective range is shifted from -1..size to
1250    // -2..2*size-1. Again, because this is safe Rust,
1251    // we cannot use negative indices, and the actual range
1252    // will be 0..=2*size. The caller is expected to adjust
1253    // its indices on receipt of the filtered edge.
1254    edge[0] = dup[0];
1255
1256    for i in 0..size {
1257      let mut s = -dup[i].to_i32().unwrap()
1258        + (9 * dup[i + 1].to_i32().unwrap())
1259        + (9 * dup[i + 2].to_i32().unwrap())
1260        - dup[i + 3].to_i32().unwrap();
1261      s = ((s + 8) / 16).clamp(0, (1 << bit_depth) - 1);
1262
1263      edge[2 * i + 1] = T::cast_from(s);
1264      edge[2 * i + 2] = dup[i + 2];
1265    }
1266  }
1267
1268  pub(crate) const fn dr_intra_derivative(p_angle: usize) -> usize {
1269    match p_angle {
1270      3 => 1023,
1271      6 => 547,
1272      9 => 372,
1273      14 => 273,
1274      17 => 215,
1275      20 => 178,
1276      23 => 151,
1277      26 => 132,
1278      29 => 116,
1279      32 => 102,
1280      36 => 90,
1281      39 => 80,
1282      42 => 71,
1283      45 => 64,
1284      48 => 57,
1285      51 => 51,
1286      54 => 45,
1287      58 => 40,
1288      61 => 35,
1289      64 => 31,
1290      67 => 27,
1291      70 => 23,
1292      73 => 19,
1293      76 => 15,
1294      81 => 11,
1295      84 => 7,
1296      87 => 3,
1297      _ => 0,
1298    }
1299  }
1300
  /// Directional (angular) intra prediction.
  ///
  /// Fills the top-left `width` x `height` area of `output`. Each output
  /// position is projected along the prediction angle `p_angle` (degrees)
  /// onto the neighbouring edges, and the value is linearly interpolated
  /// between the two nearest edge samples with 1/32 weights
  /// (`shift` in `0..32`, `round_shift(.., 5)`). Projected positions are kept
  /// in 1/64-sample units (`>> 6` yields the integer sample).
  ///
  /// The angle selects one of three branches:
  ///
  /// * `p_angle < 90`: only the above edge is read.
  /// * `90 < p_angle < 180`: the above edge is read, or the left edge when the
  ///   projection falls before the start of the above edge.
  /// * `p_angle > 180`: only the left edge is read.
  ///
  /// Exactly 90 and 180 match none of these branches and write nothing.
  /// NOTE(review): callers are presumably expected to use the V/H predictors
  /// for those angles; confirm.
  ///
  /// Parameters:
  ///
  /// * `left`: read from its end. `left[left.len() - i]` is copied to position
  ///   `i` of the local buffer, right after the top-left slot, so it appears
  ///   to be stored bottom-to-top. TODO confirm against the caller.
  /// * `top_left`: only `top_left[0]` is read.
  /// * `bit_depth`: outputs are clamped to `0..=(1 << bit_depth) - 1`.
  /// * `ief_params`: when `Some`, the edges are copied into local buffers,
  ///   smoothed with `filter_edge` (except at 90/180), and possibly 2x
  ///   upsampled with `upsample_edge` before prediction. When `None`, `above`
  ///   and `left` are used as given.
  pub(crate) fn pred_directional<T: Pixel>(
    output: &mut PlaneRegionMut<'_, T>, above: &[T], left: &[T],
    top_left: &[T], p_angle: usize, width: usize, height: usize,
    bit_depth: usize, ief_params: Option<IntraEdgeFilterParameters>,
  ) {
    let sample_max = (1 << bit_depth) - 1;

    // Last valid column/row of the plane. Used below to limit how many edge
    // samples are filtered when the block touches the plane border.
    let max_x = output.plane_cfg.width as isize - 1;
    let max_y = output.plane_cfg.height as isize - 1;

    let mut upsample_above = false;
    let mut upsample_left = false;

    // Without edge filtering, the caller's slices are read directly.
    let mut above_edge: &[T] = above;
    let mut left_edge: &[T] = left;
    let top_left_edge: T = top_left[0];

    let enable_edge_filter = ief_params.is_some();

    // Initialize above and left edge buffers of the largest possible needed size if upsampled
    // The first value is the top left pixel, also mutable and indexed at -1 in the spec
    let mut above_filtered = [MaybeUninit::<T>::uninit(); MAX_TX_SIZE * 4 + 1];
    let above_filtered = init_slice_repeat_mut(
      &mut above_filtered[..=(width + height) * 2],
      T::zero(),
    );
    let mut left_filtered = [MaybeUninit::<T>::uninit(); MAX_TX_SIZE * 4 + 1];
    let left_filtered = init_slice_repeat_mut(
      &mut left_filtered[..=(width + height) * 2],
      T::zero(),
    );

    if enable_edge_filter {
      // Copy both edges after the reserved top-left slot at index 0. `left`
      // is copied in reverse, so `left_filtered[1]` is `left`'s last sample.
      let above_len = above.len().min(above_filtered.len() - 1);
      let left_len = left.len().min(left_filtered.len() - 1);
      above_filtered[1..=above_len].clone_from_slice(&above[..above_len]);
      for i in 1..=left_len {
        left_filtered[i] = left[left.len() - i];
      }

      let smooth_filter = ief_params.unwrap().use_smooth_filter();

      // The pure vertical (90) and horizontal (180) angles skip edge
      // smoothing.
      if p_angle != 90 && p_angle != 180 {
        above_filtered[0] = top_left_edge;
        left_filtered[0] = top_left_edge;

        // Number of samples to filter on each edge, counting the top-left
        // slot. It is the block extent clipped to the plane, extended by
        // `height` on the above edge for angles < 90 or by `width` on the
        // left edge for angles > 180.
        let num_px = (
          width.min((max_x - output.rect().x + 1).try_into().unwrap())
            + if p_angle < 90 { height } else { 0 }
            + 1, // above
          height.min((max_y - output.rect().y + 1).try_into().unwrap())
            + if p_angle > 180 { width } else { 0 }
            + 1, // left
        );

        // Each edge's strength depends on how far the angle is from that
        // edge's own direction: 90 for above, 180 for left.
        let filter_strength = select_ief_strength(
          width,
          height,
          smooth_filter,
          p_angle as isize - 90,
        );
        filter_edge(num_px.0, filter_strength, above_filtered);
        let filter_strength = select_ief_strength(
          width,
          height,
          smooth_filter,
          p_angle as isize - 180,
        );
        filter_edge(num_px.1, filter_strength, left_filtered);
      }

      // Edge lengths for upsampling. These are not clipped to the plane and
      // exclude the top-left slot, which `upsample_edge` expects at index 0.
      let num_px = (
        width + if p_angle < 90 { height } else { 0 }, // above
        height + if p_angle > 180 { width } else { 0 }, // left
      );

      upsample_above = select_ief_upsample(
        width,
        height,
        smooth_filter,
        p_angle as isize - 90,
      );
      if upsample_above {
        upsample_edge(num_px.0, above_filtered, bit_depth);
      }
      upsample_left = select_ief_upsample(
        width,
        height,
        smooth_filter,
        p_angle as isize - 180,
      );
      if upsample_left {
        upsample_edge(num_px.1, left_filtered, bit_depth);
      }

      // Flip the left buffer back to the orientation of `left`. What was
      // index 0 (the top-left slot) is now the last index.
      left_filtered.reverse();
      above_edge = above_filtered;
      left_edge = left_filtered;
    }

    // Per-row step along the above edge (branches 1 and 2).
    let dx = if p_angle < 90 {
      dr_intra_derivative(p_angle)
    } else if p_angle > 90 && p_angle < 180 {
      dr_intra_derivative(180 - p_angle)
    } else {
      0 // undefined
    };

    // Per-column step along the left edge (branches 2 and 3).
    let dy = if p_angle > 90 && p_angle < 180 {
      dr_intra_derivative(p_angle - 90)
    } else if p_angle > 180 {
      dr_intra_derivative(270 - p_angle)
    } else {
      0 // undefined
    };

    // edge buffer index offsets applied due to the fact
    // that we cannot safely use negative indices in Rust.
    // With edge filtering, the local buffers carry the top-left slot at
    // index 0 (two leading slots after upsampling), so spec index 0 lives at
    // 1 or 2. Without filtering, the caller's slices are indexed directly.
    let upsample_above = upsample_above as usize;
    let upsample_left = upsample_left as usize;
    let offset_above = (enable_edge_filter as usize) << upsample_above;
    let offset_left = (enable_edge_filter as usize) << upsample_left;

    if p_angle < 90 {
      // Branch 1: only the above edge is read. Projections at or beyond
      // `max_base_x` replicate the sample there.
      for i in 0..height {
        let row = &mut output[i];
        for j in 0..width {
          let idx = (i + 1) * dx;
          let base = (idx >> (6 - upsample_above)) + (j << upsample_above);
          let shift = (((idx << upsample_above) >> 1) & 31) as i32;
          let max_base_x = (height + width - 1) << upsample_above;
          let v = (if base < max_base_x {
            let a: i32 = above_edge[base + offset_above].into();
            let b: i32 = above_edge[base + 1 + offset_above].into();
            round_shift(a * (32 - shift) + b * shift, 5)
          } else {
            let c: i32 = above_edge[max_base_x + offset_above].into();
            c
          })
          .clamp(0, sample_max);
          row[j] = T::cast_from(v);
        }
      }
    } else if p_angle > 90 && p_angle < 180 {
      // Branch 2: first project onto the above edge. If that lands before its
      // start, project onto the left edge instead. Without edge filtering,
      // negative positions read `top_left_edge`.
      for i in 0..height {
        let row = &mut output[i];
        for j in 0..width {
          let idx = (j << 6) as isize - ((i + 1) * dx) as isize;
          let base = idx >> (6 - upsample_above);
          if base >= -(1 << upsample_above) {
            let shift = (((idx << upsample_above) >> 1) & 31) as i32;
            let a: i32 = if !enable_edge_filter && base < 0 {
              top_left_edge
            } else {
              above_edge[(base + offset_above as isize) as usize]
            }
            .into();
            let b: i32 =
              above_edge[(base + 1 + offset_above as isize) as usize].into();
            let v = round_shift(a * (32 - shift) + b * shift, 5)
              .clamp(0, sample_max);
            row[j] = T::cast_from(v);
          } else {
            // `left_edge` is indexed from its end (`l - position`). A
            // combined position of -2 is special-cased to read
            // `left_edge[0]`/`left_edge[1]`.
            let idx = (i << 6) as isize - ((j + 1) * dy) as isize;
            let base = idx >> (6 - upsample_left);
            let shift = (((idx << upsample_left) >> 1) & 31) as i32;
            let l = left_edge.len() - 1;
            let a: i32 = if !enable_edge_filter && base < 0 {
              top_left_edge
            } else if (base + offset_left as isize) == -2 {
              left_edge[0]
            } else {
              left_edge[l - (base + offset_left as isize) as usize]
            }
            .into();
            let b: i32 = if (base + offset_left as isize) == -2 {
              left_edge[1]
            } else {
              left_edge[l - (base + offset_left as isize + 1) as usize]
            }
            .into();
            let v = round_shift(a * (32 - shift) + b * shift, 5)
              .clamp(0, sample_max);
            row[j] = T::cast_from(v);
          }
        }
      }
    } else if p_angle > 180 {
      // Branch 3: only the left edge is read, indexed from its end.
      // `saturating_sub` clamps positions past the far end to `left_edge[0]`.
      for i in 0..height {
        let row = &mut output[i];
        for j in 0..width {
          let idx = (j + 1) * dy;
          let base = (idx >> (6 - upsample_left)) + (i << upsample_left);
          let shift = (((idx << upsample_left) >> 1) & 31) as i32;
          let l = left_edge.len() - 1;
          let a: i32 = left_edge[l.saturating_sub(base + offset_left)].into();
          let b: i32 =
            left_edge[l.saturating_sub(base + offset_left + 1)].into();
          let v =
            round_shift(a * (32 - shift) + b * shift, 5).clamp(0, sample_max);
          row[j] = T::cast_from(v);
        }
      }
    }
  }
1506}
1507
// Unit tests for the portable Rust intra predictors, checked against fixed
// expected outputs.
#[cfg(test)]
mod test {
  use super::*;
  use crate::predict::rust::*;
  use num_traits::*;

  // Runs every predictor on a 4x4 8-bit block and compares the results with
  // hard-coded samples (row-major).
  #[test]
  fn pred_matches_u8() {
    // The value at index `i` is `max(i + 32 - 2 * MAX_TX_SIZE, 0)`. This is a
    // ramp that rises by one per index, so each output reveals which edge
    // positions a predictor read.
    let edge_buf =
      Aligned::from_fn(|i| (i + 32).saturating_sub(MAX_TX_SIZE * 2).as_());
    // NOTE(review): the split into left column / top-left / above row is
    // done by `IntraEdge::mock`, which is defined elsewhere.
    let (all_left, top_left, above) = IntraEdge::mock(&edge_buf).as_slices();
    // The last 4 left samples are passed for the 4x4 block.
    let left = &all_left[all_left.len() - 4..];

    let mut output = Plane::from_slice(&[0u8; 4 * 4], 4);

    pred_dc(&mut output.as_region_mut(), above, left, 4, 4, 8);
    assert_eq!(&output.data[..], [32u8; 16]);

    pred_dc_top(&mut output.as_region_mut(), above, left, 4, 4, 8);
    assert_eq!(&output.data[..], [35u8; 16]);

    pred_dc_left(&mut output.as_region_mut(), above, left, 4, 4, 8);
    assert_eq!(&output.data[..], [30u8; 16]);

    pred_dc_128(&mut output.as_region_mut(), above, left, 4, 4, 8);
    assert_eq!(&output.data[..], [128u8; 16]);

    pred_v(&mut output.as_region_mut(), above, 4, 4);
    assert_eq!(
      &output.data[..],
      [33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36]
    );

    pred_h(&mut output.as_region_mut(), left, 4, 4);
    assert_eq!(
      &output.data[..],
      [31, 31, 31, 31, 30, 30, 30, 30, 29, 29, 29, 29, 28, 28, 28, 28]
    );

    pred_paeth(&mut output.as_region_mut(), above, left, top_left[0], 4, 4);
    assert_eq!(
      &output.data[..],
      [32, 34, 35, 36, 30, 32, 32, 36, 29, 32, 32, 32, 28, 28, 32, 32]
    );

    pred_smooth(&mut output.as_region_mut(), above, left, 4, 4);
    assert_eq!(
      &output.data[..],
      [32, 34, 35, 35, 30, 32, 33, 34, 29, 31, 32, 32, 29, 30, 32, 32]
    );

    pred_smooth_h(&mut output.as_region_mut(), above, left, 4, 4);
    assert_eq!(
      &output.data[..],
      [31, 33, 34, 35, 30, 33, 34, 35, 29, 32, 34, 34, 28, 31, 33, 34]
    );

    pred_smooth_v(&mut output.as_region_mut(), above, left, 4, 4);
    assert_eq!(
      &output.data[..],
      [33, 34, 35, 36, 31, 31, 32, 33, 30, 30, 30, 31, 29, 30, 30, 30]
    );

    // Directional prediction gets 8 left samples. Every angle below is < 90
    // and no edge filter is requested (`None`), so `pred_directional` only
    // reads the above edge for these cases.
    let left = &all_left[all_left.len() - 8..];
    let angles = [
      3, 6, 9, 14, 17, 20, 23, 26, 29, 32, 36, 39, 42, 45, 48, 51, 54, 58, 61,
      64, 67, 70, 73, 76, 81, 84, 87,
    ];
    // One expected row-major 4x4 output per entry of `angles`.
    let expected = [
      [40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [37, 38, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [36, 37, 38, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [36, 37, 38, 39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [35, 36, 37, 38, 38, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
      [35, 36, 37, 38, 37, 38, 39, 40, 39, 40, 40, 40, 40, 40, 40, 40],
      [35, 36, 37, 38, 37, 38, 39, 40, 38, 39, 40, 40, 40, 40, 40, 40],
      [35, 36, 37, 38, 36, 37, 38, 39, 38, 39, 40, 40, 39, 40, 40, 40],
      [34, 35, 36, 37, 36, 37, 38, 39, 37, 38, 39, 40, 39, 40, 40, 40],
      [34, 35, 36, 37, 36, 37, 38, 39, 37, 38, 39, 40, 38, 39, 40, 40],
      [34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39, 37, 38, 39, 40],
      [34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39, 37, 38, 39, 40],
      [34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39, 37, 38, 39, 40],
      [34, 35, 36, 37, 35, 36, 37, 38, 35, 36, 37, 38, 36, 37, 38, 39],
      [34, 35, 36, 37, 34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39],
      [34, 35, 36, 37, 34, 35, 36, 37, 35, 36, 37, 38, 36, 37, 38, 39],
      [34, 35, 36, 37, 34, 35, 36, 37, 35, 36, 37, 38, 35, 36, 37, 38],
      [33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37, 35, 36, 37, 38],
      [33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37, 35, 36, 37, 38],
      [33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37, 34, 35, 36, 37],
      [33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37, 34, 35, 36, 37],
      [33, 34, 35, 36, 33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37],
      [33, 34, 35, 36, 33, 34, 35, 36, 34, 35, 36, 37, 34, 35, 36, 37],
      [33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36],
      [33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36, 33, 34, 35, 36],
    ];
    for (&angle, expected) in angles.iter().zip(expected.iter()) {
      pred_directional(
        &mut output.as_region_mut(),
        above,
        left,
        top_left,
        angle,
        4,
        4,
        8,
        None,
      );
      assert_eq!(&output.data[..], expected);
    }
  }

  // With saturated 12-bit edges, every non-directional predictor must
  // reproduce the saturated value exactly. Only the top-left 4x4 of the
  // 32-wide plane is predicted and checked.
  #[test]
  fn pred_max() {
    let max12bit = 4096 - 1;
    let above = [max12bit; 32];
    let left = [max12bit; 32];

    let mut o = Plane::from_slice(&vec![0u16; 32 * 32], 32);

    pred_dc(&mut o.as_region_mut(), &above[..4], &left[..4], 4, 4, 16);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    pred_h(&mut o.as_region_mut(), &left[..4], 4, 4);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    pred_v(&mut o.as_region_mut(), &above[..4], 4, 4);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    let above_left = max12bit;

    pred_paeth(
      &mut o.as_region_mut(),
      &above[..4],
      &left[..4],
      above_left,
      4,
      4,
    );

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    pred_smooth(&mut o.as_region_mut(), &above[..4], &left[..4], 4, 4);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    pred_smooth_h(&mut o.as_region_mut(), &above[..4], &left[..4], 4, 4);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }

    pred_smooth_v(&mut o.as_region_mut(), &above[..4], &left[..4], 4, 4);

    for l in o.data.chunks(32).take(4) {
      for v in l[..4].iter() {
        assert_eq!(*v, max12bit);
      }
    }
  }
}