use crate::theora::intra_pred::{select_best_mode, IntraPredContext, IntraPredMode};
use crate::theora::motion::{motion_estimation_diamond, MotionVector};
use crate::theora::tables::CodingMode;
use crate::theora::transform::{copy_block, Block8x8};
/// Outcome of a per-block mode decision.
#[derive(Debug, Clone)]
pub struct BlockDecision {
    /// Selected coding mode.
    pub mode: CodingMode,
    /// Motion vector; the producers in this file set it for inter modes and
    /// leave it `None` for intra decisions.
    pub mv: Option<MotionVector>,
    /// Intra predictor chosen; the producers in this file set it for intra
    /// decisions and leave it `None` for inter decisions.
    pub intra_mode: Option<IntraPredMode>,
    /// Cost used to compare candidate decisions (lower is better). The
    /// producers in this file derive it from the block SAD (see
    /// `BlockDecisionEngine`).
    pub cost: f64,
}
/// SAD-based mode decision for a single 8x8 block.
///
/// Compares the best intra predictor against a diamond motion search and
/// keeps whichever candidate has the lower cost.
pub struct BlockDecisionEngine {
    /// Weight applied to the estimated motion-vector bit cost.
    lambda: f32,
    /// Value forwarded to `motion_estimation_diamond` (presumably its search range).
    me_range: i16,
    /// Initialised to `true` by `new`; not read by any method in this impl.
    subpel_me: bool,
    /// Initialised to `true` by `new`; not read by any method in this impl.
    rdo_enabled: bool,
}
impl BlockDecisionEngine {
    /// Creates an engine with the given MV-cost weight and motion-search range.
    #[must_use]
    pub const fn new(lambda: f32, me_range: i16) -> Self {
        Self {
            lambda,
            me_range,
            subpel_me: true,
            rdo_enabled: true,
        }
    }

    /// Chooses the cheapest coding mode for one block.
    ///
    /// Intra is forced on keyframes and whenever no reference plane is given.
    /// Otherwise the intra and inter candidates are compared by `cost`. Intra
    /// wins only when strictly cheaper; ties go to the inter candidate.
    #[must_use]
    // The positional signature is public API; bundling the arguments into a
    // struct would break existing callers.
    #[allow(clippy::too_many_arguments)]
    pub fn decide_block_mode(
        &self,
        current: &[u8; 64],
        reference: Option<&[u8]>,
        ref_stride: usize,
        block_x: usize,
        block_y: usize,
        intra_ctx: &IntraPredContext,
        is_keyframe: bool,
    ) -> BlockDecision {
        let reference = match reference {
            Some(reference) if !is_keyframe => reference,
            _ => return self.decide_intra_mode(current, intra_ctx),
        };
        let intra_decision = self.decide_intra_mode(current, intra_ctx);
        let inter_decision =
            self.decide_inter_mode(current, reference, ref_stride, block_x, block_y);
        if intra_decision.cost < inter_decision.cost {
            intra_decision
        } else {
            inter_decision
        }
    }

    /// Picks the intra predictor with the lowest SAD; the cost is that SAD.
    fn decide_intra_mode(&self, current: &[u8; 64], ctx: &IntraPredContext) -> BlockDecision {
        let (best_mode, sad) = select_best_mode(current, ctx);
        BlockDecision {
            mode: CodingMode::Intra,
            mv: None,
            intra_mode: Some(best_mode),
            cost: f64::from(sad),
        }
    }

    /// Runs the diamond motion search and builds the inter candidate.
    ///
    /// A zero vector is always coded as `InterNoMv` with cost = SAD. Any other
    /// vector becomes `InterMv` with cost = SAD + `lambda` * estimated MV bits.
    fn decide_inter_mode(
        &self,
        current: &[u8; 64],
        reference: &[u8],
        ref_stride: usize,
        block_x: usize,
        block_y: usize,
    ) -> BlockDecision {
        let (mv, sad) = motion_estimation_diamond(
            current,
            reference,
            ref_stride,
            block_x,
            block_y,
            self.me_range,
        );
        let sad = f64::from(sad);
        if mv.is_zero() {
            // A zero vector predicts exactly like InterNoMv, which spends no
            // MV bits, so coding it as InterMv can never be cheaper.
            return BlockDecision {
                mode: CodingMode::InterNoMv,
                mv: Some(MotionVector::new(0, 0)),
                intra_mode: None,
                cost: sad,
            };
        }
        let mv_cost = f64::from(self.lambda * self.estimate_mv_bits(&mv));
        BlockDecision {
            mode: CodingMode::InterMv,
            mv: Some(mv),
            intra_mode: None,
            cost: sad + mv_cost,
        }
    }

    /// Rough bit estimate for a motion vector: sum of both component estimates.
    fn estimate_mv_bits(&self, mv: &MotionVector) -> f32 {
        self.estimate_component_bits(mv.x) + self.estimate_component_bits(mv.y)
    }

    /// Rough bit estimate for one MV component.
    ///
    /// Zero costs 1 bit. Any other value costs `2 * significant_bits + 1`.
    fn estimate_component_bits(&self, value: i16) -> f32 {
        if value == 0 {
            return 1.0;
        }
        // `unsigned_abs` avoids the overflow `abs` hits on `i16::MIN`.
        let magnitude_bits = u16::BITS - value.unsigned_abs().leading_zeros();
        // At most 33, so the cast is exact.
        (magnitude_bits * 2 + 1) as f32
    }
}
/// Heuristic mode decision that avoids the full cost model.
pub struct FastModeDecision {
    /// Multiplier applied to the intra SAD estimate before it is compared
    /// with the motion-search SAD.
    intra_bias: f32,
    /// A co-located SAD strictly below this selects `InterNoMv` immediately.
    skip_threshold: u32,
}
impl FastModeDecision {
    /// Creates a fast decider with the given intra bias and skip threshold.
    #[must_use]
    pub const fn new(intra_bias: f32, skip_threshold: u32) -> Self {
        Self {
            intra_bias,
            skip_threshold,
        }
    }

    /// Returns a coding mode for `current` using cheap SAD heuristics.
    ///
    /// The decision runs in this order:
    /// 1. With no reference, return `Intra`.
    /// 2. If the co-located reference block's SAD is below `skip_threshold`,
    ///    return `InterNoMv`.
    /// 3. Otherwise run a diamond search (range 4) and compare its SAD with
    ///    the biased intra estimate. A strictly smaller inter SAD picks
    ///    `InterNoMv`/`InterMv` (depending on whether the vector is zero);
    ///    otherwise the result is `Intra`.
    ///
    /// `block_x`/`block_y` are forwarded unchanged to `copy_block` and
    /// `motion_estimation_diamond`; their units are defined by those helpers.
    #[must_use]
    pub fn decide_fast(
        &self,
        current: &[u8; 64],
        reference: Option<&[u8]>,
        ref_stride: usize,
        block_x: usize,
        block_y: usize,
    ) -> CodingMode {
        let Some(reference) = reference else {
            return CodingMode::Intra;
        };

        let mut colocated = [0u8; 64];
        copy_block(reference, ref_stride, block_x, block_y, &mut colocated);
        if calculate_sad(current, &colocated) < self.skip_threshold {
            return CodingMode::InterNoMv;
        }

        let (vector, motion_sad) =
            motion_estimation_diamond(current, reference, ref_stride, block_x, block_y, 4);
        let inter_score = motion_sad as f32;
        let intra_score = calculate_intra_sad_estimate(current) as f32 * self.intra_bias;

        let inter_wins = inter_score < intra_score;
        match inter_wins {
            false => CodingMode::Intra,
            true if vector.is_zero() => CodingMode::InterNoMv,
            true => CodingMode::InterMv,
        }
    }
}
/// Sum of absolute differences between two 64-sample blocks.
///
/// The result is at most `64 * 255`, so it cannot overflow `u32`.
fn calculate_sad(block1: &[u8; 64], block2: &[u8; 64]) -> u32 {
    block1
        .iter()
        .zip(block2.iter())
        .map(|(&lhs, &rhs)| u32::from(lhs.abs_diff(rhs)))
        .sum()
}
/// Rough intra cost: the SAD between the block and its own mean.
///
/// The mean is truncated by integer division, so this equals the residual a
/// flat predictor at `floor(mean)` would leave.
fn calculate_intra_sad_estimate(block: &[u8; 64]) -> u32 {
    let total: u32 = block.iter().copied().map(u32::from).sum();
    let dc = total / 64;
    block
        .iter()
        .map(|&sample| u32::from(sample).abs_diff(dc))
        .sum()
}
/// Decides whether a 16x16 block should be split into its four 8x8 quadrants.
pub struct SubblockDecision {
    /// A split happens when the spread between the largest and smallest
    /// quadrant variance is strictly greater than this value.
    variance_threshold: f32,
    /// When `false`, `should_split` always returns `false`.
    adaptive: bool,
}
impl SubblockDecision {
    /// Creates a splitter with the given variance-spread threshold.
    #[must_use]
    pub const fn new(variance_threshold: f32, adaptive: bool) -> Self {
        Self {
            variance_threshold,
            adaptive,
        }
    }

    /// Returns `true` when the quadrant variances of `block` differ by more
    /// than the configured threshold.
    ///
    /// `block` is indexed as 16 rows of 16 samples (row stride 16).
    #[must_use]
    pub fn should_split(&self, block: &[u8; 256]) -> bool {
        if !self.adaptive {
            // Non-adaptive mode never splits.
            return false;
        }
        let variances = [
            self.calculate_subblock_variance(block, 0, 0),
            self.calculate_subblock_variance(block, 8, 0),
            self.calculate_subblock_variance(block, 0, 8),
            self.calculate_subblock_variance(block, 8, 8),
        ];
        let max_var = variances.iter().copied().fold(f32::MIN, f32::max);
        let min_var = variances.iter().copied().fold(f32::MAX, f32::min);
        (max_var - min_var) > self.variance_threshold
    }

    /// Population variance of the 8x8 window whose top-left sample is at
    /// column `x`, row `y` of the 16-wide `block`.
    fn calculate_subblock_variance(&self, block: &[u8; 256], x: usize, y: usize) -> f32 {
        let (sum, sum_sq) = (y..y + 8)
            .flat_map(|row| &block[row * 16 + x..row * 16 + x + 8])
            .fold((0u32, 0u32), |(sum, sum_sq), &sample| {
                let sample = u32::from(sample);
                (sum + sample, sum_sq + sample * sample)
            });
        // 4096 * variance == 64 * sum_sq - sum^2 exactly. This is >= 0 by
        // Cauchy-Schwarz and at most 64 * 64 * 255^2 (< 2^29), so the integer
        // math cannot underflow or overflow. Truncating the mean and E[x^2]
        // separately would overstate the variance by up to ~255.
        let scaled = 64 * sum_sq - sum * sum;
        (f64::from(scaled) / 4096.0) as f32
    }
}
/// SAD-based test for whether two 8x8 blocks are similar enough to merge.
pub struct MergeDecision {
    /// Blocks merge only when their SAD is strictly below this value.
    sad_threshold: u32,
}
impl MergeDecision {
    /// Creates a merge test with the given SAD threshold.
    #[must_use]
    pub const fn new(sad_threshold: u32) -> Self {
        Self { sad_threshold }
    }

    /// Returns `true` when the SAD between `block1` and `block2` is strictly
    /// below the configured threshold.
    #[must_use]
    pub fn should_merge(&self, block1: &[u8; 64], block2: &[u8; 64]) -> bool {
        self.sad_threshold > calculate_sad(block1, block2)
    }
}
/// Threshold check used to stop a search early once a SAD is good enough.
pub struct EarlyTermination {
    /// A SAD strictly below this triggers termination.
    threshold: u32,
    /// When `false`, `should_terminate` always returns `false`.
    enabled: bool,
}
impl EarlyTermination {
    /// Creates the check with the given threshold and on/off switch.
    #[must_use]
    pub const fn new(threshold: u32, enabled: bool) -> Self {
        Self { threshold, enabled }
    }

    /// Returns `true` when enabled and `current_sad` is below the threshold.
    #[must_use]
    pub fn should_terminate(&self, current_sad: u32) -> bool {
        if !self.enabled {
            return false;
        }
        current_sad < self.threshold
    }
}
/// Spatial and temporal activity measures for an 8x8 block.
pub struct BlockComplexity {
    /// Blocks whose spatial activity is strictly below this are homogeneous.
    activity_threshold: f32,
}
impl BlockComplexity {
    /// Creates an analyzer with the given homogeneity threshold.
    #[must_use]
    pub const fn new(activity_threshold: f32) -> Self {
        Self { activity_threshold }
    }

    /// Mean absolute difference between horizontally and vertically adjacent
    /// samples of a row-major 8x8 block.
    #[must_use]
    pub fn spatial_activity(&self, block: &[u8; 64]) -> f32 {
        let horizontal: u32 = block
            .chunks_exact(8)
            .flat_map(|row| row.windows(2))
            .map(|pair| u32::from(pair[0].abs_diff(pair[1])))
            .sum();
        let vertical: u32 = block
            .chunks_exact(8)
            .zip(block.chunks_exact(8).skip(1))
            .flat_map(|(upper, lower)| upper.iter().zip(lower))
            .map(|(&above, &below)| u32::from(above.abs_diff(below)))
            .sum();
        // 8 rows * 7 horizontal pairs + 7 row pairs * 8 columns = 112 pairs.
        (horizontal + vertical) as f32 / 112.0
    }

    /// Returns `true` when the spatial activity is below the threshold.
    #[must_use]
    pub fn is_homogeneous(&self, block: &[u8; 64]) -> bool {
        let activity = self.spatial_activity(block);
        activity < self.activity_threshold
    }

    /// Mean absolute per-sample difference between `current` and `reference`.
    #[must_use]
    pub fn temporal_activity(&self, current: &[u8; 64], reference: &[u8; 64]) -> f32 {
        let total = calculate_sad(current, reference);
        total as f32 / 64.0
    }
}
/// Running tally of mode decisions and their accumulated cost.
#[derive(Debug, Clone, Default)]
pub struct ModeStats {
    /// Number of `Intra` decisions recorded.
    pub intra_count: u32,
    /// Number of `InterMv` / `InterGoldenMv` decisions recorded.
    pub inter_count: u32,
    /// Number of `InterNoMv` / `InterGoldenNoMv` / `NotCoded` decisions recorded.
    pub skip_count: u32,
    /// Sum of `cost` over every decision counted in the fields above.
    pub total_cost: f64,
}
impl ModeStats {
    /// Creates an empty tally.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            intra_count: 0,
            inter_count: 0,
            skip_count: 0,
            total_cost: 0.0,
        }
    }

    /// Records one decision in its bucket and adds its cost to the total.
    ///
    /// Modes without a bucket are ignored entirely. Adding their cost while
    /// not counting the block would inflate `average_cost`.
    pub fn update(&mut self, decision: &BlockDecision) {
        let bucket = match decision.mode {
            CodingMode::Intra => &mut self.intra_count,
            CodingMode::InterMv | CodingMode::InterGoldenMv => &mut self.inter_count,
            CodingMode::InterNoMv | CodingMode::InterGoldenNoMv | CodingMode::NotCoded => {
                &mut self.skip_count
            }
            _ => return,
        };
        *bucket += 1;
        self.total_cost += decision.cost;
    }

    /// Mean cost per counted decision, or `0.0` when nothing was recorded.
    #[must_use]
    pub fn average_cost(&self) -> f64 {
        let total_blocks = self.intra_count + self.inter_count + self.skip_count;
        if total_blocks > 0 {
            self.total_cost / f64::from(total_blocks)
        } else {
            0.0
        }
    }

    /// Clears all counters and the accumulated cost.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_decision_engine() {
        let engine = BlockDecisionEngine::new(1.0, 16);
        let current = [128u8; 64];
        let ctx = IntraPredContext::default();
        let decision = engine.decide_intra_mode(&current, &ctx);
        assert_eq!(decision.mode, CodingMode::Intra);
    }

    #[test]
    fn test_fast_mode_decision() {
        let fast = FastModeDecision::new(1.2, 100);
        let current = [128u8; 64];
        let mode = fast.decide_fast(&current, None, 0, 0, 0);
        assert_eq!(mode, CodingMode::Intra);
    }

    #[test]
    fn test_sad_calculation() {
        let block1 = [100u8; 64];
        let block2 = [110u8; 64];
        let sad = calculate_sad(&block1, &block2);
        assert_eq!(sad, 64 * 10);
    }

    #[test]
    fn test_intra_sad_estimate() {
        assert_eq!(calculate_intra_sad_estimate(&[77u8; 64]), 0);
        let mut stripes = [0u8; 64];
        for (i, v) in stripes.iter_mut().enumerate() {
            if i % 2 == 1 {
                *v = 255;
            }
        }
        // Truncated mean is 127; 32 samples are 127 away and 32 are 128 away.
        assert_eq!(calculate_intra_sad_estimate(&stripes), 8160);
    }

    #[test]
    fn test_subblock_decision() {
        let decision = SubblockDecision::new(100.0, true);
        let block = [128u8; 256];
        // A flat block has identical (zero) quadrant variances.
        assert!(!decision.should_split(&block));
    }

    #[test]
    fn test_merge_decision() {
        let merge = MergeDecision::new(100);
        let block1 = [128u8; 64];
        let block2 = [129u8; 64];
        assert!(merge.should_merge(&block1, &block2));
        let block3 = [130u8; 64];
        assert!(!merge.should_merge(&block1, &block3));
    }

    #[test]
    fn test_early_termination() {
        let early = EarlyTermination::new(50, true);
        assert!(early.should_terminate(30));
        assert!(!early.should_terminate(100));
    }

    #[test]
    fn test_block_complexity() {
        let analyzer = BlockComplexity::new(10.0);
        let block = [128u8; 64];
        let activity = analyzer.spatial_activity(&block);
        // A flat block has no neighbour differences.
        assert_eq!(activity, 0.0);
        assert!(analyzer.is_homogeneous(&block));
    }

    #[test]
    fn test_spatial_activity_stripes() {
        let analyzer = BlockComplexity::new(10.0);
        let mut stripes = [0u8; 64];
        for (i, v) in stripes.iter_mut().enumerate() {
            if i % 2 == 1 {
                *v = 255;
            }
        }
        // 56 horizontal pairs differ by 255; vertical pairs are equal.
        assert_eq!(analyzer.spatial_activity(&stripes), 127.5);
        assert!(!analyzer.is_homogeneous(&stripes));
    }

    #[test]
    fn test_mode_stats() {
        let mut stats = ModeStats::new();
        let decision = BlockDecision {
            mode: CodingMode::Intra,
            mv: None,
            intra_mode: None,
            cost: 100.0,
        };
        stats.update(&decision);
        assert_eq!(stats.intra_count, 1);
        assert_eq!(stats.average_cost(), 100.0);
    }
}