1use super::embed::{stc_embed, EmbedResult};
42use super::extract::stc_extract;
43use super::hhat;
44
45pub fn stc_embed_streaming(
56 cover_bits: &[u8],
57 costs: &[f32],
58 message: &[u8],
59 hhat_matrix: &[Vec<u32>],
60 h: usize,
61 w: usize,
62 window_k: usize,
63) -> Option<EmbedResult> {
64 if w == 0 || h > 7 {
65 return None;
66 }
67
68 let n = cover_bits.len();
69 let m = message.len();
70 let num_states = 1usize << h;
71 let inf = f64::INFINITY;
72
73 if m == 0 {
74 return Some(EmbedResult {
75 stego_bits: cover_bits.to_vec(),
76 total_cost: 0.0,
77 num_modifications: 0,
78 });
79 }
80
81 let columns: Vec<usize> = (0..w)
83 .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
84 .collect();
85
86 let mut prev_cost = vec![inf; num_states];
88 prev_cost[0] = 0.0;
89 let mut curr_cost = vec![0.0f64; num_states];
90 let mut shifted_cost = vec![inf; num_states];
91
92 let mut back_ptr: Vec<u128> = Vec::with_capacity(n);
95 let mut boundary: Vec<Option<u8>> = Vec::with_capacity(n);
99
100 let mut msg_idx = 0;
101 let mut stego_bits = vec![0u8; n];
102
103 for j in 0..n {
104 let col_idx = j % w;
105 let col = columns[col_idx];
106 let flip_cost = costs[j] as f64;
107 let cover_bit = cover_bits[j] & 1;
108
109 let (cost_s0, cost_s1) = if cover_bit == 0 {
110 (0.0, flip_cost)
111 } else {
112 (flip_cost, 0.0)
113 };
114
115 let mut packed_bp = 0u128;
116 for s in 0..num_states {
117 let cost_0 = prev_cost[s] + cost_s0;
118 let cost_1 = prev_cost[s ^ col] + cost_s1;
119 if cost_1 < cost_0 {
120 curr_cost[s] = cost_1;
121 packed_bp |= 1u128 << s;
122 } else {
123 curr_cost[s] = cost_0;
124 }
125 }
126
127 back_ptr.push(packed_bp);
128
129 let is_msg_boundary = col_idx == w - 1 && msg_idx < m;
130 if is_msg_boundary {
131 let required_bit = message[msg_idx] as usize;
132 shifted_cost.fill(inf);
133 for s in 0..num_states {
134 if curr_cost[s] == inf { continue; }
135 if (s & 1) != required_bit { continue; }
136 let s_shifted = s >> 1;
137 if curr_cost[s] < shifted_cost[s_shifted] {
138 shifted_cost[s_shifted] = curr_cost[s];
139 }
140 }
141 std::mem::swap(&mut prev_cost, &mut shifted_cost);
142 boundary.push(Some(required_bit as u8));
143 msg_idx += 1;
144 } else {
145 std::mem::swap(&mut prev_cost, &mut curr_cost);
146 boundary.push(None);
147 }
148
149 if window_k < n && j + 1 > window_k {
154 let commit_j = j - window_k;
155 commit_position(
156 commit_j, j, &back_ptr, &boundary, &columns, w, m,
157 num_states, &prev_cost, message, &mut stego_bits,
158 );
159 }
160 }
161
162 let (best_state, best_cost) = find_best_state(&prev_cost);
165 if best_cost == inf {
166 return None;
167 }
168
169 let drain_start = n.saturating_sub(window_k);
172
173 let mut s = best_state;
176 for j in (drain_start..n).rev() {
177 let col_idx = j % w;
178 if let Some(msg_bit) = boundary[j] {
179 s = ((s << 1) | msg_bit as usize) & (num_states - 1);
180 }
181 let bit = ((back_ptr[j] >> s) & 1) as u8;
182 stego_bits[j] = bit;
183 if bit == 1 {
184 s ^= columns[col_idx];
185 }
186 }
187
188 let num_modifications = stego_bits
189 .iter()
190 .zip(cover_bits.iter())
191 .filter(|(s, c)| s != c)
192 .count();
193
194 let total_cost: f64 = stego_bits
196 .iter()
197 .zip(cover_bits.iter())
198 .zip(costs.iter())
199 .filter_map(|((s, c), cost)| if s != c { Some(*cost as f64) } else { None })
200 .sum();
201
202 Some(EmbedResult { stego_bits, total_cost, num_modifications })
203}
204
205#[allow(clippy::too_many_arguments)]
209fn commit_position(
210 commit_j: usize,
211 look_ahead_j: usize,
212 back_ptr: &[u128],
213 boundary: &[Option<u8>],
214 columns: &[usize],
215 _w: usize,
216 _m: usize,
217 num_states: usize,
218 prev_cost: &[f64],
219 _message: &[u8],
220 stego_bits: &mut [u8],
221) {
222 let (mut s, _) = find_best_state(prev_cost);
224 let _ = num_states;
225
226 for j in (commit_j..=look_ahead_j).rev() {
228 let col_idx = j % columns.len();
229 if let Some(msg_bit) = boundary[j] {
230 s = ((s << 1) | msg_bit as usize) & (num_states - 1);
231 }
232 let bit = ((back_ptr[j] >> s) & 1) as u8;
233 if j == commit_j {
234 stego_bits[commit_j] = bit;
235 return;
236 }
237 if bit == 1 {
238 s ^= columns[col_idx];
239 }
240 }
241}
242
/// Returns `(index, cost)` of the smallest entry in `costs`.
///
/// Ties resolve to the lowest index and NaN entries are never selected. An
/// empty slice, or one with no entry below infinity, yields
/// `(0, f64::INFINITY)`.
fn find_best_state(costs: &[f64]) -> (usize, f64) {
    costs
        .iter()
        .copied()
        .enumerate()
        .fold((0, f64::INFINITY), |(best_idx, best_val), (idx, val)| {
            if val < best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        })
}
254
255pub fn run_k_sweep(
260 cover_bits: &[u8],
261 costs: &[f32],
262 message: &[u8],
263 hhat_matrix: &[Vec<u32>],
264 h: usize,
265 w: usize,
266 k_values: &[usize],
267) -> KSweepReport {
268 let reference = stc_embed(cover_bits, costs, message, hhat_matrix, h, w)
269 .expect("reference Viterbi failed");
270 let n = cover_bits.len();
271
272 let ref_flips: Vec<bool> = reference
274 .stego_bits
275 .iter()
276 .zip(cover_bits.iter())
277 .map(|(s, c)| s != c)
278 .collect();
279 let ref_total_flips = ref_flips.iter().filter(|f| **f).count();
280
281 let mut entries = Vec::new();
282 for &k in k_values {
283 let start = std::time::Instant::now();
284 let sw = stc_embed_streaming(
285 cover_bits, costs, message, hhat_matrix, h, w, k,
286 ).expect("streaming Viterbi failed");
287 let wall_clock_ms = start.elapsed().as_millis();
288
289 let hamming: usize = sw
291 .stego_bits
292 .iter()
293 .zip(reference.stego_bits.iter())
294 .filter(|(a, b)| a != b)
295 .count();
296 let hamming_pct = (hamming as f64) / (n as f64) * 100.0;
297
298 let bin_size = (n / 100).max(1);
302 let mut max_bin_divergence: f64 = 0.0;
303 let mut sum_bin_divergence: f64 = 0.0;
304 let mut bin_count = 0;
305 let mut start_idx = 0;
306 while start_idx < n {
307 let end = (start_idx + bin_size).min(n);
308 let ref_bin: usize = ref_flips[start_idx..end].iter().filter(|f| **f).count();
309 let sw_bin: usize = sw.stego_bits[start_idx..end]
310 .iter()
311 .zip(cover_bits[start_idx..end].iter())
312 .filter(|(s, c)| s != c)
313 .count();
314 let len = end - start_idx;
315 let div = ((ref_bin as f64) - (sw_bin as f64)).abs() / (len as f64);
316 if div > max_bin_divergence { max_bin_divergence = div; }
317 sum_bin_divergence += div;
318 bin_count += 1;
319 start_idx = end;
320 }
321 let avg_bin_divergence = sum_bin_divergence / (bin_count as f64);
322
323 let extracted = stc_extract(&sw.stego_bits, hhat_matrix, w);
328 let syndrome_match_bits: usize = extracted[..message.len()]
329 .iter()
330 .zip(message.iter())
331 .filter(|(a, b)| a == b)
332 .count();
333 let syndrome_valid = syndrome_match_bits == message.len();
334 let syndrome_match_pct = (syndrome_match_bits as f64) / (message.len() as f64) * 100.0;
335
336 entries.push(KSweepEntry {
337 k,
338 n,
339 hamming,
340 hamming_pct,
341 sw_total_flips: sw.num_modifications,
342 ref_total_flips,
343 sw_total_cost: sw.total_cost,
344 ref_total_cost: reference.total_cost,
345 max_bin_flip_rate_divergence: max_bin_divergence,
346 avg_bin_flip_rate_divergence: avg_bin_divergence,
347 wall_clock_ms,
348 syndrome_valid,
349 syndrome_match_pct,
350 });
351 }
352
353 KSweepReport {
354 n,
355 m: message.len(),
356 h,
357 w,
358 ref_total_flips,
359 ref_total_cost: reference.total_cost,
360 entries,
361 }
362}
363
/// One row of a K sweep: the windowed embedder with window `k` compared with
/// the inline reference embedder.
#[derive(Debug, Clone)]
pub struct KSweepEntry {
    /// Window (look-ahead) length used for the windowed embedding.
    pub k: usize,
    /// Cover length.
    pub n: usize,
    /// Number of positions where windowed and reference stego bits differ.
    pub hamming: usize,
    /// `hamming` as a percentage of `n`.
    pub hamming_pct: f64,
    /// Number of cover positions flipped by the windowed embedding.
    pub sw_total_flips: usize,
    /// Number of cover positions flipped by the reference embedding.
    pub ref_total_flips: usize,
    /// Total flip cost of the windowed embedding.
    pub sw_total_cost: f64,
    /// Total flip cost of the reference embedding.
    pub ref_total_cost: f64,
    /// Largest per-bin absolute difference in flip rate.
    pub max_bin_flip_rate_divergence: f64,
    /// Mean per-bin absolute difference in flip rate.
    pub avg_bin_flip_rate_divergence: f64,
    /// Wall-clock time of the windowed embedding, in milliseconds.
    pub wall_clock_ms: u128,
    /// Whether every message bit is recovered from the windowed stego.
    pub syndrome_valid: bool,
    /// Percentage of message bits recovered from the windowed stego.
    pub syndrome_match_pct: f64,
}

/// All rows of a K sweep plus the shared reference statistics.
#[derive(Debug, Clone)]
pub struct KSweepReport {
    /// Cover length.
    pub n: usize,
    /// Message length in bits.
    pub m: usize,
    /// Trellis height of the code.
    pub h: usize,
    /// Columns per message bit.
    pub w: usize,
    /// Flips made by the reference embedding.
    pub ref_total_flips: usize,
    /// Total cost of the reference embedding.
    pub ref_total_cost: f64,
    /// One entry per swept window size, in sweep order.
    pub entries: Vec<KSweepEntry>,
}

impl KSweepReport {
    /// Renders the report as Markdown: a summary paragraph followed by a
    /// table with one row per entry.
    ///
    /// The reference flip percentage of an empty cover (`n == 0`) is shown as
    /// 0 rather than NaN.
    pub fn to_markdown(&self) -> String {
        use std::fmt::Write as _;

        let ref_flip_pct = if self.n == 0 {
            0.0
        } else {
            (self.ref_total_flips as f64) / (self.n as f64) * 100.0
        };

        let mut out = String::new();
        // Writing into a String cannot fail, so the fmt::Results are ignored.
        let _ = write!(
            out,
            "## K sweep results\n\n\
             Cover length: n = {}, message: m = {} bits, h = {}, w = {}, \
             reference flips: {} ({:.4}% of n), reference cost: {:.2}\n\n",
            self.n, self.m, self.h, self.w,
            self.ref_total_flips,
            ref_flip_pct,
            self.ref_total_cost,
        );
        out.push_str("| K | syndrome | Hamming | Hamming% | sw flips | ref flips | cost-Δ | max bin Δ | avg bin Δ | wall ms |\n");
        out.push_str("|---|---|---|---|---|---|---|---|---|---|\n");

        for e in &self.entries {
            let cost_delta = e.sw_total_cost - e.ref_total_cost;
            let syn = if e.syndrome_valid {
                "✓ valid".to_string()
            } else {
                format!("✗ {:.2}%", e.syndrome_match_pct)
            };
            let _ = writeln!(
                out,
                "| {} | {} | {} | {:.4}% | {} | {} | {:+.2} | {:.6} | {:.6} | {} |",
                e.k, syn, e.hamming, e.hamming_pct, e.sw_total_flips,
                e.ref_total_flips, cost_delta,
                e.max_bin_flip_rate_divergence,
                e.avg_bin_flip_rate_divergence,
                e.wall_clock_ms,
            );
        }
        out
    }
}
432
#[cfg(test)]
mod tests {
    use super::super::hhat::generate_hhat;
    use super::*;

    /// One step of a 64-bit LCG; deterministic pseudo-random fixture data.
    fn lcg(x: u64) -> u64 {
        x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407)
    }

    /// With a window covering the whole cover no early commit happens, so the
    /// windowed decoder must reproduce the inline Viterbi exactly.
    #[test]
    fn streaming_with_full_window_matches_inline() {
        let h = 7;
        let m = 100;
        let w = 10;
        let n = m * w;
        let seed = [42u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 31 + 17) % 2) as u8).collect();
        let costs: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let message: Vec<u8> = (0..m).map(|i| ((i * 13 + 7) % 2) as u8).collect();

        let inline = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        let stream = stc_embed_streaming(&cover_bits, &costs, &message, &hhat, h, w, n).unwrap();

        assert_eq!(inline.stego_bits, stream.stego_bits,
            "K=n (full window) must equal inline Viterbi");
    }

    /// A window much shorter than the cover need not be optimal. It must
    /// still assign every position a binary stego bit and report consistent
    /// modification counts.
    #[test]
    fn streaming_with_tiny_window_runs_without_error() {
        let h = 4;
        let m = 50;
        let w = 5;
        let n = m * w;
        let seed = [13u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits: Vec<u8> = (0..n).map(|i| ((i * 7) % 2) as u8).collect();
        let costs: Vec<f32> = vec![1.0; n];
        let message: Vec<u8> = (0..m).map(|i| (i % 2) as u8).collect();

        let stream = stc_embed_streaming(&cover_bits, &costs, &message, &hhat, h, w, 5).unwrap();
        assert_eq!(stream.stego_bits.len(), n);
        assert!(stream.stego_bits.iter().all(|&b| b <= 1));
        let changed = stream
            .stego_bits
            .iter()
            .zip(&cover_bits)
            .filter(|(s, c)| s != c)
            .count();
        assert_eq!(stream.num_modifications, changed);
    }

    /// Parameters the decoder cannot handle are rejected before any work.
    #[test]
    fn streaming_rejects_zero_width_and_oversized_height() {
        let no_hhat: Vec<Vec<u32>> = Vec::new();
        let cover = [0u8, 1, 1, 0];
        let costs = [1.0f32; 4];
        let message = [1u8];
        assert!(stc_embed_streaming(&cover, &costs, &message, &no_hhat, 3, 0, 2).is_none());
        assert!(stc_embed_streaming(&cover, &costs, &message, &no_hhat, 8, 2, 2).is_none());
    }

    /// An empty message leaves the cover untouched at zero cost.
    #[test]
    fn streaming_with_empty_message_returns_cover() {
        let no_hhat: Vec<Vec<u32>> = Vec::new();
        let cover = [0u8, 1, 1, 0];
        let costs = [1.0f32; 4];
        let result = stc_embed_streaming(&cover, &costs, &[], &no_hhat, 3, 2, 2).unwrap();
        assert_eq!(result.stego_bits, cover.to_vec());
        assert_eq!(result.num_modifications, 0);
        assert_eq!(result.total_cost, 0.0);
    }

    /// Prints a Markdown table comparing several window sizes on a large
    /// pseudo-random cover. Run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore = "slow benchmark-style report; run manually"]
    fn k_sweep_report() {
        let h = 7;
        let m = 1000;
        let w = 100;
        let n = m * w;
        let seed = [77u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits: Vec<u8> = (0..n)
            .map(|i| ((lcg(i as u64) >> 33) & 1) as u8)
            .collect();
        let costs: Vec<f32> = (0..n)
            .map(|i| {
                let x = lcg((i as u64).wrapping_add(0xcafef00d));
                let frac = ((x >> 32) & 0xffff) as f32 / 65536.0;
                // Costs in [0.5, 1.0).
                0.5 + frac * 0.5
            })
            .collect();
        let message: Vec<u8> = (0..m).map(|i| ((i * 19 + 11) % 2) as u8).collect();

        let k_values = vec![100, 500, 1000, 5000, 10000, 50000];
        let report = run_k_sweep(&cover_bits, &costs, &message, &hhat, h, w, &k_values);

        println!("\n{}\n", report.to_markdown());
    }
}