use super::motion_estimation::MotionVector;
use super::partition_state::EncoderMvGrid;
/// Test-only lock serialising tests that mutate the environment read by this
/// module (e.g. `PHASM_B_FORCE_MODE`, which `forced_b_mode_from_env` reads).
/// It is `pub(crate)`, presumably so tests in other modules can take the same
/// lock — NOTE(review): confirm at the other use sites.
#[cfg(test)]
pub(crate) static B_FORCE_MODE_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Mode decision for one macroblock in a B slice.
///
/// Produced by the no-RDO fallback (`mb_decision_b*`), by the RDO path
/// (`mb_decision_b_rdo`), or forced from the environment
/// (`forced_b_mode_from_env`).
// Variant names such as `L0_16x16` contain underscores, hence the lint allowance.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone)]
pub enum BMbDecision {
/// Skipped macroblock; carries no explicit motion vectors.
Skip,
/// 16x16 direct prediction; carries no explicit motion vectors.
Direct16x16,
/// 16x16 prediction from list 0 using `mv`.
L0_16x16 { mv: MotionVector },
/// 16x16 prediction from list 1 using `mv`.
L1_16x16 { mv: MotionVector },
/// 16x16 bi-prediction using `mv_l0` (list 0) and `mv_l1` (list 1).
Bi_16x16 { mv_l0: MotionVector, mv_l1: MotionVector },
/// Two-partition macroblock.
Partitioned {
/// Value that `forced_partitioned_decision` passes to
/// `b_partitioned::partitioned_b_meta` to look up partition list usage.
mb_type: u8,
/// Per-partition MVs; a list the partition does not use is `None`.
parts: [super::b_partitioned::BPartitionMv; 2],
},
/// B_8x8 macroblock: one sub-type and one MV set per 8x8 sub-macroblock.
B8x8 {
/// Sub-types as mapped by `b_8x8_part_for_subtype`:
/// 0 = direct (no MVs), 1 = L0, 2 = L1, 3 = Bi.
sub_mb_types: [u8; 4],
/// Per-sub-macroblock MVs; a list the sub-type does not use is `None`.
parts: [super::b_partitioned::BPartitionMv; 4],
},
}
/// Chooses a B-macroblock mode without rate-distortion optimisation and
/// without motion-estimation vectors.
///
/// Equivalent to [`mb_decision_b_with_mvs`] with `me_mvs = None`.
pub fn mb_decision_b(
    grid: &EncoderMvGrid,
    mb_x: usize,
    mb_y: usize,
    frame_num: u32,
    mb_addr: u32,
) -> BMbDecision {
    let no_me_mvs: Option<(MotionVector, MotionVector)> = None;
    mb_decision_b_with_mvs(grid, mb_x, mb_y, frame_num, mb_addr, no_me_mvs)
}

/// Chooses a B-macroblock mode without rate-distortion optimisation.
///
/// A mode forced through `PHASM_B_FORCE_MODE` (see [`forced_b_mode_from_env`])
/// always wins. Otherwise the decision is a deterministic pseudo-random split
/// keyed on `(frame_num, mb_addr)`:
/// - buckets `0..50` give [`BMbDecision::Skip`];
/// - the remaining buckets give [`BMbDecision::Direct16x16`].
///
/// On this fallback path no explicit-MV mode is ever produced. `grid`, the
/// macroblock position and `me_mvs` are accepted for API stability but are
/// currently unused.
pub fn mb_decision_b_with_mvs(
    grid: &EncoderMvGrid,
    mb_x: usize,
    mb_y: usize,
    frame_num: u32,
    mb_addr: u32,
    me_mvs: Option<(MotionVector, MotionVector)>,
) -> BMbDecision {
    let _ = (grid, mb_x, mb_y, me_mvs);
    match forced_b_mode_from_env() {
        Some(forced) => forced,
        None if mb_decision_bucket(frame_num, mb_addr) < 50 => BMbDecision::Skip,
        None => BMbDecision::Direct16x16,
    }
}
/// Switches for B-frame rate-distortion optimisation.
///
/// The derived `Default` has both flags `false`, i.e. the same as
/// [`BRdoConfig::SAFE`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BRdoConfig {
/// Used by `b_rdo_enabled_with` when `PHASM_B_RDO` is neither `"1"` nor `"0"`.
pub enable_rdo: bool,
/// NOTE(review): not read in this file; presumably gates residual coding
/// for B macroblocks — confirm at the use sites.
pub enable_residual: bool,
}
impl BRdoConfig {
/// Both switches off; the fallback used by `b_rdo_enabled`.
pub const SAFE: Self = Self { enable_rdo: false, enable_residual: false };
/// Both switches on.
pub const PRODUCTION_VISUAL: Self = Self {
enable_rdo: true,
enable_residual: true,
};
}
/// Reports whether B-frame RDO mode decision is enabled.
///
/// The `PHASM_B_RDO` environment variable overrides `config`:
/// - `"1"` forces RDO on;
/// - `"0"` forces it off.
///
/// Any other value, a non-UTF-8 value, or an unset variable falls back to
/// `config.enable_rdo`.
pub fn b_rdo_enabled_with(config: BRdoConfig) -> bool {
    let env_override = std::env::var("PHASM_B_RDO");
    match env_override.as_deref() {
        Ok("1") => true,
        Ok("0") => false,
        _ => config.enable_rdo,
    }
}

/// [`b_rdo_enabled_with`] with [`BRdoConfig::SAFE`] as the fallback.
///
/// RDO is therefore off unless `PHASM_B_RDO=1`.
pub fn b_rdo_enabled() -> bool {
    let fallback = BRdoConfig::SAFE;
    b_rdo_enabled_with(fallback)
}
/// Returns `true` when the direct-mode luma residual of this macroblock
/// quantizes to all zeros, so the macroblock can be coded as Skip.
///
/// Steps:
/// 1. Build the 16x16 luma prediction from the spatial-direct MVs in `direct`.
///    Both lists in use means bi-prediction; otherwise the single list in use
///    is applied.
/// 2. For each of the sixteen 4x4 blocks, visited in `BLOCK_INDEX_TO_POS`
///    order, forward-transform the residual `src_y - pred`.
/// 3. Quantize it with inter parameters at `mb_qp`. Trellis quantization is
///    tried first, and plain forward quantization is used if trellis returns
///    an error.
///
/// The first non-zero level makes the result `false`. If `direct` uses
/// neither list, `false` is returned without further work.
///
/// NOTE(review): only luma is examined; chroma residual is not checked here.
fn skip_cbp_is_zero(
    direct: &super::b_direct_predictor::BDirectSpatialResult,
    src_y: &[[u8; 16]; 16],
    l0_ref: &super::reference_buffer::ReconFrame,
    l1_ref: &super::reference_buffer::ReconFrame,
    mb_x: usize,
    mb_y: usize,
    mb_qp: u8,
) -> bool {
    use super::motion_compensation::{apply_luma_mv_block, apply_luma_mv_block_bipred};
    use super::quantization::{forward_quantize_4x4, trellis_quantize_4x4, QuantParams, QuantSlice};
    use super::transform::forward_dct_4x4;
    use crate::codec::h264::macroblock::BLOCK_INDEX_TO_POS;

    let origin_x = (mb_x * 16) as u32;
    let origin_y = (mb_y * 16) as u32;
    let mut pred = [[0u8; 16]; 16];
    {
        let out = pred.as_flattened_mut();
        match (direct.uses_l0(), direct.uses_l1()) {
            (false, false) => return false,
            (true, true) => apply_luma_mv_block_bipred(
                l0_ref, direct.mv_l0, l1_ref, direct.mv_l1,
                origin_x, origin_y, 16, 16, out, 16,
            ),
            (true, false) => apply_luma_mv_block(
                l0_ref, origin_x, origin_y, 16, 16, direct.mv_l0, out, 16,
            ),
            (false, true) => apply_luma_mv_block(
                l1_ref, origin_x, origin_y, 16, 16, direct.mv_l1, out, 16,
            ),
        }
    }

    let quant = QuantParams { qp: mb_qp, slice: QuantSlice::Inter };
    (0..16).all(|k| {
        let (bx, by) = BLOCK_INDEX_TO_POS[k];
        let col0 = bx as usize * 4;
        let row0 = by as usize * 4;
        let residual: [[i32; 4]; 4] = std::array::from_fn(|dy| {
            std::array::from_fn(|dx| {
                i32::from(src_y[row0 + dy][col0 + dx]) - i32::from(pred[row0 + dy][col0 + dx])
            })
        });
        let coeffs = forward_dct_4x4(&residual);
        let levels = trellis_quantize_4x4(&coeffs, quant, true)
            .unwrap_or_else(|_| forward_quantize_4x4(&coeffs, quant));
        levels.iter().flatten().all(|&level| level == 0)
    })
}
/// Chooses a B-macroblock mode by rate-distortion cost.
///
/// 1. Derive the spatial-direct prediction for this macroblock, passing the L1
///    reference's motion grid. If its luma residual quantizes to all zeros
///    (`skip_cbp_is_zero`), return [`BMbDecision::Skip`] immediately.
/// 2. Otherwise evaluate four 16x16 candidates with `evaluate_b_mb_rdo`:
///    - Skip/Direct;
///    - L0 with `me_mvs.0`;
///    - L1 with `me_mvs.1`;
///    - Bi with both vectors.
///
///    Return the candidate with the lowest `cost`. Only a strictly lower cost
///    replaces the current best, so ties go to the earlier candidate in that
///    order.
///
/// Diagnostics, checked on every call:
/// - `PHASM_B_RDO_LOG_BI`: print per-MB SATD of the L0/L1/Bi candidates.
/// - `PHASM_B_RDO_TRACE`: bump process-wide counters of the chosen mode and
///   print a summary whenever the traced total is a multiple of 1000.
#[allow(clippy::too_many_arguments)]
pub fn mb_decision_b_rdo(
    grid: &EncoderMvGrid,
    mb_x: usize,
    mb_y: usize,
    src_y: &[[u8; 16]; 16],
    l0_ref: &super::reference_buffer::ReconFrame,
    l1_ref: &super::reference_buffer::ReconFrame,
    mb_qp: u8,
    me_mvs: (MotionVector, MotionVector),
) -> BMbDecision {
    use super::rdo_b::{evaluate_b_mb_rdo, BMbCandidate};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Winning candidate; computed once and shared by tracing and the result.
    #[derive(Clone, Copy)]
    enum Choice {
        Direct,
        L0,
        L1,
        Bi,
    }

    let direct = super::b_direct_predictor::derive_b_direct_spatial_with_col(
        grid, mb_x, mb_y, l1_ref.motion_grid.as_ref(),
    );
    if skip_cbp_is_zero(&direct, src_y, l0_ref, l1_ref, mb_x, mb_y, mb_qp) {
        return BMbDecision::Skip;
    }

    let skip_or_direct = BMbCandidate::SkipOrDirect {
        mv_l0: direct.mv_l0,
        mv_l1: direct.mv_l1,
        uses_l0: direct.uses_l0(),
        uses_l1: direct.uses_l1(),
    };
    let l0 = BMbCandidate::L0_16x16 { mv_l0: me_mvs.0 };
    let l1 = BMbCandidate::L1_16x16 { mv_l1: me_mvs.1 };
    let bi = BMbCandidate::Bi_16x16 { mv_l0: me_mvs.0, mv_l1: me_mvs.1 };
    let r_skip_or_direct = evaluate_b_mb_rdo(&skip_or_direct, src_y, l0_ref, l1_ref, mb_x, mb_y, mb_qp);
    let r_l0 = evaluate_b_mb_rdo(&l0, src_y, l0_ref, l1_ref, mb_x, mb_y, mb_qp);
    let r_l1 = evaluate_b_mb_rdo(&l1, src_y, l0_ref, l1_ref, mb_x, mb_y, mb_qp);
    let r_bi = evaluate_b_mb_rdo(&bi, src_y, l0_ref, l1_ref, mb_x, mb_y, mb_qp);

    // Hook for gating Bi against the best single-list SATD. Gating is
    // currently disabled, so Bi always competes; the flag is still reported
    // by the PHASM_B_RDO_LOG_BI diagnostic.
    let bi_passes_threshold = true;
    let best_single_satd = r_l0.satd.min(r_l1.satd);
    if std::env::var_os("PHASM_B_RDO_LOG_BI").is_some() {
        eprintln!(
            "B-MB ({mb_x},{mb_y}) qp={mb_qp} l0_satd={} l1_satd={} bi_satd={} ratio={:.3} pass={}",
            r_l0.satd, r_l1.satd, r_bi.satd,
            (r_bi.satd as f64) / (best_single_satd.max(1) as f64),
            bi_passes_threshold,
        );
    }

    // Single source of truth for the decision. Strict `<` keeps the earlier
    // candidate on ties (Direct, then L0, then L1, then Bi).
    let mut best_cost = r_skip_or_direct.cost;
    let mut choice = Choice::Direct;
    if r_l0.cost < best_cost {
        best_cost = r_l0.cost;
        choice = Choice::L0;
    }
    if r_l1.cost < best_cost {
        best_cost = r_l1.cost;
        choice = Choice::L1;
    }
    if bi_passes_threshold && r_bi.cost < best_cost {
        choice = Choice::Bi;
    }

    if std::env::var_os("PHASM_B_RDO_TRACE").is_some() {
        static DIRECT_CNT: AtomicUsize = AtomicUsize::new(0);
        static L0_CNT: AtomicUsize = AtomicUsize::new(0);
        static L1_CNT: AtomicUsize = AtomicUsize::new(0);
        static BI_CNT: AtomicUsize = AtomicUsize::new(0);
        let (counter, dec_label): (&AtomicUsize, &str) = match choice {
            Choice::Direct => (&DIRECT_CNT, "Direct"),
            Choice::L0 => (&L0_CNT, "L0"),
            Choice::L1 => (&L1_CNT, "L1"),
            Choice::Bi => (&BI_CNT, "Bi"),
        };
        counter.fetch_add(1, Ordering::Relaxed);
        // Snapshot once so the printed counts add up to the printed total.
        let direct_n = DIRECT_CNT.load(Ordering::Relaxed);
        let l0_n = L0_CNT.load(Ordering::Relaxed);
        let l1_n = L1_CNT.load(Ordering::Relaxed);
        let bi_n = BI_CNT.load(Ordering::Relaxed);
        let total = direct_n + l0_n + l1_n + bi_n;
        if total % 1000 == 0 {
            eprintln!(
                "B-RDO trace (RDO-only) [{total}]: Direct={} L0={} L1={} Bi={} (this={dec_label})",
                direct_n, l0_n, l1_n, bi_n,
            );
        }
    }

    match choice {
        Choice::Direct => BMbDecision::Direct16x16,
        Choice::L0 => BMbDecision::L0_16x16 { mv: me_mvs.0 },
        Choice::L1 => BMbDecision::L1_16x16 { mv: me_mvs.1 },
        Choice::Bi => BMbDecision::Bi_16x16 { mv_l0: me_mvs.0, mv_l1: me_mvs.1 },
    }
}
/// Maps `(frame_num, mb_addr)` to a deterministic pseudo-random bucket in `0..100`.
///
/// Steps:
/// 1. Multiply the frame number by one odd 64-bit constant and add the
///    macroblock address.
/// 2. Mix the sum with a second odd multiplier and a 30-bit xor-shift.
/// 3. Reduce modulo 100.
///
/// All arithmetic wraps.
fn mb_decision_bucket(frame_num: u32, mb_addr: u32) -> u32 {
    const FRAME_MUL: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_MUL: u64 = 0xBF58_476D_1CE4_E5B9;
    let seeded = u64::from(frame_num)
        .wrapping_mul(FRAME_MUL)
        .wrapping_add(u64::from(mb_addr));
    let mixed = seeded.wrapping_mul(MIX_MUL);
    let mixed = mixed ^ (mixed >> 30);
    (mixed % 100) as u32
}
/// Predicts the list-0 motion vector for a whole-macroblock B partition.
///
/// The macroblock position is scaled by 4 (apparently to 4x4-block units).
/// `predict_mv_for_mb_partition` is then asked about a 4-by-4 partition at
/// that position; the two trailing zeros are passed through unchanged.
/// NOTE(review): presumably those are partition index / reference index —
/// confirm in `partition_state`.
fn predict_b_partition_mv_l0(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    let (blk_x, blk_y) = (mb_x * 4, mb_y * 4);
    let (part_w, part_h) = (4, 4);
    super::partition_state::predict_mv_for_mb_partition(grid, blk_x, blk_y, part_w, part_h, 0, 0)
}

/// Public entry point for `predict_b_partition_mv_l0`.
pub fn predict_b_partition_mv_l0_pub(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    let predicted = predict_b_partition_mv_l0(grid, mb_x, mb_y);
    predicted
}
/// Predicts the list-1 motion vector for a whole-macroblock B partition.
///
/// Scales the macroblock position by 4 (apparently to 4x4-block units) and
/// delegates to `b_direct_predictor::predict_mv_for_partition_l1_pub` with a
/// trailing `0`.
/// NOTE(review): presumably a reference index — confirm in
/// `b_direct_predictor`.
fn predict_b_partition_mv_l1(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    let (blk_x, blk_y) = (mb_x * 4, mb_y * 4);
    super::b_direct_predictor::predict_mv_for_partition_l1_pub(grid, blk_x, blk_y, 0)
}

/// Public entry point for `predict_b_partition_mv_l1`.
pub fn predict_b_partition_mv_l1_pub(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    let predicted = predict_b_partition_mv_l1(grid, mb_x, mb_y);
    predicted
}
pub fn forced_b_mode_from_env() -> Option<BMbDecision> {
let var = std::env::var("PHASM_B_FORCE_MODE").ok()?;
let zero = MotionVector { mv_x: 0, mv_y: 0 };
let forced_mv = std::env::var("PHASM_B_FORCE_MV")
.ok()
.and_then(|s| {
let mut parts = s.splitn(2, ',');
let x: i16 = parts.next()?.trim().parse().ok()?;
let y: i16 = parts.next()?.trim().parse().ok()?;
Some(MotionVector { mv_x: x, mv_y: y })
})
.unwrap_or(zero);
let forced_mv_l1 = std::env::var("PHASM_B_FORCE_MV_L1")
.ok()
.and_then(|s| {
let mut parts = s.splitn(2, ',');
let x: i16 = parts.next()?.trim().parse().ok()?;
let y: i16 = parts.next()?.trim().parse().ok()?;
Some(MotionVector { mv_x: x, mv_y: y })
})
.unwrap_or(forced_mv);
match var.to_ascii_lowercase().as_str() {
"skip" => Some(BMbDecision::Skip),
"direct" => Some(BMbDecision::Direct16x16),
"l0_16x16" => Some(BMbDecision::L0_16x16 { mv: forced_mv }),
"l1_16x16" => Some(BMbDecision::L1_16x16 { mv: forced_mv }),
"bi_16x16" => Some(BMbDecision::Bi_16x16 {
mv_l0: forced_mv,
mv_l1: forced_mv_l1,
}),
s if s.starts_with("partitioned_") => {
let mb_type: u8 = s["partitioned_".len()..].parse().ok()?;
forced_partitioned_decision(mb_type, forced_mv, forced_mv_l1)
}
"b_8x8_uniform_direct" => Some(forced_b_8x8_uniform(0, forced_mv, forced_mv_l1)),
"b_8x8_uniform_l0" => Some(forced_b_8x8_uniform(1, forced_mv, forced_mv_l1)),
"b_8x8_uniform_l1" => Some(forced_b_8x8_uniform(2, forced_mv, forced_mv_l1)),
"b_8x8_uniform_bi" => Some(forced_b_8x8_uniform(3, forced_mv, forced_mv_l1)),
"b_8x8_mixed" => Some(forced_b_8x8_mixed(forced_mv, forced_mv_l1)),
_ => None,
}
}
/// Builds a B_8x8 decision where all four sub-macroblocks share sub-type `sub`
/// and the same MVs. The sub-type to list mapping comes from
/// `b_8x8_part_for_subtype`.
fn forced_b_8x8_uniform(sub: u8, mv_l0: MotionVector, mv_l1: MotionVector) -> BMbDecision {
    let sub_mb_types = [sub; 4];
    let parts = [b_8x8_part_for_subtype(sub, mv_l0, mv_l1); 4];
    BMbDecision::B8x8 { sub_mb_types, parts }
}
/// Builds a B_8x8 decision that exercises every supported sub-type once.
///
/// The four sub-macroblocks, in order, use sub-types 0 (direct), 1 (L0),
/// 2 (L1) and 3 (Bi).
fn forced_b_8x8_mixed(mv_l0: MotionVector, mv_l1: MotionVector) -> BMbDecision {
    let sub_mb_types: [u8; 4] = [0, 1, 2, 3];
    let parts = sub_mb_types.map(|sub| b_8x8_part_for_subtype(sub, mv_l0, mv_l1));
    BMbDecision::B8x8 { sub_mb_types, parts }
}
/// Returns the per-sub-macroblock MVs implied by a B_8x8 sub-type.
///
/// Mapping:
/// - 0: direct, no explicit MVs;
/// - 1: list 0 only;
/// - 2: list 1 only;
/// - 3: both lists.
///
/// Any other value trips a `debug_assert!` and, in release builds, yields
/// no MVs.
fn b_8x8_part_for_subtype(
    sub: u8,
    mv_l0: MotionVector,
    mv_l1: MotionVector,
) -> super::b_partitioned::BPartitionMv {
    use super::b_partitioned::BPartitionMv;
    let (use_l0, use_l1) = match sub {
        0 => (false, false),
        1 => (true, false),
        2 => (false, true),
        3 => (true, true),
        _ => {
            debug_assert!(false, "B_8x8 sub_mb_type {sub} out of §6E-A6.3 scope");
            (false, false)
        }
    };
    BPartitionMv {
        mv_l0: use_l0.then_some(mv_l0),
        mv_l1: use_l1.then_some(mv_l1),
    }
}
/// Builds a two-partition B decision for `mb_type`.
///
/// Each partition gets the MVs of the lists it uses, according to
/// `partitioned_b_meta`. Returns `None` when `partitioned_b_meta` has no
/// entry for `mb_type`.
fn forced_partitioned_decision(
    mb_type: u8,
    mv_l0: MotionVector,
    mv_l1: MotionVector,
) -> Option<BMbDecision> {
    use super::b_partitioned::{partitioned_b_meta, BListUse, BPartitionMv};
    let meta = partitioned_b_meta(u32::from(mb_type))?;
    let part_mv = |usage: BListUse| match usage {
        BListUse::L0 => BPartitionMv { mv_l0: Some(mv_l0), mv_l1: None },
        BListUse::L1 => BPartitionMv { mv_l0: None, mv_l1: Some(mv_l1) },
        BListUse::Bi => BPartitionMv { mv_l0: Some(mv_l0), mv_l1: Some(mv_l1) },
    };
    let parts = [part_mv(meta.part0), part_mv(meta.part1)];
    Some(BMbDecision::Partitioned { mb_type, parts })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquires the crate-wide environment lock.
    ///
    /// A poisoned lock (another test panicked while holding it) is recovered
    /// rather than failing this test too; the guarded data is `()`, so there
    /// is no state to be corrupted.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        B_FORCE_MODE_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// The no-RDO fallback splits ~50/50 between Skip and Direct16x16 and
    /// never emits an explicit-MV mode.
    #[test]
    fn fallback_distribution_skip_direct_only() {
        let _lock = env_lock();
        // SAFETY: test-only. Environment mutation is serialised through
        // `env_lock()`. NOTE(review): tests that only read the environment
        // without the lock are not excluded by this.
        unsafe {
            std::env::remove_var("PHASM_B_FORCE_MODE");
        }
        let grid = EncoderMvGrid::new(2, 2);
        let mut counts = [0u32; 5];
        let n = 10_000u32;
        for mb_addr in 0..n {
            let d = mb_decision_b(&grid, 0, 0, 0, mb_addr);
            let idx = match d {
                BMbDecision::Skip => 0,
                BMbDecision::Direct16x16 => 1,
                BMbDecision::L0_16x16 { .. } => 2,
                BMbDecision::L1_16x16 { .. } => 3,
                BMbDecision::Bi_16x16 { .. } => 4,
                _ => panic!("unexpected variant from no-RDO fallback"),
            };
            counts[idx] += 1;
        }
        let pct = |c: u32| (c as f32 / n as f32) * 100.0;
        let (skip, direct, l0, l1, bi) =
            (pct(counts[0]), pct(counts[1]), pct(counts[2]), pct(counts[3]), pct(counts[4]));
        eprintln!(
            "no-RDO fallback mix: skip={skip:.1}% direct={direct:.1}% \
             L0={l0:.1}% L1={l1:.1}% Bi={bi:.1}%"
        );
        assert!((skip - 50.0).abs() < 3.0, "skip {skip:.1}%");
        assert!((direct - 50.0).abs() < 3.0, "direct {direct:.1}%");
        assert_eq!(counts[2], 0, "L0 must be 0 (RDO-only); got {l0:.1}%");
        assert_eq!(counts[3], 0, "L1 must be 0 (RDO-only); got {l1:.1}%");
        assert_eq!(counts[4], 0, "Bi must be 0 (RDO-only); got {bi:.1}%");
    }

    /// Repeated calls with identical inputs yield the same variant.
    ///
    /// Holds the env lock because `mb_decision_b` reads `PHASM_B_FORCE_MODE`.
    /// Without the lock, a concurrent test changing that variable between the
    /// two calls would make this test flaky.
    #[test]
    fn deterministic_output_for_same_input() {
        let _lock = env_lock();
        let grid = EncoderMvGrid::new(2, 2);
        for mb_addr in 0..256 {
            let a = mb_decision_b(&grid, 0, 0, 7, mb_addr);
            let b = mb_decision_b(&grid, 0, 0, 7, mb_addr);
            let disc = |d: &BMbDecision| std::mem::discriminant(d);
            assert_eq!(disc(&a), disc(&b), "non-deterministic mb_decision at mb={mb_addr}");
        }
    }

    /// Builds a reference frame with uniform luma `y_fill` and neutral chroma.
    fn make_recon_b(width: u32, height: u32, y_fill: u8) -> super::super::reference_buffer::ReconFrame {
        use super::super::reconstruction::ReconBuffer;
        let mut buf = ReconBuffer::new(width, height).unwrap();
        for v in buf.y.iter_mut() {
            *v = y_fill;
        }
        for v in buf.cb.iter_mut() {
            *v = 128;
        }
        for v in buf.cr.iter_mut() {
            *v = 128;
        }
        super::super::reference_buffer::ReconFrame::snapshot(&buf)
    }

    #[test]
    fn rdo_picks_skip_when_l0_matches_exactly() {
        let src = [[100u8; 16]; 16];
        let l0 = make_recon_b(64, 64, 100);
        let l1 = make_recon_b(64, 64, 100);
        let grid = EncoderMvGrid::new(4, 4);
        let zero = MotionVector { mv_x: 0, mv_y: 0 };
        let d = mb_decision_b_rdo(&grid, 0, 0, &src, &l0, &l1, 30, (zero, zero));
        assert!(matches!(d, BMbDecision::Skip),
            "Skip should win when prediction is exact, got {:?}", d);
    }

    #[test]
    fn rdo_picks_l0_when_l0_matches_l1_doesnt() {
        let src = [[100u8; 16]; 16];
        let l0 = make_recon_b(64, 64, 100);
        let l1 = make_recon_b(64, 64, 50);
        let grid = EncoderMvGrid::new(4, 4);
        let zero = MotionVector { mv_x: 0, mv_y: 0 };
        let d = mb_decision_b_rdo(&grid, 0, 0, &src, &l0, &l1, 30, (zero, zero));
        assert!(matches!(d, BMbDecision::L0_16x16 { .. }),
            "L0_16x16 should win when only L0 matches, got {:?}", d);
    }

    #[test]
    fn rdo_picks_l1_when_l1_matches_l0_doesnt() {
        let src = [[200u8; 16]; 16];
        let l0 = make_recon_b(64, 64, 50);
        let l1 = make_recon_b(64, 64, 200);
        let grid = EncoderMvGrid::new(4, 4);
        let zero = MotionVector { mv_x: 0, mv_y: 0 };
        let d = mb_decision_b_rdo(&grid, 0, 0, &src, &l0, &l1, 30, (zero, zero));
        assert!(matches!(d, BMbDecision::L1_16x16 { .. }),
            "L1_16x16 should win when only L1 matches, got {:?}", d);
    }

    #[test]
    fn rdo_deterministic() {
        let mut src = [[0u8; 16]; 16];
        for (y, row) in src.iter_mut().enumerate() {
            for (x, px) in row.iter_mut().enumerate() {
                *px = ((x * 7 + y * 11) & 0xFF) as u8;
            }
        }
        let l0 = make_recon_b(64, 64, 80);
        let l1 = make_recon_b(64, 64, 90);
        let grid = EncoderMvGrid::new(4, 4);
        let zero = MotionVector { mv_x: 0, mv_y: 0 };
        let d_a = mb_decision_b_rdo(&grid, 0, 0, &src, &l0, &l1, 30, (zero, zero));
        let d_b = mb_decision_b_rdo(&grid, 0, 0, &src, &l0, &l1, 30, (zero, zero));
        assert_eq!(std::mem::discriminant(&d_a), std::mem::discriminant(&d_b));
    }

    /// `PHASM_B_RDO` overrides the `SAFE` default in both directions.
    #[test]
    fn b_rdo_enabled_reads_env() {
        let _lock = env_lock();
        // SAFETY (all env mutations below): test-only. Environment mutation is
        // serialised through `env_lock()`.
        unsafe {
            std::env::remove_var("PHASM_B_RDO");
        }
        assert!(!b_rdo_enabled());
        unsafe {
            std::env::set_var("PHASM_B_RDO", "1");
        }
        assert!(b_rdo_enabled());
        unsafe {
            std::env::set_var("PHASM_B_RDO", "0");
        }
        assert!(!b_rdo_enabled());
        unsafe {
            std::env::remove_var("PHASM_B_RDO");
        }
    }
}