use super::motion_estimation::MotionVector;
use super::partition_state::EncoderMvGrid;
/// Test-only lock that serializes tests which modify the `PHASM_B_FORCE_MODE`
/// environment variable (read by `forced_b_mode_from_env` on every call to
/// `mb_decision_b`), so concurrently running tests do not observe each
/// other's overrides. Hold the guard for the whole duration of such a test.
///
/// NOTE(review): `pub(crate)` suggests test modules elsewhere in the crate
/// also lock it — confirm every env-touching test does.
#[cfg(test)]
pub(crate) static B_FORCE_MODE_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Coding decision for a single B-slice macroblock, as produced by
/// [`mb_decision_b`] (or forced via the `PHASM_B_FORCE_MODE` debug override).
#[derive(Debug, Clone)]
pub enum BMbDecision {
/// Skipped macroblock; the decision carries no motion data
/// (presumably coded as B_Skip — confirm against the bitstream writer).
Skip,
/// 16x16 direct-mode macroblock; the decision carries no explicit motion data.
Direct16x16,
/// 16x16 macroblock using a single list-0 (L0) motion vector `mv`.
L0_16x16 { mv: MotionVector },
/// 16x16 macroblock using a single list-1 (L1) motion vector `mv`.
L1_16x16 { mv: MotionVector },
/// 16x16 bi-predicted macroblock with one motion vector per reference list.
Bi_16x16 { mv_l0: MotionVector, mv_l1: MotionVector },
/// Macroblock split into two partitions.
///
/// `mb_type` is the raw macroblock-type number whose partition layout is
/// described by `b_partitioned::partitioned_b_meta`; `parts[i]` holds the
/// optional per-list motion vectors for partition `i`.
Partitioned {
mb_type: u8,
parts: [super::b_partitioned::BPartitionMv; 2],
},
// NOTE(review): this variant is never constructed in this file, hence the
// `dead_code` allowance; presumably a placeholder for 8x8 sub-macroblock
// partitioning that is not implemented yet — confirm before relying on it.
#[allow(dead_code)]
B8x8 {},
}
/// Chooses the coding mode for the B-slice macroblock at `(mb_x, mb_y)`.
///
/// If the `PHASM_B_FORCE_MODE` debug override is set to a recognised value,
/// that decision is returned unchanged. Otherwise the mode is drawn from a
/// deterministic pseudo-random bucket in `0..100` keyed on
/// `(frame_num, mb_addr)`, giving roughly:
///
/// | bucket   | mode          | share |
/// |----------|---------------|-------|
/// | 0..50    | `Skip`        | 50 %  |
/// | 50..85   | `Direct16x16` | 35 %  |
/// | 85..91   | `L0_16x16`    |  6 %  |
/// | 91..96   | `L1_16x16`    |  5 %  |
/// | 96..100  | `Bi_16x16`    |  4 %  |
///
/// `mb_x`/`mb_y` only feed motion-vector prediction from `grid`; they do not
/// influence which mode is chosen.
pub fn mb_decision_b(
    grid: &EncoderMvGrid,
    mb_x: usize,
    mb_y: usize,
    frame_num: u32,
    mb_addr: u32,
) -> BMbDecision {
    if let Some(forced) = forced_b_mode_from_env() {
        return forced;
    }

    let roll = mb_decision_bucket(frame_num, mb_addr);
    if roll < 50 {
        BMbDecision::Skip
    } else if roll < 85 {
        BMbDecision::Direct16x16
    } else if roll < 91 {
        BMbDecision::L0_16x16 {
            mv: predict_b_partition_mv_l0(grid, mb_x, mb_y),
        }
    } else if roll < 96 {
        BMbDecision::L1_16x16 {
            mv: predict_b_partition_mv_l1(grid, mb_x, mb_y),
        }
    } else {
        // Struct-literal fields evaluate in source order: L0 first, then L1.
        BMbDecision::Bi_16x16 {
            mv_l0: predict_b_partition_mv_l0(grid, mb_x, mb_y),
            mv_l1: predict_b_partition_mv_l1(grid, mb_x, mb_y),
        }
    }
}
/// Maps `(frame_num, mb_addr)` to a deterministic pseudo-random bucket in `0..100`.
///
/// The frame number is spread with one 64-bit multiplier, the macroblock
/// address is added, and the sum is mixed with a second multiplier followed by
/// an xor-shift before reducing modulo 100. (Both multipliers are the
/// constants also used by SplitMix64.)
fn mb_decision_bucket(frame_num: u32, mb_addr: u32) -> u32 {
    const FRAME_SPREAD: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_MUL: u64 = 0xBF58_476D_1CE4_E5B9;

    let seeded = u64::from(frame_num)
        .wrapping_mul(FRAME_SPREAD)
        .wrapping_add(u64::from(mb_addr));
    let mixed = seeded.wrapping_mul(MIX_MUL);
    let folded = mixed ^ (mixed >> 30);
    // `folded % 100` is always < 100, so the narrowing cast is lossless.
    (folded % 100) as u32
}
/// Predicts the list-0 motion vector for the whole macroblock at `(mb_x, mb_y)`
/// by delegating to `partition_state::predict_mv_for_mb_partition`.
fn predict_b_partition_mv_l0(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    // Macroblock coordinates are scaled by 4 before the lookup — presumably the
    // grid is addressed in 4x4-block units.
    let blk_x = mb_x * 4;
    let blk_y = mb_y * 4;
    // NOTE(review): the trailing `4, 4, 0, 0` look like partition width/height
    // (in 4x4 blocks) and list/ref indices — confirm against the callee.
    super::partition_state::predict_mv_for_mb_partition(grid, blk_x, blk_y, 4, 4, 0, 0)
}
/// Predicts the list-1 motion vector for the macroblock at `(mb_x, mb_y)` by
/// delegating to `b_direct_predictor::predict_mv_for_partition_l1_pub`.
fn predict_b_partition_mv_l1(grid: &EncoderMvGrid, mb_x: usize, mb_y: usize) -> MotionVector {
    // Same 4x scaling of macroblock coordinates as the list-0 predictor.
    let (blk_x, blk_y) = (mb_x * 4, mb_y * 4);
    // NOTE(review): the trailing `0` is presumably a reference index — confirm.
    super::b_direct_predictor::predict_mv_for_partition_l1_pub(grid, blk_x, blk_y, 0)
}
/// Reads the `PHASM_B_FORCE_MODE` debug override, if present.
///
/// Accepted values (ASCII case-insensitive, surrounding whitespace ignored):
///
/// * `skip`, `direct`
/// * `l0_16x16`, `l1_16x16`, `bi_16x16` — forced with zero motion vectors
/// * `partitioned_<N>` — `<N>` parsed as a `u8` macroblock type for which
///   `partitioned_b_meta` has metadata; all active lists get zero MVs
///
/// Returns `None` when the variable is unset, not valid UTF-8, or holds an
/// unrecognised value, so the caller falls back to its normal decision.
fn forced_b_mode_from_env() -> Option<BMbDecision> {
    let raw = std::env::var("PHASM_B_FORCE_MODE").ok()?;
    // Trim so values like "skip\n" or " direct" from shell scripts still apply
    // instead of being silently ignored.
    let mode = raw.trim().to_ascii_lowercase();

    if let Some(digits) = mode.strip_prefix("partitioned_") {
        let mb_type: u8 = digits.parse().ok()?;
        return forced_partitioned_decision(mb_type);
    }

    let zero = MotionVector { mv_x: 0, mv_y: 0 };
    match mode.as_str() {
        "skip" => Some(BMbDecision::Skip),
        "direct" => Some(BMbDecision::Direct16x16),
        "l0_16x16" => Some(BMbDecision::L0_16x16 { mv: zero }),
        "l1_16x16" => Some(BMbDecision::L1_16x16 { mv: zero }),
        "bi_16x16" => Some(BMbDecision::Bi_16x16 {
            mv_l0: zero,
            mv_l1: zero,
        }),
        _ => None,
    }
}
/// Builds a `Partitioned` decision for `mb_type` in which every reference list
/// used by a partition gets a zero motion vector, and unused lists get `None`.
///
/// Returns `None` when `partitioned_b_meta` has no entry for `mb_type`.
fn forced_partitioned_decision(mb_type: u8) -> Option<BMbDecision> {
    use super::b_partitioned::{partitioned_b_meta, BListUse, BPartitionMv};

    let meta = partitioned_b_meta(u32::from(mb_type))?;

    let zeroed_part = |usage: BListUse| -> BPartitionMv {
        let origin = MotionVector { mv_x: 0, mv_y: 0 };
        match usage {
            BListUse::L0 => BPartitionMv {
                mv_l0: Some(origin),
                mv_l1: None,
            },
            BListUse::L1 => BPartitionMv {
                mv_l0: None,
                mv_l1: Some(origin),
            },
            BListUse::Bi => BPartitionMv {
                mv_l0: Some(origin),
                mv_l1: Some(origin),
            },
        }
    };

    let parts = [zeroed_part(meta.part0), zeroed_part(meta.part1)];
    Some(BMbDecision::Partitioned { mb_type, parts })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquires `B_FORCE_MODE_ENV_LOCK` and clears `PHASM_B_FORCE_MODE` so the
    /// unforced decision path is exercised. Keep the returned guard alive for
    /// the whole test so no other env-touching test can interleave.
    fn lock_env_without_forced_mode() -> std::sync::MutexGuard<'static, ()> {
        // The lock only guards `()`, so a poisoned mutex (another test panicked
        // while holding it) carries no broken state — recover rather than
        // failing every subsequent env-locked test.
        let guard = B_FORCE_MODE_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // SAFETY: assumes every test that reads or writes PHASM_B_FORCE_MODE
        // holds B_FORCE_MODE_ENV_LOCK, so no other thread accesses this
        // variable concurrently. NOTE(review): confirm for all callers.
        unsafe {
            std::env::remove_var("PHASM_B_FORCE_MODE");
        }
        guard
    }

    #[test]
    fn distribution_match_x264_medium_buckets() {
        let _env = lock_env_without_forced_mode();
        let grid = EncoderMvGrid::new(2, 2);
        let n = 10_000u32;
        let mut counts = [0u32; 5];
        for mb_addr in 0..n {
            let d = mb_decision_b(&grid, 0, 0, 0, mb_addr);
            let idx = match d {
                BMbDecision::Skip => 0,
                BMbDecision::Direct16x16 => 1,
                BMbDecision::L0_16x16 { .. } => 2,
                BMbDecision::L1_16x16 { .. } => 3,
                BMbDecision::Bi_16x16 { .. } => 4,
                _ => panic!("unexpected variant from §6E-A6.1 stub"),
            };
            counts[idx] += 1;
        }
        let pct = |c: u32| (c as f32 / n as f32) * 100.0;
        let (skip, direct, l0, l1, bi) =
            (pct(counts[0]), pct(counts[1]), pct(counts[2]), pct(counts[3]), pct(counts[4]));
        eprintln!(
            "§6E-A6.1 mode mix: skip={skip:.1}% direct={direct:.1}% \
L0={l0:.1}% L1={l1:.1}% Bi={bi:.1}%"
        );
        assert!((skip - 50.0).abs() < 3.0, "skip {skip:.1}%");
        assert!((direct - 35.0).abs() < 3.0, "direct {direct:.1}%");
        assert!((l0 - 6.0).abs() < 3.0, "L0 {l0:.1}%");
        assert!((l1 - 5.0).abs() < 3.0, "L1 {l1:.1}%");
        assert!((bi - 4.0).abs() < 3.0, "Bi {bi:.1}%");
    }

    #[test]
    fn deterministic_output_for_same_input() {
        // Without the lock, a concurrent test setting PHASM_B_FORCE_MODE
        // between the two calls below would make this test flaky.
        let _env = lock_env_without_forced_mode();
        let grid = EncoderMvGrid::new(2, 2);
        for mb_addr in 0..256 {
            let a = mb_decision_b(&grid, 0, 0, 7, mb_addr);
            let b = mb_decision_b(&grid, 0, 0, 7, mb_addr);
            // Compare the full Debug rendering so motion vectors are checked
            // too, not just which variant was chosen.
            assert_eq!(
                format!("{a:?}"),
                format!("{b:?}"),
                "non-deterministic mb_decision at mb={mb_addr}"
            );
        }
    }
}