use gamut_color::clip_pixel8;
use gamut_core::{Dimensions, Error, Result};
use crate::vp8l::bit_io::BitReader;
use crate::vp8l::decoder::decode_image;
use crate::vp8l::encoder::encode_image;
/// Prediction filter applied to an alpha plane before it is stored.
///
/// The explicit discriminants match the numeric filter codes used in the
/// chunk header (0 through 3).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlphaFilter {
    /// No prediction: every sample is predicted as 0.
    None = 0,
    /// Predict from the sample to the left; the first column uses the sample above.
    Horizontal = 1,
    /// Predict from the sample above; the first row uses the sample to the left.
    Vertical = 2,
    /// Predict from `left + above - above_left`; the first row and first
    /// column fall back to the single available neighbour.
    Gradient = 3,
}
impl AlphaFilter {
fn code(self) -> u8 {
match self {
Self::None => 0,
Self::Horizontal => 1,
Self::Vertical => 2,
Self::Gradient => 3,
}
}
fn from_code(code: u8) -> Self {
match code & 0x3 {
1 => Self::Horizontal,
2 => Self::Vertical,
3 => Self::Gradient,
_ => Self::None,
}
}
fn predict(self, plane: &[u8], width: usize, x: usize, y: usize) -> u8 {
let at = |x: usize, y: usize| plane[y * width + x];
if x == 0 && y == 0 {
return 0;
}
match self {
Self::None => 0,
Self::Horizontal => {
if x > 0 {
at(x - 1, y)
} else {
at(0, y - 1)
}
}
Self::Vertical => {
if y > 0 {
at(x, y - 1)
} else {
at(x - 1, 0)
}
}
Self::Gradient => {
if x == 0 {
at(0, y - 1)
} else if y == 0 {
at(x - 1, 0)
} else {
let (a, b, c) = (
i32::from(at(x - 1, y)),
i32::from(at(x, y - 1)),
i32::from(at(x - 1, y - 1)),
);
clip_pixel8(a + b - c)
}
}
}
}
}
/// Applies `method` to `plane`, producing one residual byte per sample.
///
/// Each residual is `sample - prediction` with wrapping (mod 256) arithmetic.
/// `plane` is read row-major with a stride of `width`. The output has the
/// same length as `plane`; entries past `width * height` stay 0.
///
/// # Panics
///
/// Panics if `plane` holds fewer than `width * height` samples.
#[must_use]
pub fn filter(plane: &[u8], width: usize, height: usize, method: AlphaFilter) -> Vec<u8> {
    let mut residuals = vec![0u8; plane.len()];
    for row in 0..height {
        let row_start = row * width;
        for col in 0..width {
            let idx = row_start + col;
            residuals[idx] = plane[idx].wrapping_sub(method.predict(plane, width, col, row));
        }
    }
    residuals
}
/// Inverts [`filter`]: rebuilds the alpha plane from its residuals.
///
/// Samples are reconstructed in row-major order, so each prediction only
/// reads samples that have already been restored. Each sample is
/// `prediction + residual` with wrapping (mod 256) arithmetic.
///
/// # Panics
///
/// Panics if `residuals` holds fewer than `width * height` bytes.
#[must_use]
pub fn unfilter(residuals: &[u8], width: usize, height: usize, method: AlphaFilter) -> Vec<u8> {
    let mut restored = vec![0u8; residuals.len()];
    for row in 0..height {
        let row_start = row * width;
        for col in 0..width {
            let idx = row_start + col;
            let predicted = method.predict(&restored, width, col, row);
            restored[idx] = predicted.wrapping_add(residuals[idx]);
        }
    }
    restored
}
/// Picks the prediction filter whose residuals have the smallest total
/// magnitude for `plane`.
///
/// Every candidate is run through [`filter`]. Each residual byte `r` is
/// scored as its distance from zero modulo 256, `min(r, 256 - r)`, and the
/// scores are summed. Ties go to the earliest candidate in the order `None`,
/// `Horizontal`, `Vertical`, `Gradient`.
///
/// # Panics
///
/// Panics under the same conditions as [`filter`], i.e. when `plane` holds
/// fewer than `width * height` samples.
#[must_use]
pub fn choose_filter(plane: &[u8], width: usize, height: usize) -> AlphaFilter {
    const CANDIDATES: [AlphaFilter; 4] = [
        AlphaFilter::None,
        AlphaFilter::Horizontal,
        AlphaFilter::Vertical,
        AlphaFilter::Gradient,
    ];

    CANDIDATES
        .into_iter()
        .min_by_key(|&method| {
            filter(plane, width, height, method)
                .iter()
                // Magnitude of the residual when read as a signed byte.
                .map(|&r| u64::from(r.min(r.wrapping_neg())))
                // Accumulate in u64: at up to 128 per sample a u32 total
                // overflows once the plane exceeds roughly 33.5M samples.
                .sum::<u64>()
        })
        .unwrap_or(AlphaFilter::None)
}
/// Builds an uncompressed ALPH payload for `plane`.
///
/// The filter is picked by [`choose_filter`]. The output is one header byte,
/// with the filter code in bits 2-3 and all other bits zero, followed by the
/// filtered residuals.
#[must_use]
pub fn write_raw_alph(plane: &[u8], width: usize, height: usize) -> Vec<u8> {
    let method = choose_filter(plane, width, height);
    let residuals = filter(plane, width, height, method);
    let header = method.code() << 2;

    let mut chunk = Vec::with_capacity(1 + plane.len());
    chunk.push(header);
    chunk.extend(residuals);
    chunk
}
fn write_compressed_alph(plane: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
let argb: Vec<u32> = plane
.iter()
.map(|&a| 0xff00_0000 | (u32::from(a) << 8))
.collect();
let dims = Dimensions {
width: width as u32,
height: height as u32,
};
let stream = encode_image(&argb, dims)?;
let mut out = Vec::with_capacity(1 + stream.len());
out.push(0x01); out.extend_from_slice(&stream);
Ok(out)
}
/// Encodes `plane` as an ALPH payload, keeping whichever of the raw and the
/// compressed encoding is shorter.
///
/// The raw encoding wins when both have the same length.
///
/// # Errors
///
/// Propagates any error from the compressed encoder.
pub fn write_alph(plane: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
    let raw_chunk = write_raw_alph(plane, width, height);
    let lossless_chunk = write_compressed_alph(plane, width, height)?;

    let smallest = if raw_chunk.len() <= lossless_chunk.len() {
        raw_chunk
    } else {
        lossless_chunk
    };
    Ok(smallest)
}
pub fn read_alph(payload: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
let &header = payload
.first()
.ok_or(Error::InvalidInput("ALPH: empty chunk"))?;
let method = AlphaFilter::from_code(header >> 2);
let data = &payload[1..];
let residuals = match header & 0x3 {
0 => {
if data.len() != width * height {
return Err(Error::InvalidInput("ALPH: raw alpha length mismatch"));
}
data.to_vec()
}
1 => {
let mut r = BitReader::new(data);
let argb = decode_image(&mut r, width as u32, height as u32)?;
argb.iter().map(|&p| (p >> 8) as u8).collect() }
_ => return Err(Error::InvalidInput("ALPH: reserved compression method")),
};
Ok(unfilter(&residuals, width, height, method))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every filter variant, in code order.
    const ALL_FILTERS: [AlphaFilter; 4] = [
        AlphaFilter::None,
        AlphaFilter::Horizontal,
        AlphaFilter::Vertical,
        AlphaFilter::Gradient,
    ];

    /// Deterministic row-major test plane that varies along both axes.
    fn pattern(width: usize, height: usize) -> Vec<u8> {
        let mut samples = Vec::with_capacity(width * height);
        for y in 0..height {
            for x in 0..width {
                samples.push(((x * 9 + y * 5 + (x ^ y) * 3) & 0xff) as u8);
            }
        }
        samples
    }

    /// `unfilter` must undo `filter` exactly for every filter variant.
    #[test]
    fn each_filter_inverts_exactly() {
        let (w, h) = (23, 17);
        let plane = pattern(w, h);
        for m in ALL_FILTERS {
            let restored = unfilter(&filter(&plane, w, h, m), w, h, m);
            assert_eq!(restored, plane, "filter {m:?} round-trip");
        }
    }

    /// With no prediction the residuals are the samples themselves.
    #[test]
    fn none_filter_stores_alpha_verbatim() {
        let plane = pattern(8, 8);
        let residuals = filter(&plane, 8, 8, AlphaFilter::None);
        assert_eq!(residuals, plane);
    }

    /// A raw chunk is one header byte plus one byte per sample and decodes
    /// back to the input.
    #[test]
    fn raw_alph_chunk_round_trips() {
        let (w, h) = (19, 11);
        let plane = pattern(w, h);
        let chunk = write_raw_alph(&plane, w, h);
        assert_eq!(chunk.len(), 1 + w * h);
        assert_eq!(chunk[0] & 0x3, 0, "compression method is raw");
        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }

    /// Empty, mis-sized and truncated payloads are rejected with an error.
    #[test]
    fn read_alph_rejects_bad_input() {
        let empty: [u8; 0] = [];
        assert!(read_alph(&empty, 4, 4).is_err());
        assert!(read_alph(&[0, 1, 2], 4, 4).is_err(), "wrong raw length");
        assert!(
            read_alph(&[0x01, 0x00], 8, 8).is_err(),
            "truncated compressed stream"
        );
    }

    /// A compressed chunk is tagged with method 1 and decodes back to the input.
    #[test]
    fn compressed_alph_round_trips() {
        let (w, h) = (20, 12);
        let plane = pattern(w, h);
        let chunk = write_compressed_alph(&plane, w, h).unwrap();
        assert_eq!(chunk[0] & 0x3, 1, "compression method is lossless");
        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }

    /// A plane that is constant along each row should end up smaller than
    /// its raw size and still round-trip.
    #[test]
    fn write_alph_picks_the_smaller_and_round_trips() {
        let (w, h) = (64, 64);
        // Every sample in row `r` has the value `r * 4`.
        let plane: Vec<u8> = (0..h)
            .flat_map(|row| std::iter::repeat((row * 4) as u8).take(w))
            .collect();
        let chunk = write_alph(&plane, w, h).unwrap();
        assert!(
            chunk.len() < 1 + w * h,
            "compressible alpha should beat raw"
        );
        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }
}