1use super::hhat;
17use super::extract::stc_extract;
18use crate::stego::progress;
19
/// Outcome of a successful STC embedding.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbedResult {
    /// Stego bit sequence, one entry per cover position.
    pub stego_bits: Vec<u8>,
    /// Accumulated cost of the path chosen through the trellis
    /// (`0.0` when there was no message to embed).
    pub total_cost: f64,
    /// Number of positions where the stego bit differs from the cover bit.
    pub num_modifications: usize,
}
27
/// Number of progress intervals an embedding run is divided into. The
/// embedders in this file divide the cover length by this value to obtain the
/// stride at which they poll for cancellation and advance the progress counter.
pub const STC_PROGRESS_STEPS: u32 = 50;
33
/// Cover length (number of cover positions) above which `stc_embed` routes to
/// the streaming segmented embedder instead of the in-memory one.
const SEGMENTED_THRESHOLD: usize = 1_000_000;
37
38pub fn stc_embed(
58 cover_bits: &[u8],
59 costs: &[f32],
60 message: &[u8],
61 hhat_matrix: &[Vec<u32>],
62 h: usize,
63 w: usize,
64) -> Option<EmbedResult> {
65 if w == 0 || h > 7 {
67 return None;
68 }
69
70 let n = cover_bits.len();
71 let m = message.len();
72
73 if m == 0 {
74 return Some(EmbedResult {
75 stego_bits: cover_bits.to_vec(),
76 total_cost: 0.0,
77 num_modifications: 0,
78 });
79 }
80
81 if n > SEGMENTED_THRESHOLD {
82 use crate::stego::stc::streaming_segmented::{
93 stc_embed_streaming_segmented, InMemoryCoverFetch,
94 };
95 let k = ((m as f64).sqrt().ceil() as usize).max(1);
96 let mut cover = InMemoryCoverFetch::new(cover_bits, costs, m, w, k)?;
97 stc_embed_streaming_segmented(&mut cover, message, hhat_matrix, h, w).ok()
98 } else {
99 stc_embed_inline(cover_bits, costs, message, hhat_matrix, h, w)
100 }
101}
102
103fn stc_embed_inline(
109 cover_bits: &[u8],
110 costs: &[f32],
111 message: &[u8],
112 hhat_matrix: &[Vec<u32>],
113 h: usize,
114 w: usize,
115) -> Option<EmbedResult> {
116 let n = cover_bits.len();
117 let m = message.len();
118 let num_states = 1usize << h;
119 let inf = f64::INFINITY;
120
121 let columns: Vec<usize> = (0..w)
123 .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
124 .collect();
125
126 let progress_interval = (n / STC_PROGRESS_STEPS as usize).max(1);
128
129 let mut prev_cost = vec![inf; num_states];
137 prev_cost[0] = 0.0;
138 let mut curr_cost = vec![0.0f64; num_states];
139 let mut shifted_cost = vec![inf; num_states];
140
141 let mut back_ptr: Vec<u128> = Vec::with_capacity(n);
142 let mut msg_idx = 0;
143
144 for j in 0..n {
145 let col_idx = j % w;
146 let col = columns[col_idx];
147 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
149
150 let (cost_s0, cost_s1) = if cover_bit == 0 {
151 (0.0, flip_cost)
152 } else {
153 (flip_cost, 0.0)
154 };
155
156 let mut packed_bp = 0u128;
157
158 for s in 0..num_states {
159 let cost_0 = prev_cost[s] + cost_s0;
160 let cost_1 = prev_cost[s ^ col] + cost_s1;
161
162 if cost_1 < cost_0 {
163 curr_cost[s] = cost_1;
164 packed_bp |= 1u128 << s;
165 } else {
166 curr_cost[s] = cost_0;
167 }
168 }
169
170 back_ptr.push(packed_bp);
171
172 if col_idx == w - 1 && msg_idx < m {
173 let required_bit = message[msg_idx] as usize;
174 shifted_cost.fill(inf);
175
176 for s in 0..num_states {
177 if curr_cost[s] == inf { continue; }
178 if (s & 1) != required_bit { continue; }
179 let s_shifted = s >> 1;
180 if curr_cost[s] < shifted_cost[s_shifted] {
181 shifted_cost[s_shifted] = curr_cost[s];
182 }
183 }
184
185 std::mem::swap(&mut prev_cost, &mut shifted_cost);
186 msg_idx += 1;
187 } else {
188 std::mem::swap(&mut prev_cost, &mut curr_cost);
189 }
190
191 if (j + 1) % progress_interval == 0 {
192 if progress::is_cancelled() { return None; }
193 progress::advance();
194 }
195 }
196
197 let (best_state, best_cost) = find_best_state(&prev_cost);
199 if best_cost == inf { return None; }
200
201 let mut stego_bits = vec![0u8; n];
203 let mut s = best_state;
204
205 for j in (0..n).rev() {
206 let col_idx = j % w;
207
208 if col_idx == w - 1 && (j / w) < m {
209 let msg_bit = message[j / w] as usize;
210 s = ((s << 1) | msg_bit) & (num_states - 1);
211 }
212
213 let bit = ((back_ptr[j] >> s) & 1) as u8;
214 stego_bits[j] = bit;
215
216 if bit == 1 {
217 s ^= columns[col_idx];
218 }
219 }
220
221 debug_assert_eq!(s, 0, "traceback did not return to initial state 0");
222 debug_assert_eq!(
223 stc_extract(&stego_bits, hhat_matrix, w)[..m],
224 message[..m],
225 );
226
227 let num_modifications = stego_bits.iter().zip(cover_bits.iter())
228 .filter(|(s, c)| s != c).count();
229
230 Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
231}
232
233#[allow(dead_code)] fn stc_embed_segmented(
240 cover_bits: &[u8],
241 costs: &[f32],
242 message: &[u8],
243 hhat_matrix: &[Vec<u32>],
244 h: usize,
245 w: usize,
246) -> Option<EmbedResult> {
247 let n = cover_bits.len();
248 let m = message.len();
249 let num_states = 1usize << h;
250 let inf = f64::INFINITY;
251
252 let columns: Vec<usize> = (0..w)
255 .map(|c| hhat::column_packed(hhat_matrix, c) as usize)
256 .collect();
257
258 let k = ((m as f64).sqrt().ceil() as usize).max(1);
262 let num_segments = m.div_ceil(k);
263
264 let phase_a_steps = STC_PROGRESS_STEPS / 2;
267 let progress_interval_a = (n / phase_a_steps as usize).max(1);
268
269 let mut prev_cost = vec![inf; num_states];
271 prev_cost[0] = 0.0;
272 let mut curr_cost = vec![0.0f64; num_states];
273 let mut shifted_cost = vec![inf; num_states];
274
275 let mut checkpoints: Vec<Vec<f64>> = Vec::with_capacity(num_segments);
278 checkpoints.push(prev_cost.clone());
279
280 let mut msg_idx = 0;
281
282 for j in 0..n {
283 let col_idx = j % w;
284 let col = columns[col_idx];
285 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
287
288 let (cost_s0, cost_s1) = if cover_bit == 0 {
289 (0.0, flip_cost)
290 } else {
291 (flip_cost, 0.0)
292 };
293
294 for s in 0..num_states {
295 let cost_0 = prev_cost[s] + cost_s0;
296 let cost_1 = prev_cost[s ^ col] + cost_s1;
297 curr_cost[s] = if cost_1 < cost_0 { cost_1 } else { cost_0 };
298 }
299
300 if col_idx == w - 1 && msg_idx < m {
301 let required_bit = message[msg_idx] as usize;
302 shifted_cost.fill(inf);
303 for s in 0..num_states {
304 if curr_cost[s] == inf { continue; }
305 if (s & 1) != required_bit { continue; }
306 let s_shifted = s >> 1;
307 if curr_cost[s] < shifted_cost[s_shifted] {
308 shifted_cost[s_shifted] = curr_cost[s];
309 }
310 }
311 std::mem::swap(&mut prev_cost, &mut shifted_cost);
312 msg_idx += 1;
313
314 if msg_idx % k == 0 && msg_idx < m {
316 checkpoints.push(prev_cost.clone());
317 }
318 } else {
319 std::mem::swap(&mut prev_cost, &mut curr_cost);
320 }
321
322 if (j + 1) % progress_interval_a == 0 {
323 if progress::is_cancelled() { return None; }
324 progress::advance();
325 }
326 }
327
328 let (best_state, best_cost) = find_best_state(&prev_cost);
330 if best_cost == inf { return None; }
331
332 let phase_b_steps = STC_PROGRESS_STEPS - phase_a_steps;
335 let progress_interval_b = (n / phase_b_steps as usize).max(1);
336 let mut progress_counter = 0usize;
337
338 let mut stego_bits = vec![0u8; n];
339 let mut entry_state = best_state;
340
341 for seg in (0..num_segments).rev() {
342 let block_start = seg * k;
343 let block_end = ((seg + 1) * k).min(m);
344 let j_start = block_start * w;
345 let j_end = block_end * w;
346 let seg_len = j_end - j_start;
347
348 prev_cost.copy_from_slice(&checkpoints[seg]);
350
351 let mut seg_back_ptr: Vec<u128> = Vec::with_capacity(seg_len);
353 let mut seg_msg_idx = block_start;
354
355 for j in j_start..j_end {
356 let col_idx = j % w;
357 let col = columns[col_idx];
358 let flip_cost = costs[j] as f64; let cover_bit = cover_bits[j] & 1;
360
361 let (cost_s0, cost_s1) = if cover_bit == 0 {
362 (0.0, flip_cost)
363 } else {
364 (flip_cost, 0.0)
365 };
366
367 let mut packed_bp = 0u128;
368
369 for s in 0..num_states {
370 let cost_0 = prev_cost[s] + cost_s0;
371 let cost_1 = prev_cost[s ^ col] + cost_s1;
372 if cost_1 < cost_0 {
373 curr_cost[s] = cost_1;
374 packed_bp |= 1u128 << s;
375 } else {
376 curr_cost[s] = cost_0;
377 }
378 }
379
380 seg_back_ptr.push(packed_bp);
381
382 if col_idx == w - 1 && seg_msg_idx < m {
383 let required_bit = message[seg_msg_idx] as usize;
384 shifted_cost.fill(inf);
385 for s in 0..num_states {
386 if curr_cost[s] == inf { continue; }
387 if (s & 1) != required_bit { continue; }
388 let s_shifted = s >> 1;
389 if curr_cost[s] < shifted_cost[s_shifted] {
390 shifted_cost[s_shifted] = curr_cost[s];
391 }
392 }
393 std::mem::swap(&mut prev_cost, &mut shifted_cost);
394 seg_msg_idx += 1;
395 } else {
396 std::mem::swap(&mut prev_cost, &mut curr_cost);
397 }
398
399 progress_counter += 1;
400 if progress_counter.is_multiple_of(progress_interval_b) {
401 if progress::is_cancelled() { return None; }
402 progress::advance();
403 }
404 }
405
406 let mut s = entry_state;
408 for local_j in (0..seg_len).rev() {
409 let j = j_start + local_j;
410 let col_idx = j % w;
411
412 if col_idx == w - 1 && (j / w) < m {
413 let msg_bit = message[j / w] as usize;
414 s = ((s << 1) | msg_bit) & (num_states - 1);
415 }
416
417 let bit = ((seg_back_ptr[local_j] >> s) & 1) as u8;
418 stego_bits[j] = bit;
419
420 if bit == 1 {
421 s ^= columns[col_idx];
422 }
423 }
424
425 entry_state = s;
427 }
429
430 debug_assert_eq!(entry_state, 0, "traceback did not return to initial state 0");
431 debug_assert_eq!(
432 stc_extract(&stego_bits, hhat_matrix, w)[..m],
433 message[..m],
434 );
435
436 let num_modifications = stego_bits.iter().zip(cover_bits.iter())
437 .filter(|(s, c)| s != c).count();
438
439 Some(EmbedResult { stego_bits, total_cost: best_cost, num_modifications })
440}
441
/// Returns `(index, cost)` of the cheapest entry in `costs`.
///
/// The first strictly smallest entry wins. If no entry is below infinity
/// (including the empty slice), the result is `(0, f64::INFINITY)`; NaN
/// entries never compare as smaller and are therefore skipped.
fn find_best_state(costs: &[f64]) -> (usize, f64) {
    costs
        .iter()
        .copied()
        .enumerate()
        .fold((0, f64::INFINITY), |winner, (state, cost)| {
            if cost < winner.1 {
                (state, cost)
            } else {
                winner
            }
        })
}
458
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::hhat::generate_hhat;
    use super::super::extract::stc_extract;

    /// Bit pattern `(i * mul + add) % 2` for `i` in `0..len`.
    fn parity_pattern(len: usize, mul: usize, add: usize) -> Vec<u8> {
        (0..len).map(|i| ((i * mul + add) % 2) as u8).collect()
    }

    /// Periodic cost ramp `base + (i % period) * step`, with an infinite
    /// (never-flip) cost at every multiple of `wet_every`.
    fn sawtooth_costs(
        len: usize,
        base: f32,
        period: usize,
        step: f32,
        wet_every: usize,
    ) -> Vec<f32> {
        (0..len)
            .map(|i| {
                if i % wet_every == 0 {
                    f32::INFINITY
                } else {
                    base + (i % period) as f32 * step
                }
            })
            .collect()
    }

    /// Asserts that extracting from `result` yields `message` as its prefix.
    fn assert_carries(result: &EmbedResult, hhat: &[Vec<u32>], w: usize, message: &[u8]) {
        let recovered = stc_extract(&result.stego_bits, hhat, w);
        assert_eq!(&recovered[..message.len()], message);
    }

    /// Asserts that every `stride`-th position kept its cover bit.
    fn assert_wet_untouched(stego: &[u8], cover: &[u8], stride: usize) {
        for i in (0..cover.len()).step_by(stride) {
            assert_eq!(stego[i], cover[i], "WET position {i} was modified");
        }
    }

    /// Runs the inline and the segmented embedder on the same input.
    fn embed_both(
        cover: &[u8],
        costs: &[f32],
        message: &[u8],
        hhat: &[Vec<u32>],
        h: usize,
        w: usize,
    ) -> (EmbedResult, EmbedResult) {
        let inline = stc_embed_inline(cover, costs, message, hhat, h, w).unwrap();
        let segmented = stc_embed_segmented(cover, costs, message, hhat, h, w).unwrap();
        (inline, segmented)
    }

    /// Smallest round trip: 4 message bits in 20 cover bits, uniform costs.
    #[test]
    fn embed_extract_roundtrip_tiny() {
        let (h, n, m): (usize, usize, usize) = (3, 20, 4);
        let w = n.div_ceil(m);
        let seed = [42u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 1, 0);
        let costs = vec![1.0f32; n];
        let message: Vec<u8> = vec![1, 0, 1, 1];

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        assert_carries(&result, &hhat, w, &message);
    }

    /// Round trip at the maximum supported constraint height.
    #[test]
    fn embed_extract_roundtrip_h7() {
        let (h, n, m): (usize, usize, usize) = (7, 500, 50);
        let w = n.div_ceil(m);
        let seed = [13u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 7, 3);
        let costs: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let message = parity_pattern(m, 1, 0);

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_carries(&result, &hhat, w, &message);
    }

    /// Positions with a prohibitive cost must keep their cover bit.
    #[test]
    fn wet_coefficients_not_modified() {
        let (h, n, m): (usize, usize, usize) = (3, 20, 4);
        let w = n.div_ceil(m);
        let seed = [55u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = vec![0u8; n];
        let mut costs = vec![1.0f32; n];
        costs.iter_mut().step_by(5).for_each(|c| *c = 1e13);
        let message: Vec<u8> = vec![0, 1, 0, 1];

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();

        assert_wet_untouched(&result.stego_bits, &cover_bits, 5);
        assert_carries(&result, &hhat, w, &message);
    }

    /// An empty message returns the cover untouched at zero cost.
    #[test]
    fn empty_message() {
        let (h, n, w): (usize, usize, usize) = (3, 10, 5);
        let seed = [0u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = vec![1u8; n];
        let costs = vec![1.0f32; n];
        let message: Vec<u8> = Vec::new();

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits, cover_bits);
        assert_eq!(result.total_cost, 0.0);
    }

    /// Round trip on 100 000 cover bits with periodic infinite-cost positions.
    #[test]
    fn embed_extract_roundtrip_large() {
        let (h, m, w): (usize, usize, usize) = (7, 10_000, 10);
        let n = m * w;
        let seed = [77u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 31, 17);
        let costs = sawtooth_costs(n, 0.5, 100, 0.02, 500);
        let message = parity_pattern(m, 13, 7);

        let result = stc_embed(&cover_bits, &costs, &message, &hhat, h, w).unwrap();
        assert_eq!(result.stego_bits.len(), n);
        assert_carries(&result, &hhat, w, &message);
        assert_wet_untouched(&result.stego_bits, &cover_bits, 500);
    }

    /// Inline and segmented embedders must agree bit-for-bit.
    #[test]
    fn inline_segmented_equivalence() {
        let (h, m, w): (usize, usize, usize) = (7, 500, 10);
        let n = m * w;
        let seed = [99u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 31, 17);
        let costs = sawtooth_costs(n, 0.5, 100, 0.02, 500);
        let message = parity_pattern(m, 13, 7);

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);

        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    /// Same equivalence check with many segments.
    #[test]
    fn inline_segmented_equivalence_large() {
        let (h, m, w): (usize, usize, usize) = (7, 10_000, 10);
        let n = m * w;
        let seed = [88u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 37, 11);
        let costs = sawtooth_costs(n, 0.3, 200, 0.01, 1000);
        let message = parity_pattern(m, 19, 3);

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);

        assert_eq!(inline.stego_bits, segmented.stego_bits, "stego bits differ");
        assert_eq!(inline.total_cost, segmented.total_cost, "total cost differs");
    }

    /// Equivalence check on a tiny input (4 message bits).
    #[test]
    fn segmented_single_segment() {
        let (h, m, w): (usize, usize, usize) = (7, 4, 5);
        let n = m * w;
        let seed = [33u8; 32];
        let hhat = generate_hhat(h, w, &seed);

        let cover_bits = parity_pattern(n, 1, 0);
        let costs = vec![1.0f32; n];
        let message: Vec<u8> = vec![1, 0, 1, 1];

        let (inline, segmented) = embed_both(&cover_bits, &costs, &message, &hhat, h, w);

        assert_eq!(inline.stego_bits, segmented.stego_bits);
        assert_eq!(inline.total_cost, segmented.total_cost);
    }
}