use crate::stego::stc::streaming_segmented::CoverFetch;
/// Streaming STC cover source that replays the H.264 first pass per segment.
///
/// Rather than materialising the whole cover, each segment fetch re-runs
/// `pass1_capture_4domain_for_gop_range` over only the GOPs that overlap the
/// requested segment. It then slices out the positions belonging to a single
/// [`EmbedDomain`](super::hook::EmbedDomain).
pub struct H264GopReplayCover<'a> {
    // ---- Input video and encoder settings, replayed on every fetch ----
    /// Input video buffer forwarded to the `encode_pixels` passes.
    yuv: &'a [u8],
    width: u32,
    height: u32,
    n_frames: usize,
    gop_size: usize,
    b_count: usize,
    quality: Option<u8>,

    // ---- Domain selection and position layout ----
    /// Which of the four embedding domains this cover exposes.
    domain: super::hook::EmbedDomain,
    /// Prefix sums of per-GOP position counts in `domain`: entry `g` is the
    /// first position owned by GOP `g`, and the final entry is the domain total.
    cum_positions: Vec<usize>,

    // ---- STC geometry ----
    /// Number of message blocks.
    m: usize,
    /// Cover positions per message block.
    w: usize,
    /// Blocks returned per fetched segment (the last segment may be shorter).
    segment_size_in_blocks: usize,
}
impl<'a> H264GopReplayCover<'a> {
#[allow(clippy::too_many_arguments)]
pub fn new(
yuv: &'a [u8],
width: u32,
height: u32,
n_frames: usize,
gop_size: usize,
b_count: usize,
quality: Option<u8>,
domain: super::hook::EmbedDomain,
m: usize,
w: usize,
segment_size_in_blocks: usize,
) -> Result<Self, crate::stego::error::StegoError> {
if w == 0 || segment_size_in_blocks == 0 {
return Err(crate::stego::error::StegoError::InvalidVideo(
"w and segment_size_in_blocks must be > 0".into(),
));
}
let per_gop_counts = super::encode_pixels::pass1_count_per_gop_4domain(
yuv, width, height, n_frames, gop_size, b_count, quality,
)?;
let domain_idx = domain as usize;
let mut cum_positions = Vec::with_capacity(per_gop_counts.len() + 1);
cum_positions.push(0);
for row in per_gop_counts.iter() {
let last = *cum_positions.last().unwrap();
cum_positions.push(last + row[domain_idx]);
}
let total = *cum_positions.last().unwrap();
if m * w > total {
return Err(crate::stego::error::StegoError::InvalidVideo(format!(
"m * w = {} exceeds domain {:?} cover total {}",
m * w,
domain,
total,
)));
}
Ok(Self {
yuv,
width,
height,
n_frames,
gop_size,
b_count,
quality,
domain,
cum_positions,
segment_size_in_blocks,
w,
m,
})
}
#[allow(clippy::too_many_arguments)]
pub fn from_counts(
yuv: &'a [u8],
width: u32,
height: u32,
n_frames: usize,
gop_size: usize,
b_count: usize,
quality: Option<u8>,
domain: super::hook::EmbedDomain,
per_gop_counts: &[[usize; 4]],
m: usize,
w: usize,
segment_size_in_blocks: usize,
) -> Result<Self, crate::stego::error::StegoError> {
if w == 0 || segment_size_in_blocks == 0 {
return Err(crate::stego::error::StegoError::InvalidVideo(
"w and segment_size_in_blocks must be > 0".into(),
));
}
let domain_idx = domain as usize;
let mut cum_positions = Vec::with_capacity(per_gop_counts.len() + 1);
cum_positions.push(0);
for row in per_gop_counts.iter() {
let last = *cum_positions.last().unwrap();
cum_positions.push(last + row[domain_idx]);
}
let total = *cum_positions.last().unwrap();
if m * w > total {
return Err(crate::stego::error::StegoError::InvalidVideo(format!(
"m * w = {} exceeds domain {:?} cover total {}",
m * w,
domain,
total,
)));
}
Ok(Self {
yuv,
width,
height,
n_frames,
gop_size,
b_count,
quality,
domain,
cum_positions,
segment_size_in_blocks,
w,
m,
})
}
fn map_range(&self, j_start: usize, j_end: usize) -> (usize, usize, usize, usize) {
let gop_start = self
.cum_positions
.partition_point(|&c| c <= j_start)
.saturating_sub(1);
let gop_end = self.cum_positions.partition_point(|&c| c < j_end);
let off_start = j_start - self.cum_positions[gop_start];
let off_end = j_end - self.cum_positions[gop_start];
(gop_start, gop_end, off_start, off_end)
}
fn slice_domain(
&self,
cov: &super::orchestrate::GopCover,
off_start: usize,
off_end: usize,
) -> (Vec<u8>, Vec<f32>) {
use super::hook::EmbedDomain;
match self.domain {
EmbedDomain::CoeffSignBypass => (
cov.cover.coeff_sign_bypass.bits[off_start..off_end].to_vec(),
cov.costs.coeff_sign_bypass[off_start..off_end].to_vec(),
),
EmbedDomain::CoeffSuffixLsb => (
cov.cover.coeff_suffix_lsb.bits[off_start..off_end].to_vec(),
cov.costs.coeff_suffix_lsb[off_start..off_end].to_vec(),
),
EmbedDomain::MvdSignBypass => (
cov.cover.mvd_sign_bypass.bits[off_start..off_end].to_vec(),
cov.costs.mvd_sign_bypass[off_start..off_end].to_vec(),
),
EmbedDomain::MvdSuffixLsb => (
cov.cover.mvd_suffix_lsb.bits[off_start..off_end].to_vec(),
cov.costs.mvd_suffix_lsb[off_start..off_end].to_vec(),
),
}
}
}
impl<'a> CoverFetch for H264GopReplayCover<'a> {
    /// Total cover positions consumed by the STC: `m` blocks of `w` positions.
    fn total_positions(&self) -> usize {
        self.m * self.w
    }

    /// Number of segments needed to cover all `m` blocks; the last may be short.
    fn num_segments(&self) -> usize {
        self.m.div_ceil(self.segment_size_in_blocks)
    }

    /// Blocks per full segment.
    fn segment_size_in_blocks(&self) -> usize {
        self.segment_size_in_blocks
    }

    /// Replays the capture pass for the GOPs overlapping segment `seg_idx`.
    ///
    /// Returns that segment's cover bits and costs in the configured domain.
    /// An out-of-range `seg_idx` yields two empty vectors.
    ///
    /// # Panics
    ///
    /// Panics if `pass1_capture_4domain_for_gop_range` fails, because the
    /// trait signature has no error channel.
    fn fetch_segment(&mut self, seg_idx: usize) -> (Vec<u8>, Vec<f32>) {
        let seg = self.segment_size_in_blocks;
        // Saturate and clamp so an oversized `seg_idx` produces an empty
        // segment instead of overflowing.
        let block_start = seg_idx.saturating_mul(seg).min(self.m);
        let block_end = seg_idx.saturating_add(1).saturating_mul(seg).min(self.m);
        if block_end <= block_start {
            return (Vec::new(), Vec::new());
        }
        // Both block indices are <= m, so these products are bounded by m * w.
        let j_start = block_start * self.w;
        let j_end = block_end * self.w;
        let (gop_start, gop_end, off_start, off_end) = self.map_range(j_start, j_end);
        let cover = super::encode_pixels::pass1_capture_4domain_for_gop_range(
            self.yuv,
            self.width,
            self.height,
            self.n_frames,
            self.gop_size,
            self.b_count,
            self.quality,
            gop_start,
            gop_end,
        )
        .unwrap_or_else(|e| {
            panic!(
                "pass1_capture_4domain_for_gop_range failed for GOPs \
                 {gop_start}..{gop_end} (segment {seg_idx}): {e:?}"
            )
        });
        self.slice_domain(&cover, off_start, off_end)
    }
}