use super::inject::{
apply_coeff_sign_overrides, apply_coeff_suffix_lsb_overrides, MvdSlot,
};
use super::orchestrate::ResidualPathKind;
use super::{BinKind, BitInjector};
/// Callback interface invoked while one macroblock's syntax elements are
/// visited.
///
/// Only [`on_residual_block`](Self::on_residual_block) and
/// [`on_mvd_slot`](Self::on_mvd_slot) must be implemented; every other method
/// has a default, so an implementor overrides just what it needs (logging,
/// counting, or injecting).
pub trait StegoMbHook: Send + std::fmt::Debug {
    /// Visit one residual block.
    ///
    /// `scan_coeffs` holds the block's coefficients in scan order and may be
    /// rewritten in place. `start_idx`/`end_idx` bound the coefficient range
    /// of interest (the in-file tests pass `0, 15` for a 16-entry scan, which
    /// suggests an inclusive end — TODO confirm). `path_kind` identifies the
    /// block so per-coefficient syntax paths can be derived.
    fn on_residual_block(
        &mut self,
        frame_idx: u32,
        mb_addr: u32,
        scan_coeffs: &mut [i32],
        start_idx: usize,
        end_idx: usize,
        path_kind: ResidualPathKind,
    );

    /// Visit one motion-vector-difference component. `slot.value` may be
    /// rewritten in place.
    fn on_mvd_slot(&mut self, frame_idx: u32, mb_addr: u32, slot: &mut MvdSlot);

    /// Hand over the collected cover if this hook logs positions.
    /// Default: `None`.
    fn take_cover_if_logger(&mut self) -> Option<super::orchestrate::GopCover> {
        None
    }

    /// Hand over the collected per-MVD metadata if this hook logs positions.
    /// Default: an empty vector.
    fn take_mvd_meta_if_logger(&mut self) -> Vec<MvdPositionMeta> {
        Vec::default()
    }

    /// Report per-domain position counts if this hook counts positions; the
    /// element order is defined by the implementor. Default: `None`.
    fn take_counts_if_counter(&mut self) -> Option<[usize; 4]> {
        None
    }

    /// Mark the start of a macroblock's MVD work so it can later be committed
    /// or rolled back. Default: no-op.
    fn begin_mvd_for_mb(&mut self) {}

    /// Keep the MVD work recorded since `begin_mvd_for_mb`. Default: no-op.
    fn commit_mvd_for_mb(&mut self) {}

    /// Discard the MVD work recorded since `begin_mvd_for_mb`. Default: no-op.
    fn rollback_mvd_for_mb(&mut self) {}

    /// Ask for a replacement sign bit for `slot`.
    ///
    /// `None` means the sign is left as it is. Default: always `None`.
    fn mvd_sign_override(
        &mut self,
        _frame_idx: u32,
        _mb_addr: u32,
        _slot: &MvdSlot,
    ) -> Option<u8> {
        None
    }
}
/// Hook that rewrites residual coefficients and MVD components in place so
/// that selected bits take the values planned by the wrapped [`BitInjector`].
pub struct InjectionHook<I: BitInjector> {
    // Supplies the planned bit for each position key.
    injector: I,
    // Allow-list of MVD suffix-LSB position keys. While `None`, MVD
    // magnitudes are never modified (sign overrides are not gated).
    mvd_msl_safe_gate: Option<std::collections::HashSet<super::hook::PositionKey>>,
}

impl<I: BitInjector> InjectionHook<I> {
    /// Wrap `injector`.
    ///
    /// No MVD suffix-LSB gate is installed, so MVD magnitudes stay untouched
    /// until [`Self::set_mvd_msl_safe_gate`] is called.
    pub fn new(injector: I) -> Self {
        let mvd_msl_safe_gate = None;
        Self {
            mvd_msl_safe_gate,
            injector,
        }
    }

    /// Install the set of MVD suffix-LSB positions that may be modified,
    /// replacing any previously installed set.
    pub fn set_mvd_msl_safe_gate(
        &mut self,
        keys: std::collections::HashSet<super::hook::PositionKey>,
    ) {
        self.mvd_msl_safe_gate = Some(keys);
    }

    /// Consume the hook and return the wrapped injector; the gate is dropped.
    pub fn into_injector(self) -> I {
        let Self { injector, .. } = self;
        injector
    }
}

impl<I: BitInjector> std::fmt::Debug for InjectionHook<I> {
    /// Prints only the type name, so `I` does not need to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InjectionHook").finish_non_exhaustive()
    }
}
impl<I: BitInjector + Send> StegoMbHook for InjectionHook<I> {
    /// Apply the injector's planned bits to this block: coefficient sign bits
    /// first, then coefficient suffix LSBs.
    fn on_residual_block(
        &mut self,
        frame_idx: u32,
        mb_addr: u32,
        scan_coeffs: &mut [i32],
        start_idx: usize,
        end_idx: usize,
        path_kind: ResidualPathKind,
    ) {
        // Sign pass. Both passes consult the same injector, signs first.
        apply_coeff_sign_overrides(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::Sign),
            &mut self.injector,
        );
        // Suffix-LSB pass over the coefficients as left by the sign pass.
        apply_coeff_suffix_lsb_overrides(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::SuffixLsb),
            &mut self.injector,
        );
    }

    /// Force the LSB of a gated MVD component's magnitude to the planned bit.
    ///
    /// Does nothing in any of these cases:
    /// - the value is zero;
    /// - no gate is installed;
    /// - the slot's suffix-LSB key is not in the gate;
    /// - the injector has no plan for that key;
    /// - the LSB already equals the plan.
    ///
    /// Otherwise the magnitude moves by one with the sign preserved: down
    /// when it is above 10, up when it is 10 or less. A non-zero value
    /// therefore never becomes zero.
    fn on_mvd_slot(&mut self, frame_idx: u32, mb_addr: u32, slot: &mut MvdSlot) {
        use super::hook::{BinKind, EmbedDomain, PositionKey, SyntaxPath};

        if slot.value == 0 {
            return;
        }
        let Some(gate) = &self.mvd_msl_safe_gate else {
            return;
        };

        let key = PositionKey::new(
            frame_idx,
            mb_addr,
            EmbedDomain::MvdSuffixLsb,
            SyntaxPath::Mvd {
                list: slot.list,
                partition: slot.partition,
                axis: slot.axis,
                kind: BinKind::SuffixLsb,
            },
        );
        if !gate.contains(&key) {
            return;
        }
        // The injector is only consulted for keys that passed the gate.
        let Some(wanted) = self.injector.override_bit(key) else {
            return;
        };

        let magnitude = slot.value.unsigned_abs();
        if u8::from(magnitude % 2 == 1) == wanted {
            return;
        }
        // NOTE(review): the threshold of 10 presumably keeps the value inside
        // the same MVD binarization region — confirm against the encoder.
        let stepped = if magnitude > 10 { magnitude - 1 } else { magnitude + 1 };
        // `slot.value` is non-zero here, so `signum` is exactly +1 or -1.
        slot.value = stepped as i32 * slot.value.signum();
    }

    /// Planned sign bit for `slot`.
    ///
    /// Returns `None` for a zero MVD (it has no sign bin) or when the
    /// injector has no plan. This path is not restricted by the MVD gate.
    fn mvd_sign_override(
        &mut self,
        frame_idx: u32,
        mb_addr: u32,
        slot: &MvdSlot,
    ) -> Option<u8> {
        use super::hook::{BinKind, EmbedDomain, PositionKey, SyntaxPath};

        if slot.value == 0 {
            None
        } else {
            let path = SyntaxPath::Mvd {
                list: slot.list,
                partition: slot.partition,
                axis: slot.axis,
                kind: BinKind::Sign,
            };
            self.injector.override_bit(PositionKey::new(
                frame_idx,
                mb_addr,
                EmbedDomain::MvdSignBypass,
                path,
            ))
        }
    }
}
/// Metadata recorded by `PositionLoggerHook::on_mvd_slot` for each MVD slot
/// that produced at least one sign-bypass position.
#[derive(Copy, Clone, Debug, Default)]
pub struct MvdPositionMeta {
/// Absolute value of the MVD component at the time it was logged.
pub magnitude: u32,
/// Macroblock address passed to the hook.
pub mb_addr: u32,
/// Frame index passed to the hook.
pub frame_idx: u32,
/// Partition index copied from the slot.
pub partition: u8,
/// Component axis: `0` for `Axis::X`, `1` for `Axis::Y`.
pub axis: u8,
}
/// Hook that leaves its inputs untouched and records every embeddable
/// position it visits (bit value, position and cost) into a `GopCover`.
///
/// It also records one [`MvdPositionMeta`] per MVD slot that yields a sign
/// position.
#[derive(Debug, Default)]
pub struct PositionLoggerHook {
    // Accumulated cover (bits + positions) and matching costs.
    cover: super::orchestrate::GopCover,
    // `(mvd_sign_bypass.len(), mvd_suffix_lsb.len())` saved by `begin_mvd_for_mb`.
    mvd_savepoint: Option<(usize, usize)>,
    // Metadata for logged MVD sign positions.
    mvd_meta: Vec<MvdPositionMeta>,
    // `mvd_meta.len()` saved by `begin_mvd_for_mb`.
    mvd_meta_savepoint: Option<usize>,
}

impl PositionLoggerHook {
    /// Create an empty logger with no open MVD savepoint.
    pub fn new() -> Self {
        Self::default()
    }

    /// Move the accumulated cover out, leaving an empty one in its place.
    pub fn take_cover(&mut self) -> super::orchestrate::GopCover {
        std::mem::take(&mut self.cover)
    }

    /// Move the accumulated MVD metadata out, leaving an empty list.
    pub fn take_mvd_meta(&mut self) -> Vec<MvdPositionMeta> {
        std::mem::take(&mut self.mvd_meta)
    }
}
/// Composite hook pairing an [`InjectionHook`] with a [`PositionLoggerHook`].
pub struct InjectAndLogHook<I: BitInjector> {
    // Performs MVD modifications and sign overrides.
    inject: InjectionHook<I>,
    // Records positions, bits and costs.
    logger: PositionLoggerHook,
}

impl<I: BitInjector> InjectAndLogHook<I> {
    /// Build the composite around `injector` with an empty logger.
    // NOTE(review): no MVD suffix-LSB gate is installed on the inner
    // `InjectionHook` and this type exposes no setter for one, so MVD
    // magnitude injection stays disabled — confirm this is intended.
    pub fn new(injector: I) -> Self {
        let logger = PositionLoggerHook::new();
        let inject = InjectionHook::new(injector);
        Self { inject, logger }
    }
}

impl<I: BitInjector> std::fmt::Debug for InjectAndLogHook<I> {
    /// Prints only the type name, so `I` does not need to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InjectAndLogHook").finish_non_exhaustive()
    }
}
impl<I: BitInjector + Send> StegoMbHook for InjectAndLogHook<I> {
fn on_residual_block(
&mut self,
frame_idx: u32,
mb_addr: u32,
scan_coeffs: &mut [i32],
start_idx: usize,
end_idx: usize,
path_kind: ResidualPathKind,
) {
self.logger.on_residual_block(
frame_idx, mb_addr, scan_coeffs, start_idx, end_idx, path_kind,
);
}
fn on_mvd_slot(
&mut self,
frame_idx: u32,
mb_addr: u32,
slot: &mut super::inject::MvdSlot,
) {
self.inject.on_mvd_slot(frame_idx, mb_addr, slot);
self.logger.on_mvd_slot(frame_idx, mb_addr, slot);
}
fn mvd_sign_override(
&mut self,
frame_idx: u32,
mb_addr: u32,
slot: &super::inject::MvdSlot,
) -> Option<u8> {
self.inject.mvd_sign_override(frame_idx, mb_addr, slot)
}
fn begin_mvd_for_mb(&mut self) {
self.logger.begin_mvd_for_mb();
}
fn commit_mvd_for_mb(&mut self) {
self.logger.commit_mvd_for_mb();
}
fn rollback_mvd_for_mb(&mut self) {
self.logger.rollback_mvd_for_mb();
}
fn take_cover_if_logger(&mut self) -> Option<super::orchestrate::GopCover> {
Some(self.logger.take_cover())
}
fn take_mvd_meta_if_logger(&mut self) -> Vec<MvdPositionMeta> {
self.logger.take_mvd_meta()
}
}
impl StegoMbHook for PositionLoggerHook {
    /// Record every coefficient sign-bypass and suffix-LSB position of the
    /// block, with its current bit value and cost.
    ///
    /// The coefficients are only read; the `&mut` comes from the trait.
    fn on_residual_block(
        &mut self,
        frame_idx: u32,
        mb_addr: u32,
        scan_coeffs: &mut [i32],
        start_idx: usize,
        end_idx: usize,
        path_kind: ResidualPathKind,
    ) {
        use super::cost_model::{coeff_sign_cost_vec, coeff_suffix_lsb_cost_vec, PositionCostCtx};
        use super::{
            enumerate_coeff_sign_positions, enumerate_coeff_suffix_lsb_positions,
            extract_coeff_sign_bits, extract_coeff_suffix_lsb_bits,
        };

        let cost_ctx = PositionCostCtx::new(frame_idx, mb_addr);

        // Sign-bypass domain. `zip` stops at the shortest of the three lists.
        let sign_positions = enumerate_coeff_sign_positions(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::Sign),
        );
        let sign_bits = extract_coeff_sign_bits(scan_coeffs, start_idx, end_idx);
        let sign_costs = coeff_sign_cost_vec(scan_coeffs, start_idx, end_idx, &cost_ctx);
        for ((&pos, &bit), &cost) in sign_positions
            .iter()
            .zip(sign_bits.iter())
            .zip(sign_costs.iter())
        {
            self.cover.cover.coeff_sign_bypass.push(bit, pos);
            self.cover.costs.coeff_sign_bypass.push(cost);
        }

        // Suffix-LSB domain.
        let lsb_positions = enumerate_coeff_suffix_lsb_positions(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::SuffixLsb),
        );
        let lsb_bits = extract_coeff_suffix_lsb_bits(scan_coeffs, start_idx, end_idx);
        let lsb_costs = coeff_suffix_lsb_cost_vec(scan_coeffs, start_idx, end_idx, &cost_ctx);
        for ((&pos, &bit), &cost) in lsb_positions
            .iter()
            .zip(lsb_bits.iter())
            .zip(lsb_costs.iter())
        {
            self.cover.cover.coeff_suffix_lsb.push(bit, pos);
            self.cover.costs.coeff_suffix_lsb.push(cost);
        }
    }

    /// Record the MVD component's sign-bypass and suffix-LSB positions.
    ///
    /// When at least one sign position is logged, also push an
    /// [`MvdPositionMeta`] describing the slot.
    fn on_mvd_slot(&mut self, frame_idx: u32, mb_addr: u32, slot: &mut MvdSlot) {
        use super::cost_model::{mvd_sign_cost_vec, mvd_suffix_lsb_cost_vec, PositionCostCtx};
        use super::{
            enumerate_mvd_sign_positions, enumerate_mvd_suffix_lsb_positions,
            extract_mvd_sign_bits, extract_mvd_suffix_lsb_bits, Axis,
        };

        // The helpers operate on slot slices; wrap the single component.
        let one_slot = [*slot];
        let cost_ctx = PositionCostCtx::new(frame_idx, mb_addr);

        let sign_positions = enumerate_mvd_sign_positions(&one_slot, frame_idx, mb_addr);
        let sign_bits = extract_mvd_sign_bits(&one_slot);
        let sign_costs = mvd_sign_cost_vec(&one_slot, &cost_ctx);
        let sign_len_before = self.cover.cover.mvd_sign_bypass.len();
        for ((&pos, &bit), &cost) in sign_positions
            .iter()
            .zip(sign_bits.iter())
            .zip(sign_costs.iter())
        {
            self.cover.cover.mvd_sign_bypass.push(bit, pos);
            self.cover.costs.mvd_sign_bypass.push(cost);
        }

        if self.cover.cover.mvd_sign_bypass.len() > sign_len_before {
            let axis = match slot.axis {
                Axis::X => 0,
                Axis::Y => 1,
            };
            self.mvd_meta.push(MvdPositionMeta {
                magnitude: slot.value.unsigned_abs(),
                mb_addr,
                frame_idx,
                partition: slot.partition,
                axis,
            });
        }

        let lsb_positions = enumerate_mvd_suffix_lsb_positions(&one_slot, frame_idx, mb_addr);
        let lsb_bits = extract_mvd_suffix_lsb_bits(&one_slot);
        let lsb_costs = mvd_suffix_lsb_cost_vec(&one_slot, &cost_ctx);
        for ((&pos, &bit), &cost) in lsb_positions
            .iter()
            .zip(lsb_bits.iter())
            .zip(lsb_costs.iter())
        {
            self.cover.cover.mvd_suffix_lsb.push(bit, pos);
            self.cover.costs.mvd_suffix_lsb.push(cost);
        }
    }

    /// Save the current MVD cover and metadata lengths.
    fn begin_mvd_for_mb(&mut self) {
        let cover = &self.cover.cover;
        self.mvd_savepoint = Some((cover.mvd_sign_bypass.len(), cover.mvd_suffix_lsb.len()));
        self.mvd_meta_savepoint = Some(self.mvd_meta.len());
    }

    /// Keep everything logged since `begin_mvd_for_mb` by forgetting the savepoints.
    fn commit_mvd_for_mb(&mut self) {
        self.mvd_meta_savepoint = None;
        self.mvd_savepoint = None;
    }

    /// Truncate the MVD cover, costs and metadata back to the saved lengths.
    ///
    /// Without an open savepoint this is a no-op.
    fn rollback_mvd_for_mb(&mut self) {
        let cover_mark = self.mvd_savepoint.take();
        let meta_mark = self.mvd_meta_savepoint.take();
        if let Some((sign_len, suffix_len)) = cover_mark {
            // Cover entries and costs are pushed in lock-step by `on_mvd_slot`,
            // so the same saved length applies to both.
            let gop = &mut self.cover;
            gop.cover.mvd_sign_bypass.truncate(sign_len);
            gop.costs.mvd_sign_bypass.truncate(sign_len);
            gop.cover.mvd_suffix_lsb.truncate(suffix_len);
            gop.costs.mvd_suffix_lsb.truncate(suffix_len);
        }
        if let Some(meta_len) = meta_mark {
            self.mvd_meta.truncate(meta_len);
        }
    }

    /// Always `Some`: this hook is a logger.
    fn take_cover_if_logger(&mut self) -> Option<super::orchestrate::GopCover> {
        Some(self.take_cover())
    }

    fn take_mvd_meta_if_logger(&mut self) -> Vec<MvdPositionMeta> {
        self.take_mvd_meta()
    }
}
/// Hook that only tallies embeddable positions per domain. It records no
/// bits or costs and never modifies its inputs.
#[derive(Debug, Default)]
pub struct PositionCountingHook {
    coeff_sign: usize,
    coeff_suffix: usize,
    mvd_sign: usize,
    mvd_suffix: usize,
    // `(mvd_sign, mvd_suffix)` saved by `begin_mvd_for_mb`.
    mvd_savepoint: Option<(usize, usize)>,
}

impl PositionCountingHook {
    /// Create a counter with every tally at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Current tallies as `[coeff_sign, coeff_suffix, mvd_sign, mvd_suffix]`.
    pub fn snapshot(&self) -> [usize; 4] {
        let &Self {
            coeff_sign,
            coeff_suffix,
            mvd_sign,
            mvd_suffix,
            ..
        } = self;
        [coeff_sign, coeff_suffix, mvd_sign, mvd_suffix]
    }
}
impl StegoMbHook for PositionCountingHook {
    /// Add the block's coefficient sign and suffix-LSB position counts.
    fn on_residual_block(
        &mut self,
        frame_idx: u32,
        mb_addr: u32,
        scan_coeffs: &mut [i32],
        start_idx: usize,
        end_idx: usize,
        path_kind: ResidualPathKind,
    ) {
        use super::{enumerate_coeff_sign_positions, enumerate_coeff_suffix_lsb_positions};

        let sign_count = enumerate_coeff_sign_positions(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::Sign),
        )
        .len();
        self.coeff_sign += sign_count;

        let suffix_count = enumerate_coeff_suffix_lsb_positions(
            scan_coeffs,
            start_idx,
            end_idx,
            frame_idx,
            mb_addr,
            |ci| path_kind.path(ci, BinKind::SuffixLsb),
        )
        .len();
        self.coeff_suffix += suffix_count;
    }

    /// Add the MVD component's sign and suffix-LSB position counts.
    fn on_mvd_slot(&mut self, frame_idx: u32, mb_addr: u32, slot: &mut MvdSlot) {
        use super::{enumerate_mvd_sign_positions, enumerate_mvd_suffix_lsb_positions};

        let one_slot = [*slot];
        self.mvd_sign += enumerate_mvd_sign_positions(&one_slot, frame_idx, mb_addr).len();
        self.mvd_suffix += enumerate_mvd_suffix_lsb_positions(&one_slot, frame_idx, mb_addr).len();
    }

    /// Save the MVD tallies so a rollback can restore them.
    fn begin_mvd_for_mb(&mut self) {
        let saved = (self.mvd_sign, self.mvd_suffix);
        self.mvd_savepoint = Some(saved);
    }

    /// Keep the MVD tallies accumulated since `begin_mvd_for_mb`.
    fn commit_mvd_for_mb(&mut self) {
        self.mvd_savepoint = None;
    }

    /// Restore the saved MVD tallies; a no-op without an open savepoint.
    fn rollback_mvd_for_mb(&mut self) {
        let Some((sign, suffix)) = self.mvd_savepoint.take() else {
            return;
        };
        self.mvd_sign = sign;
        self.mvd_suffix = suffix;
    }

    /// Always `Some`: this hook is a counter.
    fn take_counts_if_counter(&mut self) -> Option<[usize; 4]> {
        Some(self.snapshot())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::Axis;

    /// Injector stub that plans the same bit for every position.
    struct ForceBit(u8);

    impl BitInjector for ForceBit {
        fn override_bit(&mut self, _key: super::super::PositionKey) -> Option<u8> {
            Some(self.0)
        }
    }

    /// MVD slot on list 0, partition 0, X axis with the given value.
    fn mvd_x(value: i32) -> MvdSlot {
        MvdSlot { list: 0, partition: 0, axis: Axis::X, value }
    }

    /// `InjectionHook` whose MVD gate admits exactly the suffix-LSB key of
    /// `mvd_x(_)` at frame 0 / macroblock 0.
    fn gated_hook(bit: u8) -> InjectionHook<ForceBit> {
        use super::super::hook::{BinKind, EmbedDomain, PositionKey, SyntaxPath};
        let key = PositionKey::new(
            0,
            0,
            EmbedDomain::MvdSuffixLsb,
            SyntaxPath::Mvd { list: 0, partition: 0, axis: Axis::X, kind: BinKind::SuffixLsb },
        );
        let mut hook = InjectionHook::new(ForceBit(bit));
        hook.set_mvd_msl_safe_gate(std::iter::once(key).collect());
        hook
    }

    #[test]
    fn position_logger_residual_block_collects_per_domain() {
        let mut hook = PositionLoggerHook::new();
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = -7;
        scan[6] = 20;
        hook.on_residual_block(0, 0, &mut scan, 0, 15, ResidualPathKind::Luma4x4 { block_idx: 0 });
        let cover = hook.take_cover();
        assert_eq!(cover.cover.coeff_sign_bypass.len(), 3);
        assert_eq!(cover.cover.coeff_suffix_lsb.len(), 1);
    }

    #[test]
    fn position_logger_mvd_slot_collects_per_domain() {
        let mut hook = PositionLoggerHook::new();
        let mut slot = mvd_x(5);
        hook.on_mvd_slot(0, 0, &mut slot);
        let cover = hook.take_cover();
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 1);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 0);

        let mut hook = PositionLoggerHook::new();
        let mut slot = mvd_x(15);
        hook.on_mvd_slot(0, 0, &mut slot);
        let cover = hook.take_cover();
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 1);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 1);
    }

    #[test]
    fn position_logger_zero_inputs_emit_no_positions() {
        let mut hook = PositionLoggerHook::new();
        let mut scan = vec![0i32; 16];
        hook.on_residual_block(0, 0, &mut scan, 0, 15, ResidualPathKind::Luma4x4 { block_idx: 0 });
        let mut slot = mvd_x(0);
        hook.on_mvd_slot(0, 0, &mut slot);
        let cover = hook.take_cover();
        assert_eq!(cover.cover.coeff_sign_bypass.len(), 0);
        assert_eq!(cover.cover.coeff_suffix_lsb.len(), 0);
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 0);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 0);
    }

    #[test]
    fn position_logger_rollback_discards_mvd_positions_and_meta() {
        let mut hook = PositionLoggerHook::new();
        hook.begin_mvd_for_mb();
        let mut slot = mvd_x(15);
        hook.on_mvd_slot(0, 0, &mut slot);
        hook.rollback_mvd_for_mb();
        assert!(hook.take_mvd_meta().is_empty(), "rollback must drop MVD meta");
        let cover = hook.take_cover();
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 0);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 0);
    }

    #[test]
    fn position_logger_commit_keeps_mvd_positions_and_meta() {
        let mut hook = PositionLoggerHook::new();
        hook.begin_mvd_for_mb();
        let mut slot = mvd_x(15);
        hook.on_mvd_slot(2, 5, &mut slot);
        hook.commit_mvd_for_mb();
        // The savepoint is gone after commit, so a stray rollback is a no-op.
        hook.rollback_mvd_for_mb();
        let meta = hook.take_mvd_meta();
        assert_eq!(meta.len(), 1);
        assert_eq!(meta[0].magnitude, 15);
        assert_eq!(meta[0].frame_idx, 2);
        assert_eq!(meta[0].mb_addr, 5);
        assert_eq!(meta[0].axis, 0);
        let cover = hook.take_cover();
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 1);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 1);
    }

    #[test]
    fn position_counter_tracks_domains_and_rolls_back_mvd() {
        let mut counter = PositionCountingHook::new();
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = -7;
        scan[6] = 20;
        counter.on_residual_block(
            0,
            0,
            &mut scan,
            0,
            15,
            ResidualPathKind::Luma4x4 { block_idx: 0 },
        );
        let mut slot = mvd_x(15);
        counter.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(counter.snapshot(), [3, 1, 1, 1]);

        counter.begin_mvd_for_mb();
        counter.on_mvd_slot(0, 0, &mut slot);
        counter.rollback_mvd_for_mb();
        assert_eq!(counter.snapshot(), [3, 1, 1, 1], "rollback restores MVD tallies");

        counter.begin_mvd_for_mb();
        counter.on_mvd_slot(0, 0, &mut slot);
        counter.commit_mvd_for_mb();
        assert_eq!(counter.take_counts_if_counter(), Some([3, 1, 2, 2]));
    }

    #[test]
    fn injection_hook_residual_block_flips_signs() {
        let mut hook = InjectionHook::new(ForceBit(1));
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = 7;
        scan[6] = -2;
        hook.on_residual_block(0, 0, &mut scan, 0, 15, ResidualPathKind::Luma4x4 { block_idx: 0 });
        assert_eq!(scan[0], -5);
        assert_eq!(scan[3], -7);
        assert_eq!(scan[6], -2);
    }

    #[test]
    fn injection_hook_residual_block_flips_suffix_lsb() {
        let mut hook = InjectionHook::new(ForceBit(0));
        let mut scan = vec![0i32; 16];
        scan[0] = 20;
        hook.on_residual_block(0, 0, &mut scan, 0, 15, ResidualPathKind::Luma4x4 { block_idx: 0 });
        assert!(scan[0] > 0, "planned sign bit 0 must keep the coefficient positive");
        assert_eq!(scan[0].unsigned_abs(), 19, "suffix LSB flip ±1");
    }

    #[test]
    fn injection_hook_mvd_slot_magnitude_lsb_flip() {
        let mut hook = gated_hook(0);
        let mut slot = mvd_x(-5);
        hook.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(
            slot.value, -6,
            "abs<10 + cur_lsb!=plan_bit → +1 magnitude flip (preserve sign)"
        );

        let mut hook = gated_hook(1);
        let mut slot = mvd_x(-12);
        hook.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(
            slot.value, -11,
            "abs>10 + cur_lsb!=plan_bit → -1 magnitude flip (preserve sign)"
        );

        let mut hook = gated_hook(1);
        let mut slot = mvd_x(15);
        hook.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(slot.value, 15, "cur_lsb==plan_bit → no mutation");

        let mut hook = gated_hook(1);
        let mut slot = mvd_x(0);
        hook.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(slot.value, 0, "zero MVD stays zero");

        let mut hook = InjectionHook::new(ForceBit(0));
        let mut slot = mvd_x(-5);
        hook.on_mvd_slot(0, 0, &mut slot);
        assert_eq!(
            slot.value, -5,
            "no gate installed → on_mvd_slot is no-op (matches pre-d.4 production)"
        );
    }

    #[test]
    fn injection_hook_mvd_slot_step_direction_around_ten() {
        // (planned bit, value before, value after): step down only above 10.
        for (bit, before, after) in [
            (1u8, 10, 11),
            (0, 11, 10),
            (0, 9, 10),
            (0, 1, 2),
            (1, 12, 11),
            (1, -10, -11),
        ] {
            let mut hook = gated_hook(bit);
            let mut slot = mvd_x(before);
            hook.on_mvd_slot(0, 0, &mut slot);
            assert_eq!(slot.value, after, "value {before} with planned bit {bit}");
        }
    }

    #[test]
    fn injection_hook_mvd_sign_override_returns_planned_bit() {
        let mut hook = InjectionHook::new(ForceBit(0));
        let slot = mvd_x(-5);
        let r = hook.mvd_sign_override(0, 0, &slot);
        assert_eq!(r, Some(0), "mvd_sign_override must return the planned bit value");
    }

    #[test]
    fn injection_hook_mvd_sign_override_zero_returns_none() {
        let mut hook = InjectionHook::new(ForceBit(1));
        let slot = mvd_x(0);
        let r = hook.mvd_sign_override(0, 0, &slot);
        assert_eq!(r, None, "mvd=0 has no sign bypass bin in spec; override is a no-op");
    }

    #[test]
    fn position_logger_does_not_mutate_inputs() {
        let mut hook = PositionLoggerHook::new();
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = -7;
        let original = scan.clone();
        hook.on_residual_block(0, 0, &mut scan, 0, 15, ResidualPathKind::Luma4x4 { block_idx: 0 });
        assert_eq!(scan, original, "position logger must not mutate inputs");
    }
}