#![forbid(unsafe_code)]
#![allow(dead_code)]
#![allow(clippy::similar_names)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::trivially_copy_pass_by_ref)]
use std::ops::{Add, Neg, Sub};
/// Largest motion-vector component, in 1/8-pel units (16383 whole pixels).
pub const MV_MAX: i32 = 16_383 << 3;
/// Smallest motion-vector component, in 1/8-pel units (-16384 whole pixels).
pub const MV_MIN: i32 = -(16_384 << 3);
/// Default half-width of the search window used by `SearchRange::default`.
pub const DEFAULT_SEARCH_RANGE: i32 = 64;
/// Sub-pixel precision of a motion-vector component.
///
/// Each discriminant equals the number of fractional bits at that precision.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum MvPrecision {
    /// Whole pixels (no fractional bits).
    FullPel = 0,
    /// Half pixels (one fractional bit).
    HalfPel = 1,
    /// Quarter pixels (two fractional bits); the default.
    #[default]
    QuarterPel = 2,
    /// Eighth pixels (three fractional bits).
    EighthPel = 3,
}
impl MvPrecision {
    /// Number of fractional bits carried at this precision.
    #[must_use]
    pub const fn fractional_bits(self) -> u8 {
        // The `repr(u8)` discriminants are defined to be exactly the bit counts.
        self as u8
    }
    /// Number of sub-pixel steps per whole pixel (`2^fractional_bits`).
    #[must_use]
    pub const fn scale(self) -> i32 {
        1_i32 << self.fractional_bits()
    }
    /// Mask selecting the fractional bits of a value at this precision.
    #[must_use]
    pub const fn frac_mask(self) -> i32 {
        self.scale() - 1
    }
    /// Re-expresses `value`, given at `self` precision, at `target` precision.
    ///
    /// Going to a finer precision shifts left. Going to a coarser one shifts right
    /// arithmetically, so results are floored (rounded towards negative infinity).
    #[must_use]
    pub const fn convert(self, value: i32, target: Self) -> i32 {
        let from = self.fractional_bits();
        let to = target.fractional_bits();
        if to > from {
            value << (to - from)
        } else {
            value >> (from - to)
        }
    }
}
/// A motion vector stored in 1/8-pel units.
///
/// Each component keeps three fractional bits. `dx >> 3` is the whole-pixel offset
/// (floored) and `dx & 7` is the eighth-pel remainder, so
/// `full_pel_x() * 8 + frac_x() == dx` holds even for negative components.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
pub struct MotionVector {
    /// Horizontal displacement in 1/8-pel units.
    pub dx: i32,
    /// Vertical displacement in 1/8-pel units.
    pub dy: i32,
}
impl MotionVector {
    /// Returns the zero vector.
    #[must_use]
    pub const fn zero() -> Self {
        Self { dx: 0, dy: 0 }
    }
    /// Builds a vector from components already expressed in 1/8-pel units.
    #[must_use]
    pub const fn new(dx: i32, dy: i32) -> Self {
        Self { dx, dy }
    }
    /// Builds a vector from whole-pixel offsets.
    #[must_use]
    pub const fn from_full_pel(dx: i32, dy: i32) -> Self {
        Self {
            dx: dx << 3,
            dy: dy << 3,
        }
    }
    /// Builds a vector from components expressed at `precision`, converting them
    /// to 1/8-pel units.
    #[must_use]
    pub const fn from_precision(dx: i32, dy: i32, precision: MvPrecision) -> Self {
        let shift = 3 - precision.fractional_bits() as i32;
        Self {
            dx: dx << shift,
            dy: dy << shift,
        }
    }
    /// Returns `true` when both components are zero.
    #[must_use]
    pub const fn is_zero(&self) -> bool {
        self.dx == 0 && self.dy == 0
    }
    /// Whole-pixel part of `dx`, floored towards negative infinity.
    #[must_use]
    pub const fn full_pel_x(&self) -> i32 {
        self.dx >> 3
    }
    /// Whole-pixel part of `dy`, floored towards negative infinity.
    #[must_use]
    pub const fn full_pel_y(&self) -> i32 {
        self.dy >> 3
    }
    /// Eighth-pel remainder of `dx`, always in `0..=7`.
    #[must_use]
    pub const fn frac_x(&self) -> i32 {
        self.dx & 7
    }
    /// Eighth-pel remainder of `dy`, always in `0..=7`.
    #[must_use]
    pub const fn frac_y(&self) -> i32 {
        self.dy & 7
    }
    /// Half-pel bit of `dx` (bit 2), either 0 or 1.
    #[must_use]
    pub const fn half_pel_x(&self) -> i32 {
        (self.dx >> 2) & 1
    }
    /// Half-pel bit of `dy` (bit 2), either 0 or 1.
    #[must_use]
    pub const fn half_pel_y(&self) -> i32 {
        (self.dy >> 2) & 1
    }
    /// Quarter-pel position of `dx` within a pixel (bits 1..=2), in `0..=3`.
    #[must_use]
    pub const fn quarter_pel_x(&self) -> i32 {
        (self.dx >> 1) & 3
    }
    /// Quarter-pel position of `dy` within a pixel (bits 1..=2), in `0..=3`.
    #[must_use]
    pub const fn quarter_pel_y(&self) -> i32 {
        (self.dy >> 1) & 3
    }
    /// Truncates both components to `precision` by clearing the finer fractional
    /// bits. This floors towards negative infinity.
    #[must_use]
    pub const fn to_precision(&self, precision: MvPrecision) -> Self {
        let shift = 3 - precision.fractional_bits() as i32;
        let mask = !((1 << shift) - 1);
        Self {
            dx: self.dx & mask,
            dy: self.dy & mask,
        }
    }
    /// Rounds both components to the nearest multiple of `precision`'s step.
    ///
    /// Ties round towards positive infinity. `EighthPel` is the native storage
    /// precision, so the vector is then returned unchanged.
    #[must_use]
    pub const fn round_to_precision(&self, precision: MvPrecision) -> Self {
        let shift = 3 - precision.fractional_bits() as i32;
        if shift <= 0 {
            // Nothing to round. The half-step bias `1 << (shift - 1)` must not be
            // computed here: a negative shift amount panics in debug builds.
            return *self;
        }
        let round = 1 << (shift - 1);
        Self {
            // Saturate so a component near `i32::MAX` cannot overflow when biased.
            dx: (self.dx.saturating_add(round) >> shift) << shift,
            dy: (self.dy.saturating_add(round) >> shift) << shift,
        }
    }
    /// Clamps both components into `MV_MIN..=MV_MAX`.
    #[must_use]
    pub fn clamp(&self) -> Self {
        Self {
            dx: self.dx.clamp(MV_MIN, MV_MAX),
            dy: self.dy.clamp(MV_MIN, MV_MAX),
        }
    }
    /// Clamps both components into the window described by `range`.
    ///
    /// The range half-widths are whole pixels and are converted to 1/8-pel units.
    /// A negative half-width is treated as an empty window (clamps to 0) instead of
    /// panicking. A half-width too large to convert saturates at `i32::MAX`.
    #[must_use]
    pub fn clamp_to_range(&self, range: &SearchRange) -> Self {
        let h = range.horizontal.max(0).saturating_mul(8);
        let v = range.vertical.max(0).saturating_mul(8);
        Self {
            dx: self.dx.clamp(-h, h),
            dy: self.dy.clamp(-v, v),
        }
    }
    /// Squared Euclidean length, computed in `i64`.
    ///
    /// Saturates at `i64::MAX`, which is reachable only when both components are
    /// `i32::MIN`.
    #[must_use]
    pub const fn magnitude_squared(&self) -> i64 {
        let x = self.dx as i64;
        let y = self.dy as i64;
        (x * x).saturating_add(y * y)
    }
    /// Manhattan length `|dx| + |dy|`, saturating at `i32::MAX`.
    ///
    /// This matches the saturating `Add`/`Sub`/`Neg` impls.
    #[must_use]
    pub const fn l1_norm(&self) -> i32 {
        self.dx.saturating_abs().saturating_add(self.dy.saturating_abs())
    }
    /// Chebyshev length `max(|dx|, |dy|)`. `|i32::MIN|` saturates to `i32::MAX`.
    #[must_use]
    pub fn linf_norm(&self) -> i32 {
        self.dx.saturating_abs().max(self.dy.saturating_abs())
    }
    /// Scales both components by `num / den`, truncating towards zero.
    ///
    /// Returns the vector unchanged when `den == 0`. Results outside the `i32`
    /// range saturate rather than having their high bits discarded.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    pub fn scale(&self, num: i32, den: i32) -> Self {
        if den == 0 {
            return *self;
        }
        let scale_component = |v: i32| -> i32 {
            // |v * num| <= 2^62, so the product cannot overflow i64.
            let scaled = i64::from(v) * i64::from(num) / i64::from(den);
            // Lossless after clamping into the i32 range.
            scaled.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        };
        Self {
            dx: scale_component(self.dx),
            dy: scale_component(self.dy),
        }
    }
}
impl Add for MotionVector {
    type Output = Self;
    /// Component-wise addition that saturates at the `i32` limits instead of
    /// overflowing.
    fn add(self, rhs: Self) -> Self {
        let dx = self.dx.saturating_add(rhs.dx);
        let dy = self.dy.saturating_add(rhs.dy);
        Self { dx, dy }
    }
}
impl Sub for MotionVector {
    type Output = Self;
    /// Component-wise subtraction that saturates at the `i32` limits instead of
    /// overflowing.
    fn sub(self, rhs: Self) -> Self {
        let Self { dx: lx, dy: ly } = self;
        Self {
            dx: lx.saturating_sub(rhs.dx),
            dy: ly.saturating_sub(rhs.dy),
        }
    }
}
impl Neg for MotionVector {
    type Output = Self;
    /// Negates both components. `i32::MIN` saturates to `i32::MAX` rather than
    /// overflowing.
    fn neg(self) -> Self {
        let Self { dx, dy } = self;
        Self {
            dx: dx.saturating_neg(),
            dy: dy.saturating_neg(),
        }
    }
}
/// Rectangular motion search window, given as half-widths in whole pixels.
///
/// A half-width `h` covers offsets `-h..=h`. A negative half-width describes an
/// empty window: `contains` accepts nothing and `num_positions` returns 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SearchRange {
    /// Horizontal half-width.
    pub horizontal: i32,
    /// Vertical half-width.
    pub vertical: i32,
}
impl Default for SearchRange {
    fn default() -> Self {
        Self::new(DEFAULT_SEARCH_RANGE, DEFAULT_SEARCH_RANGE)
    }
}
impl SearchRange {
    /// Creates a window with independent horizontal and vertical half-widths.
    #[must_use]
    pub const fn new(horizontal: i32, vertical: i32) -> Self {
        Self {
            horizontal,
            vertical,
        }
    }
    /// Creates a window with the same half-width on both axes.
    #[must_use]
    pub const fn symmetric(range: i32) -> Self {
        Self::new(range, range)
    }
    /// Number of integer positions inside the window, `(2h + 1) * (2v + 1)`.
    ///
    /// Negative half-widths yield 0 rather than wrapping through the `u64` cast.
    /// The widest span is `2^32 - 1`, so the product always fits in `u64`.
    #[must_use]
    pub const fn num_positions(&self) -> u64 {
        let w = if self.horizontal < 0 {
            0
        } else {
            2 * (self.horizontal as u64) + 1
        };
        let h = if self.vertical < 0 {
            0
        } else {
            2 * (self.vertical as u64) + 1
        };
        w * h
    }
    /// Returns `true` when the integer offset `(dx, dy)` lies inside the window.
    #[must_use]
    pub const fn contains(&self, dx: i32, dy: i32) -> bool {
        dx >= -self.horizontal
            && dx <= self.horizontal
            && dy >= -self.vertical
            && dy <= self.vertical
    }
    /// Multiplies both half-widths by `factor`, saturating at the `i32` limits.
    #[must_use]
    pub const fn scale(&self, factor: i32) -> Self {
        Self {
            horizontal: self.horizontal.saturating_mul(factor),
            vertical: self.vertical.saturating_mul(factor),
        }
    }
    /// Divides both half-widths by `factor`, truncating towards zero.
    ///
    /// A `factor` of 0 leaves the window unchanged. `i32::MIN / -1` saturates to
    /// `i32::MAX` instead of panicking.
    #[must_use]
    pub const fn reduce(&self, factor: i32) -> Self {
        if factor == 0 {
            *self
        } else {
            Self {
                horizontal: self.horizontal.saturating_div(factor),
                vertical: self.vertical.saturating_div(factor),
            }
        }
    }
}
/// Outcome of evaluating one candidate motion vector for a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BlockMatch {
    /// The candidate motion vector.
    pub mv: MotionVector,
    /// Matching error for `mv` (named `sad`; presumably sum of absolute differences).
    pub sad: u32,
    /// Ranking cost used by the comparisons below; lower is better.
    pub cost: u32,
}
impl Default for BlockMatch {
    /// Defaults to [`BlockMatch::worst`], so any real candidate replaces it.
    fn default() -> Self {
        Self::worst()
    }
}
impl BlockMatch {
    /// Bundles a vector with its error and cost.
    #[must_use]
    pub const fn new(mv: MotionVector, sad: u32, cost: u32) -> Self {
        Self { mv, sad, cost }
    }
    /// A zero-vector match whose cost equals its `sad`.
    #[must_use]
    pub const fn zero_mv(sad: u32) -> Self {
        Self::new(MotionVector::zero(), sad, sad)
    }
    /// A sentinel match with maximal `sad` and `cost`.
    #[must_use]
    pub const fn worst() -> Self {
        Self::new(MotionVector::zero(), u32::MAX, u32::MAX)
    }
    /// Strictly lower cost wins; ties are not considered better.
    #[must_use]
    pub const fn is_better_than(&self, other: &Self) -> bool {
        other.cost > self.cost
    }
    /// Replaces `self` with `other` only when `other` has strictly lower cost.
    pub fn update_if_better(&mut self, other: &Self) {
        if !other.is_better_than(self) {
            return;
        }
        *self = *other;
    }
}
/// Simple rate model that estimates how many bits a motion vector costs
/// relative to a reference vector.
#[derive(Clone, Copy, Debug)]
pub struct MvCost {
    /// Multiplier applied to the estimated bits in [`MvCost::rd_cost`].
    pub lambda: f32,
    /// Multiplier applied to the raw per-component bit estimate.
    pub mv_weight: f32,
    /// Vector that candidates are measured against; only the difference is costed.
    pub ref_mv: MotionVector,
}
impl Default for MvCost {
    /// `lambda = 1.0`, unit weight, zero reference vector.
    fn default() -> Self {
        Self::new(1.0)
    }
}
impl MvCost {
    /// Creates a model with unit weight and a zero reference vector.
    #[must_use]
    pub const fn new(lambda: f32) -> Self {
        Self::with_ref_mv(lambda, MotionVector::zero())
    }
    /// Creates a model with unit weight and the given reference vector.
    #[must_use]
    pub const fn with_ref_mv(lambda: f32, ref_mv: MotionVector) -> Self {
        Self {
            lambda,
            mv_weight: 1.0,
            ref_mv,
        }
    }
    /// Estimated bits for coding `mv` relative to `ref_mv`, scaled by `mv_weight`.
    #[must_use]
    pub fn estimate_bits(&self, mv: &MotionVector) -> f32 {
        let delta = *mv - self.ref_mv;
        let raw = Self::component_bits(delta.dx) + Self::component_bits(delta.dy);
        raw * self.mv_weight
    }
    /// Bit estimate for one component. Zero costs 1 bit. Any other value costs
    /// `2 * bit_length(|value|) + 2` bits.
    #[must_use]
    fn component_bits(value: i32) -> f32 {
        match value.unsigned_abs() {
            0 => 1.0,
            magnitude => {
                let bit_len = u32::BITS - magnitude.leading_zeros();
                (2 * bit_len + 2) as f32
            }
        }
    }
    /// Combined cost `sad + lambda * bits`. The rate term is converted to `u32` with
    /// a saturating float cast, and the sum saturates at `u32::MAX`.
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn rd_cost(&self, mv: &MotionVector, sad: u32) -> u32 {
        let rate = (self.estimate_bits(mv) * self.lambda) as u32;
        sad.saturating_add(rate)
    }
    /// Replaces the reference vector used by subsequent estimates.
    pub fn set_ref_mv(&mut self, ref_mv: MotionVector) {
        self.ref_mv = ref_mv;
    }
}
/// Rectangular block dimensions, named `Block{width}x{height}` in pixels.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum BlockSize {
    Block4x4 = 0,
    Block4x8 = 1,
    Block8x4 = 2,
    #[default]
    Block8x8 = 3,
    Block8x16 = 4,
    Block16x8 = 5,
    Block16x16 = 6,
    Block16x32 = 7,
    Block32x16 = 8,
    Block32x32 = 9,
    Block32x64 = 10,
    Block64x32 = 11,
    Block64x64 = 12,
    Block64x128 = 13,
    Block128x64 = 14,
    Block128x128 = 15,
}
impl BlockSize {
    /// Width in pixels.
    #[must_use]
    pub const fn width(&self) -> usize {
        1_usize << self.width_log2()
    }
    /// Height in pixels.
    #[must_use]
    pub const fn height(&self) -> usize {
        1_usize << self.height_log2()
    }
    /// Total pixel count, `width * height`.
    #[must_use]
    pub const fn num_pixels(&self) -> usize {
        self.width() * self.height()
    }
    /// `true` when width equals height.
    #[must_use]
    pub const fn is_square(&self) -> bool {
        self.width_log2() == self.height_log2()
    }
    /// Base-2 logarithm of the width. Every width is a power of two.
    #[must_use]
    pub const fn width_log2(&self) -> u8 {
        match self {
            Self::Block4x4 | Self::Block4x8 => 2,
            Self::Block8x4 | Self::Block8x8 | Self::Block8x16 => 3,
            Self::Block16x8 | Self::Block16x16 | Self::Block16x32 => 4,
            Self::Block32x16 | Self::Block32x32 | Self::Block32x64 => 5,
            Self::Block64x32 | Self::Block64x64 | Self::Block64x128 => 6,
            Self::Block128x64 | Self::Block128x128 => 7,
        }
    }
    /// Base-2 logarithm of the height. Every height is a power of two.
    #[must_use]
    pub const fn height_log2(&self) -> u8 {
        match self {
            Self::Block4x4 | Self::Block8x4 => 2,
            Self::Block4x8 | Self::Block8x8 | Self::Block16x8 => 3,
            Self::Block8x16 | Self::Block16x16 | Self::Block32x16 => 4,
            Self::Block16x32 | Self::Block32x32 | Self::Block64x32 => 5,
            Self::Block32x64 | Self::Block64x64 | Self::Block128x64 => 6,
            Self::Block64x128 | Self::Block128x128 => 7,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the motion-vector primitives.
    use super::*;
    #[test]
    fn test_mv_precision() {
        assert_eq!(MvPrecision::FullPel.fractional_bits(), 0);
        assert_eq!(MvPrecision::HalfPel.fractional_bits(), 1);
        assert_eq!(MvPrecision::QuarterPel.fractional_bits(), 2);
        assert_eq!(MvPrecision::EighthPel.fractional_bits(), 3);
        assert_eq!(MvPrecision::FullPel.scale(), 1);
        assert_eq!(MvPrecision::QuarterPel.scale(), 4);
        assert_eq!(MvPrecision::EighthPel.scale(), 8);
        assert_eq!(MvPrecision::HalfPel.frac_mask(), 1);
    }
    #[test]
    fn test_mv_precision_convert() {
        assert_eq!(MvPrecision::FullPel.convert(2, MvPrecision::QuarterPel), 8);
        assert_eq!(MvPrecision::QuarterPel.convert(8, MvPrecision::FullPel), 2);
        assert_eq!(MvPrecision::EighthPel.convert(12, MvPrecision::HalfPel), 3);
    }
    #[test]
    fn test_motion_vector_creation() {
        let mv = MotionVector::new(16, -24);
        assert_eq!(mv.dx, 16);
        assert_eq!(mv.dy, -24);
        let mv_fp = MotionVector::from_full_pel(2, -3);
        assert_eq!(mv_fp.dx, 16);
        assert_eq!(mv_fp.dy, -24);
    }
    #[test]
    fn test_motion_vector_from_precision() {
        assert_eq!(
            MotionVector::from_precision(3, -1, MvPrecision::QuarterPel),
            MotionVector::new(6, -2)
        );
        assert_eq!(
            MotionVector::from_precision(2, -3, MvPrecision::FullPel),
            MotionVector::from_full_pel(2, -3)
        );
    }
    #[test]
    fn test_motion_vector_components() {
        let mv = MotionVector::new(27, -19);
        assert_eq!(mv.full_pel_x(), 3);
        assert_eq!(mv.full_pel_y(), -3);
        assert_eq!(mv.frac_x(), 3);
        assert_eq!(mv.frac_y(), -19 & 7);
        // Whole-pel and fractional parts recombine exactly, even when negative.
        assert_eq!(mv.full_pel_y() * 8 + mv.frac_y(), -19);
    }
    #[test]
    fn test_motion_vector_sub_pel_fields() {
        let mv = MotionVector::new(12, 6);
        assert_eq!(mv.half_pel_x(), 1);
        assert_eq!(mv.quarter_pel_x(), 2);
        assert_eq!(mv.half_pel_y(), 1);
        assert_eq!(mv.quarter_pel_y(), 3);
    }
    #[test]
    fn test_motion_vector_zero() {
        let mv = MotionVector::zero();
        assert!(mv.is_zero());
        assert_eq!(mv.magnitude_squared(), 0);
    }
    #[test]
    fn test_motion_vector_arithmetic() {
        let mv1 = MotionVector::new(10, 20);
        let mv2 = MotionVector::new(5, -10);
        let sum = mv1 + mv2;
        assert_eq!(sum.dx, 15);
        assert_eq!(sum.dy, 10);
        let diff = mv1 - mv2;
        assert_eq!(diff.dx, 5);
        assert_eq!(diff.dy, 30);
        let neg = -mv1;
        assert_eq!(neg.dx, -10);
        assert_eq!(neg.dy, -20);
    }
    #[test]
    fn test_motion_vector_arithmetic_saturates() {
        let edge = MotionVector::new(i32::MAX, i32::MIN);
        assert_eq!(edge + MotionVector::new(1, -1), edge);
        assert_eq!(edge - MotionVector::new(-1, 1), edge);
        assert_eq!(-MotionVector::new(i32::MIN, 0), MotionVector::new(i32::MAX, 0));
    }
    #[test]
    fn test_motion_vector_magnitude() {
        let mv = MotionVector::new(3, 4);
        assert_eq!(mv.magnitude_squared(), 25);
        assert_eq!(mv.l1_norm(), 7);
        assert_eq!(mv.linf_norm(), 4);
    }
    #[test]
    fn test_motion_vector_precision_conversion() {
        let mv = MotionVector::new(27, 19);
        let qpel = mv.to_precision(MvPrecision::QuarterPel);
        assert_eq!(qpel.dx & 1, 0);
        assert_eq!(qpel.dy & 1, 0);
        let fpel = mv.to_precision(MvPrecision::FullPel);
        assert_eq!(fpel.dx & 7, 0);
        assert_eq!(fpel.dy & 7, 0);
    }
    #[test]
    fn test_motion_vector_rounding() {
        let mv = MotionVector::new(27, -19);
        assert_eq!(
            mv.round_to_precision(MvPrecision::QuarterPel),
            MotionVector::new(28, -18)
        );
        assert_eq!(
            mv.round_to_precision(MvPrecision::FullPel),
            MotionVector::new(24, -16)
        );
    }
    #[test]
    fn test_motion_vector_clamping() {
        let extreme = MotionVector::new(i32::MAX, i32::MIN);
        assert_eq!(extreme.clamp(), MotionVector::new(MV_MAX, MV_MIN));
        let range = SearchRange::new(2, 1);
        assert_eq!(
            MotionVector::new(100, -100).clamp_to_range(&range),
            MotionVector::new(16, -8)
        );
    }
    #[test]
    fn test_motion_vector_scale() {
        let mv = MotionVector::new(16, -8);
        assert_eq!(mv.scale(3, 2), MotionVector::new(24, -12));
        assert_eq!(mv.scale(5, 0), mv);
    }
    #[test]
    fn test_search_range() {
        let range = SearchRange::symmetric(32);
        assert_eq!(range.horizontal, 32);
        assert_eq!(range.vertical, 32);
        assert!(range.contains(0, 0));
        assert!(range.contains(32, 32));
        assert!(range.contains(-32, -32));
        assert!(!range.contains(33, 0));
    }
    #[test]
    fn test_search_range_positions() {
        let range = SearchRange::symmetric(2);
        assert_eq!(range.num_positions(), 25);
    }
    #[test]
    fn test_search_range_scale_reduce() {
        assert_eq!(
            SearchRange::default(),
            SearchRange::symmetric(DEFAULT_SEARCH_RANGE)
        );
        let range = SearchRange::new(10, 4);
        assert_eq!(range.scale(3), SearchRange::new(30, 12));
        assert_eq!(range.reduce(2), SearchRange::new(5, 2));
        assert_eq!(range.reduce(0), range);
    }
    #[test]
    fn test_block_match() {
        let best = BlockMatch::new(MotionVector::new(8, 16), 100, 120);
        let worst = BlockMatch::worst();
        assert!(best.is_better_than(&worst));
        assert!(!worst.is_better_than(&best));
    }
    #[test]
    fn test_block_match_update() {
        let mut current = BlockMatch::worst();
        let better = BlockMatch::new(MotionVector::new(8, 16), 100, 120);
        current.update_if_better(&better);
        assert_eq!(current.sad, 100);
    }
    #[test]
    fn test_block_match_zero_mv_and_ties() {
        let zero = BlockMatch::zero_mv(42);
        assert!(zero.mv.is_zero());
        assert_eq!(zero.cost, 42);
        let mut current = BlockMatch::new(MotionVector::new(8, 0), 10, 42);
        // Equal cost is not "better", so the incumbent must be kept.
        current.update_if_better(&zero);
        assert_eq!(current.mv, MotionVector::new(8, 0));
        assert_eq!(BlockMatch::default(), BlockMatch::worst());
    }
    #[test]
    fn test_mv_cost() {
        let cost = MvCost::new(1.0);
        let mv = MotionVector::new(16, 16);
        let bits = cost.estimate_bits(&mv);
        assert!(bits > 0.0);
        let rd = cost.rd_cost(&mv, 100);
        assert!(rd >= 100);
    }
    #[test]
    fn test_mv_cost_exact_bits() {
        let cost = MvCost::new(2.0);
        let mv = MotionVector::new(16, 16);
        assert!((cost.estimate_bits(&MotionVector::zero()) - 2.0).abs() < f32::EPSILON);
        assert!((cost.estimate_bits(&mv) - 24.0).abs() < f32::EPSILON);
        assert_eq!(cost.rd_cost(&mv, 100), 148);
    }
    #[test]
    fn test_mv_cost_with_ref() {
        let ref_mv = MotionVector::new(16, 16);
        let cost = MvCost::with_ref_mv(1.0, ref_mv);
        let same_bits = cost.estimate_bits(&ref_mv);
        let diff_mv = MotionVector::new(32, 32);
        let diff_bits = cost.estimate_bits(&diff_mv);
        assert!(same_bits < diff_bits);
    }
    #[test]
    fn test_block_size() {
        assert_eq!(BlockSize::Block8x8.width(), 8);
        assert_eq!(BlockSize::Block8x8.height(), 8);
        assert_eq!(BlockSize::Block8x8.num_pixels(), 64);
        assert!(BlockSize::Block8x8.is_square());
        assert_eq!(BlockSize::Block16x8.width(), 16);
        assert_eq!(BlockSize::Block16x8.height(), 8);
        assert!(!BlockSize::Block16x8.is_square());
    }
    #[test]
    fn test_block_size_dimensions() {
        assert_eq!(BlockSize::Block64x128.width(), 64);
        assert_eq!(BlockSize::Block64x128.height(), 128);
        assert_eq!(BlockSize::Block64x128.height_log2(), 7);
        assert_eq!(BlockSize::Block4x8.height_log2(), 3);
        assert_eq!(BlockSize::default(), BlockSize::Block8x8);
    }
    #[test]
    fn test_block_size_log2() {
        assert_eq!(BlockSize::Block4x4.width_log2(), 2);
        assert_eq!(BlockSize::Block8x8.width_log2(), 3);
        assert_eq!(BlockSize::Block16x16.width_log2(), 4);
        assert_eq!(BlockSize::Block64x64.width_log2(), 6);
    }
}