use crate::error::{CodecError, CodecResult};
/// Compression method carried in the two least-significant bits (`C`) of the
/// ALPH header byte.
///
/// Each discriminant equals the on-wire field value, so `as u8` yields the
/// encoded bits directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlphaCompression {
    /// `0b00`: alpha samples follow the header uncompressed.
    NoCompression = 0b00,
    /// `0b01`: alpha samples are a VP8L (WebP lossless) bitstream.
    WebPLossless = 0b01,
}
/// Prediction filter carried in bits 2–3 (`F`) of the ALPH header byte.
///
/// Each discriminant equals the on-wire field value. Border samples follow
/// special prediction rules implemented by the filter code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlphaFilter {
    /// `0b00`: no prediction; samples are stored as-is.
    None = 0b00,
    /// `0b01`: predict from the left neighbour.
    Horizontal = 0b01,
    /// `0b10`: predict from the sample above.
    Vertical = 0b10,
    /// `0b11`: predict from `clamp(left + top - top_left, 0, 255)`.
    Gradient = 0b11,
}
/// Decoded form of the single header byte that begins an ALPH chunk.
///
/// Bit layout, most- to least-significant:
/// reserved (2) | pre-processing `P` (2) | filter `F` (2) | compression `C` (2).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlphaHeader {
    /// Compression method (`C`, bits 0–1).
    pub compression: AlphaCompression,
    /// Prediction filter (`F`, bits 2–3).
    pub filter: AlphaFilter,
    /// Pre-processing hint (`P`, bits 4–5). Only the low two bits are
    /// serialized by `to_byte`.
    pub pre_processing: u8,
}
impl AlphaCompression {
    /// Maps the 2-bit `C` header field to a compression method.
    ///
    /// Only the two least-significant bits of `bits` are inspected. Field
    /// values `2` and `3` have no defined method and yield
    /// `CodecError::InvalidBitstream`.
    fn from_bits(bits: u8) -> CodecResult<Self> {
        let method = bits & 0b11;
        if method == 0 {
            Ok(Self::NoCompression)
        } else if method == 1 {
            Ok(Self::WebPLossless)
        } else {
            Err(CodecError::InvalidBitstream(format!(
                "unknown alpha compression method: {method}"
            )))
        }
    }
}
impl AlphaFilter {
fn from_bits(bits: u8) -> CodecResult<Self> {
match bits & 0x03 {
0 => Ok(Self::None),
1 => Ok(Self::Horizontal),
2 => Ok(Self::Vertical),
3 => Ok(Self::Gradient),
_ => Err(CodecError::InvalidBitstream(
"unknown alpha filter method".to_string(),
)),
}
}
}
impl AlphaHeader {
    /// Parses the ALPH header byte.
    ///
    /// Fields are validated in this order: the reserved top two bits must be
    /// zero, then the compression field, then the filter field. Pre-processing
    /// values `0..=3` are all accepted and preserved.
    /// NOTE(review): libwebp's decoder appears to reject pre-processing values
    /// above 1 — confirm whether this parser should be equally strict.
    ///
    /// # Errors
    /// `CodecError::InvalidBitstream` if the reserved bits are set or the
    /// compression field holds an undefined value.
    pub fn parse(byte: u8) -> CodecResult<Self> {
        // Extracts the two-bit field whose lowest bit sits at `shift`.
        let field = |shift: u8| (byte >> shift) & 0b11;

        let reserved = field(6);
        if reserved != 0 {
            return Err(CodecError::InvalidBitstream(format!(
                "ALPH header reserved bits are non-zero: {reserved}"
            )));
        }

        // Struct-literal fields evaluate in source order, so compression is
        // validated before filter.
        Ok(Self {
            compression: AlphaCompression::from_bits(field(0))?,
            filter: AlphaFilter::from_bits(field(2))?,
            pre_processing: field(4),
        })
    }

    /// Serializes the header back into its single-byte form.
    ///
    /// Reserved bits are always written as zero, and `pre_processing` is
    /// truncated to its low two bits.
    pub fn to_byte(&self) -> u8 {
        let Self {
            compression,
            filter,
            pre_processing,
        } = *self;
        let c_bits = compression as u8;
        let f_bits = (filter as u8) << 2;
        let p_bits = (pre_processing & 0b11) << 4;
        p_bits | f_bits | c_bits
    }
}
/// Gradient predictor: `left + top - top_left`, clamped into `0..=255`.
///
/// The arithmetic is done in `i16`, which holds the full range `-255..=510`
/// without overflow.
#[inline]
fn gradient_predict(left: u8, top: u8, top_left: u8) -> u8 {
    let [a, b, c] = [left, top, top_left].map(i16::from);
    (a + b - c).clamp(0, 255) as u8
}
/// Predictor for an interior sample (`x >= 1` and `y >= 1`) under `filter`.
///
/// Border samples are not handled here. Per the WebP container specification,
/// for every filter other than [`AlphaFilter::None`]:
/// - the top-left sample is stored verbatim;
/// - the rest of the top row is predicted from the left neighbour;
/// - the rest of the left column is predicted from the sample above.
///
/// Callers apply those border rules explicitly.
#[inline]
fn interior_predict(filter: AlphaFilter, left: u8, top: u8, top_left: u8) -> u8 {
    match filter {
        AlphaFilter::None => 0,
        AlphaFilter::Horizontal => left,
        AlphaFilter::Vertical => top,
        AlphaFilter::Gradient => gradient_predict(left, top, top_left),
    }
}

/// Reverses the prediction filter in place (decoder side).
///
/// Despite its name, this reconstructs alpha values from filtered residuals:
/// each sample becomes `(residual + predictor) mod 256`. Samples are visited
/// in raster order, so every predictor reads neighbours that have already
/// been reconstructed. It is the exact inverse of [`apply_inverse_filter`].
///
/// Borders follow the WebP specification (see [`interior_predict`]). For
/// `Horizontal`, the left column is predicted from above. For `Vertical`, the
/// top row is predicted from the left.
///
/// Silently does nothing when the plane is empty, when `width * height`
/// overflows `usize`, when `data` is shorter than `width * height`, or when
/// `filter` is `None`. Bytes past `width * height` are left untouched.
fn apply_filter(data: &mut [u8], width: u32, height: u32, filter: AlphaFilter) {
    let w = width as usize;
    let h = height as usize;
    // Saturating: an overflowing product can never fit in `data`, so the
    // length check below turns it into a no-op rather than a debug panic.
    let total = w.saturating_mul(h);
    if total == 0 || data.len() < total || filter == AlphaFilter::None {
        return;
    }
    let data = &mut data[..total];

    // Top row: sample (0, 0) is stored verbatim; the rest predict from the left.
    for x in 1..w {
        data[x] = data[x].wrapping_add(data[x - 1]);
    }
    for y in 1..h {
        let row = y * w;
        // Left column: predicted from the sample directly above.
        data[row] = data[row].wrapping_add(data[row - w]);
        for x in 1..w {
            let idx = row + x;
            let pred =
                interior_predict(filter, data[idx - 1], data[idx - w], data[idx - w - 1]);
            data[idx] = data[idx].wrapping_add(pred);
        }
    }
}

/// Applies the prediction filter to `data` and returns the residuals (encoder side).
///
/// Despite its name, this is the forward filter: each output byte is
/// `(sample - predictor) mod 256`. Predictors are taken from the original
/// samples, so [`apply_filter`] restores `data` exactly. The result has
/// exactly `width * height` bytes; extra input bytes are ignored.
///
/// # Panics
/// Panics if `data` is shorter than `width * height`. Callers validate the
/// plane size (including overflow, via [`plane_len`]) before calling.
fn apply_inverse_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> Vec<u8> {
    let w = width as usize;
    let h = height as usize;
    let total = w * h;
    if total == 0 {
        return Vec::new();
    }
    let data = &data[..total];
    if filter == AlphaFilter::None {
        return data.to_vec();
    }

    let mut out = vec![0u8; total];
    // Top row: sample (0, 0) is copied verbatim; the rest predict from the left.
    out[0] = data[0];
    for x in 1..w {
        out[x] = data[x].wrapping_sub(data[x - 1]);
    }
    for y in 1..h {
        let row = y * w;
        // Left column: predicted from the sample directly above.
        out[row] = data[row].wrapping_sub(data[row - w]);
        for x in 1..w {
            let idx = row + x;
            let pred =
                interior_predict(filter, data[idx - 1], data[idx - w], data[idx - w - 1]);
            out[idx] = data[idx].wrapping_sub(pred);
        }
    }
    out
}

/// Cost heuristic for encoding `data` with `filter`: the sum of residual
/// magnitudes.
///
/// Each residual byte is read as a two's-complement `i8`. For example, `255`
/// costs 1 and `128` costs 128, so small wrap-around deltas are cheap.
fn score_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> u64 {
    apply_inverse_filter(data, width, height, filter)
        .iter()
        .map(|&b| u64::from((b as i8).unsigned_abs()))
        .sum()
}

/// Picks the filter with the lowest [`score_filter`] cost.
///
/// Ties go to the earliest candidate in `None, Horizontal, Vertical, Gradient`
/// order.
fn select_best_filter(data: &[u8], width: u32, height: u32) -> AlphaFilter {
    const CANDIDATES: [AlphaFilter; 4] = [
        AlphaFilter::None,
        AlphaFilter::Horizontal,
        AlphaFilter::Vertical,
        AlphaFilter::Gradient,
    ];
    // `min_by_key` returns the first of several equal minima.
    CANDIDATES
        .iter()
        .copied()
        .min_by_key(|&f| score_filter(data, width, height, f))
        .unwrap_or(AlphaFilter::None)
}

/// Number of samples in a `width` × `height` alpha plane.
///
/// # Errors
/// `CodecError::InvalidParameter` if the product overflows `usize`.
fn plane_len(width: u32, height: u32) -> CodecResult<usize> {
    (width as usize)
        .checked_mul(height as usize)
        .ok_or_else(|| {
            CodecError::InvalidParameter(format!(
                "alpha plane dimensions overflow: {width} x {height}"
            ))
        })
}

/// Decodes an `ALPH` chunk payload into a `width * height` alpha plane.
///
/// `data` is one header byte followed by the alpha samples. For a zero-sized
/// plane, an empty `Vec` is returned without inspecting the header byte.
/// Uncompressed payloads may carry trailing bytes; they are ignored.
///
/// # Errors
/// - `InvalidBitstream` if `data` is empty or the header byte is invalid.
/// - `InvalidParameter` if `width * height` overflows `usize`.
/// - `BufferTooSmall` if an uncompressed payload is shorter than `width * height`.
/// - `UnsupportedFeature` for VP8L-compressed alpha.
pub fn decode_alpha(data: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
    let (&header_byte, payload) = data
        .split_first()
        .ok_or_else(|| CodecError::InvalidBitstream("ALPH chunk is empty".to_string()))?;
    let total = plane_len(width, height)?;
    if total == 0 {
        return Ok(Vec::new());
    }
    let header = AlphaHeader::parse(header_byte)?;
    match header.compression {
        AlphaCompression::NoCompression => {
            if payload.len() < total {
                return Err(CodecError::BufferTooSmall {
                    needed: total,
                    have: payload.len(),
                });
            }
            let mut alpha = payload[..total].to_vec();
            apply_filter(&mut alpha, width, height, header.filter);
            Ok(alpha)
        }
        AlphaCompression::WebPLossless => Err(CodecError::UnsupportedFeature(
            "VP8L-compressed alpha channel is not yet supported".to_string(),
        )),
    }
}

/// Encodes an alpha plane as an uncompressed `ALPH` chunk payload.
///
/// Only the first `width * height` bytes of `alpha` are used. The filter with
/// the lowest [`score_filter`] cost is chosen. The output is the header byte
/// followed by the filtered residuals, so its length is `1 + width * height`.
/// A zero-sized plane encodes as a lone header byte.
///
/// # Errors
/// - `InvalidParameter` if `width * height` overflows `usize`.
/// - `BufferTooSmall` if `alpha` is shorter than `width * height`.
pub fn encode_alpha(alpha: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
    let total = plane_len(width, height)?;
    if alpha.len() < total {
        return Err(CodecError::BufferTooSmall {
            needed: total,
            have: alpha.len(),
        });
    }
    if total == 0 {
        let hdr = AlphaHeader {
            compression: AlphaCompression::NoCompression,
            filter: AlphaFilter::None,
            pre_processing: 0,
        };
        return Ok(vec![hdr.to_byte()]);
    }
    let input = &alpha[..total];
    let best_filter = select_best_filter(input, width, height);
    let header = AlphaHeader {
        compression: AlphaCompression::NoCompression,
        filter: best_filter,
        pre_processing: 0,
    };
    let residuals = apply_inverse_filter(input, width, height, best_filter);
    let mut out = Vec::with_capacity(1 + residuals.len());
    out.push(header.to_byte());
    out.extend_from_slice(&residuals);
    Ok(out)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn header_roundtrip_no_compression_no_filter() {
        let hdr = AlphaHeader {
            compression: AlphaCompression::NoCompression,
            filter: AlphaFilter::None,
            pre_processing: 0,
        };
        let byte = hdr.to_byte();
        assert_eq!(byte, 0x00);
        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
        assert_eq!(parsed.compression, AlphaCompression::NoCompression);
        assert_eq!(parsed.filter, AlphaFilter::None);
        assert_eq!(parsed.pre_processing, 0);
    }

    #[test]
    fn header_roundtrip_all_fields() {
        let hdr = AlphaHeader {
            compression: AlphaCompression::NoCompression,
            filter: AlphaFilter::Gradient,
            pre_processing: 1,
        };
        let byte = hdr.to_byte();
        assert_eq!(byte, 0x1C);
        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
        assert_eq!(parsed.compression, AlphaCompression::NoCompression);
        assert_eq!(parsed.filter, AlphaFilter::Gradient);
        assert_eq!(parsed.pre_processing, 1);
    }

    #[test]
    fn header_roundtrip_webp_lossless_horizontal() {
        let hdr = AlphaHeader {
            compression: AlphaCompression::WebPLossless,
            filter: AlphaFilter::Horizontal,
            pre_processing: 0,
        };
        let byte = hdr.to_byte();
        assert_eq!(byte, 0x05);
        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
        assert_eq!(parsed.compression, AlphaCompression::WebPLossless);
        assert_eq!(parsed.filter, AlphaFilter::Horizontal);
    }

    #[test]
    fn header_reserved_bits_rejected() {
        let result = AlphaHeader::parse(0x40);
        assert!(result.is_err());
    }

    #[test]
    fn filter_none_roundtrip() {
        let original: Vec<u8> = (0..12).collect();
        let w = 4u32;
        let h = 3u32;
        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::None);
        assert_eq!(residuals, original);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, w, h, AlphaFilter::None);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn filter_horizontal_roundtrip() {
        let original: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120];
        let w = 4u32;
        let h = 3u32;
        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Horizontal);
        assert_eq!(&residuals[0..4], &[10, 10, 10, 10]);
        // Left column of later rows is predicted from above: 50 - 10, 90 - 50.
        assert_eq!(&residuals[4..8], &[40, 10, 10, 10]);
        assert_eq!(&residuals[8..12], &[40, 10, 10, 10]);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, w, h, AlphaFilter::Horizontal);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn filter_vertical_roundtrip() {
        let original: Vec<u8> = vec![10, 20, 30, 40, 15, 25, 35, 45, 20, 30, 40, 50];
        let w = 4u32;
        let h = 3u32;
        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Vertical);
        // Top row is predicted from the left neighbour.
        assert_eq!(&residuals[0..4], &[10, 10, 10, 10]);
        assert_eq!(&residuals[4..8], &[5, 5, 5, 5]);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, w, h, AlphaFilter::Vertical);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn horizontal_filter_left_column_predicted_from_above() {
        let original = vec![7u8, 9, 12];
        let residuals = apply_inverse_filter(&original, 1, 3, AlphaFilter::Horizontal);
        assert_eq!(residuals, vec![7, 2, 3]);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, 1, 3, AlphaFilter::Horizontal);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn vertical_filter_top_row_predicted_from_left() {
        let original = vec![7u8, 9, 12];
        let residuals = apply_inverse_filter(&original, 3, 1, AlphaFilter::Vertical);
        assert_eq!(residuals, vec![7, 2, 3]);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, 3, 1, AlphaFilter::Vertical);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn filter_gradient_roundtrip() {
        let original: Vec<u8> = vec![100, 110, 120, 130, 105, 115, 125, 135, 110, 120, 130, 140];
        let w = 4u32;
        let h = 3u32;
        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn filter_gradient_known_vector() {
        let original: Vec<u8> = vec![100, 150, 120, 170];
        let w = 2u32;
        let h = 2u32;
        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
        assert_eq!(residuals[0], 100);
        assert_eq!(residuals[1], 50);
        assert_eq!(residuals[2], 20);
        assert_eq!(residuals[3], 0);
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn gradient_predict_clamp_high() {
        assert_eq!(gradient_predict(200, 200, 0), 255);
    }

    #[test]
    fn gradient_predict_clamp_low() {
        assert_eq!(gradient_predict(0, 0, 200), 0);
    }

    #[test]
    fn gradient_predict_normal() {
        assert_eq!(gradient_predict(100, 80, 60), 120);
    }

    #[test]
    fn encode_decode_roundtrip_uniform() {
        let w = 8u32;
        let h = 6u32;
        let alpha = vec![128u8; (w * h) as usize];
        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn encode_decode_roundtrip_gradient_data() {
        let w = 16u32;
        let h = 8u32;
        let mut alpha = vec![0u8; (w * h) as usize];
        for y in 0..h as usize {
            for x in 0..w as usize {
                alpha[y * w as usize + x] = ((x * 16 + y * 8) & 0xFF) as u8;
            }
        }
        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn encode_decode_roundtrip_random_like() {
        let w = 10u32;
        let h = 10u32;
        let mut alpha = vec![0u8; (w * h) as usize];
        let mut state: u32 = 0xDEAD_BEEF;
        for byte in alpha.iter_mut() {
            state = state.wrapping_mul(1664525).wrapping_add(1013904223);
            *byte = (state >> 16) as u8;
        }
        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn encode_decode_roundtrip_single_pixel() {
        let alpha = vec![42u8];
        let encoded = encode_alpha(&alpha, 1, 1).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, 1, 1).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn encode_decode_roundtrip_single_row() {
        let alpha: Vec<u8> = (0..=255).collect();
        let encoded = encode_alpha(&alpha, 256, 1).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, 256, 1).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn encode_decode_roundtrip_single_column() {
        let alpha: Vec<u8> = (0..128).collect();
        let encoded = encode_alpha(&alpha, 1, 128).expect("encode should succeed");
        let decoded = decode_alpha(&encoded, 1, 128).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn decode_empty_chunk_is_error() {
        let result = decode_alpha(&[], 4, 4);
        assert!(result.is_err());
    }

    #[test]
    fn decode_truncated_payload_is_error() {
        let data = vec![0x00, 1, 2, 3];
        let result = decode_alpha(&data, 4, 4);
        assert!(result.is_err());
    }

    #[test]
    fn decode_vp8l_alpha_is_unsupported() {
        let data = vec![0x01, 0, 0, 0, 0];
        let result = decode_alpha(&data, 2, 2);
        assert!(result.is_err());
        let err_msg = format!("{}", result.expect_err("should be error"));
        assert!(err_msg.contains("not yet supported"));
    }

    #[test]
    fn encode_too_short_input_is_error() {
        let alpha = vec![0u8; 3];
        let result = encode_alpha(&alpha, 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn encode_decode_zero_dimensions() {
        let alpha: Vec<u8> = Vec::new();
        let encoded = encode_alpha(&alpha, 0, 0).expect("encode 0x0 should succeed");
        let decoded = decode_alpha(&encoded, 0, 0).expect("decode 0x0 should succeed");
        assert!(decoded.is_empty());
    }

    #[test]
    fn overflow_dimensions_rejected() {
        // On 32-bit targets this overflows `usize`; on 64-bit targets the
        // product fits and the one-byte input is reported as too small.
        // Either way it must be an error.
        let result = encode_alpha(&[0], u32::MAX, u32::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn known_vector_no_filter_no_compression() {
        let alpha_raw = vec![255, 128, 64, 0, 200, 100, 50, 25];
        let w = 4u32;
        let h = 2u32;
        let mut chunk = vec![0x00u8];
        chunk.extend_from_slice(&alpha_raw);
        let decoded = decode_alpha(&chunk, w, h).expect("decode should succeed");
        assert_eq!(decoded, alpha_raw);
    }

    #[test]
    fn known_vector_horizontal_filter() {
        let expected = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
        // Row 1's first sample is predicted from above: 50 - 10 = 40.
        let residuals = vec![10u8, 10, 10, 10, 40, 10, 10, 10];
        let mut chunk = vec![0x04u8];
        chunk.extend_from_slice(&residuals);
        let decoded = decode_alpha(&chunk, 4, 2).expect("decode should succeed");
        assert_eq!(decoded, expected);
    }

    #[test]
    fn known_vector_vertical_filter() {
        let expected = vec![10u8, 20, 30, 15, 25, 35, 20, 30, 40];
        // Top row is predicted from the left: 10, 20 - 10, 30 - 20.
        let residuals = vec![10u8, 10, 10, 5, 5, 5, 5, 5, 5];
        let mut chunk = vec![0x08u8];
        chunk.extend_from_slice(&residuals);
        let decoded = decode_alpha(&chunk, 3, 3).expect("decode should succeed");
        assert_eq!(decoded, expected);
    }

    #[test]
    fn known_vector_gradient_filter() {
        let expected = vec![100u8, 150, 120, 170];
        let residuals = vec![100u8, 50, 20, 0];
        let mut chunk = vec![0x0Cu8];
        chunk.extend_from_slice(&residuals);
        let decoded = decode_alpha(&chunk, 2, 2).expect("decode should succeed");
        assert_eq!(decoded, expected);
    }

    #[test]
    fn select_best_filter_for_uniform_data() {
        let data = vec![128u8; 64];
        let best = select_best_filter(&data, 8, 8);
        let best_score = score_filter(&data, 8, 8, best);
        let none_score = score_filter(&data, 8, 8, AlphaFilter::None);
        assert!(best_score <= none_score);
    }

    #[test]
    fn select_best_filter_for_horizontal_ramp() {
        let mut data = vec![0u8; 64];
        for y in 0..8usize {
            for x in 0..8usize {
                data[y * 8 + x] = (x * 30) as u8;
            }
        }
        let best = select_best_filter(&data, 8, 8);
        let best_score = score_filter(&data, 8, 8, best);
        let horiz_score = score_filter(&data, 8, 8, AlphaFilter::Horizontal);
        assert!(best_score <= horiz_score);
    }

    #[test]
    fn select_best_filter_for_vertical_ramp() {
        let mut data = vec![0u8; 64];
        for y in 0..8usize {
            for x in 0..8usize {
                data[y * 8 + x] = (y * 30) as u8;
            }
        }
        let best = select_best_filter(&data, 8, 8);
        let best_score = score_filter(&data, 8, 8, best);
        let vert_score = score_filter(&data, 8, 8, AlphaFilter::Vertical);
        assert!(best_score <= vert_score);
    }

    #[test]
    fn filter_horizontal_wrapping() {
        let original = vec![250u8, 10];
        let residuals = apply_inverse_filter(&original, 2, 1, AlphaFilter::Horizontal);
        assert_eq!(residuals[0], 250);
        assert_eq!(residuals[1], 10u8.wrapping_sub(250));
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, 2, 1, AlphaFilter::Horizontal);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn filter_vertical_wrapping() {
        let original = vec![5u8, 250];
        let residuals = apply_inverse_filter(&original, 1, 2, AlphaFilter::Vertical);
        assert_eq!(residuals[0], 5);
        assert_eq!(residuals[1], 250u8.wrapping_sub(5));
        let mut reconstructed = residuals;
        apply_filter(&mut reconstructed, 1, 2, AlphaFilter::Vertical);
        assert_eq!(reconstructed, original);
    }

    #[test]
    fn encode_decode_large_plane() {
        let w = 320u32;
        let h = 240u32;
        let total = (w * h) as usize;
        let mut alpha = vec![0u8; total];
        let mut state: u64 = 42;
        for byte in alpha.iter_mut() {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            *byte = (state >> 33) as u8;
        }
        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
        assert_eq!(encoded.len(), 1 + total);
        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
        assert_eq!(decoded, alpha);
    }

    #[test]
    fn all_valid_header_bytes_parse() {
        for comp in 0..=1u8 {
            for filt in 0..=3u8 {
                for prep in 0..=3u8 {
                    let byte = comp | (filt << 2) | (prep << 4);
                    let hdr = AlphaHeader::parse(byte)
                        .unwrap_or_else(|e| panic!("valid byte {byte:#04x} failed: {e}"));
                    assert_eq!(hdr.to_byte(), byte);
                }
            }
        }
    }

    #[test]
    fn all_reserved_header_bytes_rejected() {
        for reserved in 1..=3u8 {
            let byte = reserved << 6;
            assert!(
                AlphaHeader::parse(byte).is_err(),
                "reserved={reserved} should be rejected"
            );
        }
    }
}