1use gamut_color::clip_pixel8;
12use gamut_core::{Dimensions, Error, Result};
13
14use crate::vp8l::bit_io::BitReader;
15use crate::vp8l::decoder::decode_image;
16use crate::vp8l::encoder::encode_image;
17
/// Prediction filter applied to an alpha plane before it is stored.
///
/// The discriminants are spelled out so that they visibly match the 2-bit
/// codes used by `AlphaFilter::code` / `AlphaFilter::from_code`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlphaFilter {
    /// No prediction: every sample is stored as-is.
    None = 0,
    /// Predict each sample from its left neighbour.
    Horizontal = 1,
    /// Predict each sample from the neighbour above it.
    Vertical = 2,
    /// Predict each sample from `left + above - above_left`.
    Gradient = 3,
}
30
31impl AlphaFilter {
32 fn code(self) -> u8 {
34 match self {
35 Self::None => 0,
36 Self::Horizontal => 1,
37 Self::Vertical => 2,
38 Self::Gradient => 3,
39 }
40 }
41
42 fn from_code(code: u8) -> Self {
44 match code & 0x3 {
45 1 => Self::Horizontal,
46 2 => Self::Vertical,
47 3 => Self::Gradient,
48 _ => Self::None,
49 }
50 }
51
52 fn predict(self, plane: &[u8], width: usize, x: usize, y: usize) -> u8 {
57 let at = |x: usize, y: usize| plane[y * width + x];
58 if x == 0 && y == 0 {
59 return 0;
60 }
61 match self {
62 Self::None => 0,
63 Self::Horizontal => {
64 if x > 0 {
65 at(x - 1, y)
66 } else {
67 at(0, y - 1)
68 }
69 }
70 Self::Vertical => {
71 if y > 0 {
72 at(x, y - 1)
73 } else {
74 at(x - 1, 0)
75 }
76 }
77 Self::Gradient => {
78 if x == 0 {
79 at(0, y - 1)
80 } else if y == 0 {
81 at(x - 1, 0)
82 } else {
83 let (a, b, c) = (
84 i32::from(at(x - 1, y)),
85 i32::from(at(x, y - 1)),
86 i32::from(at(x - 1, y - 1)),
87 );
88 clip_pixel8(a + b - c)
89 }
90 }
91 }
92 }
93}
94
95#[must_use]
99pub fn filter(plane: &[u8], width: usize, height: usize, method: AlphaFilter) -> Vec<u8> {
100 let mut out = vec![0u8; plane.len()];
101 for y in 0..height {
102 for x in 0..width {
103 let i = y * width + x;
104 out[i] = plane[i].wrapping_sub(method.predict(plane, width, x, y));
105 }
106 }
107 out
108}
109
110#[must_use]
113pub fn unfilter(residuals: &[u8], width: usize, height: usize, method: AlphaFilter) -> Vec<u8> {
114 let mut plane = vec![0u8; residuals.len()];
115 for y in 0..height {
116 for x in 0..width {
117 let i = y * width + x;
118 let pred = method.predict(&plane, width, x, y);
119 plane[i] = pred.wrapping_add(residuals[i]);
120 }
121 }
122 plane
123}
124
125#[must_use]
129pub fn choose_filter(plane: &[u8], width: usize, height: usize) -> AlphaFilter {
130 [
131 AlphaFilter::None,
132 AlphaFilter::Horizontal,
133 AlphaFilter::Vertical,
134 AlphaFilter::Gradient,
135 ]
136 .into_iter()
137 .min_by_key(|&m| {
138 filter(plane, width, height, m)
139 .iter()
140 .map(|&r| u32::from(r.min(r.wrapping_neg())))
141 .sum::<u32>()
142 })
143 .unwrap_or(AlphaFilter::None)
144}
145
146#[must_use]
149pub fn write_raw_alph(plane: &[u8], width: usize, height: usize) -> Vec<u8> {
150 let method = choose_filter(plane, width, height);
151 let mut out = Vec::with_capacity(1 + plane.len());
152 out.push(method.code() << 2); out.extend_from_slice(&filter(plane, width, height, method));
154 out
155}
156
157fn write_compressed_alph(plane: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
166 let argb: Vec<u32> = plane
167 .iter()
168 .map(|&a| 0xff00_0000 | (u32::from(a) << 8))
169 .collect();
170 let dims = Dimensions {
171 width: width as u32,
172 height: height as u32,
173 };
174 let stream = encode_image(&argb, dims)?;
175 let mut out = Vec::with_capacity(1 + stream.len());
176 out.push(0x01); out.extend_from_slice(&stream);
178 Ok(out)
179}
180
181pub fn write_alph(plane: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
187 let raw = write_raw_alph(plane, width, height);
188 let compressed = write_compressed_alph(plane, width, height)?;
189 Ok(if compressed.len() < raw.len() {
190 compressed
191 } else {
192 raw
193 })
194}
195
196pub fn read_alph(payload: &[u8], width: usize, height: usize) -> Result<Vec<u8>> {
206 let &header = payload
207 .first()
208 .ok_or(Error::InvalidInput("ALPH: empty chunk"))?;
209 let method = AlphaFilter::from_code(header >> 2);
210 let data = &payload[1..];
211 let residuals = match header & 0x3 {
212 0 => {
213 if data.len() != width * height {
214 return Err(Error::InvalidInput("ALPH: raw alpha length mismatch"));
215 }
216 data.to_vec()
217 }
218 1 => {
219 let mut r = BitReader::new(data);
220 let argb = decode_image(&mut r, width as u32, height as u32)?;
221 argb.iter().map(|&p| (p >> 8) as u8).collect() }
223 _ => return Err(Error::InvalidInput("ALPH: reserved compression method")),
224 };
225 Ok(unfilter(&residuals, width, height, method))
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Every filter method, in header-code order.
    const ALL_FILTERS: [AlphaFilter; 4] = [
        AlphaFilter::None,
        AlphaFilter::Horizontal,
        AlphaFilter::Vertical,
        AlphaFilter::Gradient,
    ];

    /// Deterministic, non-trivial alpha plane in row-major order.
    fn pattern(width: usize, height: usize) -> Vec<u8> {
        (0..height)
            .flat_map(|y| {
                (0..width).map(move |x| ((x * 9 + y * 5 + (x ^ y) * 3) & 0xff) as u8)
            })
            .collect()
    }

    #[test]
    fn each_filter_inverts_exactly() {
        let (w, h) = (23, 17);
        let plane = pattern(w, h);
        for m in ALL_FILTERS {
            let restored = unfilter(&filter(&plane, w, h, m), w, h, m);
            assert_eq!(restored, plane, "filter {m:?} round-trip");
        }
    }

    #[test]
    fn none_filter_stores_alpha_verbatim() {
        let side = 8;
        let plane = pattern(side, side);
        let residuals = filter(&plane, side, side, AlphaFilter::None);
        assert_eq!(residuals, plane);
    }

    #[test]
    fn raw_alph_chunk_round_trips() {
        let (w, h) = (19, 11);
        let plane = pattern(w, h);
        let chunk = write_raw_alph(&plane, w, h);

        // One header byte plus one residual per sample.
        assert_eq!(chunk.len(), 1 + w * h);
        assert_eq!(chunk[0] & 0x3, 0, "compression method is raw");

        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }

    #[test]
    fn read_alph_rejects_bad_input() {
        let empty = read_alph(&[], 4, 4);
        assert!(empty.is_err());

        let short_raw = read_alph(&[0, 1, 2], 4, 4);
        assert!(short_raw.is_err(), "wrong raw length");

        let truncated = read_alph(&[0x01, 0x00], 8, 8);
        assert!(truncated.is_err(), "truncated compressed stream");
    }

    #[test]
    fn compressed_alph_round_trips() {
        let (w, h) = (20, 12);
        let plane = pattern(w, h);
        let chunk = write_compressed_alph(&plane, w, h).unwrap();

        assert_eq!(chunk[0] & 0x3, 1, "compression method is lossless");

        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }

    #[test]
    fn write_alph_picks_the_smaller_and_round_trips() {
        let (w, h) = (64, 64);
        // Each row is a constant value that grows by 4 per row.
        let plane: Vec<u8> = (0..h)
            .flat_map(|row| std::iter::repeat((row * 4) as u8).take(w))
            .collect();
        let chunk = write_alph(&plane, w, h).unwrap();

        assert!(
            chunk.len() < 1 + w * h,
            "compressible alpha should beat raw"
        );

        let decoded = read_alph(&chunk, w, h).unwrap();
        assert_eq!(decoded, plane);
    }
}