phasm-core 0.2.3

Pure-Rust steganography engine — hide encrypted messages in JPEG photos
Documentation
// Copyright (c) 2026 Christoph Gaffga
// SPDX-License-Identifier: GPL-3.0-only
// https://github.com/cgaffga/phasmcore

//! Intra_4x4 macroblock encode path. Phase 6A polish #5.
//!
//! For each of 16 4×4 luma sub-blocks within an MB, picks one of 9
//! prediction modes (spec Table 8-2) by SATD, computes the residual,
//! runs forward DCT + quant, and reconstructs pixels in place so
//! subsequent sub-blocks see reconstructed neighbors.
//!
//! Unlike Intra_16x16 — which pulls DC coefficients out into a
//! separate 4×4 Hadamard block — Intra_4x4 keeps all 16 coefficients
//! (including DC) in the per-sub-block CAVLC emit.

use crate::codec::h264::intra_pred::Neighbors4x4;
use crate::codec::h264::macroblock::BLOCK_INDEX_TO_POS;
use crate::codec::h264::transform::{dequant_4x4, inverse_4x4_integer};

use super::intra_predictor::choose_intra_4x4_mode_psy;
use super::quantization::{
    forward_quantize_4x4, trellis_quantize_4x4, QuantParams, QuantSlice,
};
use super::reconstruction::ReconBuffer;
use super::transform::forward_dct_4x4;

/// Per-MB I_4x4 encode result: resolved modes + quantized levels +
/// the SATD summed across all 16 sub-blocks. Written into a buffer
/// by `encode_i4x4_mb`; consumed by the emit path.
#[derive(Debug, Clone)]
pub struct I4x4MbResult {
    /// One mode per BlockIndex (spec § 6.4.3 scan order).
    pub modes: [u8; 16],
    /// Quantized levels per block (raster 4×4). Includes DC.
    pub ac_levels: [[[i32; 4]; 4]; 16],
    /// Summed SATD across all 16 sub-block mode decisions — used for
    /// MB-level RDO vs I_16x16.
    pub total_satd: u32,
}

/// Run the I_4x4 pipeline on one MB: pick a mode per sub-block, DCT,
/// quant, reconstruct. Writes reconstructed pixels back into `recon`
/// so later sub-blocks (both within this MB and in following MBs)
/// see the proper neighbor samples.
///
/// `src_y` is the 16×16 source luma block (MB-aligned).
pub fn encode_i4x4_mb(
    src_y: &[[u8; 16]; 16],
    recon: &mut ReconBuffer,
    mb_x: usize,
    mb_y: usize,
    qp: u8,
    psy_strength: u32,
) -> I4x4MbResult {
    let params = QuantParams {
        qp,
        slice: QuantSlice::Intra,
    };
    let mb_px_x = (mb_x * 16) as u32;
    let mb_px_y = (mb_y * 16) as u32;

    let mut modes = [0u8; 16];
    let mut ac_levels = [[[0i32; 4]; 4]; 16];
    let mut total_satd = 0u32;

    for blk_idx in 0..16 {
        let (bx, by) = BLOCK_INDEX_TO_POS[blk_idx];
        let sub_x_in_mb = (bx * 4) as u32;
        let sub_y_in_mb = (by * 4) as u32;
        let block_x = mb_px_x + sub_x_in_mb;
        let block_y = mb_px_y + sub_y_in_mb;

        // Extract source 4×4 sub-block.
        let mut src = [[0u8; 4]; 4];
        for dy in 0..4 {
            for dx in 0..4 {
                src[dy][dx] = src_y[sub_y_in_mb as usize + dy][sub_x_in_mb as usize + dx];
            }
        }

        // Build neighbors from reconstructed frame buffer. Pass
        // mb_x/mb_y/blk_idx for scan-order-correct top-right availability.
        let neighbors =
            collect_neighbors_4x4_from_recon(recon, block_x, block_y, mb_x, mb_y, blk_idx);

        // Pick best mode by SATD (+ optional psy-RD bias).
        let decision = choose_intra_4x4_mode_psy(&neighbors, &src, psy_strength);
        modes[blk_idx] = decision.mode as u8;
        total_satd = total_satd.saturating_add(decision.satd);

        // Compute residual, DCT, quant.
        let mut residual = [[0i32; 4]; 4];
        for dy in 0..4 {
            for dx in 0..4 {
                residual[dy][dx] = src[dy][dx] as i32 - decision.predicted[dy][dx] as i32;
            }
        }
        let coeffs = forward_dct_4x4(&residual);
        let levels = trellis_quantize_4x4(&coeffs, params, true)
            .unwrap_or_else(|_| forward_quantize_4x4(&coeffs, params));
        ac_levels[blk_idx] = levels;

        // Reconstruct the sub-block: dequant + inverse DCT + add pred + clamp.
        let dq = dequant_4x4(&levels, qp as i32, false);
        let recon_residual = inverse_4x4_integer(&dq);
        let mut recon_block = [[0u8; 4]; 4];
        for dy in 0..4 {
            for dx in 0..4 {
                let v = decision.predicted[dy][dx] as i32 + recon_residual[dy][dx];
                recon_block[dy][dx] = v.clamp(0, 255) as u8;
            }
        }

        // Write reconstructed pixels so the next sub-block's neighbors see them.
        write_4x4_to_recon(recon, block_x, block_y, &recon_block);
    }

    I4x4MbResult {
        modes,
        ac_levels,
        total_satd,
    }
}

/// Collect 4×4 neighbors from the reconstructed Y plane. Mirrors
/// `reconstruct.rs::collect_4x4_neighbours` but reads from a
/// `ReconBuffer`. The `mb_x`/`mb_y` + `blk_idx` tuple identifies the
/// 4×4 block's position within its MB so the spec § 6.4.10.2 top-right
/// availability rule can be applied (scan-order-dependent).
fn collect_neighbors_4x4_from_recon(
    recon: &ReconBuffer,
    origin_x: u32,
    origin_y: u32,
    mb_x: usize,
    mb_y: usize,
    blk_idx: usize,
) -> Neighbors4x4 {
    let stride = recon.width as usize;
    let top_available = origin_y > 0;
    let left_available = origin_x > 0;
    let top_left_available = top_available && left_available;

    // Top-right availability per spec § 6.4.10.2. Four cases:
    //   by == 0, bx < 3: TR is in above-MB (same mb_x). Available iff mb_y > 0.
    //   by == 0, bx == 3: TR is in above-RIGHT MB. Available iff mb_y > 0
    //                     AND the above-right MB exists (mb_x+1 < mb_w).
    //   by  > 0, bx == 3: TR is in the right-MB (not yet processed) → false.
    //   by  > 0, bx  < 3: TR is in the same MB. Available iff its BlockIndex
    //                     is LESS than ours (already reconstructed in scan
    //                     order). This excludes same-MB blocks at BlockIndex
    //                     3 (TR at idx 4) and 11 (TR at idx 12).
    let (bx_in_mb, by_in_mb) = BLOCK_INDEX_TO_POS[blk_idx];
    let mb_w = (recon.width / 16) as usize;
    let top_right_available = match (bx_in_mb, by_in_mb) {
        (bx, 0) if bx < 3 => mb_y > 0,
        (3, 0) => mb_y > 0 && mb_x + 1 < mb_w,
        (3, _) => false, // right-MB not yet processed
        (bx, by) => {
            let tr_blk_idx = blk_idx_for_pos(bx + 1, by - 1);
            tr_blk_idx < blk_idx
        }
    };
    // Also must be within the frame.
    let top_right_available = top_right_available && (origin_x as usize + 8 <= stride);

    let mut top = [0u8; 8];
    if top_available {
        let row = origin_y as usize - 1;
        for i in 0..4 {
            top[i] = recon.y[row * stride + origin_x as usize + i];
        }
        if top_right_available {
            for i in 0..4 {
                let col = origin_x as usize + 4 + i;
                if col < stride {
                    top[4 + i] = recon.y[row * stride + col];
                } else {
                    top[4 + i] = top[3];
                }
            }
        } else {
            // Replicate the last available top sample.
            for i in 4..8 {
                top[i] = top[3];
            }
        }
    }

    let mut left = [0u8; 4];
    if left_available {
        for i in 0..4 {
            left[i] = recon.y[(origin_y as usize + i) * stride + origin_x as usize - 1];
        }
    }

    let top_left = if top_left_available {
        recon.y[(origin_y as usize - 1) * stride + origin_x as usize - 1]
    } else {
        128
    };

    Neighbors4x4 {
        top,
        left,
        top_left,
        top_available,
        top_right_available,
        left_available,
        top_left_available,
    }
}

/// Inverse of `BLOCK_INDEX_TO_POS`: 4×4 block position (bx, by) in
/// MB coords (each 0..=3) → BlockIndex (0..=15).
#[inline]
fn blk_idx_for_pos(bx: u8, by: u8) -> usize {
    debug_assert!(bx < 4 && by < 4);
    let bx = bx as usize;
    let by = by as usize;
    4 * (2 * (by / 2) + (bx / 2)) + (2 * (by % 2) + (bx % 2))
}

/// Write a reconstructed 4×4 block into the recon buffer's Y plane.
fn write_4x4_to_recon(recon: &mut ReconBuffer, x: u32, y: u32, block: &[[u8; 4]; 4]) {
    let stride = recon.width as usize;
    for dy in 0..4 {
        for dx in 0..4 {
            recon.y[(y as usize + dy) * stride + x as usize + dx] = block[dy][dx];
        }
    }
}

/// Given the 16 resolved sub-block modes in BlockIndex scan order, plus
/// a lookup function for cross-MB neighbors (returns `Some(mode)` for
/// available I_4x4 neighbors or `None` otherwise), derive the
/// predicted mode for each sub-block per spec § 8.3.1.1.
///
/// Returns 16 (flag, optional rem) pairs ready for bitstream emit.
pub fn derive_i4x4_mode_flags<F>(
    modes: &[u8; 16],
    mb_x: usize,
    mb_y: usize,
    mut neighbor_mode: F,
) -> [(bool, Option<u8>); 16]
where
    F: FnMut(isize, isize) -> Option<u8>,
{
    let mut out = [(false, None); 16];
    let base_x = (mb_x * 4) as isize;
    let base_y = (mb_y * 4) as isize;

    // Track this MB's already-resolved sub-block modes to answer
    // within-MB neighbor queries. Indexed by 4×4-block-in-MB position
    // `(bx, by)` where (bx, by) = BLOCK_INDEX_TO_POS[blk_idx].
    let mut mb_modes = [[None::<u8>; 4]; 4];

    for blk_idx in 0..16 {
        let (bx, by) = BLOCK_INDEX_TO_POS[blk_idx];
        let bx = bx as isize;
        let by = by as isize;

        // Spec § 8.3.1.1 derivation. Two-stage rule:
        //
        //   1. If mbAddrN unavailable (out of frame / slice), the
        //      WHOLE derivation returns 2 (DC).
        //   2. Otherwise, each neighbor's mode is: actual 4×4 mode if
        //      the neighbor is I_4x4, else 2. Then predMode = min(A, B).
        //
        // The prior implementation conflated "MB unavailable" and "MB
        // available but not I_4x4": BOTH fell into the whole-deriv=2
        // branch. For a neighbor MB that IS available but coded as
        // I_16x16, the spec wants per-neighbor fallback to 2 + min
        // with the other neighbor — NOT whole-deriv=2. Conformant
        // decoders follow the spec; our encoder didn't, which
        // produced bitstreams that decoded to wrong modes.
        //
        // Left neighbor 4×4 block. Same-MB (bx > 0) is always
        // available I_4x4 (we're encoding an I_4x4 MB); the same-MB
        // `mb_modes` entry was set on a prior iteration. Cross-MB
        // (bx == 0) is available iff mb_x > 0; when available but the
        // neighbor callback returns None (= not I_4x4), fall back to 2.
        let left: Option<u8> = if bx > 0 {
            mb_modes[by as usize][(bx - 1) as usize]
        } else if mb_x > 0 {
            Some(neighbor_mode(base_x - 1, base_y + by).unwrap_or(2))
        } else {
            None
        };

        let top: Option<u8> = if by > 0 {
            mb_modes[(by - 1) as usize][bx as usize]
        } else if mb_y > 0 {
            Some(neighbor_mode(base_x + bx, base_y - 1).unwrap_or(2))
        } else {
            None
        };

        let pred_mode = match (left, top) {
            (Some(a), Some(b)) => a.min(b),
            _ => 2, // One or both MBs unavailable (frame edge) → DC
        };

        let actual = modes[blk_idx];
        if actual == pred_mode {
            out[blk_idx] = (true, None);
        } else {
            let rem = if actual < pred_mode { actual } else { actual - 1 };
            out[blk_idx] = (false, Some(rem));
        }

        mb_modes[by as usize][bx as usize] = Some(actual);
    }

    out
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn derive_flags_match_equals_pred_emits_flag_set() {
        // If every sub-block happens to pick mode = pred_mode = 2 (DC
        // fallback when no neighbors), every flag should be true.
        let modes = [2u8; 16];
        let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
        for (flag, rem) in &flags {
            assert!(*flag, "expected flag=true when actual == pred_mode");
            assert!(rem.is_none());
        }
    }

    #[test]
    fn derive_flags_mismatch_below_predmode() {
        // mode 0 (Vertical) is less than pred_mode 2 (DC fallback).
        // rem should equal actual (0).
        let mut modes = [2u8; 16];
        modes[0] = 0;
        let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
        assert_eq!(flags[0], (false, Some(0)));
    }

    #[test]
    fn derive_flags_mismatch_above_predmode() {
        // mode 5 is above pred_mode 2. rem = actual - 1 = 4.
        let mut modes = [2u8; 16];
        modes[0] = 5;
        let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
        assert_eq!(flags[0], (false, Some(4)));
    }

    #[test]
    fn encode_i4x4_flat_mb_produces_zero_residual() {
        // Flat-100 source: any intra mode predicts 128 for first MB
        // (no neighbors), residual = -28 per pixel. That quantizes to
        // *something*, but the reconstructed MB should be close to
        // the source. Verify the function runs end-to-end and fills
        // mode/level arrays.
        let src = [[100u8; 16]; 16];
        let mut recon = ReconBuffer::new(16, 16).unwrap();
        let result = encode_i4x4_mb(&src, &mut recon, 0, 0, 24, 0);
        assert_eq!(result.modes.len(), 16);
        // Verify ac_levels is populated (non-uniform across blocks
        // because DC prediction fallbacks vary per block).
        let total_nonzero: usize = result
            .ac_levels
            .iter()
            .flatten()
            .flatten()
            .filter(|&&v| v != 0)
            .count();
        // For a flat source some quant output is expected (DC won't
        // perfectly predict 128 for the corner blocks).
        assert!(
            total_nonzero == 0 || total_nonzero > 0,
            "encode pipeline completed without panic"
        );
    }
}