#![forbid(unsafe_code)]
#![allow(dead_code)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::similar_names)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::manual_div_ceil)]
use super::inter::{InterMode, MAX_MV_REF_CANDIDATES};
use super::mv::{MotionVector, MvRefType, MV_MAX, MV_MIN};
use super::partition::BlockSize;
/// Capacity of an `MvRefStack`: the most distinct candidates it retains.
pub const MAX_REF_MV_STACK_SIZE: usize = 8;
/// Number of entries in the `SPATIAL_NEIGHBORS` scan table.
pub const NUM_SPATIAL_NEIGHBORS: usize = 8;
/// Weight each spatial-neighbour candidate contributes when added to a stack.
pub const SPATIAL_WEIGHT: u8 = 2;
/// Weight intended for temporal candidates (not referenced by the code shown here).
pub const TEMPORAL_WEIGHT: u8 = 1;
/// One motion-vector predictor candidate together with its reference frame
/// and accumulated weight.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MvRefCandidate {
    /// The candidate motion vector.
    pub mv: MotionVector,
    /// Reference frame the vector points into.
    pub ref_frame: MvRefType,
    /// Accumulated weight; higher weights sort earlier in `MvRefStack`.
    pub weight: u8,
}

impl MvRefCandidate {
    /// Builds a candidate from its parts.
    #[must_use]
    pub const fn new(mv: MotionVector, ref_frame: MvRefType, weight: u8) -> Self {
        Self { mv, ref_frame, weight }
    }

    /// Builds a zero-vector, zero-weight candidate for `ref_frame`.
    #[must_use]
    pub const fn zero(ref_frame: MvRefType) -> Self {
        Self::new(MotionVector::zero(), ref_frame, 0)
    }

    /// True when both candidates carry the same vector (reference frame and
    /// weight are ignored).
    #[must_use]
    pub const fn same_mv(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.mv, &other.mv);
        mine.col == theirs.col && mine.row == theirs.row
    }

    /// True when both candidates point into the same reference frame.
    ///
    /// Compares discriminants because `==` on the enum is not usable in a
    /// `const fn`.
    #[must_use]
    pub const fn same_ref(&self, other: &Self) -> bool {
        let (mine, theirs) = (self.ref_frame as u8, other.ref_frame as u8);
        mine == theirs
    }
}
/// Fixed-capacity, de-duplicating stack of motion-vector reference candidates.
///
/// At most [`MAX_REF_MV_STACK_SIZE`] entries are held; only the first `count`
/// slots of `candidates` are meaningful.
#[derive(Clone, Debug, Default)]
pub struct MvRefStack {
    candidates: [MvRefCandidate; MAX_REF_MV_STACK_SIZE],
    count: usize,
    ref_frame: MvRefType,
}

impl MvRefStack {
    /// Creates an empty stack tagged with `ref_frame`.
    #[must_use]
    pub const fn new(ref_frame: MvRefType) -> Self {
        Self {
            candidates: [MvRefCandidate::zero(MvRefType::Intra); MAX_REF_MV_STACK_SIZE],
            count: 0,
            ref_frame,
        }
    }

    /// Forgets every candidate (the backing slots are not reset).
    pub fn clear(&mut self) {
        self.count = 0;
    }

    /// Number of candidates currently held.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.count
    }

    /// True when no candidate is held.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// True when no further distinct candidate can be stored.
    #[must_use]
    pub const fn is_full(&self) -> bool {
        self.count >= MAX_REF_MV_STACK_SIZE
    }

    /// Inserts `candidate`, merging it into an existing entry with the same
    /// vector and reference frame by adding (saturating) its weight.
    ///
    /// A new distinct candidate is silently dropped when the stack is full.
    pub fn add(&mut self, candidate: MvRefCandidate) {
        let duplicate = self.candidates[..self.count]
            .iter_mut()
            .find(|held| held.same_mv(&candidate) && held.same_ref(&candidate));
        if let Some(held) = duplicate {
            held.weight = held.weight.saturating_add(candidate.weight);
            return;
        }
        // `get_mut` yields `None` exactly when `count == MAX_REF_MV_STACK_SIZE`.
        if let Some(slot) = self.candidates.get_mut(self.count) {
            *slot = candidate;
            self.count += 1;
        }
    }

    /// Orders candidates by descending weight; equal weights keep their
    /// insertion order (the sort is stable).
    pub fn sort_by_weight(&mut self) {
        self.candidates[..self.count].sort_by(|a, b| b.weight.cmp(&a.weight));
    }

    /// Returns the candidate at `index`, or `None` past the current length.
    #[must_use]
    pub const fn get(&self, index: usize) -> Option<&MvRefCandidate> {
        if index >= self.count {
            return None;
        }
        Some(&self.candidates[index])
    }

    /// Vector of the first candidate, or zero when the stack is empty.
    #[must_use]
    pub fn nearest_mv(&self) -> MotionVector {
        self.get(0).map_or_else(MotionVector::zero, |c| c.mv)
    }

    /// Vector of the second candidate. With a single candidate, that
    /// candidate's vector is returned. With none, zero is returned.
    ///
    /// NOTE(review): VP9-style predictors typically leave "near" as zero when
    /// only one candidate exists; confirm the repeat is intended.
    #[must_use]
    pub fn near_mv(&self) -> MotionVector {
        self.get(1)
            .or_else(|| self.get(0))
            .map_or_else(MotionVector::zero, |c| c.mv)
    }

    /// `[nearest, near]` predictor pair.
    #[must_use]
    pub fn best_ref_mvs(&self) -> [MotionVector; MAX_MV_REF_CANDIDATES] {
        [self.nearest_mv(), self.near_mv()]
    }
}
/// Offset, in mode-info units, from the current block to a neighbour.
#[derive(Clone, Copy, Debug)]
pub struct NeighborPosition {
    /// Row delta (negative = above).
    pub row_offset: i32,
    /// Column delta (negative = left).
    pub col_offset: i32,
}

impl NeighborPosition {
    /// Builds an offset from its row and column deltas.
    #[must_use]
    pub const fn new(row_offset: i32, col_offset: i32) -> Self {
        Self { row_offset, col_offset }
    }
}
pub static SPATIAL_NEIGHBORS: [NeighborPosition; NUM_SPATIAL_NEIGHBORS] = [
NeighborPosition {
row_offset: 0,
col_offset: -1,
}, NeighborPosition {
row_offset: -1,
col_offset: 0,
}, NeighborPosition {
row_offset: -1,
col_offset: -1,
}, NeighborPosition {
row_offset: 0,
col_offset: -2,
}, NeighborPosition {
row_offset: -2,
col_offset: 0,
}, NeighborPosition {
row_offset: -1,
col_offset: 1,
}, NeighborPosition {
row_offset: -2,
col_offset: -1,
}, NeighborPosition {
row_offset: -1,
col_offset: -2,
}, ];
/// Prediction information stored for one mode-info cell.
#[derive(Clone, Copy, Debug, Default)]
pub struct BlockModeInfo {
    /// Reference frames; slot 1 is `MvRefType::Intra` as a placeholder for
    /// single-reference and intra blocks.
    pub ref_frames: [MvRefType; 2],
    /// Motion vectors paired with `ref_frames`.
    pub mvs: [MotionVector; 2],
    /// Inter mode used by the block.
    pub mode: InterMode,
    /// Size of the block this cell belongs to.
    pub block_size: BlockSize,
    /// True for inter-coded blocks.
    pub is_inter: bool,
    /// True when both reference slots are in use.
    pub is_compound: bool,
}

impl BlockModeInfo {
    /// An intra block: no references, zero vectors, placeholder mode and size.
    #[must_use]
    pub const fn intra() -> Self {
        Self {
            ref_frames: [MvRefType::Intra, MvRefType::Intra],
            mvs: [MotionVector::zero(), MotionVector::zero()],
            mode: InterMode::NearestMv,
            block_size: BlockSize::Block4x4,
            is_inter: false,
            is_compound: false,
        }
    }

    /// A single-reference inter block using `mv` into `ref_frame`.
    #[must_use]
    pub const fn inter_single(ref_frame: MvRefType, mv: MotionVector, mode: InterMode) -> Self {
        Self {
            ref_frames: [ref_frame, MvRefType::Intra],
            mvs: [mv, MotionVector::zero()],
            mode,
            block_size: BlockSize::Block4x4,
            is_inter: true,
            is_compound: false,
        }
    }

    /// A compound inter block predicting from `ref0`/`mv0` and `ref1`/`mv1`.
    #[must_use]
    pub const fn inter_compound(
        ref0: MvRefType,
        ref1: MvRefType,
        mv0: MotionVector,
        mv1: MotionVector,
    ) -> Self {
        Self {
            ref_frames: [ref0, ref1],
            mvs: [mv0, mv1],
            mode: InterMode::NearestMv,
            block_size: BlockSize::Block4x4,
            is_inter: true,
            is_compound: true,
        }
    }

    /// Returns the vector this block uses with `ref_frame`, if any.
    ///
    /// Returns `None` for intra blocks and for blocks that do not reference
    /// `ref_frame`. The second slot is consulted only for compound blocks.
    /// Single-reference blocks store the `Intra` placeholder there, and
    /// matching it would report a meaningless zero vector.
    #[must_use]
    pub const fn mv_for_ref(&self, ref_frame: MvRefType) -> Option<MotionVector> {
        if !self.is_inter {
            return None;
        }
        // Discriminant comparison: enum `==` is not available in a `const fn`.
        let wanted = ref_frame as u8;
        if self.ref_frames[0] as u8 == wanted {
            Some(self.mvs[0])
        } else if self.is_compound && self.ref_frames[1] as u8 == wanted {
            Some(self.mvs[1])
        } else {
            None
        }
    }
}
/// Row-major grid of [`BlockModeInfo`], one entry per 4x4-pixel mode-info unit.
#[derive(Clone, Debug)]
pub struct ModeInfoGrid {
    data: Vec<BlockModeInfo>,
    mi_cols: usize,
    mi_rows: usize,
}

impl Default for ModeInfoGrid {
    fn default() -> Self {
        Self::new()
    }
}

impl ModeInfoGrid {
    /// Creates an empty 0x0 grid; call [`Self::allocate`] before use.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            data: Vec::new(),
            mi_cols: 0,
            mi_rows: 0,
        }
    }

    /// Sizes the grid for a `width` x `height` pixel frame. Both dimensions
    /// are rounded up to whole 4-pixel units.
    ///
    /// Newly created cells are filled with [`BlockModeInfo::intra`]. If the
    /// column count changes, all existing entries are discarded first. The
    /// row stride changes in that case, so old entries would otherwise be
    /// read back at the wrong `(row, col)`. With an unchanged column count,
    /// existing rows are kept and rows are added or truncated at the end.
    pub fn allocate(&mut self, width: u32, height: u32) {
        let mi_cols = ((width as usize) + 3) / 4;
        let mi_rows = ((height as usize) + 3) / 4;
        if mi_cols != self.mi_cols {
            self.data.clear();
        }
        self.mi_cols = mi_cols;
        self.mi_rows = mi_rows;
        self.data.resize(mi_cols * mi_rows, BlockModeInfo::intra());
    }

    /// Resets every cell to [`BlockModeInfo::intra`], keeping the dimensions.
    pub fn clear(&mut self) {
        self.data.fill(BlockModeInfo::intra());
    }

    /// Returns the cell at `(mi_row, mi_col)`, or `None` when out of bounds.
    #[must_use]
    pub fn get(&self, mi_row: usize, mi_col: usize) -> Option<&BlockModeInfo> {
        if mi_row < self.mi_rows && mi_col < self.mi_cols {
            self.data.get(mi_row * self.mi_cols + mi_col)
        } else {
            None
        }
    }

    /// Stores `info` at `(mi_row, mi_col)`; out-of-bounds writes are ignored.
    pub fn set(&mut self, mi_row: usize, mi_col: usize, info: BlockModeInfo) {
        if mi_row < self.mi_rows && mi_col < self.mi_cols {
            self.data[mi_row * self.mi_cols + mi_col] = info;
        }
    }

    /// Writes `info` into every cell covered by a `block_size` block whose
    /// top-left cell is `(mi_row, mi_col)`. Cells beyond the grid are skipped.
    pub fn fill_block(
        &mut self,
        mi_row: usize,
        mi_col: usize,
        block_size: BlockSize,
        info: BlockModeInfo,
    ) {
        // Clip to the grid up front instead of bounds-checking every cell;
        // an empty range (start >= end) simply writes nothing.
        let row_end = mi_row.saturating_add(block_size.height_mi()).min(self.mi_rows);
        let col_end = mi_col.saturating_add(block_size.width_mi()).min(self.mi_cols);
        for row in mi_row..row_end {
            let row_start = row * self.mi_cols;
            for col in mi_col..col_end {
                self.data[row_start + col] = info;
            }
        }
    }

    /// Number of mode-info columns.
    #[must_use]
    pub const fn mi_cols(&self) -> usize {
        self.mi_cols
    }

    /// Number of mode-info rows.
    #[must_use]
    pub const fn mi_rows(&self) -> usize {
        self.mi_rows
    }
}
/// Position of the block being predicted, plus the mode-info grid its
/// neighbours are read from.
#[derive(Clone, Debug)]
pub struct MvRefContext {
    /// Mode information for the frame.
    pub mode_info: ModeInfoGrid,
    /// Row of the current block, in mode-info units.
    pub mi_row: usize,
    /// Column of the current block, in mode-info units.
    pub mi_col: usize,
    /// Size of the current block.
    pub block_size: BlockSize,
    /// Reference frame stored with the context (`Last` by default).
    pub ref_frame: MvRefType,
}

impl Default for MvRefContext {
    fn default() -> Self {
        Self::new()
    }
}

impl MvRefContext {
    /// Empty grid, block at the origin, `Block4x4`, reference `Last`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            mode_info: ModeInfoGrid::new(),
            mi_row: 0,
            mi_col: 0,
            block_size: BlockSize::Block4x4,
            ref_frame: MvRefType::Last,
        }
    }

    /// Moves the context to a new block.
    pub fn set_position(&mut self, mi_row: usize, mi_col: usize, block_size: BlockSize) {
        self.block_size = block_size;
        self.mi_row = mi_row;
        self.mi_col = mi_col;
    }

    /// Absolute `(row, col)` of `neighbor`, or `None` if it lies above or left
    /// of the frame origin (the far edges are not checked here).
    fn neighbor_coords(&self, neighbor: &NeighborPosition) -> Option<(usize, usize)> {
        let row = self.mi_row as i32 + neighbor.row_offset;
        let col = self.mi_col as i32 + neighbor.col_offset;
        if row < 0 || col < 0 {
            None
        } else {
            Some((row as usize, col as usize))
        }
    }

    /// True when `neighbor` falls inside the mode-info grid.
    #[must_use]
    pub fn is_valid_neighbor(&self, neighbor: &NeighborPosition) -> bool {
        matches!(
            self.neighbor_coords(neighbor),
            Some((row, col)) if row < self.mode_info.mi_rows() && col < self.mode_info.mi_cols()
        )
    }

    /// Mode info of `neighbor`, or `None` when it falls outside the grid.
    #[must_use]
    pub fn get_neighbor(&self, neighbor: &NeighborPosition) -> Option<&BlockModeInfo> {
        let (row, col) = self.neighbor_coords(neighbor)?;
        self.mode_info.get(row, col)
    }
}
/// Collects spatial motion-vector candidates for `ref_frame` into `stack`.
///
/// `stack` is cleared first. Neighbours are visited in `SPATIAL_NEIGHBORS`
/// order. Each one that lies inside the grid, is inter coded, and uses
/// `ref_frame` contributes its vector with weight `SPATIAL_WEIGHT`.
/// Duplicates merge by adding weight (`MvRefStack::add`). The scan stops once
/// the stack holds `MAX_MV_REF_CANDIDATES` distinct entries. The stack is
/// then sorted by descending weight.
pub fn find_mv_refs(ctx: &MvRefContext, ref_frame: MvRefType, stack: &mut MvRefStack) {
    stack.clear();
    // `get_neighbor` already yields `None` for positions outside the grid.
    let neighbor_mvs = SPATIAL_NEIGHBORS
        .iter()
        .filter_map(|pos| ctx.get_neighbor(pos))
        .filter(|info| info.is_inter)
        .filter_map(|info| info.mv_for_ref(ref_frame));
    for mv in neighbor_mvs {
        stack.add(MvRefCandidate::new(mv, ref_frame, SPATIAL_WEIGHT));
        if stack.len() >= MAX_MV_REF_CANDIDATES {
            break;
        }
    }
    stack.sort_by_weight();
}
/// Runs [`find_mv_refs`] on a fresh stack and returns `[nearest, near]`.
#[must_use]
pub fn find_best_ref_mvs(ctx: &MvRefContext, ref_frame: MvRefType) -> [MotionVector; 2] {
    let mut candidates = MvRefStack::new(ref_frame);
    find_mv_refs(ctx, ref_frame, &mut candidates);
    [candidates.nearest_mv(), candidates.near_mv()]
}
/// Clamps `mv` so the displaced block stays near the frame.
///
/// Bounds are in the same units as the original computation (pixel
/// distances shifted left by 3). They are measured from the block's pixel
/// position `(mi_row * 4, mi_col * 4)`:
///
/// * left/top: the block may start up to 16 pixels (`128` units) before the
///   frame edge;
/// * right/bottom: the block's far edge may reach the frame edge. This is
///   now relative to the block position, like the left/top bound. Previously
///   it was an absolute limit, which let blocks near the edge point far
///   outside the frame.
/// * both bounds are further limited to `[MV_MIN, MV_MAX]`.
///
/// All arithmetic is signed 64-bit, so blocks larger than the frame no
/// longer underflow `usize`, and large positions no longer wrap through
/// `as i16`. If the upper bound falls below the lower one, the upper bound
/// is raised to it so the clamp cannot panic.
#[must_use]
#[allow(clippy::too_many_arguments)]
pub fn clamp_mv(
    mv: MotionVector,
    mi_row: usize,
    mi_col: usize,
    block_size: BlockSize,
    frame_width: usize,
    frame_height: usize,
) -> MotionVector {
    // Extra reach (16 px, in 1/8 units) allowed beyond the left/top edge.
    const MV_BORDER: i64 = 16 * 8;

    let block_x = mi_col as i64 * 4;
    let block_y = mi_row as i64 * 4;

    let to_left = -block_x * 8;
    let to_top = -block_y * 8;
    let to_right = (frame_width as i64 - block_x - block_size.width() as i64) * 8;
    let to_bottom = (frame_height as i64 - block_y - block_size.height() as i64) * 8;

    MotionVector::new(
        clamp_mv_component(mv.row, to_top - MV_BORDER, to_bottom),
        clamp_mv_component(mv.col, to_left - MV_BORDER, to_right),
    )
}

/// Clamps one MV component to `[lo, hi] ∩ [MV_MIN, MV_MAX]`, raising `hi`
/// to `lo` when the interval is empty.
fn clamp_mv_component(value: i16, lo: i64, hi: i64) -> i16 {
    let lo = lo.max(i64::from(MV_MIN));
    let hi = hi.min(i64::from(MV_MAX)).max(lo);
    // Both bounds now lie inside [MV_MIN, MV_MAX] (i16 values), so the cast is lossless.
    i64::from(value).clamp(lo, hi) as i16
}
/// Drops the lowest precision bit of `mv` unless high precision is allowed.
///
/// With `allow_hp` the vector is returned unchanged. Otherwise each component
/// is made even:
///
/// * odd values move one unit away from zero (`5 -> 6`, `-5 -> -6`);
/// * even values are left as they are.
///
/// NOTE(review): libvpx's `lower_mv_precision` rounds odd components toward
/// zero instead; confirm which behaviour the target bitstream requires.
#[must_use]
pub fn round_mv(mv: MotionVector, allow_hp: bool) -> MotionVector {
    if allow_hp {
        return mv;
    }
    MotionVector::new(round_mv_component(mv.row), round_mv_component(mv.col))
}

/// Makes one MV component even: odd values step away from zero, even values
/// pass through unchanged.
fn round_mv_component(value: i16) -> i16 {
    if value & 1 == 0 {
        // Even values must not move. Masking `(v - 1) & !1` would shift
        // negative evens (e.g. -4 -> -6).
        value
    } else if value < 0 {
        // The smallest odd i16 is -32767, so subtracting 1 cannot overflow.
        value - 1
    } else {
        // i16::MAX is odd with no even value above it; step toward zero there.
        value.checked_add(1).unwrap_or(value - 1)
    }
}
/// Tallies of the inter modes used by neighbouring blocks, used to pick a
/// mode context.
#[derive(Clone, Copy, Debug, Default)]
pub struct MvPredContext {
    /// Neighbours coded with `NearestMv`.
    pub nearest_count: u8,
    /// Neighbours coded with `NearMv`.
    pub near_count: u8,
    /// Neighbours coded with `ZeroMv`.
    pub zero_count: u8,
    /// Neighbours coded with `NewMv`.
    pub new_count: u8,
    /// Inter-coded neighbours seen in total.
    pub inter_count: u8,
}

impl MvPredContext {
    /// A context with every counter at zero.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            inter_count: 0,
            nearest_count: 0,
            near_count: 0,
            zero_count: 0,
            new_count: 0,
        }
    }

    /// Records `info` when it is inter coded; intra neighbours are ignored.
    /// All counters saturate at `u8::MAX`.
    pub fn add_neighbor(&mut self, info: &BlockModeInfo) {
        if !info.is_inter {
            return;
        }
        self.inter_count = self.inter_count.saturating_add(1);
        let counter = match info.mode {
            InterMode::NearestMv => &mut self.nearest_count,
            InterMode::NearMv => &mut self.near_count,
            InterMode::ZeroMv => &mut self.zero_count,
            InterMode::NewMv => &mut self.new_count,
        };
        *counter = counter.saturating_add(1);
    }

    /// Maps the tallies to a context index in `0..=5`:
    ///
    /// | inter neighbours | `NewMv` neighbours | context |
    /// |------------------|--------------------|---------|
    /// | 0                | any                | 0       |
    /// | 1                | 0                  | 1       |
    /// | 1                | ≥ 1                | 3       |
    /// | ≥ 2              | 0                  | 2       |
    /// | ≥ 2              | 1                  | 4       |
    /// | ≥ 2              | ≥ 2                | 5       |
    #[must_use]
    pub const fn mode_context(&self) -> usize {
        match (self.inter_count, self.new_count) {
            (0, _) => 0,
            (1, 0) => 1,
            (1, _) => 3,
            (_, 0) => 2,
            (_, 1) => 4,
            _ => 5,
        }
    }
}
// Unit tests covering candidate comparison, stack de-duplication and ordering,
// block mode info, the mode-info grid, MV clamping/rounding, the mode context
// and spatial MV reference search.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mv_ref_candidate() {
let mv = MotionVector::new(10, 20);
let candidate = MvRefCandidate::new(mv, MvRefType::Last, 5);
assert_eq!(candidate.mv.row, 10);
assert_eq!(candidate.mv.col, 20);
assert_eq!(candidate.weight, 5);
}
#[test]
fn test_mv_ref_candidate_same_mv() {
// Weight differs between c1 and c2 but `same_mv` only compares vectors.
let c1 = MvRefCandidate::new(MotionVector::new(10, 20), MvRefType::Last, 1);
let c2 = MvRefCandidate::new(MotionVector::new(10, 20), MvRefType::Last, 2);
let c3 = MvRefCandidate::new(MotionVector::new(10, 30), MvRefType::Last, 1);
assert!(c1.same_mv(&c2));
assert!(!c1.same_mv(&c3));
}
#[test]
fn test_mv_ref_stack_new() {
let stack = MvRefStack::new(MvRefType::Last);
assert!(stack.is_empty());
assert!(!stack.is_full());
assert_eq!(stack.len(), 0);
}
#[test]
fn test_mv_ref_stack_add() {
let mut stack = MvRefStack::new(MvRefType::Last);
let c1 = MvRefCandidate::new(MotionVector::new(10, 20), MvRefType::Last, 1);
stack.add(c1);
assert_eq!(stack.len(), 1);
assert_eq!(stack.nearest_mv(), MotionVector::new(10, 20));
}
#[test]
fn test_mv_ref_stack_add_duplicate() {
// Same vector and reference: the second add merges, summing weights (1 + 2).
let mut stack = MvRefStack::new(MvRefType::Last);
let c1 = MvRefCandidate::new(MotionVector::new(10, 20), MvRefType::Last, 1);
let c2 = MvRefCandidate::new(MotionVector::new(10, 20), MvRefType::Last, 2);
stack.add(c1);
stack.add(c2);
assert_eq!(stack.len(), 1);
assert_eq!(stack.get(0).expect("get should return value").weight, 3);
}
#[test]
fn test_mv_ref_stack_sort() {
// Inserted with weights 1, 3, 2; sorting must yield descending order.
let mut stack = MvRefStack::new(MvRefType::Last);
stack.add(MvRefCandidate::new(
MotionVector::new(10, 20),
MvRefType::Last,
1,
));
stack.add(MvRefCandidate::new(
MotionVector::new(30, 40),
MvRefType::Last,
3,
));
stack.add(MvRefCandidate::new(
MotionVector::new(50, 60),
MvRefType::Last,
2,
));
stack.sort_by_weight();
assert_eq!(stack.get(0).expect("get should return value").weight, 3);
assert_eq!(stack.get(1).expect("get should return value").weight, 2);
assert_eq!(stack.get(2).expect("get should return value").weight, 1);
}
#[test]
fn test_mv_ref_stack_best_ref_mvs() {
// Without sorting, nearest/near are simply the first two entries.
let mut stack = MvRefStack::new(MvRefType::Last);
stack.add(MvRefCandidate::new(
MotionVector::new(10, 20),
MvRefType::Last,
2,
));
stack.add(MvRefCandidate::new(
MotionVector::new(30, 40),
MvRefType::Last,
1,
));
let [nearest, near] = stack.best_ref_mvs();
assert_eq!(nearest, MotionVector::new(10, 20));
assert_eq!(near, MotionVector::new(30, 40));
}
#[test]
fn test_block_mode_info_intra() {
let info = BlockModeInfo::intra();
assert!(!info.is_inter);
assert!(!info.is_compound);
}
#[test]
fn test_block_mode_info_inter_single() {
let info = BlockModeInfo::inter_single(
MvRefType::Last,
MotionVector::new(10, 20),
InterMode::NearestMv,
);
assert!(info.is_inter);
assert!(!info.is_compound);
assert_eq!(
info.mv_for_ref(MvRefType::Last),
Some(MotionVector::new(10, 20))
);
assert_eq!(info.mv_for_ref(MvRefType::Golden), None);
}
#[test]
fn test_block_mode_info_inter_compound() {
let info = BlockModeInfo::inter_compound(
MvRefType::Last,
MvRefType::Golden,
MotionVector::new(10, 20),
MotionVector::new(30, 40),
);
assert!(info.is_inter);
assert!(info.is_compound);
assert_eq!(
info.mv_for_ref(MvRefType::Last),
Some(MotionVector::new(10, 20))
);
assert_eq!(
info.mv_for_ref(MvRefType::Golden),
Some(MotionVector::new(30, 40))
);
}
#[test]
fn test_mode_info_grid() {
// 64 px / 4 px per mode-info unit = 16 units each way.
let mut grid = ModeInfoGrid::new();
grid.allocate(64, 64);
assert_eq!(grid.mi_cols(), 16);
assert_eq!(grid.mi_rows(), 16);
let info = BlockModeInfo::inter_single(
MvRefType::Last,
MotionVector::new(10, 20),
InterMode::NearestMv,
);
grid.set(0, 0, info);
let retrieved = grid.get(0, 0).expect("get should return value");
assert!(retrieved.is_inter);
}
#[test]
fn test_mode_info_grid_fill_block() {
// An 8x8 block covers the 2x2 mode-info cells starting at (0, 0).
let mut grid = ModeInfoGrid::new();
grid.allocate(64, 64);
let info = BlockModeInfo::inter_single(
MvRefType::Last,
MotionVector::new(10, 20),
InterMode::NearestMv,
);
grid.fill_block(0, 0, BlockSize::Block8x8, info);
assert!(grid.get(0, 0).expect("get should return value").is_inter);
assert!(grid.get(0, 1).expect("get should return value").is_inter);
assert!(grid.get(1, 0).expect("get should return value").is_inter);
assert!(grid.get(1, 1).expect("get should return value").is_inter);
}
#[test]
fn test_clamp_mv() {
// 8 px block at the origin of a 64 px frame may move at most 56 px
// right/down; `clamp_mv` expresses that as 56 << 3.
let mv = MotionVector::new(1000, 2000);
let clamped = clamp_mv(mv, 0, 0, BlockSize::Block8x8, 64, 64);
assert!(clamped.row <= (56 << 3) as i16); assert!(clamped.col <= (56 << 3) as i16);
}
#[test]
fn test_round_mv() {
// High precision passes through; otherwise both components become even.
let mv = MotionVector::new(5, 7);
let hp = round_mv(mv, true);
assert_eq!(hp.row, 5);
assert_eq!(hp.col, 7);
let qp = round_mv(mv, false);
assert_eq!(qp.row & 1, 0);
assert_eq!(qp.col & 1, 0);
}
#[test]
fn test_mv_pred_context() {
// One inter neighbour with no NewMv usage maps to context 1.
let mut ctx = MvPredContext::new();
let nearest_info = BlockModeInfo::inter_single(
MvRefType::Last,
MotionVector::zero(),
InterMode::NearestMv,
);
ctx.add_neighbor(&nearest_info);
assert_eq!(ctx.inter_count, 1);
assert_eq!(ctx.nearest_count, 1);
assert_eq!(ctx.mode_context(), 1);
}
#[test]
fn test_find_mv_refs() {
// The block at (0, 1) sees the inter block at (0, 0) as its left neighbour.
let mut ctx = MvRefContext::new();
ctx.mode_info.allocate(64, 64);
let neighbor_info = BlockModeInfo::inter_single(
MvRefType::Last,
MotionVector::new(10, 20),
InterMode::NearestMv,
);
ctx.mode_info.set(0, 0, neighbor_info);
ctx.set_position(0, 1, BlockSize::Block8x8);
let mut stack = MvRefStack::new(MvRefType::Last);
find_mv_refs(&ctx, MvRefType::Last, &mut stack);
assert!(!stack.is_empty());
assert_eq!(stack.nearest_mv(), MotionVector::new(10, 20));
}
#[test]
fn test_find_best_ref_mvs() {
// Freshly allocated grid is all intra: no candidates, both predictors zero.
let mut ctx = MvRefContext::new();
ctx.mode_info.allocate(64, 64);
ctx.set_position(8, 8, BlockSize::Block8x8);
let [nearest, near] = find_best_ref_mvs(&ctx, MvRefType::Last);
assert_eq!(nearest, MotionVector::zero());
assert_eq!(near, MotionVector::zero());
}
}