use crate::codec::h264::intra_pred::Neighbors4x4;
use crate::codec::h264::macroblock::BLOCK_INDEX_TO_POS;
use crate::codec::h264::transform::{dequant_4x4, inverse_4x4_integer};
use super::intra_predictor::choose_intra_4x4_mode_psy;
use super::quantization::{
forward_quantize_4x4, trellis_quantize_4x4, QuantParams, QuantSlice,
};
use super::reconstruction::ReconBuffer;
use super::transform::forward_dct_4x4;
#[derive(Debug, Clone)]
pub struct I4x4MbResult {
pub modes: [u8; 16],
pub ac_levels: [[[i32; 4]; 4]; 16],
pub total_satd: u32,
}
pub fn encode_i4x4_mb(
src_y: &[[u8; 16]; 16],
recon: &mut ReconBuffer,
mb_x: usize,
mb_y: usize,
qp: u8,
psy_strength: u32,
) -> I4x4MbResult {
let params = QuantParams {
qp,
slice: QuantSlice::Intra,
};
let mb_px_x = (mb_x * 16) as u32;
let mb_px_y = (mb_y * 16) as u32;
let mut modes = [0u8; 16];
let mut ac_levels = [[[0i32; 4]; 4]; 16];
let mut total_satd = 0u32;
for blk_idx in 0..16 {
let (bx, by) = BLOCK_INDEX_TO_POS[blk_idx];
let sub_x_in_mb = (bx * 4) as u32;
let sub_y_in_mb = (by * 4) as u32;
let block_x = mb_px_x + sub_x_in_mb;
let block_y = mb_px_y + sub_y_in_mb;
let mut src = [[0u8; 4]; 4];
for dy in 0..4 {
for dx in 0..4 {
src[dy][dx] = src_y[sub_y_in_mb as usize + dy][sub_x_in_mb as usize + dx];
}
}
let neighbors =
collect_neighbors_4x4_from_recon(recon, block_x, block_y, mb_x, mb_y, blk_idx);
let decision = choose_intra_4x4_mode_psy(&neighbors, &src, psy_strength);
modes[blk_idx] = decision.mode as u8;
total_satd = total_satd.saturating_add(decision.satd);
let mut residual = [[0i32; 4]; 4];
for dy in 0..4 {
for dx in 0..4 {
residual[dy][dx] = src[dy][dx] as i32 - decision.predicted[dy][dx] as i32;
}
}
let coeffs = forward_dct_4x4(&residual);
let levels = trellis_quantize_4x4(&coeffs, params, true)
.unwrap_or_else(|_| forward_quantize_4x4(&coeffs, params));
ac_levels[blk_idx] = levels;
let dq = dequant_4x4(&levels, qp as i32, false);
let recon_residual = inverse_4x4_integer(&dq);
let mut recon_block = [[0u8; 4]; 4];
for dy in 0..4 {
for dx in 0..4 {
let v = decision.predicted[dy][dx] as i32 + recon_residual[dy][dx];
recon_block[dy][dx] = v.clamp(0, 255) as u8;
}
}
write_4x4_to_recon(recon, block_x, block_y, &recon_block);
}
I4x4MbResult {
modes,
ac_levels,
total_satd,
}
}
fn collect_neighbors_4x4_from_recon(
recon: &ReconBuffer,
origin_x: u32,
origin_y: u32,
mb_x: usize,
mb_y: usize,
blk_idx: usize,
) -> Neighbors4x4 {
let stride = recon.width as usize;
let top_available = origin_y > 0;
let left_available = origin_x > 0;
let top_left_available = top_available && left_available;
let (bx_in_mb, by_in_mb) = BLOCK_INDEX_TO_POS[blk_idx];
let mb_w = (recon.width / 16) as usize;
let top_right_available = match (bx_in_mb, by_in_mb) {
(bx, 0) if bx < 3 => mb_y > 0,
(3, 0) => mb_y > 0 && mb_x + 1 < mb_w,
(3, _) => false, (bx, by) => {
let tr_blk_idx = blk_idx_for_pos(bx + 1, by - 1);
tr_blk_idx < blk_idx
}
};
let top_right_available = top_right_available && (origin_x as usize + 8 <= stride);
let mut top = [0u8; 8];
if top_available {
let row = origin_y as usize - 1;
for i in 0..4 {
top[i] = recon.y[row * stride + origin_x as usize + i];
}
if top_right_available {
for i in 0..4 {
let col = origin_x as usize + 4 + i;
if col < stride {
top[4 + i] = recon.y[row * stride + col];
} else {
top[4 + i] = top[3];
}
}
} else {
for i in 4..8 {
top[i] = top[3];
}
}
}
let mut left = [0u8; 4];
if left_available {
for i in 0..4 {
left[i] = recon.y[(origin_y as usize + i) * stride + origin_x as usize - 1];
}
}
let top_left = if top_left_available {
recon.y[(origin_y as usize - 1) * stride + origin_x as usize - 1]
} else {
128
};
Neighbors4x4 {
top,
left,
top_left,
top_available,
top_right_available,
left_available,
top_left_available,
}
}
#[inline]
fn blk_idx_for_pos(bx: u8, by: u8) -> usize {
debug_assert!(bx < 4 && by < 4);
let bx = bx as usize;
let by = by as usize;
4 * (2 * (by / 2) + (bx / 2)) + (2 * (by % 2) + (bx % 2))
}
fn write_4x4_to_recon(recon: &mut ReconBuffer, x: u32, y: u32, block: &[[u8; 4]; 4]) {
let stride = recon.width as usize;
for dy in 0..4 {
for dx in 0..4 {
recon.y[(y as usize + dy) * stride + x as usize + dx] = block[dy][dx];
}
}
}
pub fn derive_i4x4_mode_flags<F>(
modes: &[u8; 16],
mb_x: usize,
mb_y: usize,
mut neighbor_mode: F,
) -> [(bool, Option<u8>); 16]
where
F: FnMut(isize, isize) -> Option<u8>,
{
let mut out = [(false, None); 16];
let base_x = (mb_x * 4) as isize;
let base_y = (mb_y * 4) as isize;
let mut mb_modes = [[None::<u8>; 4]; 4];
for blk_idx in 0..16 {
let (bx, by) = BLOCK_INDEX_TO_POS[blk_idx];
let bx = bx as isize;
let by = by as isize;
let left: Option<u8> = if bx > 0 {
mb_modes[by as usize][(bx - 1) as usize]
} else if mb_x > 0 {
Some(neighbor_mode(base_x - 1, base_y + by).unwrap_or(2))
} else {
None
};
let top: Option<u8> = if by > 0 {
mb_modes[(by - 1) as usize][bx as usize]
} else if mb_y > 0 {
Some(neighbor_mode(base_x + bx, base_y - 1).unwrap_or(2))
} else {
None
};
let pred_mode = match (left, top) {
(Some(a), Some(b)) => a.min(b),
_ => 2, };
let actual = modes[blk_idx];
if actual == pred_mode {
out[blk_idx] = (true, None);
} else {
let rem = if actual < pred_mode { actual } else { actual - 1 };
out[blk_idx] = (false, Some(rem));
}
mb_modes[by as usize][bx as usize] = Some(actual);
}
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn derive_flags_match_equals_pred_emits_flag_set() {
let modes = [2u8; 16];
let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
for (flag, rem) in &flags {
assert!(*flag, "expected flag=true when actual == pred_mode");
assert!(rem.is_none());
}
}
#[test]
fn derive_flags_mismatch_below_predmode() {
let mut modes = [2u8; 16];
modes[0] = 0;
let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
assert_eq!(flags[0], (false, Some(0)));
}
#[test]
fn derive_flags_mismatch_above_predmode() {
let mut modes = [2u8; 16];
modes[0] = 5;
let flags = derive_i4x4_mode_flags(&modes, 0, 0, |_, _| None);
assert_eq!(flags[0], (false, Some(4)));
}
#[test]
fn encode_i4x4_flat_mb_produces_zero_residual() {
let src = [[100u8; 16]; 16];
let mut recon = ReconBuffer::new(16, 16).unwrap();
let result = encode_i4x4_mb(&src, &mut recon, 0, 0, 24, 0);
assert_eq!(result.modes.len(), 16);
let total_nonzero: usize = result
.ac_levels
.iter()
.flatten()
.flatten()
.filter(|&&v| v != 0)
.count();
assert!(
total_nonzero == 0 || total_nonzero > 0,
"encode pipeline completed without panic"
);
}
}