use super::hook::{BinKind, BitInjector, EmbedDomain, PositionKey, SyntaxPath};
/// Lists the sign-bin embedding positions for the non-zero coefficients in
/// `scan_coeffs[start_idx..=end_idx]`.
///
/// Keys are produced in reverse scan order (highest index first). This is the
/// same order used by [`apply_coeff_sign_overrides`] and
/// [`extract_coeff_sign_bits`], so the i-th key pairs with the i-th extracted bit.
///
/// `path_for_coeff` maps a scan index (cast to `u8`) to its syntax path. The
/// path's `kind` is then replaced with `BinKind::Sign`. Every key is tagged with
/// `EmbedDomain::CoeffSignBypass`.
///
/// # Panics
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn enumerate_coeff_sign_positions(
    scan_coeffs: &[i32],
    start_idx: usize,
    end_idx: usize,
    frame_idx: u32,
    mb_addr: u32,
    mut path_for_coeff: impl FnMut(u8) -> SyntaxPath,
) -> Vec<PositionKey> {
    // Walk the range backwards directly instead of collecting and reversing,
    // which would allocate a temporary Vec for every block.
    (start_idx..=end_idx)
        .rev()
        .filter(|&i| scan_coeffs[i] != 0)
        .map(|i| {
            let path = with_sign_kind(path_for_coeff(i as u8));
            PositionKey::new(frame_idx, mb_addr, EmbedDomain::CoeffSignBypass, path)
        })
        .collect()
}
/// Forces the sign of each non-zero coefficient in
/// `scan_coeffs[start_idx..=end_idx]` to whatever `injector` requests.
///
/// Coefficients are visited in reverse scan order, matching
/// [`enumerate_coeff_sign_positions`]. For each one, the injector is queried
/// with a `PositionKey` in `EmbedDomain::CoeffSignBypass`. A returned `1`
/// requests a negative value; any other returned value requests a positive one.
/// `None` leaves the coefficient untouched. Magnitudes are never changed.
///
/// Returns the number of coefficients whose sign was flipped.
///
/// # Panics
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn apply_coeff_sign_overrides(
    scan_coeffs: &mut [i32],
    start_idx: usize,
    end_idx: usize,
    frame_idx: u32,
    mb_addr: u32,
    mut path_for_coeff: impl FnMut(u8) -> SyntaxPath,
    injector: &mut dyn BitInjector,
) -> usize {
    let mut count = 0usize;
    // Each index is visited exactly once. Flipping a sign never changes
    // zero-ness, so testing eligibility lazily matches a precomputed index list
    // and avoids allocating one.
    for i in (start_idx..=end_idx).rev() {
        let coeff = scan_coeffs[i];
        if coeff == 0 {
            continue;
        }
        let path = with_sign_kind(path_for_coeff(i as u8));
        let key = PositionKey::new(frame_idx, mb_addr, EmbedDomain::CoeffSignBypass, path);
        if let Some(bit) = injector.override_bit(key) {
            let want_negative = bit == 1;
            if want_negative != (coeff < 0) {
                scan_coeffs[i] = -coeff;
                count += 1;
            }
        }
    }
    count
}
/// Reads back the sign bits of the non-zero coefficients in
/// `scan_coeffs[start_idx..=end_idx]`: `1` for negative, `0` for positive.
///
/// Bits are returned in reverse scan order, matching
/// [`enumerate_coeff_sign_positions`].
///
/// # Panics
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn extract_coeff_sign_bits(
    scan_coeffs: &[i32],
    start_idx: usize,
    end_idx: usize,
) -> Vec<u8> {
    // Iterate the range backwards in place; no temporary index Vec is needed.
    (start_idx..=end_idx)
        .rev()
        .map(|i| scan_coeffs[i])
        .filter(|&c| c != 0)
        .map(|c| u8::from(c < 0))
        .collect()
}
/// One motion-vector-difference component that may carry embedded bits.
///
/// The MVD sign and suffix-LSB helpers in this module read and rewrite `value`.
/// They use `list`, `partition` and `axis` only to build the `SyntaxPath::Mvd`
/// position key.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct MvdSlot {
/// Reference list index (presumably 0 = L0, 1 = L1 — TODO confirm against caller).
pub list: u8,
/// Partition index within the macroblock.
pub partition: u8,
/// Which MV component this slot holds (tests use `Axis::X` / `Axis::Y`).
pub axis: super::Axis,
/// Signed MVD value. Zero values are skipped by the sign-domain helpers.
pub value: i32,
}
/// Lists the sign-bin embedding positions for a macroblock's MVD slots.
///
/// Slots with a zero `value` are skipped. Keys follow slot order and are tagged
/// with `EmbedDomain::MvdSignBypass` and `BinKind::Sign`, matching
/// [`apply_mvd_sign_overrides`] and [`extract_mvd_sign_bits`].
pub fn enumerate_mvd_sign_positions(
    slots: &[MvdSlot],
    frame_idx: u32,
    mb_addr: u32,
) -> Vec<PositionKey> {
    let mut keys = Vec::new();
    for slot in slots {
        if slot.value == 0 {
            // Zero has no sign to carry a bit.
            continue;
        }
        let path = SyntaxPath::Mvd {
            list: slot.list,
            partition: slot.partition,
            axis: slot.axis,
            kind: BinKind::Sign,
        };
        keys.push(PositionKey::new(
            frame_idx,
            mb_addr,
            EmbedDomain::MvdSignBypass,
            path,
        ));
    }
    keys
}
/// Forces the sign of each non-zero MVD slot to whatever `injector` requests.
///
/// Slots are visited in order. A returned `1` requests a negative value, any
/// other returned value requests a positive one, and `None` leaves the slot
/// unchanged. Magnitudes are preserved.
///
/// Returns how many slots had their sign flipped.
pub fn apply_mvd_sign_overrides(
    slots: &mut [MvdSlot],
    frame_idx: u32,
    mb_addr: u32,
    injector: &mut dyn BitInjector,
) -> usize {
    let mut flips = 0usize;
    for slot in slots.iter_mut().filter(|s| s.value != 0) {
        let key = PositionKey::new(
            frame_idx,
            mb_addr,
            EmbedDomain::MvdSignBypass,
            SyntaxPath::Mvd {
                list: slot.list,
                partition: slot.partition,
                axis: slot.axis,
                kind: BinKind::Sign,
            },
        );
        match injector.override_bit(key) {
            // Flip only when the requested sign differs from the current one.
            Some(bit) if (bit == 1) != (slot.value < 0) => {
                slot.value = -slot.value;
                flips += 1;
            }
            _ => {}
        }
    }
    flips
}
/// Reads back one sign bit per non-zero MVD slot, in slot order:
/// `1` for negative, `0` for positive.
pub fn extract_mvd_sign_bits(slots: &[MvdSlot]) -> Vec<u8> {
    slots
        .iter()
        .filter_map(|slot| match slot.value {
            0 => None,
            v => Some(u8::from(v < 0)),
        })
        .collect()
}
/// Cover bits and their positions, grouped by embedding domain.
///
/// There is one [`DomainBits`] per `EmbedDomain` variant. Each field is the
/// store that [`DomainCover::for_domain_mut`] returns for that variant.
#[derive(Default, Debug, Clone)]
pub struct DomainCover {
/// Store for `EmbedDomain::CoeffSignBypass`.
pub coeff_sign_bypass: DomainBits,
/// Store for `EmbedDomain::CoeffSuffixLsb`.
pub coeff_suffix_lsb: DomainBits,
/// Store for `EmbedDomain::MvdSignBypass`.
pub mvd_sign_bypass: DomainBits,
/// Store for `EmbedDomain::MvdSuffixLsb`.
pub mvd_suffix_lsb: DomainBits,
}
/// Cover bits for a single embedding domain, kept as two index-aligned vectors:
/// `bits[i]` was pushed together with `positions[i]`.
///
/// Use the methods (`push`, `extend`, `truncate`) to keep the two vectors in
/// step. `len` debug-asserts that their lengths agree.
#[derive(Default, Debug, Clone)]
pub struct DomainBits {
/// Bit values; `push` debug-asserts each is 0 or 1.
pub bits: Vec<u8>,
/// Position key for each bit, index-aligned with `bits`.
pub positions: Vec<PositionKey>,
}
impl DomainBits {
pub fn len(&self) -> usize {
debug_assert_eq!(self.bits.len(), self.positions.len());
self.bits.len()
}
pub fn is_empty(&self) -> bool {
self.bits.is_empty()
}
pub fn push(&mut self, bit: u8, pos: PositionKey) {
debug_assert!(bit <= 1);
self.bits.push(bit);
self.positions.push(pos);
}
pub fn extend(&mut self, other: DomainBits) {
self.bits.extend(other.bits);
self.positions.extend(other.positions);
}
pub fn truncate(&mut self, new_len: usize) {
self.bits.truncate(new_len);
self.positions.truncate(new_len);
}
}
impl DomainCover {
pub fn new() -> Self {
Self::default()
}
pub fn total_len(&self) -> usize {
self.coeff_sign_bypass.len()
+ self.coeff_suffix_lsb.len()
+ self.mvd_sign_bypass.len()
+ self.mvd_suffix_lsb.len()
}
pub fn capacity(&self) -> super::GopCapacity {
super::GopCapacity {
coeff_sign_bypass: self.coeff_sign_bypass.len(),
coeff_suffix_lsb: self.coeff_suffix_lsb.len(),
mvd_sign_bypass: self.mvd_sign_bypass.len(),
mvd_suffix_lsb: self.mvd_suffix_lsb.len(),
}
}
pub fn for_domain_mut(&mut self, domain: EmbedDomain) -> &mut DomainBits {
match domain {
EmbedDomain::CoeffSignBypass => &mut self.coeff_sign_bypass,
EmbedDomain::CoeffSuffixLsb => &mut self.coeff_suffix_lsb,
EmbedDomain::MvdSignBypass => &mut self.mvd_sign_bypass,
EmbedDomain::MvdSuffixLsb => &mut self.mvd_suffix_lsb,
}
}
pub fn extend_from(&mut self, other: DomainCover) {
self.coeff_sign_bypass.extend(other.coeff_sign_bypass);
self.coeff_suffix_lsb.extend(other.coeff_suffix_lsb);
self.mvd_sign_bypass.extend(other.mvd_sign_bypass);
self.mvd_suffix_lsb.extend(other.mvd_suffix_lsb);
}
}
/// Minimum `|coefficient|` eligible for the `CoeffSuffixLsb` domain.
/// A magnitude exactly at this value is flipped upward (see `flipped_magnitude`)
/// so it stays eligible.
// NOTE(review): presumably chosen so the CABAC coeff_abs_level_minus1 Exp-Golomb
// suffix is non-empty (UEG0 prefix cutoff 14) — confirm against the binarizer.
const COEFF_SUFFIX_LSB_THRESHOLD: u32 = 16;
/// Cover bit carried by a magnitude in the suffix-LSB domains:
/// `1` for even magnitudes, `0` for odd ones.
#[inline]
fn suffix_lsb_bit_for_magnitude(abs: u32) -> u8 {
    u8::from(abs % 2 == 0)
}
/// Returns a magnitude of opposite parity that is still `>= threshold`.
///
/// At exactly `threshold` it steps up by one; above it, it steps down by one.
/// Either way the result stays eligible for the suffix-LSB domain.
///
/// Precondition: `abs >= threshold`. Below the threshold, the step down would
/// move further from eligibility (and underflow at 0). Debug builds assert the
/// precondition.
#[inline]
fn flipped_magnitude(abs: u32, threshold: u32) -> u32 {
    debug_assert!(
        abs >= threshold,
        "flipped_magnitude requires abs >= threshold (abs={abs}, threshold={threshold})"
    );
    if abs == threshold {
        abs + 1
    } else {
        abs - 1
    }
}
/// Lists the suffix-LSB embedding positions for coefficients in
/// `scan_coeffs[start_idx..=end_idx]` with `|coeff| >= COEFF_SUFFIX_LSB_THRESHOLD`.
///
/// Keys are produced in reverse scan order, matching
/// [`apply_coeff_suffix_lsb_overrides`] and [`extract_coeff_suffix_lsb_bits`].
/// The path returned by `path_for_coeff` has its `kind` replaced with
/// `BinKind::SuffixLsb`. Every key is tagged with `EmbedDomain::CoeffSuffixLsb`.
///
/// # Panics
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn enumerate_coeff_suffix_lsb_positions(
    scan_coeffs: &[i32],
    start_idx: usize,
    end_idx: usize,
    frame_idx: u32,
    mb_addr: u32,
    mut path_for_coeff: impl FnMut(u8) -> SyntaxPath,
) -> Vec<PositionKey> {
    // Reverse the range lazily rather than collecting and reversing indices.
    (start_idx..=end_idx)
        .rev()
        .filter(|&i| scan_coeffs[i].unsigned_abs() >= COEFF_SUFFIX_LSB_THRESHOLD)
        .map(|i| {
            let path = with_suffix_lsb_kind(path_for_coeff(i as u8));
            PositionKey::new(frame_idx, mb_addr, EmbedDomain::CoeffSuffixLsb, path)
        })
        .collect()
}
/// Forces the suffix-LSB cover bit of every coefficient in
/// `scan_coeffs[start_idx..=end_idx]` with `|coeff| >= COEFF_SUFFIX_LSB_THRESHOLD`
/// to the value requested by `injector`.
///
/// Coefficients are visited in reverse scan order, matching
/// [`enumerate_coeff_suffix_lsb_positions`]. The injector is queried with a
/// `PositionKey` in `EmbedDomain::CoeffSuffixLsb`. A returned `1` requests
/// cover bit 1; any other returned value requests 0, which is the same
/// interpretation as [`apply_coeff_sign_overrides`]. When the current cover bit
/// differs, the magnitude is nudged by one via `flipped_magnitude`, never
/// dropping below the threshold, and the sign is preserved.
///
/// Returns the number of coefficients whose magnitude was changed.
///
/// # Panics
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn apply_coeff_suffix_lsb_overrides(
    scan_coeffs: &mut [i32],
    start_idx: usize,
    end_idx: usize,
    frame_idx: u32,
    mb_addr: u32,
    mut path_for_coeff: impl FnMut(u8) -> SyntaxPath,
    injector: &mut dyn BitInjector,
) -> usize {
    let mut count = 0usize;
    // Each index is visited once, and only index `i` is modified at step `i`.
    // Checking eligibility lazily therefore sees the same values as a
    // precomputed index list would, without allocating one.
    for i in (start_idx..=end_idx).rev() {
        let abs = scan_coeffs[i].unsigned_abs();
        if abs < COEFF_SUFFIX_LSB_THRESHOLD {
            continue;
        }
        let path = with_suffix_lsb_kind(path_for_coeff(i as u8));
        let key = PositionKey::new(frame_idx, mb_addr, EmbedDomain::CoeffSuffixLsb, path);
        if let Some(requested) = injector.override_bit(key) {
            // Normalize like the sign domains do. A raw non-binary value would
            // otherwise never equal the cover bit and would always force a flip.
            let target_bit = u8::from(requested == 1);
            if target_bit != suffix_lsb_bit_for_magnitude(abs) {
                let new_abs = flipped_magnitude(abs, COEFF_SUFFIX_LSB_THRESHOLD) as i32;
                scan_coeffs[i] = if scan_coeffs[i] < 0 { -new_abs } else { new_abs };
                count += 1;
            }
        }
    }
    count
}
pub fn extract_coeff_suffix_lsb_bits(
scan_coeffs: &[i32],
start_idx: usize,
end_idx: usize,
) -> Vec<u8> {
let mut sig: Vec<usize> = (start_idx..=end_idx)
.filter(|&i| scan_coeffs[i].unsigned_abs() >= COEFF_SUFFIX_LSB_THRESHOLD)
.collect();
sig.reverse();
sig.into_iter()
.map(|i| suffix_lsb_bit_for_magnitude(scan_coeffs[i].unsigned_abs()))
.collect()
}
/// Minimum `|mvd|` eligible for the `MvdSuffixLsb` domain.
/// A magnitude exactly at this value is flipped upward (see `flipped_magnitude`).
// NOTE(review): presumably matches the CABAC mvd UEG3 prefix cutoff (uCoff = 9),
// above which an Exp-Golomb suffix is present — confirm against the binarizer.
const MVD_SUFFIX_LSB_THRESHOLD: u32 = 9;
/// Lists the suffix-LSB embedding positions for MVD slots whose
/// `|value| >= MVD_SUFFIX_LSB_THRESHOLD`, in slot order.
///
/// Keys are tagged with `EmbedDomain::MvdSuffixLsb` and `BinKind::SuffixLsb`.
pub fn enumerate_mvd_suffix_lsb_positions(
    slots: &[MvdSlot],
    frame_idx: u32,
    mb_addr: u32,
) -> Vec<PositionKey> {
    slots
        .iter()
        .filter_map(|slot| {
            let eligible = slot.value.unsigned_abs() >= MVD_SUFFIX_LSB_THRESHOLD;
            eligible.then(|| {
                PositionKey::new(
                    frame_idx,
                    mb_addr,
                    EmbedDomain::MvdSuffixLsb,
                    SyntaxPath::Mvd {
                        list: slot.list,
                        partition: slot.partition,
                        axis: slot.axis,
                        kind: BinKind::SuffixLsb,
                    },
                )
            })
        })
        .collect()
}
/// Forces the suffix-LSB cover bit of every MVD slot with
/// `|value| >= MVD_SUFFIX_LSB_THRESHOLD` to the value requested by `injector`.
///
/// Slots are visited in order. A returned `1` requests cover bit 1; any other
/// returned value requests 0, which is the same interpretation as
/// [`apply_mvd_sign_overrides`]. When the cover bit differs, the magnitude is
/// nudged by one via `flipped_magnitude` (staying `>= MVD_SUFFIX_LSB_THRESHOLD`)
/// and the sign is preserved.
///
/// Returns the number of slots whose magnitude was changed.
pub fn apply_mvd_suffix_lsb_overrides(
    slots: &mut [MvdSlot],
    frame_idx: u32,
    mb_addr: u32,
    injector: &mut dyn BitInjector,
) -> usize {
    let mut count = 0usize;
    for s in slots.iter_mut() {
        let abs = s.value.unsigned_abs();
        if abs < MVD_SUFFIX_LSB_THRESHOLD {
            continue;
        }
        let path = SyntaxPath::Mvd {
            list: s.list,
            partition: s.partition,
            axis: s.axis,
            kind: BinKind::SuffixLsb,
        };
        let key = PositionKey::new(frame_idx, mb_addr, EmbedDomain::MvdSuffixLsb, path);
        if let Some(requested) = injector.override_bit(key) {
            // Normalize like the sign domains do. A raw non-binary value would
            // otherwise never equal the cover bit and would always force a flip.
            let target_bit = u8::from(requested == 1);
            if target_bit != suffix_lsb_bit_for_magnitude(abs) {
                let new_abs = flipped_magnitude(abs, MVD_SUFFIX_LSB_THRESHOLD) as i32;
                s.value = if s.value < 0 { -new_abs } else { new_abs };
                count += 1;
            }
        }
    }
    count
}
/// Reads back one suffix-LSB cover bit per MVD slot with
/// `|value| >= MVD_SUFFIX_LSB_THRESHOLD`, in slot order.
pub fn extract_mvd_suffix_lsb_bits(slots: &[MvdSlot]) -> Vec<u8> {
    let mut bits = Vec::new();
    for slot in slots {
        let magnitude = slot.value.unsigned_abs();
        if magnitude >= MVD_SUFFIX_LSB_THRESHOLD {
            bits.push(suffix_lsb_bit_for_magnitude(magnitude));
        }
    }
    bits
}
/// Returns `path` with its bin kind replaced by `BinKind::SuffixLsb`.
/// All other fields of the variant are carried over unchanged.
fn with_suffix_lsb_kind(path: SyntaxPath) -> SyntaxPath {
    let kind = BinKind::SuffixLsb;
    match path {
        SyntaxPath::Luma4x4 { block_idx, coeff_idx, .. } => {
            SyntaxPath::Luma4x4 { block_idx, coeff_idx, kind }
        }
        SyntaxPath::Luma8x8 { block_idx, coeff_idx, .. } => {
            SyntaxPath::Luma8x8 { block_idx, coeff_idx, kind }
        }
        SyntaxPath::ChromaAc { plane, block_idx, coeff_idx, .. } => {
            SyntaxPath::ChromaAc { plane, block_idx, coeff_idx, kind }
        }
        SyntaxPath::ChromaDc { plane, coeff_idx, .. } => {
            SyntaxPath::ChromaDc { plane, coeff_idx, kind }
        }
        SyntaxPath::LumaDcIntra16x16 { coeff_idx, .. } => {
            SyntaxPath::LumaDcIntra16x16 { coeff_idx, kind }
        }
        SyntaxPath::Mvd { list, partition, axis, .. } => {
            SyntaxPath::Mvd { list, partition, axis, kind }
        }
    }
}
/// Returns `path` with its bin kind replaced by `BinKind::Sign`.
/// All other fields of the variant are carried over unchanged.
fn with_sign_kind(path: SyntaxPath) -> SyntaxPath {
    let kind = BinKind::Sign;
    match path {
        SyntaxPath::Luma4x4 { block_idx, coeff_idx, .. } => {
            SyntaxPath::Luma4x4 { block_idx, coeff_idx, kind }
        }
        SyntaxPath::Luma8x8 { block_idx, coeff_idx, .. } => {
            SyntaxPath::Luma8x8 { block_idx, coeff_idx, kind }
        }
        SyntaxPath::ChromaAc { plane, block_idx, coeff_idx, .. } => {
            SyntaxPath::ChromaAc { plane, block_idx, coeff_idx, kind }
        }
        SyntaxPath::ChromaDc { plane, coeff_idx, .. } => {
            SyntaxPath::ChromaDc { plane, coeff_idx, kind }
        }
        SyntaxPath::LumaDcIntra16x16 { coeff_idx, .. } => {
            SyntaxPath::LumaDcIntra16x16 { coeff_idx, kind }
        }
        SyntaxPath::Mvd { list, partition, axis, .. } => {
            SyntaxPath::Mvd { list, partition, axis, kind }
        }
    }
}
// Unit and property tests for the sign-bypass and suffix-LSB embedding helpers
// in this module: position enumeration, override application, and bit extraction.
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::h264::stego::PositionRecorder;
// Path factory used as `path_for_coeff`: maps a scan index to Luma4x4 block 0.
// The kind is always Sign here; the functions under test overwrite it
// (via with_sign_kind / with_suffix_lsb_kind).
fn luma4x4_path(coeff_idx: u8) -> SyntaxPath {
SyntaxPath::Luma4x4 { block_idx: 0, coeff_idx, kind: BinKind::Sign }
}
#[test]
fn enumerate_empty_block() {
let scan = vec![0i32; 16];
let positions = enumerate_coeff_sign_positions(&scan, 0, 15, 0, 0, luma4x4_path);
assert!(positions.is_empty());
}
#[test]
fn enumerate_single_coeff() {
let mut scan = vec![0i32; 16];
scan[5] = 3;
let positions = enumerate_coeff_sign_positions(&scan, 0, 15, 7, 100, luma4x4_path);
assert_eq!(positions.len(), 1);
assert_eq!(positions[0].frame_idx(), 7);
assert_eq!(positions[0].mb_addr(), 100);
assert_eq!(positions[0].domain(), EmbedDomain::CoeffSignBypass);
match positions[0].syntax_path() {
SyntaxPath::Luma4x4 { block_idx: 0, coeff_idx: 5, kind: BinKind::Sign } => (),
other => panic!("wrong path {other:?}"),
}
}
// Positions must come out highest scan index first.
#[test]
fn enumerate_reverse_scan_order() {
let mut scan = vec![0i32; 16];
scan[0] = 1;
scan[3] = 2;
scan[7] = 3;
let positions = enumerate_coeff_sign_positions(&scan, 0, 15, 0, 0, luma4x4_path);
let coeff_idxs: Vec<u8> = positions
.iter()
.map(|k| match k.syntax_path() {
SyntaxPath::Luma4x4 { coeff_idx, .. } => coeff_idx,
_ => panic!(),
})
.collect();
assert_eq!(coeff_idxs, vec![7, 3, 0]);
}
// Injector that requests the same bit at every position it is asked about.
struct ConstantBitInjector(u8);
impl BitInjector for ConstantBitInjector {
fn override_bit(&mut self, _key: PositionKey) -> Option<u8> {
Some(self.0)
}
}
// Bit 0 means "positive": only the two negative coefficients flip.
#[test]
fn apply_overrides_force_all_positive() {
let mut scan = vec![0i32; 16];
scan[0] = 5;
scan[2] = -3;
scan[5] = -7;
scan[10] = 4;
let mut inj = ConstantBitInjector(0); let count = apply_coeff_sign_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 2);
assert_eq!(scan[0], 5);
assert_eq!(scan[2], 3); assert_eq!(scan[5], 7); assert_eq!(scan[10], 4);
}
// Bit 1 means "negative": only the two positive coefficients flip.
#[test]
fn apply_overrides_force_all_negative() {
let mut scan = vec![0i32; 16];
scan[0] = 5;
scan[2] = -3;
scan[5] = -7;
scan[10] = 4;
let mut inj = ConstantBitInjector(1); let count = apply_coeff_sign_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 2);
assert_eq!(scan[0], -5); assert_eq!(scan[2], -3);
assert_eq!(scan[5], -7);
assert_eq!(scan[10], -4); }
// Injector that only overrides keys present in its map; every other key gets None.
struct PlanInjector {
plan: std::collections::HashMap<PositionKey, u8>,
}
impl BitInjector for PlanInjector {
fn override_bit(&mut self, key: PositionKey) -> Option<u8> {
self.plan.get(&key).copied()
}
}
// Only the coefficient whose key is in the plan may change.
#[test]
fn apply_overrides_with_explicit_plan() {
let target_key = PositionKey::new(
0, 0, EmbedDomain::CoeffSignBypass,
SyntaxPath::Luma4x4 { block_idx: 0, coeff_idx: 5, kind: BinKind::Sign },
);
let mut plan = std::collections::HashMap::new();
plan.insert(target_key, 1u8);
let mut inj = PlanInjector { plan };
let mut scan = vec![0i32; 16];
scan[2] = 3;
scan[5] = 7;
scan[10] = -4;
let count = apply_coeff_sign_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 1);
assert_eq!(scan[2], 3);
assert_eq!(scan[5], -7);
assert_eq!(scan[10], -4);
}
// Extraction order is reverse scan: indices 10, 5, 0.
#[test]
fn extract_sign_bits_matches_enumerate_order() {
let mut scan = vec![0i32; 16];
scan[0] = -3; scan[5] = 7; scan[10] = -1; let bits = extract_coeff_sign_bits(&scan, 0, 15);
assert_eq!(bits, vec![1, 0, 1]);
}
#[test]
fn mvd_enumerate_skips_zero_values() {
let slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 5 },
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::Y, value: 0 },
MvdSlot { list: 0, partition: 1, axis: super::super::Axis::X, value: -3 },
MvdSlot { list: 0, partition: 1, axis: super::super::Axis::Y, value: 0 },
];
let positions = enumerate_mvd_sign_positions(&slots, 5, 100);
assert_eq!(positions.len(), 2);
assert_eq!(positions[0].frame_idx(), 5);
assert_eq!(positions[0].mb_addr(), 100);
for k in &positions {
assert_eq!(k.domain(), EmbedDomain::MvdSignBypass);
}
}
#[test]
fn mvd_apply_overrides_force_negative() {
let mut slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 5 },
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::Y, value: 0 },
MvdSlot { list: 0, partition: 1, axis: super::super::Axis::X, value: -3 },
];
let mut inj = ConstantBitInjector(1);
let count = apply_mvd_sign_overrides(&mut slots, 0, 0, &mut inj);
assert_eq!(count, 1, "only the positive slot 0 should flip");
assert_eq!(slots[0].value, -5);
assert_eq!(slots[1].value, 0); assert_eq!(slots[2].value, -3);
}
#[test]
fn mvd_extract_skips_zero_values() {
let slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 5 },
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::Y, value: 0 },
MvdSlot { list: 0, partition: 1, axis: super::super::Axis::X, value: -3 },
];
let bits = extract_mvd_sign_bits(&slots);
assert_eq!(bits, vec![0, 1]);
}
#[test]
fn domain_cover_capacity_and_total() {
let mut cover = DomainCover::new();
cover.coeff_sign_bypass.push(
0,
PositionKey::new(
0, 0, EmbedDomain::CoeffSignBypass,
SyntaxPath::Luma4x4 { block_idx: 0, coeff_idx: 0, kind: BinKind::Sign },
),
);
cover.coeff_sign_bypass.push(
1,
PositionKey::new(
0, 0, EmbedDomain::CoeffSignBypass,
SyntaxPath::Luma4x4 { block_idx: 0, coeff_idx: 1, kind: BinKind::Sign },
),
);
cover.mvd_sign_bypass.push(
0,
PositionKey::new(
0, 0, EmbedDomain::MvdSignBypass,
SyntaxPath::Mvd { list: 0, partition: 0, axis: super::super::Axis::X,
kind: BinKind::Sign },
),
);
let cap = cover.capacity();
assert_eq!(cap.coeff_sign_bypass, 2);
assert_eq!(cap.mvd_sign_bypass, 1);
assert_eq!(cap.coeff_suffix_lsb, 0);
assert_eq!(cap.mvd_suffix_lsb, 0);
assert_eq!(cover.total_len(), 3);
}
#[test]
fn domain_cover_for_domain_mut_dispatches_correctly() {
let mut cover = DomainCover::new();
let dummy_key = PositionKey::new(
0, 0, EmbedDomain::MvdSuffixLsb,
SyntaxPath::Mvd { list: 0, partition: 0, axis: super::super::Axis::X,
kind: BinKind::SuffixLsb },
);
cover.for_domain_mut(EmbedDomain::MvdSuffixLsb).push(1, dummy_key);
assert_eq!(cover.mvd_suffix_lsb.len(), 1);
assert_eq!(cover.coeff_sign_bypass.len(), 0);
assert_eq!(cover.mvd_sign_bypass.len(), 0);
}
// 15 is one below COEFF_SUFFIX_LSB_THRESHOLD (16), so it is not eligible.
#[test]
fn coeff_suffix_lsb_below_threshold_not_eligible() {
let mut scan = vec![0i32; 16];
scan[0] = 15;
let positions = enumerate_coeff_suffix_lsb_positions(
&scan, 0, 15, 0, 0, luma4x4_path,
);
assert!(positions.is_empty());
}
// Exactly at the threshold is eligible; even magnitude 16 carries cover bit 1.
#[test]
fn coeff_suffix_lsb_threshold_eligible() {
let mut scan = vec![0i32; 16];
scan[0] = 16;
let positions = enumerate_coeff_suffix_lsb_positions(
&scan, 0, 15, 0, 0, luma4x4_path,
);
assert_eq!(positions.len(), 1);
let bits = extract_coeff_suffix_lsb_bits(&scan, 0, 15);
assert_eq!(bits, vec![1]);
}
// At the threshold, a flip must increase the magnitude to stay eligible.
#[test]
fn coeff_suffix_lsb_threshold_flip_must_go_up() {
let mut scan = vec![0i32; 16];
scan[0] = 16;
let mut inj = ConstantBitInjector(0); let count = apply_coeff_suffix_lsb_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 1);
assert_eq!(scan[0], 17, "must go +1 to stay above threshold");
}
#[test]
fn coeff_suffix_lsb_above_threshold_flip_goes_down() {
let mut scan = vec![0i32; 16];
scan[0] = 20;
let mut inj = ConstantBitInjector(0);
let count = apply_coeff_suffix_lsb_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 1);
assert_eq!(scan[0], 19, "should go -1 since |20| > threshold");
}
#[test]
fn coeff_suffix_lsb_negative_sign_preserved() {
let mut scan = vec![0i32; 16];
scan[0] = -20;
let mut inj = ConstantBitInjector(0);
apply_coeff_suffix_lsb_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(scan[0], -19, "sign preserved across magnitude flip");
}
// Odd magnitude 17 already carries cover bit 0, so nothing changes.
#[test]
fn coeff_suffix_lsb_no_flip_when_cover_matches_target() {
let mut scan = vec![0i32; 16];
scan[0] = 17; let mut inj = ConstantBitInjector(0); let count = apply_coeff_suffix_lsb_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut inj,
);
assert_eq!(count, 0);
assert_eq!(scan[0], 17);
}
// 8 is one below MVD_SUFFIX_LSB_THRESHOLD (9).
#[test]
fn mvd_suffix_lsb_below_threshold_not_eligible() {
let slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 8 },
];
let positions = enumerate_mvd_suffix_lsb_positions(&slots, 0, 0);
assert!(positions.is_empty());
}
#[test]
fn mvd_suffix_lsb_threshold_eligible() {
let slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 9 },
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::Y, value: -10 },
];
let positions = enumerate_mvd_suffix_lsb_positions(&slots, 0, 0);
assert_eq!(positions.len(), 2);
let bits = extract_mvd_suffix_lsb_bits(&slots);
assert_eq!(bits, vec![0, 1]);
}
#[test]
fn mvd_suffix_lsb_threshold_flip_goes_up() {
let mut slots = vec![
MvdSlot { list: 0, partition: 0, axis: super::super::Axis::X, value: 9 },
];
let mut inj = ConstantBitInjector(1); let count = apply_mvd_suffix_lsb_overrides(&mut slots, 0, 0, &mut inj);
assert_eq!(count, 1);
assert_eq!(slots[0].value, 10);
}
// Small fixed-seed linear congruential generator. It keeps the property tests
// reproducible without an external RNG crate.
struct Lcg(u32);
impl Lcg {
fn next_u32(&mut self) -> u32 {
self.0 = self.0.wrapping_mul(1664525).wrapping_add(1013904223);
self.0
}
fn next_bit(&mut self) -> u8 {
(self.next_u32() & 1) as u8
}
// Value in 0..n; n == 0 is treated as 1 to avoid division by zero.
fn next_bounded(&mut self, n: u32) -> u32 {
self.next_u32() % n.max(1)
}
// Magnitude in 0..=max_abs with a random sign (can return 0).
fn next_signed(&mut self, max_abs: i32) -> i32 {
let mag = (self.next_u32() % (max_abs as u32 + 1)) as i32;
if self.next_bit() == 0 { mag } else { -mag }
}
}
// Property: on random blocks, sign injection preserves every magnitude, and
// extraction returns exactly the injected plan.
#[test]
fn coeff_sign_inject_random_roundtrip_property() {
let mut lcg = Lcg(0x1234_5678);
for _trial in 0..32 {
let mut scan = vec![0i32; 16];
let nonzero_count = lcg.next_bounded(9) as usize;
for _ in 0..nonzero_count {
let pos = lcg.next_bounded(16) as usize;
scan[pos] = lcg.next_signed(50);
// Keep the drawn coefficient non-zero so it is a sign position.
if scan[pos] == 0 {
scan[pos] = 1; }
}
let positions = enumerate_coeff_sign_positions(
&scan, 0, 15, 0, 0, luma4x4_path,
);
let plan: Vec<u8> = (0..positions.len())
.map(|_| lcg.next_bit())
.collect();
let plan_map: std::collections::HashMap<PositionKey, u8> =
positions.iter().zip(plan.iter()).map(|(&k, &b)| (k, b)).collect();
// Local tuple-struct injector; shadows the module-level PlanInjector inside this fn.
struct PlanInjector(std::collections::HashMap<PositionKey, u8>);
impl BitInjector for PlanInjector {
fn override_bit(&mut self, key: PositionKey) -> Option<u8> {
self.0.get(&key).copied()
}
}
let mut injector = PlanInjector(plan_map);
let original_magnitudes: Vec<u32> =
scan.iter().map(|c| c.unsigned_abs()).collect();
apply_coeff_sign_overrides(
&mut scan, 0, 15, 0, 0, luma4x4_path, &mut injector,
);
for (i, &orig) in original_magnitudes.iter().enumerate() {
assert_eq!(
scan[i].unsigned_abs(),
orig,
"magnitude shifted at scan[{i}]",
);
}
let extracted = extract_coeff_sign_bits(&scan, 0, 15);
assert_eq!(
extracted, plan,
"extracted bits != plan after random injection",
);
}
}
// Same property as above for MVD slots; about a quarter of the values are forced to zero.
#[test]
fn mvd_sign_inject_random_roundtrip_property() {
let mut lcg = Lcg(0xDEAD_BEEF);
for _trial in 0..32 {
let slot_count = (lcg.next_bounded(4) + 1) as usize;
let mut slots = Vec::with_capacity(slot_count);
for i in 0..slot_count {
let value = if lcg.next_bounded(4) == 0 {
0i32
} else {
lcg.next_signed(20)
};
slots.push(MvdSlot {
list: 0,
partition: i as u8,
axis: super::super::Axis::X,
value,
});
}
let positions = enumerate_mvd_sign_positions(&slots, 0, 0);
let plan: Vec<u8> = (0..positions.len())
.map(|_| lcg.next_bit())
.collect();
let plan_map: std::collections::HashMap<PositionKey, u8> =
positions.iter().zip(plan.iter()).map(|(&k, &b)| (k, b)).collect();
struct PlanInjector(std::collections::HashMap<PositionKey, u8>);
impl BitInjector for PlanInjector {
fn override_bit(&mut self, key: PositionKey) -> Option<u8> {
self.0.get(&key).copied()
}
}
let mut injector = PlanInjector(plan_map);
let orig_magnitudes: Vec<u32> = slots.iter().map(|s| s.value.unsigned_abs()).collect();
apply_mvd_sign_overrides(&mut slots, 0, 0, &mut injector);
for (s, &orig) in slots.iter().zip(orig_magnitudes.iter()) {
assert_eq!(s.value.unsigned_abs(), orig);
}
let extracted = extract_mvd_sign_bits(&slots);
assert_eq!(extracted, plan);
}
}
// Enumeration and extraction must agree on count. A PositionRecorder fed the
// enumerated keys must store them unchanged and in order.
#[test]
fn position_count_equals_extracted_bit_count() {
let mut scan = vec![0i32; 16];
for i in 0..16 {
scan[i] = if i % 3 == 0 { (i as i32) - 4 } else { 0 };
}
let positions = enumerate_coeff_sign_positions(&scan, 0, 15, 0, 0, luma4x4_path);
let bits = extract_coeff_sign_bits(&scan, 0, 15);
assert_eq!(positions.len(), bits.len());
let recorder = {
let mut r = PositionRecorder::new();
for pos in &positions {
use crate::codec::h264::stego::PositionLogger;
r.register(*pos);
}
r
};
assert_eq!(recorder.positions, positions);
}
}