1use super::hhat;
17use super::extract::stc_extract;
18use crate::stego::progress;
19
/// Outcome of a successful STC embedding.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbedResult {
    /// One stego bit per cover position (the vector has the same length as the cover input).
    pub stego_bits: Vec<u8>,
    /// Total cost of the chosen path: the summed costs of the positions whose bit was flipped
    /// (0.0 when nothing had to be embedded).
    pub total_cost: f64,
    /// Number of positions whose stego bit differs from the corresponding cover bit.
    pub num_modifications: usize,
}
27
/// Approximate number of `progress::advance` ticks a single embedding run emits in total;
/// each embedder divides the cover length by (a share of) this value to get its tick interval.
pub const STC_PROGRESS_STEPS: u32 = 50;

/// Covers with strictly more positions than this are routed to the streaming segmented
/// embedder instead of the all-in-memory inline trellis.
const SEGMENTED_THRESHOLD: usize = 1_000_000;
37
38pub fn stc_embed(
58 cover_bits: &[u8],
59 costs: &[f32],
60 message: &[u8],
61 hhat_matrix: &[Vec<u32>],
62 h: usize,
63 w: usize,
64) -> Option<EmbedResult> {
65 if w == 0 || h > 7 {
67 return None;
68 }
69
70 let n = cover_bits.len();
71 let m = message.len();
72
73 if m == 0 {
74 return Some(EmbedResult {
75 stego_bits: cover_bits.to_vec(),
76 total_cost: 0.0,
77 num_modifications: 0,
78 });
79 }
80
81 if n > SEGMENTED_THRESHOLD {
82 use crate::stego::stc::streaming_segmented::{
93 stc_embed_streaming_segmented, InMemoryCoverFetch,
94 };
95 let k = ((m as f64).sqrt().ceil() as usize).max(1);
96 let mut cover = InMemoryCoverFetch::new(cover_bits, costs, m, w, k)?;
97 stc_embed_streaming_segmented(&mut cover, message, hhat_matrix, h, w).ok()
98 } else {
99 stc_embed_inline(cover_bits, costs, message, hhat_matrix, h, w)
100 }
101}
102
103fn stc_embed_inline(
109 cover_bits: &[u8],
110 costs: &[f32],
111 message: &[u8],
112 hhat_matrix: &[Vec<u32>],
113 h: usize,
114 w: usize,
115) -> Option<EmbedResult> {
116 let n = cover_bits.len();
117 let m = message.len();
118 let num_states = 1usize << h;
119 let inf = f64::INFINITY;
120
121 let columns: Vec<usize> = (0..w)
123 .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
124 .collect();
125
126 let progress_interval = (n / STC_PROGRESS_STEPS as usize).max(1);
128
129 let mut prev_cost = vec![inf; num_states];
137 prev_cost[0] = 0.0;
138 let mut curr_cost = vec![0.0f64; num_states];
139 let mut shifted_cost = vec![inf; num_states];
140
141 let mut back_ptr: Vec<u128> = Vec::with_capacity(n);
142 let mut msg_idx = 0;
143
144 for j in 0..n {
145 let col_idx = j % w;
146 let col = columns[col_idx];
147 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
149
150 let (cost_s0, cost_s1) = if cover_bit == 0 {
151 (0.0, flip_cost)
152 } else {
153 (flip_cost, 0.0)
154 };
155
156 let mut packed_bp = 0u128;
157
158 for s in 0..num_states {
159 let cost_0 = prev_cost[s] + cost_s0;
160 let cost_1 = prev_cost[s ^ col] + cost_s1;
161
162 if cost_1 < cost_0 {
163 curr_cost[s] = cost_1;
164 packed_bp |= 1u128 << s;
165 } else {
166 curr_cost[s] = cost_0;
167 }
168 }
169
170 back_ptr.push(packed_bp);
171
172 if col_idx == w - 1 && msg_idx < m {
173 let required_bit = message[msg_idx] as usize;
174 shifted_cost.fill(inf);
175
176 for s in 0..num_states {
177 if curr_cost[s] == inf { continue; }
178 if (s & 1) != required_bit { continue; }
179 let s_shifted = s >> 1;
180 if curr_cost[s] < shifted_cost[s_shifted] {
181 shifted_cost[s_shifted] = curr_cost[s];
182 }
183 }
184
185 std::mem::swap(&mut prev_cost, &mut shifted_cost);
186 msg_idx += 1;
187 } else {
188 std::mem::swap(&mut prev_cost, &mut curr_cost);
189 }
190
191 if (j + 1) % progress_interval == 0 {
192 if progress::is_cancelled() { return None; }
193 progress::advance();
194 }
195 }
196
197 let (best_state, best_cost) = find_best_state(&prev_cost);
199 if best_cost == inf { return None; }
200
201 let mut stego_bits = vec![0u8; n];
203 let mut s = best_state;
204
205 for j in (0..n).rev() {
206 let col_idx = j % w;
207
208 if col_idx == w - 1 && (j / w) < m {
209 let msg_bit = message[j / w] as usize;
210 s = ((s << 1) | msg_bit) & (num_states - 1);
211 }
212
213 let bit = ((back_ptr[j] >> s) & 1) as u8;
214 stego_bits[j] = bit;
215
216 if bit == 1 {
217 s ^= columns[col_idx];
218 }
219 }
220
221 debug_assert_eq!(s, 0, "traceback did not return to initial state 0");
222 debug_assert_eq!(
223 stc_extract(&stego_bits, hhat_matrix, w)[..m],
224 message[..m],
225 );
226
227 let num_modifications = stego_bits.iter().zip(cover_bits.iter())
228 .filter(|(s, c)| s != c).count();
229
230 Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
231}
232
233fn stc_embed_segmented(
239 cover_bits: &[u8],
240 costs: &[f32],
241 message: &[u8],
242 hhat_matrix: &[Vec<u32>],
243 h: usize,
244 w: usize,
245) -> Option<EmbedResult> {
246 let n = cover_bits.len();
247 let m = message.len();
248 let num_states = 1usize << h;
249 let inf = f64::INFINITY;
250
251 let columns: Vec<usize> = (0..w)
254 .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
255 .collect();
256
257 let k = ((m as f64).sqrt().ceil() as usize).max(1);
261 let num_segments = m.div_ceil(k);
262
263 let phase_a_steps = STC_PROGRESS_STEPS / 2;
266 let progress_interval_a = (n / phase_a_steps as usize).max(1);
267
268 let mut prev_cost = vec![inf; num_states];
270 prev_cost[0] = 0.0;
271 let mut curr_cost = vec![0.0f64; num_states];
272 let mut shifted_cost = vec![inf; num_states];
273
274 let mut checkpoints: Vec<Vec<f64>> = Vec::with_capacity(num_segments);
277 checkpoints.push(prev_cost.clone());
278
279 let mut msg_idx = 0;
280
281 for j in 0..n {
282 let col_idx = j % w;
283 let col = columns[col_idx];
284 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
286
287 let (cost_s0, cost_s1) = if cover_bit == 0 {
288 (0.0, flip_cost)
289 } else {
290 (flip_cost, 0.0)
291 };
292
293 for s in 0..num_states {
294 let cost_0 = prev_cost[s] + cost_s0;
295 let cost_1 = prev_cost[s ^ col] + cost_s1;
296 curr_cost[s] = if cost_1 < cost_0 { cost_1 } else { cost_0 };
297 }
298
299 if col_idx == w - 1 && msg_idx < m {
300 let required_bit = message[msg_idx] as usize;
301 shifted_cost.fill(inf);
302 for s in 0..num_states {
303 if curr_cost[s] == inf { continue; }
304 if (s & 1) != required_bit { continue; }
305 let s_shifted = s >> 1;
306 if curr_cost[s] < shifted_cost[s_shifted] {
307 shifted_cost[s_shifted] = curr_cost[s];
308 }
309 }
310 std::mem::swap(&mut prev_cost, &mut shifted_cost);
311 msg_idx += 1;
312
313 if msg_idx % k == 0 && msg_idx < m {
315 checkpoints.push(prev_cost.clone());
316 }
317 } else {
318 std::mem::swap(&mut prev_cost, &mut curr_cost);
319 }
320
321 if (j + 1) % progress_interval_a == 0 {
322 if progress::is_cancelled() { return None; }
323 progress::advance();
324 }
325 }
326
327 let (best_state, best_cost) = find_best_state(&prev_cost);
329 if best_cost == inf { return None; }
330
331 let phase_b_steps = STC_PROGRESS_STEPS - phase_a_steps;
334 let progress_interval_b = (n / phase_b_steps as usize).max(1);
335 let mut progress_counter = 0usize;
336
337 let mut stego_bits = vec![0u8; n];
338 let mut entry_state = best_state;
339
340 for seg in (0..num_segments).rev() {
341 let block_start = seg * k;
342 let block_end = ((seg + 1) * k).min(m);
343 let j_start = block_start * w;
344 let j_end = block_end * w;
345 let seg_len = j_end - j_start;
346
347 prev_cost.copy_from_slice(&checkpoints[seg]);
349
350 let mut seg_back_ptr: Vec<u128> = Vec::with_capacity(seg_len);
352 let mut seg_msg_idx = block_start;
353
354 for j in j_start..j_end {
355 let col_idx = j % w;
356 let col = columns[col_idx];
357 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
359
360 let (cost_s0, cost_s1) = if cover_bit == 0 {
361 (0.0, flip_cost)
362 } else {
363 (flip_cost, 0.0)
364 };
365
366 let mut packed_bp = 0u128;
367
368 for s in 0..num_states {
369 let cost_0 = prev_cost[s] + cost_s0;
370 let cost_1 = prev_cost[s ^ col] + cost_s1;
371 if cost_1 < cost_0 {
372 curr_cost[s] = cost_1;
373 packed_bp |= 1u128 << s;
374 } else {
375 curr_cost[s] = cost_0;
376 }
377 }
378
379 seg_back_ptr.push(packed_bp);
380
381 if col_idx == w - 1 && seg_msg_idx < m {
382 let required_bit = message[seg_msg_idx] as usize;
383 shifted_cost.fill(inf);
384 for s in 0..num_states {
385 if curr_cost[s] == inf { continue; }
386 if (s & 1) != required_bit { continue; }
387 let s_shifted = s >> 1;
388 if curr_cost[s] < shifted_cost[s_shifted] {
389 shifted_cost[s_shifted] = curr_cost[s];
390 }
391 }
392 std::mem::swap(&mut prev_cost, &mut shifted_cost);
393 seg_msg_idx += 1;
394 } else {
395 std::mem::swap(&mut prev_cost, &mut curr_cost);
396 }
397
398 progress_counter += 1;
399 if progress_counter.is_multiple_of(progress_interval_b) {
400 if progress::is_cancelled() { return None; }
401 progress::advance();
402 }
403 }
404
405 let mut s = entry_state;
407 for local_j in (0..seg_len).rev() {
408 let j = j_start + local_j;
409 let col_idx = j % w;
410
411 if col_idx == w - 1 && (j / w) < m {
412 let msg_bit = message[j / w] as usize;
413 s = ((s << 1) | msg_bit) & (num_states - 1);
414 }
415
416 let bit = ((seg_back_ptr[local_j] >> s) & 1) as u8;
417 stego_bits[j] = bit;
418
419 if bit == 1 {
420 s ^= columns[col_idx];
421 }
422 }
423
424 entry_state = s;
426 }
428
429 debug_assert_eq!(entry_state, 0, "traceback did not return to initial state 0");
430 debug_assert_eq!(
431 stc_extract(&stego_bits, hhat_matrix, w)[..m],
432 message[..m],
433 );
434
435 let num_modifications = stego_bits.iter().zip(cover_bits.iter())
436 .filter(|(s, c)| s != c).count();
437
438 Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
439}
440
/// Returns `(index, cost)` of the cheapest entry in `costs`.
///
/// The first index wins on ties. NaN entries are never selected. When no entry is below
/// infinity (including an empty slice), the result is `(0, f64::INFINITY)`.
fn find_best_state(costs: &[f64]) -> (usize, f64) {
    costs
        .iter()
        .copied()
        .enumerate()
        .fold((0, f64::INFINITY), |(best_idx, best_val), (idx, val)| {
            if val < best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        })
}
457
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::hhat::generate_hhat;
    use super::super::extract::stc_extract;

    /// `len` values following `(i * mul + add) % 2`; used for both covers and messages.
    fn parity_pattern(len: usize, mul: usize, add: usize) -> Vec<u8> {
        (0..len).map(|i| ((i * mul + add) % 2) as u8).collect()
    }

    /// Sawtooth costs `offset + (i % period) * step`, with every `wet_every`-th position set to
    /// infinity (wet).
    fn sawtooth_costs(len: usize, offset: f32, period: usize, step: f32, wet_every: usize) -> Vec<f32> {
        (0..len)
            .map(|i| {
                if i % wet_every == 0 {
                    f32::INFINITY
                } else {
                    offset + (i % period) as f32 * step
                }
            })
            .collect()
    }

    /// Asserts that extracting from `result` yields `message` as its first bits.
    fn assert_message_recovered(result: &EmbedResult, hhat: &[Vec<u32>], w: usize, message: &[u8]) {
        let recovered = stc_extract(&result.stego_bits, hhat, w);
        assert_eq!(&recovered[..message.len()], message);
    }

    /// Asserts that every `stride`-th position kept its cover bit.
    fn assert_wet_untouched(result: &EmbedResult, cover_bits: &[u8], stride: usize) {
        for i in (0..cover_bits.len()).step_by(stride) {
            assert_eq!(
                result.stego_bits[i], cover_bits[i],
                "WET position {i} was modified"
            );
        }
    }

    /// Runs both the inline and the segmented embedder on the same input.
    fn embed_both(
        cover_bits: &[u8],
        costs: &[f32],
        message: &[u8],
        hhat: &[Vec<u32>],
        h: usize,
        w: usize,
    ) -> (EmbedResult, EmbedResult) {
        let inline = stc_embed_inline(cover_bits, costs, message, hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(cover_bits, costs, message, hhat, h, w).unwrap();
        (inline, segmented)
    }

    #[test]
    fn embed_extract_roundtrip_tiny() {
        let (h, n, m) = (3, 20usize, 4usize);
        let w = n.div_ceil(m);
        let hhat = generate_hhat(h, w, &[42u8; 32]);

        let cover_bits = parity_pattern(n, 1, 0);
        let costs = vec![1.0f32; n];
        let message = vec![1u8, 0, 1, 1];

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        assert_message_recovered(&result, &hhat, w, &message);
    }

    #[test]
    fn embed_extract_roundtrip_h7() {
        let (h, n, m) = (7, 500usize, 50usize);
        let w = n.div_ceil(m);
        let hhat = generate_hhat(h, w, &[13u8; 32]);

        let cover_bits = parity_pattern(n, 7, 3);
        let costs: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let message = parity_pattern(m, 1, 0);

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_message_recovered(&result, &hhat, w, &message);
    }

    #[test]
    fn wet_coefficients_not_modified() {
        let (h, n, m) = (3, 20usize, 4usize);
        let w = n.div_ceil(m);
        let hhat = generate_hhat(h, w, &[55u8; 32]);

        let cover_bits = vec![0u8; n];
        // Every fifth position is made (practically) unchangeable.
        let costs: Vec<f32> = (0..n).map(|i| if i % 5 == 0 { 1e13 } else { 1.0 }).collect();
        let message = vec![0u8, 1, 0, 1];

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_wet_untouched(&result, &cover_bits, 5);
        assert_message_recovered(&result, &hhat, w, &message);
    }

    #[test]
    fn empty_message() {
        let (h, n, w) = (3, 10usize, 5usize);
        let hhat = generate_hhat(h, w, &[0u8; 32]);

        let cover_bits = vec![1u8; n];
        let costs = vec![1.0f32; n];
        let message: Vec<u8> = vec![];

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits, cover_bits);
        assert_eq!(result.total_cost, 0.0);
    }

    #[test]
    fn embed_extract_roundtrip_large() {
        let (h, m, w) = (7, 10_000usize, 10usize);
        let n = m * w;
        let hhat = generate_hhat(h, w, &[77u8; 32]);

        let cover_bits = parity_pattern(n, 31, 17);
        let costs = sawtooth_costs(n, 0.5, 100, 0.02, 500);
        let message = parity_pattern(m, 13, 7);

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        assert_message_recovered(&result, &hhat, w, &message);
        assert_wet_untouched(&result, &cover_bits, 500);
    }

    #[test]
    fn inline_segmented_equivalence() {
        let (h, m, w) = (7, 500usize, 10usize);
        let n = m * w;
        let hhat = generate_hhat(h, w, &[99u8; 32]);

        let cover_bits = parity_pattern(n, 31, 17);
        let costs = sawtooth_costs(n, 0.5, 100, 0.02, 500);
        let message = parity_pattern(m, 13, 7);

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);
        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    #[test]
    fn inline_segmented_equivalence_large() {
        let (h, m, w) = (7, 10_000usize, 10usize);
        let n = m * w;
        let hhat = generate_hhat(h, w, &[88u8; 32]);

        let cover_bits = parity_pattern(n, 37, 11);
        let costs = sawtooth_costs(n, 0.3, 200, 0.01, 1000);
        let message = parity_pattern(m, 19, 3);

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);
        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    #[test]
    fn segmented_single_segment() {
        let (h, m, w) = (7, 4usize, 5usize);
        let n = m * w;
        let hhat = generate_hhat(h, w, &[33u8; 32]);

        let cover_bits = parity_pattern(n, 1, 0);
        let costs = vec![1.0f32; n];
        let message: Vec<u8> = vec![1, 0, 1, 1];

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);
        assert_eq!(inline.stego_bits, segmented.stego_bits);
        assert_eq!(inline.total_cost, segmented.total_cost);
    }
}