const QUANT_MF: [[i32; 3]; 6] = [
[13107, 5243, 8066],
[11916, 4660, 7490],
[10082, 4194, 6554],
[9362, 3647, 5825],
[8192, 3355, 5243],
[7282, 2893, 4559],
];
const POSITION_CLASS: [usize; 16] = [0, 2, 0, 2, 2, 1, 2, 1, 0, 2, 0, 2, 2, 1, 2, 1];
const ZIGZAG_SCAN_4X4: [usize; 16] = {
let zig: [(usize, usize); 16] = [
(0, 0),
(0, 1),
(1, 0),
(2, 0),
(1, 1),
(0, 2),
(0, 3),
(1, 2),
(2, 1),
(3, 0),
(3, 1),
(2, 2),
(1, 3),
(2, 3),
(3, 2),
(3, 3),
];
let mut result = [0usize; 16];
let mut i = 0;
while i < 16 {
let (r, c) = zig[i];
result[r * 4 + c] = i;
i += 1;
}
result
};
struct BitWriter {
buf: Vec<u8>,
current: u8,
bits_in_current: u8,
}
impl BitWriter {
fn new() -> Self {
Self {
buf: Vec::with_capacity(4096),
current: 0,
bits_in_current: 0,
}
}
fn write_bits(&mut self, value: u32, n: u8) {
for i in (0..n).rev() {
let bit = (value >> i) & 1;
self.current = (self.current << 1) | bit as u8;
self.bits_in_current += 1;
if self.bits_in_current == 8 {
self.buf.push(self.current);
self.current = 0;
self.bits_in_current = 0;
}
}
}
fn write_bit(&mut self, bit: u8) {
self.write_bits(bit as u32, 1);
}
fn write_ue(&mut self, value: u32) {
if value == 0 {
self.write_bit(1);
return;
}
let n = 32 - (value + 1).leading_zeros(); let leading_zeros = n - 1;
for _ in 0..leading_zeros {
self.write_bit(0);
}
self.write_bits(value + 1, n as u8);
}
fn write_se(&mut self, value: i32) {
let code = if value > 0 {
(value as u32) * 2 - 1
} else if value < 0 {
((-value) as u32) * 2
} else {
0
};
self.write_ue(code);
}
fn rbsp_trailing_bits(&mut self) {
self.write_bit(1);
while self.bits_in_current != 0 {
self.write_bit(0);
}
}
fn into_bytes(mut self) -> Vec<u8> {
if self.bits_in_current > 0 {
self.current <<= 8 - self.bits_in_current;
self.buf.push(self.current);
}
self.buf
}
}
pub fn forward_dct_4x4(block: &mut [i32; 16]) {
#[cfg(target_arch = "aarch64")]
{
unsafe {
forward_dct_4x4_neon(block);
}
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("sse2") {
unsafe {
forward_dct_4x4_sse2(block);
}
return;
}
}
#[allow(unreachable_code)]
forward_dct_4x4_scalar(block);
}
fn forward_dct_4x4_scalar(block: &mut [i32; 16]) {
for i in 0..4 {
let base = i * 4;
let s0 = block[base];
let s1 = block[base + 1];
let s2 = block[base + 2];
let s3 = block[base + 3];
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
block[base] = p0 + p1;
block[base + 1] = (p3 << 1) + p2;
block[base + 2] = p0 - p1;
block[base + 3] = p3 - (p2 << 1);
}
for j in 0..4 {
let s0 = block[j];
let s1 = block[4 + j];
let s2 = block[8 + j];
let s3 = block[12 + j];
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
block[j] = p0 + p1;
block[4 + j] = (p3 << 1) + p2;
block[8 + j] = p0 - p1;
block[12 + j] = p3 - (p2 << 1);
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn forward_dct_4x4_neon(block: &mut [i32; 16]) {
use std::arch::aarch64::*;
let ptr = block.as_mut_ptr();
for i in 0..4 {
let row = vld1q_s32(ptr.add(i * 4));
let s0 = vgetq_lane_s32(row, 0);
let s1 = vgetq_lane_s32(row, 1);
let s2 = vgetq_lane_s32(row, 2);
let s3 = vgetq_lane_s32(row, 3);
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
let out = [p0 + p1, (p3 << 1) + p2, p0 - p1, p3 - (p2 << 1)];
vst1q_s32(ptr.add(i * 4), vld1q_s32(out.as_ptr()));
}
for j in 0..4 {
let s0 = *ptr.add(j);
let s1 = *ptr.add(4 + j);
let s2 = *ptr.add(8 + j);
let s3 = *ptr.add(12 + j);
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
*ptr.add(j) = p0 + p1;
*ptr.add(4 + j) = (p3 << 1) + p2;
*ptr.add(8 + j) = p0 - p1;
*ptr.add(12 + j) = p3 - (p2 << 1);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn forward_dct_4x4_sse2(block: &mut [i32; 16]) {
use std::arch::x86_64::*;
let ptr = block.as_mut_ptr();
for i in 0..4 {
let row = _mm_loadu_si128(ptr.add(i * 4) as *const __m128i);
let s0 = _mm_extract_epi32::<0>(row);
let s1 = _mm_extract_epi32::<1>(row);
let s2 = _mm_extract_epi32::<2>(row);
let s3 = _mm_extract_epi32::<3>(row);
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
let out = _mm_set_epi32(p3 - (p2 << 1), p0 - p1, (p3 << 1) + p2, p0 + p1);
_mm_storeu_si128(ptr.add(i * 4) as *mut __m128i, out);
}
for j in 0..4 {
let s0 = *ptr.add(j);
let s1 = *ptr.add(4 + j);
let s2 = *ptr.add(8 + j);
let s3 = *ptr.add(12 + j);
let p0 = s0 + s3;
let p1 = s1 + s2;
let p2 = s1 - s2;
let p3 = s0 - s3;
*ptr.add(j) = p0 + p1;
*ptr.add(4 + j) = (p3 << 1) + p2;
*ptr.add(8 + j) = p0 - p1;
*ptr.add(12 + j) = p3 - (p2 << 1);
}
}
fn quantize_4x4(coeffs: &mut [i32; 16], qp: u8) {
let qp = qp.min(51);
let qp_div6 = (qp / 6) as u32;
let qp_mod6 = (qp % 6) as usize;
let qbits = 15 + qp_div6;
let f = (1i32 << qbits) / 3;
let mf = &QUANT_MF[qp_mod6];
for i in 0..16 {
let class = POSITION_CLASS[i];
let sign = if coeffs[i] < 0 { -1i32 } else { 1i32 };
let abs_coeff = coeffs[i].abs();
coeffs[i] = sign * ((abs_coeff * mf[class] + f) >> qbits);
}
}
#[cfg(test)]
fn dequantize_4x4(coeffs: &mut [i32; 16], qp: u8) {
let qp = qp.min(51) as i32;
super::h264_transform::dequant_4x4(coeffs, qp);
}
fn cavlc_encode_block(writer: &mut BitWriter, coeffs: &[i32; 16], nc: i32) -> usize {
let mut levels = [0i32; 16];
let mut total_coeffs: usize = 0;
let mut scan_coeffs = [0i32; 16];
for i in 0..16 {
scan_coeffs[ZIGZAG_SCAN_4X4[i]] = coeffs[i];
}
let mut last_nonzero: Option<usize> = None;
for i in (0..16).rev() {
if scan_coeffs[i] != 0 {
last_nonzero = Some(i);
break;
}
}
let last_nonzero = match last_nonzero {
Some(v) => v,
None => {
write_coeff_token(writer, 0, 0, nc);
return 0;
}
};
let mut runs = [0usize; 16];
let mut run_count = 0usize;
let mut total_zeros = 0usize;
let mut current_run = 0usize;
for i in (0..=last_nonzero).rev() {
if scan_coeffs[i] != 0 {
if total_coeffs > 0 {
runs[run_count] = current_run;
run_count += 1;
}
levels[total_coeffs] = scan_coeffs[i];
total_coeffs += 1;
current_run = 0;
} else {
current_run += 1;
total_zeros += 1;
}
}
if total_coeffs > 0 {
runs[run_count] = current_run;
}
let mut trailing_ones = 0usize;
let mut t1_signs = [0u8; 3]; for i in 0..total_coeffs.min(3) {
if levels[i] == 1 || levels[i] == -1 {
t1_signs[trailing_ones] = if levels[i] < 0 { 1 } else { 0 };
trailing_ones += 1;
} else {
break;
}
}
write_coeff_token(writer, total_coeffs as u8, trailing_ones as u8, nc);
for i in 0..trailing_ones {
writer.write_bit(t1_signs[i]);
}
let mut suffix_length: u8 = if total_coeffs > 10 && trailing_ones < 3 {
1
} else {
0
};
for i in trailing_ones..total_coeffs {
let level = levels[i];
let mut level_code = if level > 0 {
(level - 1) * 2
} else {
-level * 2 - 1
};
if i == trailing_ones && trailing_ones < 3 {
level_code -= 2;
}
let (prefix, suffix_val, suffix_len) = encode_level_code(level_code as u32, suffix_length);
for _ in 0..prefix {
writer.write_bit(0);
}
writer.write_bit(1);
if suffix_len > 0 {
writer.write_bits(suffix_val, suffix_len);
}
if suffix_length == 0 {
suffix_length = 1;
}
if level.unsigned_abs() > (3 << (suffix_length - 1)) {
suffix_length += 1;
}
}
if total_coeffs < 16 {
write_total_zeros(writer, total_zeros as u8, total_coeffs as u8);
}
let mut zeros_left = total_zeros;
for i in 0..total_coeffs.saturating_sub(1) {
if zeros_left == 0 {
break;
}
let run = runs[i];
write_run_before(writer, run as u8, zeros_left as u8);
zeros_left = zeros_left.saturating_sub(run);
}
total_coeffs
}
fn encode_level_code(level_code: u32, suffix_length: u8) -> (u32, u32, u8) {
if suffix_length == 0 {
if level_code < 14 {
(level_code, 0, 0)
} else if level_code < 30 {
(14, level_code - 14, 4)
} else {
let suffix_bits = 12u8;
(15, level_code - 30, suffix_bits)
}
} else {
let shift = suffix_length as u32;
let prefix = level_code >> shift;
let suffix = level_code & ((1 << shift) - 1);
if prefix < 15 {
(prefix, suffix, suffix_length)
} else {
let suffix_bits = 12u8;
(15, level_code - (15 << shift), suffix_bits)
}
}
}
fn write_coeff_token(writer: &mut BitWriter, total_coeffs: u8, trailing_ones: u8, nc: i32) {
if nc >= 8 {
let code = if total_coeffs == 0 {
3u32
} else {
((total_coeffs as u32 - 1) << 2) | trailing_ones as u32
};
writer.write_bits(code, 6);
return;
}
let table: &[(u32, u8, u8, u8)] = match nc {
0..=1 => COEFF_TOKEN_ENC_0_1,
2..=3 => COEFF_TOKEN_ENC_2_3,
4..=7 => COEFF_TOKEN_ENC_4_7,
_ => COEFF_TOKEN_ENC_0_1, };
for &(pattern, length, tc, t1) in table {
if tc == total_coeffs && t1 == trailing_ones {
writer.write_bits(pattern, length);
return;
}
}
let code = if total_coeffs == 0 {
3u32
} else {
((total_coeffs as u32 - 1) << 2) | trailing_ones as u32
};
writer.write_bits(code, 6);
}
const COEFF_TOKEN_ENC_0_1: &[(u32, u8, u8, u8)] = &[
(0b1, 1, 0, 0),
(0b000101, 6, 1, 0),
(0b01, 2, 1, 1),
(0b00000111, 8, 2, 0),
(0b000100, 6, 2, 1),
(0b001, 3, 2, 2),
(0b000000111, 9, 3, 0),
(0b00000110, 8, 3, 1),
(0b0000101, 7, 3, 2),
(0b00011, 5, 3, 3),
(0b0000000111, 10, 4, 0),
(0b000000110, 9, 4, 1),
(0b00000101, 8, 4, 2),
(0b000011, 6, 4, 3),
];
const COEFF_TOKEN_ENC_2_3: &[(u32, u8, u8, u8)] = &[
(0b11, 2, 0, 0),
(0b001011, 6, 1, 0),
(0b10, 2, 1, 1),
(0b000111, 6, 2, 0),
(0b00111, 5, 2, 1),
(0b011, 3, 2, 2),
(0b0000111, 7, 3, 0),
(0b001010, 6, 3, 1),
(0b001001, 6, 3, 2),
(0b00101, 5, 3, 3),
(0b00000111, 8, 4, 0),
(0b0000110, 7, 4, 1),
(0b000110, 6, 4, 2),
(0b00100, 5, 4, 3),
];
const COEFF_TOKEN_ENC_4_7: &[(u32, u8, u8, u8)] = &[
(0b1111, 4, 0, 0),
(0b001111, 6, 1, 0),
(0b1110, 4, 1, 1),
(0b001011, 6, 2, 0),
(0b01111, 5, 2, 1),
(0b1101, 4, 2, 2),
(0b001000, 6, 3, 0),
(0b01110, 5, 3, 1),
(0b01101, 5, 3, 2),
(0b1100, 4, 3, 3),
(0b0000111, 7, 4, 0),
(0b001110, 6, 4, 1),
(0b001010, 6, 4, 2),
(0b1011, 4, 4, 3),
];
fn write_total_zeros(writer: &mut BitWriter, total_zeros: u8, total_coeffs: u8) {
let table: &[(u32, u8, u8)] = match total_coeffs {
1 => TOTAL_ZEROS_ENC_TC1,
2 => TOTAL_ZEROS_ENC_TC2,
3 => TOTAL_ZEROS_ENC_TC3,
4 => TOTAL_ZEROS_ENC_TC4,
_ => {
write_ue_total_zeros(writer, total_zeros);
return;
}
};
for &(pattern, length, value) in table {
if value == total_zeros {
writer.write_bits(pattern, length);
return;
}
}
write_ue_total_zeros(writer, total_zeros);
}
fn write_ue_total_zeros(writer: &mut BitWriter, total_zeros: u8) {
writer.write_ue(total_zeros as u32);
}
const TOTAL_ZEROS_ENC_TC1: &[(u32, u8, u8)] = &[
(0b1, 1, 0),
(0b011, 3, 1),
(0b010, 3, 2),
(0b0011, 4, 3),
(0b0010, 4, 4),
(0b00011, 5, 5),
(0b00010, 5, 6),
(0b00001, 5, 7),
(0b000001, 6, 8),
(0b0000001, 7, 9),
(0b00000001, 8, 10),
(0b000000001, 9, 11),
(0b0000000001, 10, 12),
(0b00000000011, 11, 13),
(0b00000000010, 11, 14),
(0b00000000001, 11, 15),
];
const TOTAL_ZEROS_ENC_TC2: &[(u32, u8, u8)] = &[
(0b111, 3, 0),
(0b110, 3, 1),
(0b101, 3, 2),
(0b100, 3, 3),
(0b011, 3, 4),
(0b0101, 4, 5),
(0b0100, 4, 6),
(0b0011, 4, 7),
(0b0010, 4, 8),
(0b00011, 5, 9),
(0b00010, 5, 10),
(0b000011, 6, 11),
(0b000010, 6, 12),
(0b000001, 6, 13),
(0b000000, 6, 14),
];
const TOTAL_ZEROS_ENC_TC3: &[(u32, u8, u8)] = &[
(0b0101, 4, 0),
(0b111, 3, 1),
(0b110, 3, 2),
(0b101, 3, 3),
(0b0100, 4, 4),
(0b0011, 4, 5),
(0b100, 3, 6),
(0b011, 3, 7),
(0b0010, 4, 8),
(0b00011, 5, 9),
(0b00010, 5, 10),
(0b000001, 6, 11),
(0b00001, 5, 12),
(0b000000, 6, 13),
];
const TOTAL_ZEROS_ENC_TC4: &[(u32, u8, u8)] = &[
(0b00011, 5, 0),
(0b111, 3, 1),
(0b0101, 4, 2),
(0b0100, 4, 3),
(0b110, 3, 4),
(0b101, 3, 5),
(0b100, 3, 6),
(0b0011, 4, 7),
(0b011, 3, 8),
(0b00010, 5, 9),
(0b00001, 5, 10),
(0b00000, 5, 11),
(0b0010, 4, 12),
];
fn write_run_before(writer: &mut BitWriter, run: u8, zeros_left: u8) {
match zeros_left {
0 => {}
1 => {
writer.write_bit(if run == 0 { 1 } else { 0 });
}
2 => match run {
0 => writer.write_bit(1),
1 => {
writer.write_bit(0);
writer.write_bit(1);
}
_ => {
writer.write_bit(0);
writer.write_bit(0);
}
},
3 => match run {
0 => writer.write_bits(0b11, 2),
1 => writer.write_bits(0b10, 2),
2 => writer.write_bits(0b01, 2),
_ => writer.write_bits(0b00, 2),
},
4 => match run {
0 => writer.write_bits(0b11, 2),
1 => writer.write_bits(0b10, 2),
2 => writer.write_bits(0b01, 2),
3 => {
writer.write_bits(0b00, 2);
writer.write_bit(1);
}
_ => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
}
},
5 => match run {
0 => writer.write_bits(0b11, 2),
1 => writer.write_bits(0b10, 2),
2 => writer.write_bits(0b01, 2),
3 => {
writer.write_bits(0b00, 2);
writer.write_bit(1);
}
4 => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
writer.write_bit(1);
}
_ => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
writer.write_bit(0);
}
},
6 => match run {
0 => writer.write_bits(0b11, 2),
1 => writer.write_bits(0b10, 2),
2 => writer.write_bits(0b01, 2),
3 => {
writer.write_bits(0b00, 2);
writer.write_bit(1);
}
4 => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
writer.write_bit(1);
}
5 => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
writer.write_bit(0);
writer.write_bit(1);
}
_ => {
writer.write_bits(0b00, 2);
writer.write_bit(0);
writer.write_bit(0);
writer.write_bit(0);
}
},
_ => {
for _ in 0..run {
writer.write_bit(0);
}
writer.write_bit(1);
}
}
}
fn add_emulation_prevention(rbsp: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(rbsp.len() + rbsp.len() / 128);
let mut zero_count = 0u32;
for &byte in rbsp {
if zero_count >= 2 && byte <= 0x03 {
out.push(0x03);
zero_count = 0;
}
out.push(byte);
if byte == 0x00 {
zero_count += 1;
} else {
zero_count = 0;
}
}
out
}
fn write_nal_unit(output: &mut Vec<u8>, nal_ref_idc: u8, nal_type: u8, rbsp: &[u8]) {
output.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
let header = (nal_ref_idc << 5) | (nal_type & 0x1F);
output.push(header);
output.extend_from_slice(&add_emulation_prevention(rbsp));
}
pub fn rgb8_to_yuv420(rgb: &[u8], width: usize, height: usize) -> Vec<u8> {
let y_size = width * height;
let uv_w = width / 2;
let uv_h = height / 2;
let uv_size = uv_w * uv_h;
let mut yuv = vec![0u8; y_size + uv_size * 2];
#[cfg(target_arch = "aarch64")]
{
unsafe {
rgb8_to_yuv420_neon(rgb, width, height, &mut yuv);
}
return yuv;
}
#[allow(unreachable_code)]
{
rgb8_to_yuv420_scalar(rgb, width, height, &mut yuv);
yuv
}
}
fn rgb8_to_yuv420_scalar(rgb: &[u8], width: usize, height: usize, yuv: &mut [u8]) {
let y_size = width * height;
let uv_w = width / 2;
for y in 0..height {
for x in 0..width {
let idx = (y * width + x) * 3;
let r = rgb[idx] as i32;
let g = rgb[idx + 1] as i32;
let b = rgb[idx + 2] as i32;
let luma = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16;
yuv[y * width + x] = luma.clamp(0, 255) as u8;
}
}
for cy in 0..height / 2 {
for cx in 0..uv_w {
let mut r_sum = 0i32;
let mut g_sum = 0i32;
let mut b_sum = 0i32;
for dy in 0..2 {
for dx in 0..2 {
let idx = ((cy * 2 + dy) * width + cx * 2 + dx) * 3;
r_sum += rgb[idx] as i32;
g_sum += rgb[idx + 1] as i32;
b_sum += rgb[idx + 2] as i32;
}
}
let r = r_sum >> 2;
let g = g_sum >> 2;
let b = b_sum >> 2;
let u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128;
let v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128;
yuv[y_size + cy * uv_w + cx] = u.clamp(0, 255) as u8;
yuv[y_size + uv_w * (height / 2) + cy * uv_w + cx] = v.clamp(0, 255) as u8;
}
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn rgb8_to_yuv420_neon(rgb: &[u8], width: usize, height: usize, yuv: &mut [u8]) {
use std::arch::aarch64::*;
let y_size = width * height;
let uv_w = width / 2;
let uv_h = height / 2;
let coeff_r = vdupq_n_s16(66);
let coeff_g = vdupq_n_s16(129);
let coeff_b = vdupq_n_s16(25);
let offset_y = vdupq_n_s16(16);
let round = vdupq_n_s16(128);
for y in 0..height {
let row_start = y * width * 3;
let mut x = 0usize;
while x + 8 <= width {
let rgb_ptr = rgb.as_ptr().add(row_start + x * 3);
let rgb_data = vld3_u8(rgb_ptr);
let r = vreinterpretq_s16_u16(vmovl_u8(rgb_data.0));
let g = vreinterpretq_s16_u16(vmovl_u8(rgb_data.1));
let b = vreinterpretq_s16_u16(vmovl_u8(rgb_data.2));
let yr = vmulq_s16(r, coeff_r);
let yg = vmulq_s16(g, coeff_g);
let yb = vmulq_s16(b, coeff_b);
let ysum = vaddq_s16(vaddq_s16(yr, yg), vaddq_s16(yb, round));
let yval = vaddq_s16(vshrq_n_s16(ysum, 8), offset_y);
let yclamped = vqmovun_s16(yval);
vst1_u8(yuv.as_mut_ptr().add(y * width + x), yclamped);
x += 8;
}
while x < width {
let idx = row_start + x * 3;
let r = rgb[idx] as i32;
let g = rgb[idx + 1] as i32;
let b = rgb[idx + 2] as i32;
let luma = ((66 * r + 129 * g + 25 * b + 128) >> 8) + 16;
yuv[y * width + x] = luma.clamp(0, 255) as u8;
x += 1;
}
}
for cy in 0..uv_h {
for cx in 0..uv_w {
let mut r_sum = 0i32;
let mut g_sum = 0i32;
let mut b_sum = 0i32;
for dy in 0..2 {
for dx in 0..2 {
let idx = ((cy * 2 + dy) * width + cx * 2 + dx) * 3;
r_sum += rgb[idx] as i32;
g_sum += rgb[idx + 1] as i32;
b_sum += rgb[idx + 2] as i32;
}
}
let r = r_sum >> 2;
let g = g_sum >> 2;
let b = b_sum >> 2;
let u = ((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128;
let v = ((112 * r - 94 * g - 18 * b + 128) >> 8) + 128;
yuv[y_size + cy * uv_w + cx] = u.clamp(0, 255) as u8;
yuv[y_size + uv_w * uv_h + cy * uv_w + cx] = v.clamp(0, 255) as u8;
}
}
}
pub struct H264Encoder {
width: u32,
height: u32,
qp: u8,
frame_num: u32,
}
impl H264Encoder {
pub fn new(width: u32, height: u32, qp: u8) -> Self {
Self {
width,
height,
qp: qp.min(51),
frame_num: 0,
}
}
pub fn encode_frame(&mut self, yuv420: &[u8]) -> Vec<u8> {
let mut output = Vec::with_capacity(yuv420.len() / 4);
output.extend_from_slice(&self.build_sps_nal());
output.extend_from_slice(&self.build_pps_nal());
output.extend_from_slice(&self.encode_idr_slice(yuv420));
self.frame_num = self.frame_num.wrapping_add(1);
output
}
fn mb_width(&self) -> u32 {
self.width.div_ceil(16)
}
fn mb_height(&self) -> u32 {
self.height.div_ceil(16)
}
fn build_sps_nal(&self) -> Vec<u8> {
let mut w = BitWriter::new();
w.write_bits(66, 8);
w.write_bits(0b11000000, 8);
w.write_bits(30, 8);
w.write_ue(0);
w.write_ue(0);
w.write_ue(0);
w.write_ue(0);
w.write_ue(0);
w.write_bit(0);
w.write_ue(self.mb_width() - 1);
w.write_ue(self.mb_height() - 1);
w.write_bit(1);
w.write_bit(0);
w.write_bit(0);
w.write_bit(0);
w.rbsp_trailing_bits();
let rbsp = w.into_bytes();
let mut nal = Vec::new();
write_nal_unit(&mut nal, 3, 7, &rbsp);
nal
}
fn build_pps_nal(&self) -> Vec<u8> {
let mut w = BitWriter::new();
w.write_ue(0);
w.write_ue(0);
w.write_bit(0);
w.write_bit(0);
w.write_ue(0);
w.write_ue(0);
w.write_ue(0);
w.write_bit(0);
w.write_bits(0, 2);
w.write_se(self.qp as i32 - 26);
w.write_se(0);
w.write_se(0);
w.write_bit(1);
w.write_bit(0);
w.write_bit(0);
w.rbsp_trailing_bits();
let rbsp = w.into_bytes();
let mut nal = Vec::new();
write_nal_unit(&mut nal, 3, 8, &rbsp);
nal
}
fn encode_idr_slice(&self, yuv420: &[u8]) -> Vec<u8> {
let mut w = BitWriter::new();
w.write_ue(0);
w.write_ue(2);
w.write_ue(0);
w.write_bits(self.frame_num & 0xF, 4);
w.write_ue(self.frame_num);
w.write_bits((self.frame_num * 2) & 0xF, 4);
w.write_bit(0); w.write_bit(0); w.write_se(0);
w.write_ue(1);
let mb_w = self.mb_width() as usize;
let mb_h = self.mb_height() as usize;
let width = self.width as usize;
let height = self.height as usize;
let y_plane = &yuv420[..width * height];
let uv_offset = width * height;
let uv_w = width / 2;
let uv_h = height / 2;
let u_plane = &yuv420[uv_offset..uv_offset + uv_w * uv_h];
let v_plane = &yuv420[uv_offset + uv_w * uv_h..];
let mut nc_above = vec![0i32; mb_w * 4]; let mut nc_left = [0i32; 4];
for mb_y in 0..mb_h {
for i in 0..4 {
nc_left[i] = 0;
}
for mb_x in 0..mb_w {
self.encode_macroblock(
&mut w,
y_plane,
u_plane,
v_plane,
width,
height,
uv_w,
mb_x,
mb_y,
&mut nc_above,
&mut nc_left,
);
}
}
w.rbsp_trailing_bits();
let rbsp = w.into_bytes();
let mut nal = Vec::new();
write_nal_unit(&mut nal, 3, 5, &rbsp);
nal
}
#[allow(clippy::too_many_arguments)]
fn encode_macroblock(
&self,
w: &mut BitWriter,
y_plane: &[u8],
u_plane: &[u8],
v_plane: &[u8],
width: usize,
height: usize,
uv_w: usize,
mb_x: usize,
mb_y: usize,
nc_above: &mut [i32],
nc_left: &mut [i32],
) {
w.write_ue(0);
for _ in 0..16 {
w.write_bit(0); w.write_bits(2, 3); }
w.write_ue(0);
let mut any_luma_nonzero = false;
let mut luma_nz = [0usize; 16];
let mut luma_blocks = [[0i32; 16]; 16];
for blk_idx in 0..16 {
let blk_y = blk_idx / 4;
let blk_x = blk_idx % 4;
let px = mb_x * 16 + blk_x * 4;
let py = mb_y * 16 + blk_y * 4;
let mut block = [0i32; 16];
for row in 0..4 {
for col in 0..4 {
let sy = py + row;
let sx = px + col;
block[row * 4 + col] = if sy < height && sx < width {
y_plane[sy * width + sx] as i32
} else {
128
};
}
}
let dc_pred = self.compute_dc_pred_4x4(y_plane, width, height, px, py);
for i in 0..16 {
block[i] -= dc_pred;
}
forward_dct_4x4(&mut block);
quantize_4x4(&mut block, self.qp);
if block.iter().any(|&c| c != 0) {
any_luma_nonzero = true;
luma_nz[blk_idx] = block.iter().filter(|&&c| c != 0).count();
}
luma_blocks[blk_idx] = block;
}
let mut chroma_blocks_u = [[0i32; 16]; 4];
let mut chroma_blocks_v = [[0i32; 16]; 4];
let mut any_chroma_nonzero = false;
let uv_h = height / 2;
for blk_idx in 0..4 {
let blk_y = blk_idx / 2;
let blk_x = blk_idx % 2;
for (plane, blocks) in [
(u_plane, &mut chroma_blocks_u),
(v_plane, &mut chroma_blocks_v),
] {
let px = mb_x * 8 + blk_x * 4;
let py = mb_y * 8 + blk_y * 4;
let mut block = [0i32; 16];
for row in 0..4 {
for col in 0..4 {
let sy = py + row;
let sx = px + col;
block[row * 4 + col] = if sy < uv_h && sx < uv_w {
plane[sy * uv_w + sx] as i32
} else {
128
};
}
}
let dc_pred = 128i32; for i in 0..16 {
block[i] -= dc_pred;
}
forward_dct_4x4(&mut block);
quantize_4x4(&mut block, self.qp);
if block.iter().any(|&c| c != 0) {
any_chroma_nonzero = true;
}
blocks[blk_idx] = block;
}
}
let mut cbp_luma = 0u32;
for group in 0..4 {
let base_y = (group / 2) * 2;
let base_x = (group % 2) * 2;
let mut group_nonzero = false;
for dy in 0..2 {
for dx in 0..2 {
let idx = (base_y + dy) * 4 + base_x + dx;
if luma_nz[idx] > 0 {
group_nonzero = true;
}
}
}
if group_nonzero {
cbp_luma |= 1 << group;
}
}
let cbp_chroma = if any_chroma_nonzero { 2u32 } else { 0 };
let cbp = cbp_luma | (cbp_chroma << 4);
w.write_ue(me_code_intra_4x4(cbp));
if any_luma_nonzero || any_chroma_nonzero {
w.write_se(0);
for blk_idx in 0..16 {
let blk_y = blk_idx / 4;
let blk_x = blk_idx % 4;
let group = (blk_y / 2) * 2 + blk_x / 2;
if cbp_luma & (1 << group) == 0 {
nc_above[mb_x * 4 + blk_x] = 0;
if blk_x == 3 {
nc_left[blk_y] = 0;
}
continue;
}
let nc_a = if blk_x > 0 {
nc_left[blk_y]
} else if mb_x > 0 {
nc_above[(mb_x - 1) * 4 + 3]
} else {
0
};
let nc_b = if blk_y > 0 {
nc_above[mb_x * 4 + blk_x]
} else {
0
};
let nc = (nc_a + nc_b + 1) / 2;
let nz = cavlc_encode_block(w, &luma_blocks[blk_idx], nc);
nc_above[mb_x * 4 + blk_x] = nz as i32;
nc_left[blk_y] = nz as i32;
}
if cbp_chroma > 0 {
for chroma_blocks in [&chroma_blocks_u, &chroma_blocks_v] {
let mut dc_block = [0i32; 16];
for i in 0..4 {
dc_block[i] = chroma_blocks[i][0]; }
cavlc_encode_block(w, &dc_block, -1);
}
if cbp_chroma >= 2 {
for chroma_blocks in [&chroma_blocks_u, &chroma_blocks_v] {
for blk_idx in 0..4 {
let mut ac_block = chroma_blocks[blk_idx];
ac_block[0] = 0; cavlc_encode_block(w, &ac_block, 0);
}
}
}
}
}
}
fn compute_dc_pred_4x4(
&self,
y_plane: &[u8],
width: usize,
height: usize,
px: usize,
py: usize,
) -> i32 {
let mut sum = 0i32;
let mut count = 0i32;
if py > 0 {
for col in 0..4 {
let sx = px + col;
if sx < width {
sum += y_plane[(py - 1) * width + sx] as i32;
count += 1;
}
}
}
if px > 0 {
for row in 0..4 {
let sy = py + row;
if sy < height {
sum += y_plane[sy * width + px - 1] as i32;
count += 1;
}
}
}
if count > 0 {
(sum + count / 2) / count
} else {
128 }
}
}
fn me_code_intra_4x4(cbp: u32) -> u32 {
const TABLE: [u32; 48] = [
3, 29, 30, 17, 31, 18, 37, 8, 32, 38, 19, 9, 20, 10, 11, 2, 16, 33, 34, 21, 35, 22, 39, 4,
36, 40, 23, 5, 24, 6, 7, 1, 41, 42, 43, 25, 44, 26, 46, 12, 45, 47, 27, 13, 28, 14, 15, 0,
];
let idx = cbp as usize;
if idx < TABLE.len() { TABLE[idx] } else { 0 }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn forward_inverse_dct_roundtrip() {
let original = [
10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 5, 15, 25, 35,
];
let mut block = original;
forward_dct_4x4(&mut block);
assert_ne!(block, original);
let sum: i32 = original.iter().sum();
assert_eq!(
block[0], sum,
"DC coefficient should equal sum of input block"
);
let mut scaled = block;
let qp = 10u8;
quantize_4x4(&mut scaled, qp);
dequantize_4x4(&mut scaled, qp);
super::super::h264_transform::inverse_dct_4x4(&mut scaled);
for i in 0..16 {
assert!(
(scaled[i] - original[i]).abs() <= 5,
"position {i}: expected ~{}, got {} (QP={qp})",
original[i],
scaled[i]
);
}
}
#[test]
fn quantize_dequantize_roundtrip() {
let original = [
100, -50, 30, -20, 10, -5, 3, -1, 200, -100, 80, -40, 0, 0, 0, 0,
];
let qp = 20u8;
let mut block = original;
quantize_4x4(&mut block, qp);
let nonzero: usize = block.iter().filter(|&&c| c != 0).count();
assert!(
nonzero < 16,
"quantization should zero out some coefficients"
);
let mut reconstructed = block;
dequantize_4x4(&mut reconstructed, qp);
for i in 0..16 {
if block[i] != 0 {
assert_eq!(
original[i].signum(),
reconstructed[i].signum(),
"sign mismatch at position {i}"
);
}
}
}
#[test]
fn sps_nal_generation() {
let enc = H264Encoder::new(320, 240, 26);
let sps = enc.build_sps_nal();
assert_eq!(&sps[..4], &[0x00, 0x00, 0x00, 0x01]);
assert_eq!(sps[4] & 0x80, 0, "forbidden_zero_bit must be 0");
assert_eq!((sps[4] >> 5) & 0x03, 3, "nal_ref_idc should be 3");
assert_eq!(sps[4] & 0x1F, 7, "nal_unit_type should be 7 (SPS)");
assert_eq!(sps[5], 66);
}
#[test]
fn pps_nal_generation() {
let enc = H264Encoder::new(320, 240, 26);
let pps = enc.build_pps_nal();
assert_eq!(&pps[..4], &[0x00, 0x00, 0x00, 0x01]);
assert_eq!(pps[4] & 0x80, 0, "forbidden_zero_bit must be 0");
assert_eq!((pps[4] >> 5) & 0x03, 3, "nal_ref_idc should be 3");
assert_eq!(pps[4] & 0x1F, 8, "nal_unit_type should be 8 (PPS)");
}
#[test]
fn rgb8_to_yuv420_neutral_gray() {
let w = 4;
let h = 4;
let rgb = vec![128u8; w * h * 3];
let yuv = rgb8_to_yuv420(&rgb, w, h);
let y_size = w * h;
let uv_size = (w / 2) * (h / 2);
assert_eq!(yuv.len(), y_size + uv_size * 2);
for &y in &yuv[..y_size] {
assert!(
(y as i32 - 128).abs() <= 2,
"Y should be ~128 for gray, got {y}"
);
}
for &u in &yuv[y_size..y_size + uv_size] {
assert!(
(u as i32 - 128).abs() <= 2,
"U should be ~128 for gray, got {u}"
);
}
for &v in &yuv[y_size + uv_size..] {
assert!(
(v as i32 - 128).abs() <= 2,
"V should be ~128 for gray, got {v}"
);
}
}
#[test]
fn rgb8_to_yuv420_distinct_colors() {
let w = 2;
let h = 2;
let rgb = vec![255, 0, 0, 0, 0, 255, 0, 0, 255, 0, 0, 255];
let yuv = rgb8_to_yuv420(&rgb, w, h);
assert_ne!(
yuv[0], yuv[1],
"different colors should produce different Y"
);
}
#[test]
fn emulation_prevention_insertion() {
let input = [0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02];
let output = add_emulation_prevention(&input);
assert!(output.len() > input.len());
for w in output.windows(3) {
if w[0] == 0x00 && w[1] == 0x00 && w[2] <= 0x03 {
assert_eq!(
w[2], 0x03,
"should only have emulation prevention byte (0x03)"
);
}
}
}
#[test]
fn encode_frame_produces_valid_annex_b() {
let w = 16;
let h = 16;
let rgb = vec![128u8; w * h * 3];
let yuv = rgb8_to_yuv420(&rgb, w, h);
let mut enc = H264Encoder::new(w as u32, h as u32, 26);
let data = enc.encode_frame(&yuv);
let mut start_codes = 0;
for i in 0..data.len().saturating_sub(3) {
if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 0 && data[i + 3] == 1 {
start_codes += 1;
}
}
assert!(
start_codes >= 3,
"should have at least 3 NAL units (SPS+PPS+IDR), found {start_codes}"
);
}
#[test]
fn bitwriter_exp_golomb() {
let mut w = BitWriter::new();
w.write_ue(0); w.write_ue(1); w.write_ue(2); w.write_ue(3); w.rbsp_trailing_bits();
let bytes = w.into_bytes();
let mut r = super::super::cavlc::BitReader::new(&bytes);
assert_eq!(r.read_ue(), Some(0));
assert_eq!(r.read_ue(), Some(1));
assert_eq!(r.read_ue(), Some(2));
assert_eq!(r.read_ue(), Some(3));
}
#[test]
fn zigzag_scan_is_inverse_of_unscan() {
use super::super::h264_transform::ZIGZAG_4X4;
for scan_idx in 0..16 {
let (r, c) = ZIGZAG_4X4[scan_idx];
let raster_pos = r * 4 + c;
assert_eq!(
ZIGZAG_SCAN_4X4[raster_pos], scan_idx,
"zigzag scan inconsistency at scan_idx={scan_idx}, raster_pos={raster_pos}"
);
}
}
}