#![forbid(unsafe_code)]
#![allow(dead_code)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::similar_names)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::struct_excessive_bools)]
use super::block::{BlockModeInfo, InterMode, IntraMode as Av1IntraMode};
use crate::error::{CodecError, CodecResult};
use crate::frame::VideoFrame;
use crate::intra::{
BitDepth, BlockDimensions, DcPredictor, DirectionalPredictor, HorizontalPredictor, IntraMode,
IntraPredContext, IntraPredictor, PaethPredictor, SmoothHPredictor, SmoothPredictor,
SmoothVPredictor, VerticalPredictor,
};
pub const SUBPEL_BITS: u8 = 3;
pub const SUBPEL_SCALE: i32 = 1 << SUBPEL_BITS;
pub const INTERP_TAPS: usize = 8;
pub const MAX_OBMC_SIZE: usize = 128;
pub const WARP_PARAMS: usize = 6;
#[derive(Debug)]
pub struct PredictionEngine {
current_frame: Option<VideoFrame>,
reference_frames: Vec<Option<VideoFrame>>,
bit_depth: u8,
intra_context: IntraPredContext,
mc_buffer: Vec<u16>,
}
impl PredictionEngine {
pub fn new(width: u32, height: u32, bit_depth: u8) -> Self {
let intra_bd = match bit_depth {
8 => BitDepth::Bits8,
10 => BitDepth::Bits10,
12 => BitDepth::Bits12,
_ => BitDepth::Bits8,
};
Self {
current_frame: None,
reference_frames: vec![None; 8],
bit_depth,
intra_context: IntraPredContext::new(width as usize, height as usize, intra_bd),
mc_buffer: vec![0; MAX_OBMC_SIZE * MAX_OBMC_SIZE],
}
}
pub fn predict_block(
&mut self,
mode_info: &BlockModeInfo,
x: u32,
y: u32,
plane: u8,
dst: &mut [u16],
stride: usize,
) -> CodecResult<()> {
if mode_info.is_inter {
self.predict_inter(mode_info, x, y, plane, dst, stride)
} else {
self.predict_intra(mode_info, x, y, plane, dst, stride)
}
}
fn predict_intra(
&mut self,
mode_info: &BlockModeInfo,
_x: u32,
_y: u32,
plane: u8,
dst: &mut [u16],
stride: usize,
) -> CodecResult<()> {
let bsize = mode_info.block_size;
let width = bsize.width() as usize;
let height = bsize.height() as usize;
let mode = if plane == 0 {
mode_info.intra_mode
} else {
mode_info.uv_mode
};
let intra_mode = self.map_intra_mode(mode);
if mode.is_directional() && plane == 0 {
let _angle_delta = mode_info.angle_delta[0];
}
self.apply_intra_mode(intra_mode, mode, width, height, dst, stride)?;
if mode_info.filter_intra_mode > 0 && plane == 0 {
self.apply_filter_intra(dst, width, height, stride, mode_info.filter_intra_mode)?;
}
Ok(())
}
fn map_intra_mode(&self, mode: Av1IntraMode) -> IntraMode {
match mode {
Av1IntraMode::DcPred => IntraMode::Dc,
Av1IntraMode::VPred => IntraMode::Vertical,
Av1IntraMode::HPred => IntraMode::Horizontal,
Av1IntraMode::D45Pred => IntraMode::D45,
Av1IntraMode::D135Pred => IntraMode::D135,
Av1IntraMode::D113Pred => IntraMode::D113,
Av1IntraMode::D157Pred => IntraMode::D157,
Av1IntraMode::D203Pred => IntraMode::D203,
Av1IntraMode::D67Pred => IntraMode::D67,
Av1IntraMode::SmoothPred => IntraMode::Smooth,
Av1IntraMode::SmoothVPred => IntraMode::SmoothV,
Av1IntraMode::SmoothHPred => IntraMode::SmoothH,
Av1IntraMode::PaethPred => IntraMode::Paeth,
}
}
fn apply_intra_mode(
&self,
intra_mode: IntraMode,
_av1_mode: Av1IntraMode,
width: usize,
height: usize,
dst: &mut [u16],
stride: usize,
) -> CodecResult<()> {
let block_dims = BlockDimensions::new(width, height);
let bit_depth = self.intra_context.bit_depth();
match intra_mode {
IntraMode::Dc => {
let predictor = DcPredictor::new(bit_depth);
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::Vertical => {
let predictor = VerticalPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::Horizontal => {
let predictor = HorizontalPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::Smooth => {
let predictor = SmoothPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::SmoothV => {
let predictor = SmoothVPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::SmoothH => {
let predictor = SmoothHPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::Paeth => {
let predictor = PaethPredictor::new();
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::D45
| IntraMode::D135
| IntraMode::D113
| IntraMode::D157
| IntraMode::D203
| IntraMode::D67 => {
let angle = self.intra_mode_to_angle(intra_mode);
let bit_depth = self.intra_context.bit_depth();
let predictor = DirectionalPredictor::new(angle, bit_depth);
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
IntraMode::FilterIntra => {
let predictor = DcPredictor::new(bit_depth);
predictor.predict(&self.intra_context, dst, stride, block_dims);
}
}
Ok(())
}
fn intra_mode_to_angle(&self, mode: IntraMode) -> u16 {
match mode {
IntraMode::D45 => 45,
IntraMode::D67 => 67,
IntraMode::D113 => 113,
IntraMode::D135 => 135,
IntraMode::D157 => 157,
IntraMode::D203 => 203,
_ => 0,
}
}
fn apply_angle_delta(&mut self, _ctx: &mut IntraPredContext, _delta: i8) {
}
fn apply_filter_intra(
&self,
dst: &mut [u16],
width: usize,
height: usize,
stride: usize,
_mode: u8,
) -> CodecResult<()> {
for y in 0..height {
for x in 0..width {
let idx = y * stride + x;
if idx < dst.len() {
dst[idx] = self.apply_filter_tap(dst, x, y, width, height, stride);
}
}
}
Ok(())
}
fn apply_filter_tap(
&self,
src: &[u16],
x: usize,
y: usize,
width: usize,
height: usize,
stride: usize,
) -> u16 {
let mut sum = 0u32;
let mut count = 0u32;
for dy in -1i32..=1 {
for dx in -1i32..=1 {
let nx = (x as i32 + dx) as usize;
let ny = (y as i32 + dy) as usize;
if nx < width && ny < height {
let idx = ny * stride + nx;
if idx < src.len() {
sum += u32::from(src[idx]);
count += 1;
}
}
}
}
sum.checked_div(count)
.map_or(src[y * stride + x], |v| v as u16)
}
fn predict_inter(
&mut self,
mode_info: &BlockModeInfo,
x: u32,
y: u32,
plane: u8,
dst: &mut [u16],
stride: usize,
) -> CodecResult<()> {
let bsize = mode_info.block_size;
let width = bsize.width() as usize;
let height = bsize.height() as usize;
if mode_info.is_compound() {
self.predict_compound(mode_info, x, y, plane, dst, stride, width, height)
} else {
self.predict_single_ref(mode_info, x, y, plane, dst, stride, width, height)
}
}
fn predict_single_ref(
&mut self,
mode_info: &BlockModeInfo,
x: u32,
y: u32,
plane: u8,
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
) -> CodecResult<()> {
let ref_idx = mode_info.ref_frames[0];
if ref_idx < 0 || ref_idx >= self.reference_frames.len() as i8 {
return Err(CodecError::InvalidBitstream(
"Invalid reference frame".to_string(),
));
}
let ref_frame = &self.reference_frames[ref_idx as usize];
if ref_frame.is_none() {
return Err(CodecError::InvalidBitstream(
"Reference frame not available".to_string(),
));
}
let mv = self.get_motion_vector(mode_info, 0);
let ref_frame_inner = ref_frame.as_ref().ok_or_else(|| {
CodecError::InvalidBitstream("Reference frame not available".to_string())
})?;
self.motion_compensate(
ref_frame_inner,
x,
y,
mv,
plane,
dst,
stride,
width,
height,
mode_info.interp_filter[0],
)?;
if mode_info.motion_mode == 1 {
self.apply_obmc(mode_info, x, y, plane, dst, stride, width, height)?;
}
if mode_info.motion_mode == 2 {
self.apply_warped_motion(mode_info, x, y, plane, dst, stride, width, height)?;
}
Ok(())
}
fn predict_compound(
&mut self,
mode_info: &BlockModeInfo,
x: u32,
y: u32,
plane: u8,
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
) -> CodecResult<()> {
let mut pred0 = vec![0u16; width * height];
let mut pred1 = vec![0u16; width * height];
if mode_info.ref_frames[0] >= 0 {
let mv0 = self.get_motion_vector(mode_info, 0);
let ref0 = &self.reference_frames[mode_info.ref_frames[0] as usize];
if let Some(ref frame) = ref0 {
self.motion_compensate(
frame,
x,
y,
mv0,
plane,
&mut pred0,
width,
width,
height,
mode_info.interp_filter[0],
)?;
}
}
if mode_info.ref_frames[1] >= 0 {
let mv1 = self.get_motion_vector(mode_info, 1);
let ref1 = &self.reference_frames[mode_info.ref_frames[1] as usize];
if let Some(ref frame) = ref1 {
self.motion_compensate(
frame,
x,
y,
mv1,
plane,
&mut pred1,
width,
width,
height,
mode_info.interp_filter[1],
)?;
}
}
self.blend_compound_predictions(
&pred0,
&pred1,
dst,
stride,
width,
height,
mode_info.compound_type,
);
Ok(())
}
fn get_motion_vector(&self, mode_info: &BlockModeInfo, ref_idx: usize) -> [i32; 2] {
match mode_info.inter_mode {
InterMode::NewMv => [
i32::from(mode_info.mv[ref_idx][0]),
i32::from(mode_info.mv[ref_idx][1]),
],
InterMode::NearestMv | InterMode::NearMv => {
[0, 0]
}
InterMode::GlobalMv => {
[0, 0]
}
}
}
#[allow(clippy::too_many_lines)]
fn motion_compensate(
&self,
ref_frame: &VideoFrame,
x: u32,
y: u32,
mv: [i32; 2],
plane: u8,
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
_interp_filter: u8,
) -> CodecResult<()> {
let ref_x = (x as i32 * SUBPEL_SCALE) + mv[1];
let ref_y = (y as i32 * SUBPEL_SCALE) + mv[0];
let int_x = ref_x >> SUBPEL_BITS;
let int_y = ref_y >> SUBPEL_BITS;
let frac_x = (ref_x & (SUBPEL_SCALE - 1)) as usize;
let frac_y = (ref_y & (SUBPEL_SCALE - 1)) as usize;
let plane_idx = plane as usize;
let (ref_data, ref_stride) = if plane_idx < ref_frame.planes.len() {
(
&ref_frame.planes[plane_idx].data[..],
ref_frame.planes[plane_idx].stride,
)
} else {
return Err(CodecError::Internal("Invalid plane index".to_string()));
};
if frac_x == 0 && frac_y == 0 {
self.copy_block(
ref_data, ref_stride, int_x, int_y, dst, stride, width, height,
);
} else if frac_y == 0 {
self.interp_horizontal(
ref_data, ref_stride, int_x, int_y, frac_x, dst, stride, width, height,
);
} else if frac_x == 0 {
self.interp_vertical(
ref_data, ref_stride, int_x, int_y, frac_y, dst, stride, width, height,
);
} else {
self.interp_2d(
ref_data, ref_stride, int_x, int_y, frac_x, frac_y, dst, stride, width, height,
);
}
Ok(())
}
fn copy_block(
&self,
src: &[u8],
src_stride: usize,
x: i32,
y: i32,
dst: &mut [u16],
dst_stride: usize,
width: usize,
height: usize,
) {
for row in 0..height {
let src_y = (y + row as i32).max(0) as usize;
let src_start = src_y * src_stride + x.max(0) as usize;
for col in 0..width {
if src_start + col < src.len() {
let dst_idx = row * dst_stride + col;
if dst_idx < dst.len() {
dst[dst_idx] = u16::from(src[src_start + col]);
}
}
}
}
}
fn interp_horizontal(
&self,
src: &[u8],
src_stride: usize,
x: i32,
y: i32,
frac: usize,
dst: &mut [u16],
dst_stride: usize,
width: usize,
height: usize,
) {
let filter = self.get_interp_filter(frac);
for row in 0..height {
let src_y = (y + row as i32).max(0) as usize;
for col in 0..width {
let mut sum = 0i32;
for tap in 0..INTERP_TAPS {
let src_x =
(x + col as i32 + tap as i32 - INTERP_TAPS as i32 / 2).max(0) as usize;
let src_idx = src_y * src_stride + src_x;
if src_idx < src.len() {
sum += i32::from(src[src_idx]) * filter[tap];
}
}
let dst_idx = row * dst_stride + col;
if dst_idx < dst.len() {
dst[dst_idx] = ((sum + 64) >> 7).clamp(0, (1 << self.bit_depth) - 1) as u16;
}
}
}
}
fn interp_vertical(
&self,
src: &[u8],
src_stride: usize,
x: i32,
y: i32,
frac: usize,
dst: &mut [u16],
dst_stride: usize,
width: usize,
height: usize,
) {
let filter = self.get_interp_filter(frac);
for row in 0..height {
for col in 0..width {
let mut sum = 0i32;
for tap in 0..INTERP_TAPS {
let src_y =
(y + row as i32 + tap as i32 - INTERP_TAPS as i32 / 2).max(0) as usize;
let src_x = (x + col as i32).max(0) as usize;
let src_idx = src_y * src_stride + src_x;
if src_idx < src.len() {
sum += i32::from(src[src_idx]) * filter[tap];
}
}
let dst_idx = row * dst_stride + col;
if dst_idx < dst.len() {
dst[dst_idx] = ((sum + 64) >> 7).clamp(0, (1 << self.bit_depth) - 1) as u16;
}
}
}
}
fn interp_2d(
&self,
src: &[u8],
src_stride: usize,
x: i32,
y: i32,
frac_x: usize,
frac_y: usize,
dst: &mut [u16],
dst_stride: usize,
width: usize,
height: usize,
) {
let mut temp = vec![0u16; (width + INTERP_TAPS) * (height + INTERP_TAPS)];
let temp_stride = width + INTERP_TAPS;
self.interp_horizontal(
src,
src_stride,
x,
y - INTERP_TAPS as i32 / 2,
frac_x,
&mut temp,
temp_stride,
width,
height + INTERP_TAPS,
);
let filter_y = self.get_interp_filter(frac_y);
for row in 0..height {
for col in 0..width {
let mut sum = 0i32;
for tap in 0..INTERP_TAPS {
let temp_idx = (row + tap) * temp_stride + col;
if temp_idx < temp.len() {
sum += i32::from(temp[temp_idx]) * filter_y[tap];
}
}
let dst_idx = row * dst_stride + col;
if dst_idx < dst.len() {
dst[dst_idx] = ((sum + 64) >> 7).clamp(0, (1 << self.bit_depth) - 1) as u16;
}
}
}
}
fn get_interp_filter(&self, frac: usize) -> [i32; INTERP_TAPS] {
match frac {
0 => [0, 0, 0, 128, 0, 0, 0, 0],
1 => [0, 0, 16, 112, 16, 0, 0, 0],
2 => [0, 0, 32, 96, 32, 0, 0, 0],
3 => [0, 0, 48, 80, 48, 0, 0, 0],
4 => [0, 0, 64, 64, 64, 0, 0, 0],
5 => [0, 0, 48, 80, 48, 0, 0, 0],
6 => [0, 0, 32, 96, 32, 0, 0, 0],
7 => [0, 0, 16, 112, 16, 0, 0, 0],
_ => [0, 0, 0, 128, 0, 0, 0, 0],
}
}
fn blend_compound_predictions(
&self,
pred0: &[u16],
pred1: &[u16],
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
compound_type: u8,
) {
for row in 0..height {
for col in 0..width {
let src_idx = row * width + col;
let dst_idx = row * stride + col;
if src_idx < pred0.len() && src_idx < pred1.len() && dst_idx < dst.len() {
let blended = if compound_type == 0 {
(u32::from(pred0[src_idx]) + u32::from(pred1[src_idx]) + 1) >> 1
} else {
u32::from(pred0[src_idx])
};
dst[dst_idx] = blended as u16;
}
}
}
}
fn apply_obmc(
&mut self,
_mode_info: &BlockModeInfo,
_x: u32,
_y: u32,
_plane: u8,
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
) -> CodecResult<()> {
self.smooth_boundaries(dst, stride, width, height);
Ok(())
}
fn smooth_boundaries(&self, dst: &mut [u16], stride: usize, width: usize, height: usize) {
for row in 0..height {
for col in 0..width {
if row == 0 || col == 0 || row == height - 1 || col == width - 1 {
let idx = row * stride + col;
if idx < dst.len() {
let current = dst[idx];
dst[idx] = ((u32::from(current) * 3 + 2) >> 2) as u16;
}
}
}
}
}
fn apply_warped_motion(
&mut self,
_mode_info: &BlockModeInfo,
_x: u32,
_y: u32,
_plane: u8,
dst: &mut [u16],
stride: usize,
width: usize,
height: usize,
) -> CodecResult<()> {
for row in 0..height {
for col in 0..width {
let idx = row * stride + col;
if idx < dst.len() {
dst[idx] = dst[idx];
}
}
}
Ok(())
}
pub fn set_reference_frame(&mut self, idx: usize, frame: VideoFrame) {
if idx < self.reference_frames.len() {
self.reference_frames[idx] = Some(frame);
}
}
pub fn set_current_frame(&mut self, frame: VideoFrame) {
self.current_frame = Some(frame);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::frame::{FrameType, VideoFrame};
use oximedia_core::{PixelFormat, Rational, Timestamp};
fn create_test_frame(width: u32, height: u32) -> VideoFrame {
let mut frame = VideoFrame::new(PixelFormat::Yuv420p, width, height);
frame.allocate();
frame.frame_type = FrameType::Key;
frame.timestamp = Timestamp::new(0, Rational::new(1, 30));
frame
}
fn create_test_mode_info() -> BlockModeInfo {
BlockModeInfo::new()
}
#[test]
fn test_prediction_engine_creation() {
let engine = PredictionEngine::new(1920, 1080, 8);
assert_eq!(engine.bit_depth, 8);
}
#[test]
fn test_map_intra_mode() {
let engine = PredictionEngine::new(64, 64, 8);
assert_eq!(engine.map_intra_mode(Av1IntraMode::DcPred), IntraMode::Dc);
assert_eq!(
engine.map_intra_mode(Av1IntraMode::VPred),
IntraMode::Vertical
);
assert_eq!(
engine.map_intra_mode(Av1IntraMode::HPred),
IntraMode::Horizontal
);
}
#[test]
fn test_predict_intra() {
let mut engine = PredictionEngine::new(64, 64, 8);
let frame = create_test_frame(64, 64);
engine.set_current_frame(frame);
let mode_info = create_test_mode_info();
let mut dst = vec![0u16; 16 * 16];
let result = engine.predict_intra(&mode_info, 0, 0, 0, &mut dst, 16);
assert!(result.is_ok());
}
#[test]
fn test_get_interp_filter() {
let engine = PredictionEngine::new(64, 64, 8);
let filter = engine.get_interp_filter(0);
assert_eq!(filter[3], 128);
let filter_half = engine.get_interp_filter(4);
assert!(filter_half[3] > 0);
}
#[test]
fn test_set_reference_frame() {
let mut engine = PredictionEngine::new(64, 64, 8);
let frame = create_test_frame(64, 64);
engine.set_reference_frame(0, frame);
assert!(engine.reference_frames[0].is_some());
}
#[test]
fn test_constants() {
assert_eq!(SUBPEL_BITS, 3);
assert_eq!(SUBPEL_SCALE, 8);
assert_eq!(INTERP_TAPS, 8);
assert_eq!(WARP_PARAMS, 6);
}
}