use crate::stego::stc::embed::{stc_embed, EmbedResult};
use crate::stego::stc::hhat::generate_hhat;
use super::cost_model::{
coeff_sign_cost_vec, coeff_suffix_lsb_cost_vec, mvd_sign_cost_vec,
mvd_suffix_lsb_cost_vec, PositionCostCtx,
};
use super::keys::CabacStegoMasterKeys;
#[allow(unused_imports)]
use super::keys::DomainSeeds;
use super::{
apply_coeff_sign_overrides, apply_coeff_suffix_lsb_overrides,
apply_mvd_sign_overrides, apply_mvd_suffix_lsb_overrides,
enumerate_coeff_sign_positions, enumerate_coeff_suffix_lsb_positions,
enumerate_mvd_sign_positions, enumerate_mvd_suffix_lsb_positions,
extract_coeff_sign_bits, extract_coeff_suffix_lsb_bits,
extract_mvd_sign_bits, extract_mvd_suffix_lsb_bits,
BinKind, BitInjector, DomainBits, DomainCover, EmbedDomain, MvdSlot, PositionKey,
SyntaxPath,
};
/// Per-macroblock encoder decisions cached during pass 1 so that pass 3 can
/// re-apply them with selected CABAC bins overridden.
#[derive(Clone, Debug)]
pub struct MbDecision {
    /// Index of the frame this macroblock belongs to (used in position keys).
    pub frame_idx: u32,
    /// Macroblock address within the frame (used in position keys).
    pub mb_addr: u32,
    /// Quantized residual coefficient blocks carried by this macroblock.
    pub residual_blocks: Vec<MbResidualBlock>,
    /// Motion-vector-difference slots carried by this macroblock.
    pub mvd_slots: Vec<MvdSlot>,
}
/// One residual coefficient block inside a macroblock, in scan order.
#[derive(Clone, Debug)]
pub struct MbResidualBlock {
    /// Coefficients in scan order; zero entries carry no sign/suffix bins.
    pub scan_coeffs: Vec<i32>,
    /// First scan index considered by enumeration/extraction helpers.
    pub start_idx: usize,
    /// Last scan index considered (tests use 0..=15 for 16 coeffs — appears
    /// inclusive; confirm against the enumeration helpers).
    pub end_idx: usize,
    /// CABAC ctxBlockCat of this block (not read in this module).
    pub ctx_block_cat: u8,
    /// Which syntax-path family the block's coefficients map to.
    pub path_kind: ResidualPathKind,
}
/// Family of residual syntax paths a block belongs to; combined with a
/// coefficient index and bin kind it yields a full [`SyntaxPath`].
#[derive(Copy, Clone, Debug)]
pub enum ResidualPathKind {
    /// 4x4 luma block at `block_idx`.
    Luma4x4 { block_idx: u8 },
    /// 8x8 luma block at `block_idx`.
    Luma8x8 { block_idx: u8 },
    /// Chroma AC block of `plane` at `block_idx`.
    ChromaAc { plane: u8, block_idx: u8 },
    /// Chroma DC block of `plane`.
    ChromaDc { plane: u8 },
    /// Luma DC block of an Intra 16x16 macroblock.
    LumaDcIntra16x16,
}
impl ResidualPathKind {
    /// Builds the full [`SyntaxPath`] for the coefficient at `coeff_idx`
    /// within this residual block, tagged with the given bin `kind`.
    pub fn path(self, coeff_idx: u8, kind: BinKind) -> SyntaxPath {
        match self {
            Self::Luma4x4 { block_idx } => {
                SyntaxPath::Luma4x4 { block_idx, coeff_idx, kind }
            }
            Self::Luma8x8 { block_idx } => {
                SyntaxPath::Luma8x8 { block_idx, coeff_idx, kind }
            }
            Self::ChromaAc { plane, block_idx } => {
                SyntaxPath::ChromaAc { plane, block_idx, coeff_idx, kind }
            }
            Self::ChromaDc { plane } => {
                SyntaxPath::ChromaDc { plane, coeff_idx, kind }
            }
            Self::LumaDcIntra16x16 => {
                SyntaxPath::LumaDcIntra16x16 { coeff_idx, kind }
            }
        }
    }
}
/// All cached macroblock decisions for one GOP, in encode order.
#[derive(Default, Clone, Debug)]
pub struct GopDecisionCache {
    /// Macroblock decisions in the order they were pushed.
    pub mbs: Vec<MbDecision>,
}
impl GopDecisionCache {
pub fn new() -> Self {
Self::default()
}
pub fn push(&mut self, mb: MbDecision) {
self.mbs.push(mb);
}
}
/// Pass-1 output: the per-domain cover bits plus their index-aligned costs.
#[derive(Default, Clone, Debug)]
pub struct GopCover {
    /// Cover bits and position keys per embedding domain.
    pub cover: DomainCover,
    /// Per-bit embedding costs, aligned 1:1 with `cover`'s bits.
    pub costs: DomainCosts,
}
/// Per-domain embedding cost vectors; each entry is the distortion cost of
/// flipping the cover bit at the same index in [`DomainCover`].
#[derive(Default, Clone, Debug)]
pub struct DomainCosts {
    pub coeff_sign_bypass: Vec<f32>,
    pub coeff_suffix_lsb: Vec<f32>,
    pub mvd_sign_bypass: Vec<f32>,
    pub mvd_suffix_lsb: Vec<f32>,
}
/// Pass 1: walks every cached macroblock decision and collects, per embedding
/// domain, the cover bits together with their position keys and per-bit costs.
///
/// Positions, bits and costs stay index-aligned because the corresponding
/// enumerate/extract/cost helpers iterate the same coefficient or MVD slots.
pub fn pass1_collect_cover(cache: &GopDecisionCache) -> GopCover {
    let mut gop = GopCover::default();
    for mb in &cache.mbs {
        let cost_ctx = PositionCostCtx::new(mb.frame_idx, mb.mb_addr);
        for blk in &mb.residual_blocks {
            // Coefficient sign bypass bins of this residual block.
            let keys = enumerate_coeff_sign_positions(
                &blk.scan_coeffs,
                blk.start_idx,
                blk.end_idx,
                mb.frame_idx,
                mb.mb_addr,
                |ci| blk.path_kind.path(ci, BinKind::Sign),
            );
            let sign_bits =
                extract_coeff_sign_bits(&blk.scan_coeffs, blk.start_idx, blk.end_idx);
            let sign_costs =
                coeff_sign_cost_vec(&blk.scan_coeffs, blk.start_idx, blk.end_idx, &cost_ctx);
            for (key, (bit, cost)) in keys.iter().zip(sign_bits.iter().zip(sign_costs.iter())) {
                gop.cover.coeff_sign_bypass.push(*bit, *key);
                gop.costs.coeff_sign_bypass.push(*cost);
            }
            // Coefficient suffix-LSB bins of this residual block.
            let keys = enumerate_coeff_suffix_lsb_positions(
                &blk.scan_coeffs,
                blk.start_idx,
                blk.end_idx,
                mb.frame_idx,
                mb.mb_addr,
                |ci| blk.path_kind.path(ci, BinKind::SuffixLsb),
            );
            let lsb_bits =
                extract_coeff_suffix_lsb_bits(&blk.scan_coeffs, blk.start_idx, blk.end_idx);
            let lsb_costs =
                coeff_suffix_lsb_cost_vec(&blk.scan_coeffs, blk.start_idx, blk.end_idx, &cost_ctx);
            for (key, (bit, cost)) in keys.iter().zip(lsb_bits.iter().zip(lsb_costs.iter())) {
                gop.cover.coeff_suffix_lsb.push(*bit, *key);
                gop.costs.coeff_suffix_lsb.push(*cost);
            }
        }
        // MVD sign bypass bins of this macroblock.
        let keys = enumerate_mvd_sign_positions(&mb.mvd_slots, mb.frame_idx, mb.mb_addr);
        let sign_bits = extract_mvd_sign_bits(&mb.mvd_slots);
        let sign_costs = mvd_sign_cost_vec(&mb.mvd_slots, &cost_ctx);
        for (key, (bit, cost)) in keys.iter().zip(sign_bits.iter().zip(sign_costs.iter())) {
            gop.cover.mvd_sign_bypass.push(*bit, *key);
            gop.costs.mvd_sign_bypass.push(*cost);
        }
        // MVD suffix-LSB bins of this macroblock.
        let keys = enumerate_mvd_suffix_lsb_positions(&mb.mvd_slots, mb.frame_idx, mb.mb_addr);
        let lsb_bits = extract_mvd_suffix_lsb_bits(&mb.mvd_slots);
        let lsb_costs = mvd_suffix_lsb_cost_vec(&mb.mvd_slots, &cost_ctx);
        for (key, (bit, cost)) in keys.iter().zip(lsb_bits.iter().zip(lsb_costs.iter())) {
            gop.cover.mvd_suffix_lsb.push(*bit, *key);
            gop.costs.mvd_suffix_lsb.push(*cost);
        }
    }
    gop
}
/// Pass-2 output: the planned stego bit strings per domain plus aggregate
/// embedding statistics across all four domains.
#[derive(Default, Clone, Debug)]
pub struct DomainPlan {
    pub coeff_sign_bypass: Vec<u8>,
    pub coeff_suffix_lsb: Vec<u8>,
    pub mvd_sign_bypass: Vec<u8>,
    pub mvd_suffix_lsb: Vec<u8>,
    /// Total number of cover bits the STC embedder flipped, summed over domains.
    pub total_modifications: usize,
    /// Total embedding cost, summed over domains.
    pub total_cost: f64,
}
/// Pass 2 (unkeyed): plans STC embedding for every domain with an all-zero
/// H-hat seed. See [`pass2_stc_plan_with_keys`] for the keyed variant.
///
/// `h` is the STC constraint height. Returns `None` if any domain cannot fit
/// its message (see `plan_one_domain_seeded`).
pub fn pass2_stc_plan(
    cover: &GopCover,
    messages: &DomainMessages,
    h: usize,
) -> Option<DomainPlan> {
    pass2_stc_plan_internal(cover, messages, h, None)
}
/// Pass 2 (keyed): plans STC embedding with per-domain, per-GOP H-hat seeds
/// derived from the master keys, so each domain uses a distinct keyed matrix.
pub fn pass2_stc_plan_with_keys(
    cover: &GopCover,
    messages: &DomainMessages,
    h: usize,
    keys: &CabacStegoMasterKeys,
    gop_idx: u32,
) -> Option<DomainPlan> {
    pass2_stc_plan_internal(cover, messages, h, Some((keys, gop_idx)))
}
/// Shared pass-2 driver: runs the STC embedder once per domain and
/// accumulates the per-domain results into a single [`DomainPlan`].
///
/// With `keys_and_gop` present, each domain gets its derived per-GOP H-hat
/// seed; otherwise a zero seed is used (matching the unkeyed test decoders).
fn pass2_stc_plan_internal(
    cover: &GopCover,
    messages: &DomainMessages,
    h: usize,
    keys_and_gop: Option<(&CabacStegoMasterKeys, u32)>,
) -> Option<DomainPlan> {
    let mut plan = DomainPlan::default();
    // Each tuple borrows a *different* field of `plan` mutably; the disjoint
    // field borrows keep this array legal alongside the += updates below.
    for (domain, cover_bits, costs, message, plan_slot) in [
        (EmbedDomain::CoeffSignBypass, &cover.cover.coeff_sign_bypass,
        cover.costs.coeff_sign_bypass.as_slice(), &messages.coeff_sign_bypass,
        &mut plan.coeff_sign_bypass),
        (EmbedDomain::CoeffSuffixLsb, &cover.cover.coeff_suffix_lsb,
        cover.costs.coeff_suffix_lsb.as_slice(), &messages.coeff_suffix_lsb,
        &mut plan.coeff_suffix_lsb),
        (EmbedDomain::MvdSignBypass, &cover.cover.mvd_sign_bypass,
        cover.costs.mvd_sign_bypass.as_slice(), &messages.mvd_sign_bypass,
        &mut plan.mvd_sign_bypass),
        (EmbedDomain::MvdSuffixLsb, &cover.cover.mvd_suffix_lsb,
        cover.costs.mvd_suffix_lsb.as_slice(), &messages.mvd_suffix_lsb,
        &mut plan.mvd_suffix_lsb),
    ] {
        // Zero seed in the unkeyed path; keyed path derives per-domain seeds.
        let seed = match keys_and_gop {
            Some((k, gop)) => k.per_gop_seeds(domain, gop).hhat_seed,
            None => [0u8; 32],
        };
        // Any single domain failing to embed aborts the whole plan.
        let r = plan_one_domain_seeded(cover_bits, costs, message, h, &seed)?;
        *plan_slot = r.bits;
        plan.total_modifications += r.num_modifications;
        plan.total_cost += r.total_cost;
    }
    Some(plan)
}
/// Result of embedding one domain: the stego bits plus embedding statistics.
struct DomainPlanResult {
    /// Stego bit string for the domain (same length as its cover bits).
    bits: Vec<u8>,
    /// Number of cover bits the embedder flipped.
    num_modifications: usize,
    /// Total embedding cost reported by the STC embedder.
    total_cost: f64,
}
/// Message bits split per embedding domain (see [`split_message_per_domain`]).
#[derive(Default, Clone, Debug)]
pub struct DomainMessages {
    pub coeff_sign_bypass: Vec<u8>,
    pub coeff_suffix_lsb: Vec<u8>,
    pub mvd_sign_bypass: Vec<u8>,
    pub mvd_suffix_lsb: Vec<u8>,
}
/// Splits `message` across the four embedding domains proportionally to each
/// domain's capacity, assigning any rounding leftover one unit at a time to
/// the domain with the largest remaining headroom.
///
/// `message` is normally already a bit vector (every byte 0 or 1) and is
/// consumed as-is; otherwise it is expanded MSB-first into bits.
///
/// Returns `None` when the total capacity is smaller than `message.len()`,
/// and `Some(DomainMessages::default())` for an empty message. The split is
/// deterministic, so encoder and decoder compute identical partitions.
pub fn split_message_per_domain(
    message: &[u8],
    capacities: &super::GopCapacity,
) -> Option<DomainMessages> {
    let m_total = message.len();
    if m_total == 0 {
        return Some(DomainMessages::default());
    }
    let n_total = capacities.total();
    if n_total < m_total {
        return None;
    }
    // Floor-proportional allocation per domain. The floors can undershoot
    // m_total; the leftover is distributed greedily by headroom below.
    let mut m_coeff_sign = (m_total * capacities.coeff_sign_bypass) / n_total;
    let mut m_coeff_suffix = (m_total * capacities.coeff_suffix_lsb) / n_total;
    let mut m_mvd_sign = (m_total * capacities.mvd_sign_bypass) / n_total;
    let mut m_mvd_suffix = (m_total * capacities.mvd_suffix_lsb) / n_total;
    let mut leftover = m_total - (m_coeff_sign + m_coeff_suffix + m_mvd_sign + m_mvd_suffix);
    while leftover > 0 {
        let pick = pick_max_headroom(capacities, &[
            (m_coeff_sign, capacities.coeff_sign_bypass),
            (m_coeff_suffix, capacities.coeff_suffix_lsb),
            (m_mvd_sign, capacities.mvd_sign_bypass),
            (m_mvd_suffix, capacities.mvd_suffix_lsb),
        ]);
        match pick {
            0 => m_coeff_sign += 1,
            1 => m_coeff_suffix += 1,
            2 => m_mvd_sign += 1,
            3 => m_mvd_suffix += 1,
            _ => unreachable!(),
        }
        leftover -= 1;
    }
    // Normalize the payload to a bit stream, expanding only when needed.
    // NOTE(review): when bytes are > 1 the expanded stream is 8x longer than
    // m_total, yet only the first m_total bits are consumed below — callers
    // are expected to pass bit vectors; confirm whether the byte path is used.
    let bit_stream: Vec<u8> = if message.iter().all(|&b| b <= 1) {
        message.to_vec()
    } else {
        message
            .iter()
            .flat_map(|&b| (0..8).rev().map(move |i| (b >> i) & 1))
            .collect()
    };
    // Carve consecutive, per-domain slices off the bit stream in fixed order.
    let mut cursor = 0usize;
    let mut take = |n: usize| -> Vec<u8> {
        let end = (cursor + n).min(bit_stream.len());
        let chunk = bit_stream[cursor..end].to_vec();
        cursor = end;
        chunk
    };
    let coeff_sign_bypass = take(m_coeff_sign);
    let coeff_suffix_lsb = take(m_coeff_suffix);
    let mvd_sign_bypass = take(m_mvd_sign);
    let mvd_suffix_lsb = take(m_mvd_suffix);
    Some(DomainMessages {
        coeff_sign_bypass,
        coeff_suffix_lsb,
        mvd_sign_bypass,
        mvd_suffix_lsb,
    })
}
/// Stealth-weighted allocation parameters: per-domain weights scale each
/// domain's capacity when splitting a message, and `mvd_drift_budget_frac`
/// caps the combined MVD share of the message.
#[derive(Copy, Clone, Debug)]
pub struct StealthAllocator {
    /// Weight applied to coefficient-sign capacity.
    pub w_coeff_sign: f64,
    /// Weight applied to coefficient-suffix-LSB capacity.
    pub w_coeff_suffix: f64,
    /// Weight applied to MVD-sign capacity.
    pub w_mvd_sign: f64,
    /// Weight applied to MVD-suffix-LSB capacity.
    pub w_mvd_suffix: f64,
    /// Maximum fraction of the message allowed into MVD domains combined.
    pub mvd_drift_budget_frac: f64,
}
/// Defaults to the v1 parameter set.
impl Default for StealthAllocator {
    fn default() -> Self {
        Self::v1_default()
    }
}
impl StealthAllocator {
    /// The v1 parameter set: coefficient domains are down-weighted relative
    /// to MVD domains, and at most 20% of the message may land in MVD bins.
    pub const fn v1_default() -> Self {
        Self {
            w_coeff_sign: 0.5,
            w_coeff_suffix: 0.8,
            w_mvd_sign: 1.0,
            w_mvd_suffix: 1.0,
            mvd_drift_budget_frac: 0.20,
        }
    }
}
/// Splits `m_total` message bits across the four domains proportionally to
/// weighted capacity, while capping the combined MVD share of the message at
/// `mvd_drift_budget_frac` (motion-vector changes accumulate visible drift,
/// so they are rationed).
///
/// Returns `Some((m_cs, m_cl, m_ms, m_ml))` summing exactly to `m_total`, or
/// `None` when raw capacity — or capacity reachable under the drift budget —
/// cannot hold the message.
pub fn stealth_weighted_allocation(
    m_total: usize,
    capacities: &super::GopCapacity,
    allocator: &StealthAllocator,
) -> Option<(usize, usize, usize, usize)> {
    if m_total == 0 {
        return Some((0, 0, 0, 0));
    }
    let n_cs = capacities.coeff_sign_bypass as f64;
    let n_cl = capacities.coeff_suffix_lsb as f64;
    let n_ms = capacities.mvd_sign_bypass as f64;
    let n_ml = capacities.mvd_suffix_lsb as f64;
    // Weighted capacities: each domain's share is proportional to
    // capacity x weight.
    let nw_cs = n_cs * allocator.w_coeff_sign;
    let nw_cl = n_cl * allocator.w_coeff_suffix;
    let nw_ms = n_ms * allocator.w_mvd_sign;
    let nw_ml = n_ml * allocator.w_mvd_suffix;
    let nw_sum = nw_cs + nw_cl + nw_ms + nw_ml;
    if nw_sum <= 0.0 {
        return None;
    }
    // Raw (unweighted) capacity must at least cover the message.
    if (capacities.coeff_sign_bypass
        + capacities.coeff_suffix_lsb
        + capacities.mvd_sign_bypass
        + capacities.mvd_suffix_lsb) < m_total
    {
        return None;
    }
    // Initial floor-proportional split by weighted capacity.
    let m_total_f = m_total as f64;
    let mut m_cs = ((m_total_f * nw_cs) / nw_sum).floor() as usize;
    let mut m_cl = ((m_total_f * nw_cl) / nw_sum).floor() as usize;
    let mut m_ms = ((m_total_f * nw_ms) / nw_sum).floor() as usize;
    let mut m_ml = ((m_total_f * nw_ml) / nw_sum).floor() as usize;
    // Enforce the MVD drift budget: shrink the MVD share proportionally and
    // push the overflow back onto the coefficient domains by weight.
    let mvd_share_max = (m_total_f * allocator.mvd_drift_budget_frac).floor() as usize;
    let mvd_share = m_ms + m_ml;
    if mvd_share > mvd_share_max {
        let overflow = mvd_share - mvd_share_max;
        let mvd_total = (m_ms + m_ml) as f64;
        if mvd_total > 0.0 {
            let new_m_ms = ((mvd_share_max as f64) * (m_ms as f64) / mvd_total).floor() as usize;
            let new_m_ml = mvd_share_max.saturating_sub(new_m_ms);
            m_ms = new_m_ms;
            m_ml = new_m_ml;
        } else {
            m_ms = 0;
            m_ml = 0;
        }
        let coeff_nw = nw_cs + nw_cl;
        if coeff_nw > 0.0 {
            let extra_cs = ((overflow as f64) * nw_cs / coeff_nw).floor() as usize;
            let extra_cl = overflow.saturating_sub(extra_cs);
            m_cs = m_cs.saturating_add(extra_cs);
            m_cl = m_cl.saturating_add(extra_cl);
        }
    }
    // Clamp to per-domain capacity; redistribution above may have overshot.
    let cap_cs = capacities.coeff_sign_bypass;
    let cap_cl = capacities.coeff_suffix_lsb;
    let cap_ms = capacities.mvd_sign_bypass;
    let cap_ml = capacities.mvd_suffix_lsb;
    if m_cs > cap_cs { m_cs = cap_cs; }
    if m_cl > cap_cl { m_cl = cap_cl; }
    if m_ms > cap_ms { m_ms = cap_ms; }
    if m_ml > cap_ml { m_ml = cap_ml; }
    // Greedily hand out whatever flooring/clamping left over, one bit at a
    // time, to the domain with the largest weighted headroom; MVD domains
    // become ineligible once the drift budget is exhausted.
    let mut remainder = m_total.saturating_sub(m_cs + m_cl + m_ms + m_ml);
    while remainder > 0 {
        let mvd_share_now = m_ms + m_ml;
        let mvd_room = mvd_share_max.saturating_sub(mvd_share_now);
        let h_cs = ((cap_cs - m_cs) as f64) * allocator.w_coeff_sign;
        let h_cl = ((cap_cl - m_cl) as f64) * allocator.w_coeff_suffix;
        let h_ms = if mvd_room > 0 {
            ((cap_ms - m_ms) as f64) * allocator.w_mvd_sign
        } else { -1.0 };
        let h_ml = if mvd_room > 0 {
            ((cap_ml - m_ml) as f64) * allocator.w_mvd_suffix
        } else { -1.0 };
        let candidates = [(h_cs, 0usize), (h_cl, 1), (h_ms, 2), (h_ml, 3)];
        let (best_h, best_idx) = candidates
            .iter()
            .copied()
            .filter(|(h, _)| *h > 0.0)
            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((-1.0, 0));
        // No eligible domain has headroom left: the budgeted capacity cannot
        // hold the message.
        if best_h <= 0.0 {
            return None;
        }
        match best_idx {
            0 => m_cs += 1,
            1 => m_cl += 1,
            2 => m_ms += 1,
            3 => m_ml += 1,
            _ => unreachable!(),
        }
        remainder -= 1;
    }
    debug_assert_eq!(m_cs + m_cl + m_ms + m_ml, m_total);
    Some((m_cs, m_cl, m_ms, m_ml))
}
/// Returns the index of the `(allocated, capacity)` pair with the largest
/// headroom (`capacity - allocated`); ties keep the earliest index, and an
/// empty slice yields 0.
fn pick_max_headroom(
    _caps: &super::GopCapacity,
    allocated_capacity_pairs: &[(usize, usize)],
) -> usize {
    allocated_capacity_pairs
        .iter()
        .enumerate()
        .fold((0usize, i64::MIN), |(best_i, best_h), (i, &(used, cap))| {
            let headroom = cap as i64 - used as i64;
            // Strict comparison keeps the first index on ties.
            if headroom > best_h {
                (i, headroom)
            } else {
                (best_i, best_h)
            }
        })
        .0
}
#[cfg(test)]
mod stealth_alloc_tests {
    use super::*;
    // Shorthand for building a GopCapacity from the four per-domain counts.
    fn caps(cs: usize, cl: usize, ms: usize, ml: usize) -> super::super::GopCapacity {
        super::super::GopCapacity {
            coeff_sign_bypass: cs,
            coeff_suffix_lsb: cl,
            mvd_sign_bypass: ms,
            mvd_suffix_lsb: ml,
        }
    }
    // The four per-domain allocations must always sum to exactly m_total.
    #[test]
    fn weighted_alloc_total_equals_m_total() {
        let alloc = StealthAllocator::v1_default();
        let c = caps(500, 100, 200, 50);
        for &m in &[1usize, 8, 100, 400, 600] {
            let (a, b, c2, d) = stealth_weighted_allocation(m, &c, &alloc).unwrap();
            assert_eq!(a + b + c2 + d, m,
                "allocation must sum to m_total ({m}); got {a}+{b}+{c2}+{d}");
        }
    }
    // Raw capacity fits 850 bits, but the 20% drift budget makes 800 infeasible.
    #[test]
    fn weighted_alloc_returns_none_when_drift_budget_blocks_capacity() {
        let alloc = StealthAllocator::v1_default();
        let c = caps(500, 100, 200, 50);
        assert!(stealth_weighted_allocation(800, &c, &alloc).is_none(),
            "expected None when drift budget can't accommodate M");
    }
    // The combined MVD allocation never exceeds floor(m * drift_budget_frac).
    #[test]
    fn weighted_alloc_respects_mvd_drift_budget() {
        let alloc = StealthAllocator::v1_default(); let c = caps(500, 100, 200, 50);
        let m = 400usize;
        let (_, _, ms, ml) = stealth_weighted_allocation(m, &c, &alloc).unwrap();
        let mvd_share = ms + ml;
        let cap_share = (m as f64 * alloc.mvd_drift_budget_frac).floor() as usize;
        assert!(mvd_share <= cap_share,
            "mvd_share {mvd_share} must respect drift budget cap {cap_share}");
    }
    // With equal capacities, weights still bound MVD (80 * 0.20 = 16) and keep
    // coeff_sign from absorbing all coefficient bits.
    #[test]
    fn weighted_alloc_pushes_to_mvd_under_balanced_caps() {
        let alloc = StealthAllocator::v1_default();
        let c = caps(100, 100, 100, 100);
        let m = 80usize;
        let (m_cs, _m_cl, m_ms, m_ml) = stealth_weighted_allocation(m, &c, &alloc).unwrap();
        assert!(m_ms + m_ml <= 16, "MVD share capped at drift budget");
        assert!(m_cs < (m - m_ms - m_ml),
            "coeff_sign should not absorb all coeff bits (lower weight)");
    }
    // Total capacity (40) below the message (100) must fail fast.
    #[test]
    fn weighted_alloc_returns_none_when_capacity_insufficient() {
        let alloc = StealthAllocator::v1_default();
        let c = caps(10, 10, 10, 10);
        let m = 100usize;
        assert!(stealth_weighted_allocation(m, &c, &alloc).is_none());
    }
    // m_total == 0 short-circuits to the all-zero allocation.
    #[test]
    fn weighted_alloc_zero_message() {
        let alloc = StealthAllocator::v1_default();
        let c = caps(100, 100, 100, 100);
        assert_eq!(
            stealth_weighted_allocation(0, &c, &alloc).unwrap(),
            (0, 0, 0, 0)
        );
    }
}
/// Runs the STC embedder over one domain's cover bits using the given H-hat
/// seed.
///
/// An empty cover or empty message is a no-op: the cover bits pass through
/// unchanged with zero modifications and zero cost. Otherwise the sub-matrix
/// width `w` is derived from the cover/message length ratio; `None` means the
/// domain cannot carry the message (`w == 0`) or the embedder itself failed.
fn plan_one_domain_seeded(
    cover_bits: &DomainBits,
    cost: &[f32],
    message: &[u8],
    h: usize,
    seed: &[u8; 32],
) -> Option<DomainPlanResult> {
    if cover_bits.is_empty() || message.is_empty() {
        // Nothing to embed: pass the cover through untouched.
        return Some(DomainPlanResult {
            bits: cover_bits.bits.clone(),
            num_modifications: 0,
            total_cost: 0.0,
        });
    }
    let cover_len = cover_bits.bits.len();
    let msg_len = message.len();
    let w = cover_len / msg_len.max(1);
    if w == 0 {
        // Fewer cover bits than message bits: embedding is impossible.
        return None;
    }
    let hhat = generate_hhat(h, w, seed);
    let embedded: EmbedResult = stc_embed(&cover_bits.bits, cost, message, &hhat, h, w)?;
    Some(DomainPlanResult {
        bits: embedded.stego_bits,
        num_modifications: embedded.num_modifications,
        total_cost: embedded.total_cost,
    })
}
/// Pass-3 bit injector: maps each cover position key to its planned stego bit.
pub struct PlanInjector {
    // Position -> planned bit, built from all four domains.
    plan: std::collections::HashMap<PositionKey, u8>,
}
impl PlanInjector {
    /// Builds the injector by pairing each domain's cover positions with the
    /// corresponding planned stego bits from all four domains.
    pub fn from_plan(cover: &DomainCover, plan: &DomainPlan) -> Self {
        let mut table = std::collections::HashMap::new();
        Self::extend(&mut table, &cover.coeff_sign_bypass.positions, &plan.coeff_sign_bypass);
        Self::extend(&mut table, &cover.coeff_suffix_lsb.positions, &plan.coeff_suffix_lsb);
        Self::extend(&mut table, &cover.mvd_sign_bypass.positions, &plan.mvd_sign_bypass);
        Self::extend(&mut table, &cover.mvd_suffix_lsb.positions, &plan.mvd_suffix_lsb);
        Self { plan: table }
    }

    /// Inserts position -> bit pairs; whichever of the two slices is shorter
    /// bounds how many pairs are inserted.
    fn extend(
        map: &mut std::collections::HashMap<PositionKey, u8>,
        positions: &[PositionKey],
        bits: &[u8],
    ) {
        for (&key, &bit) in positions.iter().zip(bits.iter()) {
            map.insert(key, bit);
        }
    }

    /// Read-only view of the position -> bit override table.
    pub fn map(&self) -> &std::collections::HashMap<PositionKey, u8> {
        &self.plan
    }
}
impl BitInjector for PlanInjector {
    /// Returns the planned bit for `key`, or `None` when the position is not
    /// part of the plan (the original bit is then kept by the caller).
    fn override_bit(&mut self, key: PositionKey) -> Option<u8> {
        self.plan.get(&key).copied()
    }
}
/// Pass 3: replays the cached decisions and rewrites coefficients/MVD values
/// wherever the plan's bit differs from the cover bit.
///
/// The four override passes mirror the enumeration order of pass 1, so the
/// position keys resolve to the same syntax elements. Returns the number of
/// values actually modified.
pub fn pass3_apply_overrides(
    cache: &mut GopDecisionCache,
    cover: &DomainCover,
    plan: &DomainPlan,
) -> usize {
    let mut injector = PlanInjector::from_plan(cover, plan);
    let mut count = 0usize;
    for mb in &mut cache.mbs {
        for blk in &mut mb.residual_blocks {
            // Coefficient sign overrides for this residual block.
            count += apply_coeff_sign_overrides(
                &mut blk.scan_coeffs,
                blk.start_idx,
                blk.end_idx,
                mb.frame_idx,
                mb.mb_addr,
                |ci| blk.path_kind.path(ci, BinKind::Sign),
                &mut injector,
            );
            // Coefficient suffix-LSB overrides for this residual block.
            count += apply_coeff_suffix_lsb_overrides(
                &mut blk.scan_coeffs,
                blk.start_idx,
                blk.end_idx,
                mb.frame_idx,
                mb.mb_addr,
                |ci| blk.path_kind.path(ci, BinKind::SuffixLsb),
                &mut injector,
            );
        }
        // MVD sign and suffix-LSB overrides for this macroblock.
        count += apply_mvd_sign_overrides(
            &mut mb.mvd_slots, mb.frame_idx, mb.mb_addr, &mut injector,
        );
        count += apply_mvd_suffix_lsb_overrides(
            &mut mb.mvd_slots, mb.frame_idx, mb.mb_addr, &mut injector,
        );
    }
    count
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::h264::stego::Axis;
    // Builds a tiny two-macroblock GOP: MB0 carries three nonzero residual
    // coefficients, MB1 carries two nonzero coefficients plus two MVD slots —
    // 5 coeff-sign and 2 mvd-sign cover bits in total.
    fn build_synthetic_gop() -> GopDecisionCache {
        let mut cache = GopDecisionCache::new();
        let mut scan = vec![0i32; 16];
        scan[0] = 3;
        scan[3] = -7;
        scan[6] = 2;
        cache.push(MbDecision {
            frame_idx: 0, mb_addr: 0,
            residual_blocks: vec![MbResidualBlock {
                scan_coeffs: scan,
                start_idx: 0, end_idx: 15,
                ctx_block_cat: 1,
                path_kind: ResidualPathKind::Luma4x4 { block_idx: 0 },
            }],
            mvd_slots: vec![],
        });
        let mut scan = vec![0i32; 16];
        scan[1] = -4;
        scan[5] = 1;
        cache.push(MbDecision {
            frame_idx: 0, mb_addr: 1,
            residual_blocks: vec![MbResidualBlock {
                scan_coeffs: scan,
                start_idx: 0, end_idx: 15,
                ctx_block_cat: 1,
                path_kind: ResidualPathKind::Luma4x4 { block_idx: 0 },
            }],
            mvd_slots: vec![
                MvdSlot { list: 0, partition: 0, axis: Axis::X, value: 5 },
                MvdSlot { list: 0, partition: 0, axis: Axis::Y, value: -3 },
            ],
        });
        cache
    }
    // 20 bits against capacities 100/50/30/20 split exactly 10/5/3/2.
    #[test]
    fn split_message_proportional_to_capacity() {
        use super::super::GopCapacity;
        let caps = GopCapacity {
            coeff_sign_bypass: 100,
            coeff_suffix_lsb: 50,
            mvd_sign_bypass: 30,
            mvd_suffix_lsb: 20,
        };
        let msg = vec![0u8, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1];
        let split = split_message_per_domain(&msg, &caps).unwrap();
        assert_eq!(split.coeff_sign_bypass.len(), 10);
        assert_eq!(split.coeff_suffix_lsb.len(), 5);
        assert_eq!(split.mvd_sign_bypass.len(), 3);
        assert_eq!(split.mvd_suffix_lsb.len(), 2);
        let total = split.coeff_sign_bypass.len() + split.coeff_suffix_lsb.len()
            + split.mvd_sign_bypass.len() + split.mvd_suffix_lsb.len();
        assert_eq!(total, msg.len());
    }
    // A 10-bit message cannot fit into 5 bits of total capacity.
    #[test]
    fn split_message_too_large_returns_none() {
        use super::super::GopCapacity;
        let caps = GopCapacity {
            coeff_sign_bypass: 5, coeff_suffix_lsb: 0,
            mvd_sign_bypass: 0, mvd_suffix_lsb: 0,
        };
        let msg = vec![0u8; 10];
        assert!(split_message_per_domain(&msg, &caps).is_none());
    }
    // Empty message yields the all-empty default split.
    #[test]
    fn split_message_empty_message_returns_default() {
        use super::super::GopCapacity;
        let caps = GopCapacity {
            coeff_sign_bypass: 100, coeff_suffix_lsb: 50,
            mvd_sign_bypass: 30, mvd_suffix_lsb: 20,
        };
        let split = split_message_per_domain(&[], &caps).unwrap();
        assert_eq!(split.coeff_sign_bypass.len(), 0);
        assert_eq!(split.coeff_suffix_lsb.len(), 0);
        assert_eq!(split.mvd_sign_bypass.len(), 0);
        assert_eq!(split.mvd_suffix_lsb.len(), 0);
    }
    // Flooring leaves leftover bits; they go to the domain with most headroom.
    #[test]
    fn split_message_leftover_goes_to_largest_headroom() {
        use super::super::GopCapacity;
        let caps = GopCapacity {
            coeff_sign_bypass: 9, coeff_suffix_lsb: 2,
            mvd_sign_bypass: 0, mvd_suffix_lsb: 0,
        };
        let msg = vec![0u8; 5];
        let split = split_message_per_domain(&msg, &caps).unwrap();
        assert_eq!(split.coeff_sign_bypass.len(), 5);
        assert_eq!(split.coeff_suffix_lsb.len(), 0);
    }
    // The split is deterministic, so a decoder re-running it recovers the
    // exact same partition the encoder used.
    #[test]
    fn split_message_decoder_mirror_recovery() {
        use super::super::GopCapacity;
        let caps = GopCapacity {
            coeff_sign_bypass: 100, coeff_suffix_lsb: 50,
            mvd_sign_bypass: 30, mvd_suffix_lsb: 20,
        };
        let msg: Vec<u8> = (0..20).map(|i| (i & 1) as u8).collect();
        let enc_split = split_message_per_domain(&msg, &caps).unwrap();
        let dec_split = split_message_per_domain(&msg, &caps).unwrap();
        assert_eq!(enc_split.coeff_sign_bypass, dec_split.coeff_sign_bypass);
        assert_eq!(enc_split.coeff_suffix_lsb, dec_split.coeff_suffix_lsb);
        assert_eq!(enc_split.mvd_sign_bypass, dec_split.mvd_sign_bypass);
        assert_eq!(enc_split.mvd_suffix_lsb, dec_split.mvd_suffix_lsb);
    }
    // Pass 1 must find the 5 nonzero-coefficient signs and 2 MVD signs of the
    // synthetic GOP, with costs aligned 1:1 to the cover bits.
    #[test]
    fn pass1_collects_per_domain_cover() {
        let cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        assert_eq!(cover.cover.coeff_sign_bypass.len(), 5);
        assert_eq!(cover.cover.coeff_suffix_lsb.len(), 0);
        assert_eq!(cover.cover.mvd_sign_bypass.len(), 2);
        assert_eq!(cover.cover.mvd_suffix_lsb.len(), 0);
        assert_eq!(
            cover.costs.coeff_sign_bypass.len(),
            cover.cover.coeff_sign_bypass.len(),
        );
        assert_eq!(
            cover.costs.mvd_sign_bypass.len(),
            cover.cover.mvd_sign_bypass.len(),
        );
    }
    // With empty messages, pass 2 must pass the cover through unmodified.
    #[test]
    fn pass2_empty_message_returns_cover_bits() {
        let cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let messages = DomainMessages::default();
        let plan = pass2_stc_plan(&cover, &messages, 7).unwrap();
        assert_eq!(plan.coeff_sign_bypass, cover.cover.coeff_sign_bypass.bits);
        assert_eq!(plan.mvd_sign_bypass, cover.cover.mvd_sign_bypass.bits);
        assert_eq!(plan.total_modifications, 0);
    }
    // A 1-bit message over 5 cover bits must yield a same-length plan.
    #[test]
    fn pass2_stc_embeds_message_bits() {
        let cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let messages = DomainMessages {
            coeff_sign_bypass: vec![1u8],
            ..Default::default()
        };
        let plan = pass2_stc_plan(&cover, &messages, 4).unwrap();
        assert_eq!(plan.coeff_sign_bypass.len(), 5);
    }
    // Inverting every planned coeff-sign bit must flip all 5 coefficients,
    // verified by re-collecting the cover after pass 3.
    #[test]
    fn pass3_apply_overrides_modifies_decision_cache() {
        let mut cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let mut plan = DomainPlan {
            coeff_sign_bypass: cover
                .cover
                .coeff_sign_bypass
                .bits
                .iter()
                .map(|b| b ^ 1)
                .collect(),
            ..Default::default()
        };
        plan.total_modifications = plan.coeff_sign_bypass.len();
        let count = pass3_apply_overrides(&mut cache, &cover.cover, &plan);
        assert_eq!(count, 5, "all 5 coeff sign positions must flip");
        let new_cover = pass1_collect_cover(&cache);
        let new_bits = new_cover.cover.coeff_sign_bypass.bits;
        let old_bits = cover.cover.coeff_sign_bypass.bits;
        for (n, o) in new_bits.iter().zip(old_bits.iter()) {
            assert_eq!(*n, o ^ 1, "every bit should be inverted");
        }
    }
    // A plan identical to the cover must modify nothing.
    #[test]
    fn pass3_no_op_plan_does_not_modify_cache() {
        let mut cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let plan = DomainPlan {
            coeff_sign_bypass: cover.cover.coeff_sign_bypass.bits.clone(),
            mvd_sign_bypass: cover.cover.mvd_sign_bypass.bits.clone(),
            ..Default::default()
        };
        let count = pass3_apply_overrides(&mut cache, &cover.cover, &plan);
        assert_eq!(count, 0);
    }
    // Full keyed roundtrip: embed with per-domain derived seeds, re-collect
    // the stego cover, and extract with the same seed-derived H-hat.
    #[test]
    fn three_pass_roundtrip_with_per_domain_keys() {
        use crate::stego::stc::extract::stc_extract;
        use crate::stego::stc::hhat::generate_hhat;
        use super::super::keys::CabacStegoMasterKeys;
        let mut cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let original_message = vec![1u8, 0];
        let messages = DomainMessages {
            coeff_sign_bypass: original_message.clone(),
            ..Default::default()
        };
        let h = 4;
        let keys = CabacStegoMasterKeys::derive("phase-6d-8-test").unwrap();
        let plan = pass2_stc_plan_with_keys(&cover, &messages, h, &keys, 0).unwrap();
        pass3_apply_overrides(&mut cache, &cover.cover, &plan);
        let stego_cover = pass1_collect_cover(&cache);
        assert_eq!(
            stego_cover.cover.coeff_sign_bypass.bits,
            plan.coeff_sign_bypass,
        );
        let domain_seed = keys
            .per_gop_seeds(super::super::EmbedDomain::CoeffSignBypass, 0)
            .hhat_seed;
        let n = stego_cover.cover.coeff_sign_bypass.len();
        let w = n / original_message.len();
        let hhat = generate_hhat(h, w, &domain_seed);
        let recovered = stc_extract(&stego_cover.cover.coeff_sign_bypass.bits, &hhat, w);
        assert_eq!(
            recovered[..original_message.len()],
            original_message,
            "per-domain-keyed STC roundtrip must recover the message",
        );
    }
    // Unkeyed roundtrip: the zero seed used by pass2_stc_plan must match the
    // zero seed used here for extraction.
    #[test]
    fn three_pass_roundtrip_synthetic_gop() {
        use crate::stego::stc::extract::stc_extract;
        use crate::stego::stc::hhat::generate_hhat;
        let mut cache = build_synthetic_gop();
        let cover = pass1_collect_cover(&cache);
        let original_message = vec![1u8, 0];
        let messages = DomainMessages {
            coeff_sign_bypass: original_message.clone(),
            ..Default::default()
        };
        let h = 4;
        let plan = pass2_stc_plan(&cover, &messages, h).unwrap();
        pass3_apply_overrides(&mut cache, &cover.cover, &plan);
        let stego_cover = pass1_collect_cover(&cache);
        assert_eq!(
            stego_cover.cover.coeff_sign_bypass.bits,
            plan.coeff_sign_bypass,
        );
        let n = stego_cover.cover.coeff_sign_bypass.len();
        let w = n / original_message.len();
        let seed = [0u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let recovered = stc_extract(&stego_cover.cover.coeff_sign_bypass.bits, &hhat, w);
        assert_eq!(
            recovered[..original_message.len()],
            original_message,
            "STC decode must recover the embedded message",
        );
    }
}