use super::hhat;
use super::extract::stc_extract;
use crate::stego::progress;
/// Outcome of a successful STC embedding.
#[derive(Debug, Clone)]
pub struct EmbedResult {
    /// Stego bit chosen for each cover position (one entry per cover position).
    pub stego_bits: Vec<u8>,
    /// Total flip cost of the chosen stego bits, as accumulated by the trellis.
    pub total_cost: f64,
    /// Number of positions whose stego bit differs from the cover bit.
    pub num_modifications: usize,
}
/// Approximate number of `progress::advance()` calls one embedding issues over its
/// whole run (the segmented variant splits them between its two phases).
pub const STC_PROGRESS_STEPS: u32 = 50;
/// Covers with more positions than this are embedded by the streaming segmented
/// implementation instead of the single-pass inline trellis.
const SEGMENTED_THRESHOLD: usize = 10usize.pow(6);
/// Embeds `message` into `cover_bits` with a syndrome-trellis code, choosing the stego
/// bits that minimise the summed flip cost given by `costs`.
///
/// * `cover_bits` – cover bits; the trellis only looks at the least-significant bit of
///   each byte.
/// * `costs` – cost of flipping each cover position; must have at least
///   `cover_bits.len()` entries.
/// * `message` – message bits, one per byte (`0` or `1`).
/// * `hhat_matrix` – the H-hat submatrix, read column by column via
///   `hhat::column_packed`.
/// * `h` – constraint height; the trellis has `2^h` states. At most 7, because a
///   position's back-pointers for all states are packed into one `u128`.
/// * `w` – number of cover positions consumed per message bit.
///
/// Covers longer than `SEGMENTED_THRESHOLD` go through the streaming segmented
/// embedder; smaller covers use the inline trellis.
///
/// Returns `None` for invalid parameters (`w == 0`, `h > 7`, fewer costs than cover
/// bits), when no finite-cost embedding exists, on cancellation, or when the
/// streaming embedder reports an error (its error value is discarded).
/// An empty message returns the cover unchanged at zero cost.
pub fn stc_embed(
    cover_bits: &[u8],
    costs: &[f32],
    message: &[u8],
    hhat_matrix: &[Vec<u32>],
    h: usize,
    w: usize,
) -> Option<EmbedResult> {
    if w == 0 || h > 7 {
        return None;
    }
    let n = cover_bits.len();
    let m = message.len();
    if m == 0 {
        return Some(EmbedResult {
            stego_bits: cover_bits.to_vec(),
            total_cost: 0.0,
            num_modifications: 0,
        });
    }
    // Every cover position needs a flip cost; report the mismatch instead of
    // panicking on an out-of-bounds index inside the trellis loop.
    if costs.len() < n {
        return None;
    }
    if n > SEGMENTED_THRESHOLD {
        use crate::stego::stc::streaming_segmented::{
            stc_embed_streaming_segmented, InMemoryCoverFetch,
        };
        // NOTE(review): presumably the segment length in message bits, mirroring the
        // k = ceil(sqrt(m)) of `stc_embed_segmented` — confirm against
        // `InMemoryCoverFetch::new`.
        let k = ((m as f64).sqrt().ceil() as usize).max(1);
        let mut cover = InMemoryCoverFetch::new(cover_bits, costs, m, w, k)?;
        stc_embed_streaming_segmented(&mut cover, message, hhat_matrix, h, w).ok()
    } else {
        stc_embed_inline(cover_bits, costs, message, hhat_matrix, h, w)
    }
}
/// Embeds `message` with one full-length Viterbi pass over the syndrome trellis.
///
/// The forward pass keeps the cheapest cost of reaching each of the `2^h` states. It
/// also stores one packed back-pointer word (`u128`, one bit per state) for every cover
/// position, so back-pointer memory is `16 * cover_bits.len()` bytes. The traceback
/// then starts from the cheapest final state and walks the positions in reverse to
/// recover the stego bits.
///
/// Returns `None` in these cases:
/// * the parameters are unusable: `w == 0`, `h > 7`, fewer costs than cover bits, or
///   a cover shorter than `message.len() * w`;
/// * every trellis path has infinite cost;
/// * `progress::is_cancelled()` reports cancellation.
fn stc_embed_inline(
    cover_bits: &[u8],
    costs: &[f32],
    message: &[u8],
    hhat_matrix: &[Vec<u32>],
    h: usize,
    w: usize,
) -> Option<EmbedResult> {
    let n = cover_bits.len();
    let m = message.len();
    // Message bit i is only enforced at the last column of block i. If the cover has
    // fewer than m complete w-wide blocks, trailing message bits would be left
    // unconstrained, and the result would silently fail to carry them.
    // `n / w < m` is `n < m * w` without risking overflow.
    if w == 0 || h > 7 || costs.len() < n || n / w < m {
        return None;
    }
    let num_states = 1usize << h;
    let inf = f64::INFINITY;
    // Column `j % w` is XORed into the state whenever stego bit j is 1.
    let columns: Vec<usize> = (0..w)
        .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
        .collect();
    let progress_interval = (n / STC_PROGRESS_STEPS as usize).max(1);

    // Before any position is processed only the all-zero state is reachable.
    let mut prev_cost = vec![inf; num_states];
    prev_cost[0] = 0.0;
    let mut curr_cost = vec![0.0f64; num_states];
    let mut shifted_cost = vec![inf; num_states];
    let mut back_ptr: Vec<u128> = Vec::with_capacity(n);
    let mut msg_idx = 0;

    for j in 0..n {
        let col_idx = j % w;
        let col = columns[col_idx];
        let flip_cost = costs[j] as f64;
        let cover_bit = cover_bits[j] & 1;
        // Cost of emitting stego bit 0 / stego bit 1 at this position.
        let (cost_s0, cost_s1) = if cover_bit == 0 {
            (0.0, flip_cost)
        } else {
            (flip_cost, 0.0)
        };
        // Bit s of `packed_bp` is set when the cheapest way into state s emits stego
        // bit 1 (arriving from state s ^ col). Ties favour stego bit 0.
        let mut packed_bp = 0u128;
        for s in 0..num_states {
            let cost_0 = prev_cost[s] + cost_s0;
            let cost_1 = prev_cost[s ^ col] + cost_s1;
            if cost_1 < cost_0 {
                curr_cost[s] = cost_1;
                packed_bp |= 1u128 << s;
            } else {
                curr_cost[s] = cost_0;
            }
        }
        back_ptr.push(packed_bp);

        if col_idx == w - 1 && msg_idx < m {
            // End of block `msg_idx`: only states whose low bit equals the message
            // bit survive, and they are shifted right to make room for the next block.
            let required_bit = message[msg_idx] as usize;
            shifted_cost.fill(inf);
            for s in 0..num_states {
                if curr_cost[s] == inf {
                    continue;
                }
                if (s & 1) != required_bit {
                    continue;
                }
                let s_shifted = s >> 1;
                if curr_cost[s] < shifted_cost[s_shifted] {
                    shifted_cost[s_shifted] = curr_cost[s];
                }
            }
            std::mem::swap(&mut prev_cost, &mut shifted_cost);
            msg_idx += 1;
        } else {
            std::mem::swap(&mut prev_cost, &mut curr_cost);
        }

        if (j + 1) % progress_interval == 0 {
            if progress::is_cancelled() {
                return None;
            }
            progress::advance();
        }
    }

    let (best_state, best_cost) = find_best_state(&prev_cost);
    if best_cost == inf {
        return None;
    }

    // Traceback: replay the recorded decisions from the last position to the first.
    let mut stego_bits = vec![0u8; n];
    let mut s = best_state;
    for j in (0..n).rev() {
        let col_idx = j % w;
        if col_idx == w - 1 && (j / w) < m {
            // Undo the end-of-block shift: the message bit was the low state bit.
            let msg_bit = message[j / w] as usize;
            s = ((s << 1) | msg_bit) & (num_states - 1);
        }
        let bit = ((back_ptr[j] >> s) & 1) as u8;
        stego_bits[j] = bit;
        if bit == 1 {
            s ^= columns[col_idx];
        }
    }
    debug_assert_eq!(s, 0, "traceback did not return to initial state 0");
    debug_assert_eq!(
        stc_extract(&stego_bits, hhat_matrix, w)[..m],
        message[..m],
    );
    // Only the LSB of each cover byte takes part in the trellis, so only LSB
    // changes count as modifications.
    let num_modifications = stego_bits
        .iter()
        .zip(cover_bits.iter())
        .filter(|&(&s, &c)| s != (c & 1))
        .count();
    Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
}
/// Checkpointed two-phase variant of `stc_embed_inline` that produces the same result
/// while storing back-pointers for only one segment at a time.
///
/// * Phase A runs the forward pass without back-pointers. It saves a copy of the path
///   costs at the start of every segment of `k = ceil(sqrt(m))` message bits.
/// * Phase B visits the segments from last to first. For each one it re-runs the
///   forward pass from the segment's checkpoint, this time recording back-pointers.
///   It then traces the segment back and hands the state at the segment's start to
///   the previous segment.
/// * The last segment also owns any trailing cover positions past the final message
///   block.
///
/// Returns `None` under the same conditions as `stc_embed_inline`.
// Within this file only the tests call this function (large covers use the streaming
// embedder), so silence the dead-code lint outside test builds.
#[cfg_attr(not(test), allow(dead_code))]
fn stc_embed_segmented(
    cover_bits: &[u8],
    costs: &[f32],
    message: &[u8],
    hhat_matrix: &[Vec<u32>],
    h: usize,
    w: usize,
) -> Option<EmbedResult> {
    let n = cover_bits.len();
    let m = message.len();
    // Same preconditions as `stc_embed_inline`; a cover shorter than m * w cannot
    // enforce every message bit. `n / w < m` avoids overflowing `m * w`.
    if w == 0 || h > 7 || costs.len() < n || n / w < m {
        return None;
    }
    let num_states = 1usize << h;
    let inf = f64::INFINITY;
    let columns: Vec<usize> = (0..w)
        .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
        .collect();
    // Segment length in message bits. At least one segment is kept even for an empty
    // message, so that phase B still traces back every cover position.
    let k = ((m as f64).sqrt().ceil() as usize).max(1);
    let num_segments = m.div_ceil(k).max(1);

    // ---- Phase A: forward pass, checkpointing path costs at segment starts. ----
    let phase_a_steps = STC_PROGRESS_STEPS / 2;
    let progress_interval_a = (n / phase_a_steps as usize).max(1);
    let mut prev_cost = vec![inf; num_states];
    prev_cost[0] = 0.0;
    let mut curr_cost = vec![0.0f64; num_states];
    let mut shifted_cost = vec![inf; num_states];
    // checkpoints[seg] holds the path costs at the start of message block seg * k.
    let mut checkpoints: Vec<Vec<f64>> = Vec::with_capacity(num_segments);
    checkpoints.push(prev_cost.clone());
    let mut msg_idx = 0;
    for j in 0..n {
        let col_idx = j % w;
        let col = columns[col_idx];
        let flip_cost = costs[j] as f64;
        let cover_bit = cover_bits[j] & 1;
        let (cost_s0, cost_s1) = if cover_bit == 0 {
            (0.0, flip_cost)
        } else {
            (flip_cost, 0.0)
        };
        // Same recurrence as phase B, without recording back-pointers.
        for s in 0..num_states {
            let cost_0 = prev_cost[s] + cost_s0;
            let cost_1 = prev_cost[s ^ col] + cost_s1;
            curr_cost[s] = if cost_1 < cost_0 { cost_1 } else { cost_0 };
        }
        if col_idx == w - 1 && msg_idx < m {
            let required_bit = message[msg_idx] as usize;
            shifted_cost.fill(inf);
            for s in 0..num_states {
                if curr_cost[s] == inf {
                    continue;
                }
                if (s & 1) != required_bit {
                    continue;
                }
                let s_shifted = s >> 1;
                if curr_cost[s] < shifted_cost[s_shifted] {
                    shifted_cost[s_shifted] = curr_cost[s];
                }
            }
            std::mem::swap(&mut prev_cost, &mut shifted_cost);
            msg_idx += 1;
            if msg_idx % k == 0 && msg_idx < m {
                checkpoints.push(prev_cost.clone());
            }
        } else {
            std::mem::swap(&mut prev_cost, &mut curr_cost);
        }
        if (j + 1) % progress_interval_a == 0 {
            if progress::is_cancelled() {
                return None;
            }
            progress::advance();
        }
    }
    let (best_state, best_cost) = find_best_state(&prev_cost);
    if best_cost == inf {
        return None;
    }

    // ---- Phase B: per-segment recompute with back-pointers, then traceback. ----
    let phase_b_steps = STC_PROGRESS_STEPS - phase_a_steps;
    let progress_interval_b = (n / phase_b_steps as usize).max(1);
    let mut progress_counter = 0usize;
    let mut stego_bits = vec![0u8; n];
    // Reused across segments to avoid one allocation per segment.
    let mut seg_back_ptr: Vec<u128> = Vec::new();
    // Trellis state at the end of the segment being traced; starts at the optimum
    // found after the last cover position.
    let mut entry_state = best_state;
    for seg in (0..num_segments).rev() {
        let block_start = seg * k;
        let block_end = ((seg + 1) * k).min(m);
        let j_start = block_start * w;
        // Phase A also processed positions past the final message block (n > m * w).
        // They belong to the last segment, because `best_state` is the state after
        // position n - 1, not after position m * w - 1.
        let j_end = if seg + 1 == num_segments { n } else { block_end * w };
        let seg_len = j_end - j_start;
        prev_cost.copy_from_slice(&checkpoints[seg]);
        seg_back_ptr.clear();
        seg_back_ptr.reserve(seg_len);
        let mut seg_msg_idx = block_start;
        for j in j_start..j_end {
            let col_idx = j % w;
            let col = columns[col_idx];
            let flip_cost = costs[j] as f64;
            let cover_bit = cover_bits[j] & 1;
            let (cost_s0, cost_s1) = if cover_bit == 0 {
                (0.0, flip_cost)
            } else {
                (flip_cost, 0.0)
            };
            // Bit s is set when the cheapest way into state s emits stego bit 1.
            let mut packed_bp = 0u128;
            for s in 0..num_states {
                let cost_0 = prev_cost[s] + cost_s0;
                let cost_1 = prev_cost[s ^ col] + cost_s1;
                if cost_1 < cost_0 {
                    curr_cost[s] = cost_1;
                    packed_bp |= 1u128 << s;
                } else {
                    curr_cost[s] = cost_0;
                }
            }
            seg_back_ptr.push(packed_bp);
            if col_idx == w - 1 && seg_msg_idx < m {
                let required_bit = message[seg_msg_idx] as usize;
                shifted_cost.fill(inf);
                for s in 0..num_states {
                    if curr_cost[s] == inf {
                        continue;
                    }
                    if (s & 1) != required_bit {
                        continue;
                    }
                    let s_shifted = s >> 1;
                    if curr_cost[s] < shifted_cost[s_shifted] {
                        shifted_cost[s_shifted] = curr_cost[s];
                    }
                }
                std::mem::swap(&mut prev_cost, &mut shifted_cost);
                seg_msg_idx += 1;
            } else {
                std::mem::swap(&mut prev_cost, &mut curr_cost);
            }
            progress_counter += 1;
            if progress_counter.is_multiple_of(progress_interval_b) {
                if progress::is_cancelled() {
                    return None;
                }
                progress::advance();
            }
        }
        let mut s = entry_state;
        for local_j in (0..seg_len).rev() {
            let j = j_start + local_j;
            let col_idx = j % w;
            if col_idx == w - 1 && (j / w) < m {
                let msg_bit = message[j / w] as usize;
                s = ((s << 1) | msg_bit) & (num_states - 1);
            }
            let bit = ((seg_back_ptr[local_j] >> s) & 1) as u8;
            stego_bits[j] = bit;
            if bit == 1 {
                s ^= columns[col_idx];
            }
        }
        // The state at this segment's start is the end state of the previous one.
        entry_state = s;
    }
    debug_assert_eq!(entry_state, 0, "traceback did not return to initial state 0");
    debug_assert_eq!(
        stc_extract(&stego_bits, hhat_matrix, w)[..m],
        message[..m],
    );
    // Only the LSB of each cover byte takes part in the trellis.
    let num_modifications = stego_bits
        .iter()
        .zip(cover_bits.iter())
        .filter(|&(&s, &c)| s != (c & 1))
        .count();
    Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
}
/// Returns the index of the cheapest state together with its cost.
///
/// Ties resolve to the lowest index, and NaN entries never win. For an empty slice,
/// or when no entry is below `f64::INFINITY`, the result is `(0, f64::INFINITY)`.
fn find_best_state(costs: &[f64]) -> (usize, f64) {
    costs
        .iter()
        .copied()
        .enumerate()
        .fold((0, f64::INFINITY), |best, (state, cost)| {
            if cost < best.1 {
                (state, cost)
            } else {
                best
            }
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::hhat::generate_hhat;
    use super::super::extract::stc_extract;

    // Smallest round trip: h = 3, four message bits carried by four 5-wide blocks.
    #[test]
    fn embed_extract_roundtrip_tiny() {
        let h = 3;
        let n: usize = 20;
        let m: usize = 4;
        let w = n.div_ceil(m);
        let seed = [42u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
        let costs: Vec<f32> = vec![1.0; n];
        let message = vec![1u8, 0, 1, 1];
        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        let extracted = stc_extract(&result.stego_bits, &hhat, w);
        assert_eq!(&extracted[..m], &message[..]);
    }

    // Round trip at the maximum supported constraint height (h = 7, 128 states).
    #[test]
    fn embed_extract_roundtrip_h7() {
        let h = 7;
        let n: usize = 500;
        let m: usize = 50;
        let w = n.div_ceil(m);
        let seed = [13u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 7 + 3) % 2) as u8).collect();
        let costs: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let message: Vec<u8> = (0..m).map(|i| (i % 2) as u8).collect();
        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let extracted = stc_extract(&result.stego_bits, &hhat, w);
        assert_eq!(&extracted[..m], &message[..]);
    }

    // Positions with a huge (finite) cost must be left untouched.
    #[test]
    fn wet_coefficients_not_modified() {
        let h = 3;
        let n: usize = 20;
        let m: usize = 4;
        let w = n.div_ceil(m);
        let seed = [55u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = vec![0; n];
        let mut costs: Vec<f32> = vec![1.0; n];
        for i in (0..n).step_by(5) {
            costs[i] = 1e13;
        }
        let message = vec![0u8, 1, 0, 1];
        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        for i in (0..n).step_by(5) {
            assert_eq!(
                result.stego_bits[i], cover_bits[i],
                "WET position {i} was modified"
            );
        }
        let extracted = stc_extract(&result.stego_bits, &hhat, w);
        assert_eq!(&extracted[..m], &message[..]);
    }

    // An empty message returns the cover unchanged at zero cost.
    #[test]
    fn empty_message() {
        let h = 3;
        let n = 10;
        let w = 5;
        let seed = [0u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = vec![1; n];
        let costs: Vec<f32> = vec![1.0; n];
        let message: Vec<u8> = vec![];
        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits, cover_bits);
        assert_eq!(result.total_cost, 0.0);
    }

    // n = 100_000 is below SEGMENTED_THRESHOLD, so this exercises the inline path;
    // every 500th position has infinite cost and must not be flipped.
    #[test]
    fn embed_extract_roundtrip_large() {
        let h = 7;
        let m = 10_000;
        let w = 10;
        let n = m * w;
        let seed = [77u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 31 + 17) % 2) as u8).collect();
        let costs: Vec<f32> = (0..n)
            .map(|i| {
                let base = 0.5 + (i % 100) as f32 * 0.02;
                if i % 500 == 0 { f32::INFINITY } else { base }
            })
            .collect();
        let message: Vec<u8> = (0..m).map(|i| ((i * 13 + 7) % 2) as u8).collect();
        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        let extracted = stc_extract(&result.stego_bits, &hhat, w);
        assert_eq!(&extracted[..m], &message[..]);
        for i in (0..n).step_by(500) {
            assert_eq!(
                result.stego_bits[i], cover_bits[i],
                "WET position {i} was modified"
            );
        }
    }

    // m = 500 gives k = 23 and 22 segments, the last one partial (17 message bits).
    #[test]
    fn inline_segmented_equivalence() {
        let h = 7;
        let m = 500;
        let w = 10;
        let n = m * w;
        let seed = [99u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 31 + 17) % 2) as u8).collect();
        let costs: Vec<f32> = (0..n)
            .map(|i| {
                let base = 0.5 + (i % 100) as f32 * 0.02;
                if i % 500 == 0 { f32::INFINITY } else { base }
            })
            .collect();
        let message: Vec<u8> = (0..m).map(|i| ((i * 13 + 7) % 2) as u8).collect();
        let inline = stc_embed_inline(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    // m = 10_000 gives k = 100 and 100 equal segments.
    #[test]
    fn inline_segmented_equivalence_large() {
        let h = 7;
        let m = 10_000;
        let w = 10;
        let n = m * w;
        let seed = [88u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 37 + 11) % 2) as u8).collect();
        let costs: Vec<f32> = (0..n)
            .map(|i| {
                let base = 0.3 + (i % 200) as f32 * 0.01;
                if i % 1000 == 0 { f32::INFINITY } else { base }
            })
            .collect();
        let message: Vec<u8> = (0..m).map(|i| ((i * 19 + 3) % 2) as u8).collect();
        let inline = stc_embed_inline(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    // m = 2 gives k = ceil(sqrt(2)) = 2, so phase B runs exactly one segment.
    #[test]
    fn segmented_single_segment() {
        let h = 7;
        let m = 2;
        let w = 5;
        let n = m * w;
        let seed = [33u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
        let costs: Vec<f32> = vec![1.0; n];
        let message: Vec<u8> = vec![1, 0];
        let inline = stc_embed_inline(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(inline.stego_bits, segmented.stego_bits);
        assert_eq!(inline.total_cost, segmented.total_cost);
    }

    // m = 4 gives k = 2: two full segments of two message bits each.
    #[test]
    fn segmented_two_small_segments() {
        let h = 7;
        let m = 4;
        let w = 5;
        let n = m * w;
        let seed = [33u8; 32];
        let hhat = generate_hhat(h, w, &seed);
        let cover_bits: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
        let costs: Vec<f32> = vec![1.0; n];
        let message: Vec<u8> = vec![1, 0, 1, 1];
        let inline = stc_embed_inline(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(inline.stego_bits, segmented.stego_bits);
        assert_eq!(inline.total_cost, segmented.total_cost);
    }

    // Ties go to the lowest state; an all-infinite row reports (0, inf).
    #[test]
    fn find_best_state_picks_first_minimum() {
        assert_eq!(find_best_state(&[3.0, 1.0, 1.0, 2.0]), (1, 1.0));
        let (state, cost) = find_best_state(&[f64::INFINITY; 4]);
        assert_eq!(state, 0);
        assert!(cost.is_infinite());
    }
}