Skip to main content

oximedia_codec/webp/
alpha.rs

//! WebP ALPH chunk handler for alpha channel encoding and decoding.
//!
//! In the WebP extended format, the alpha channel is stored separately in an ALPH chunk.
//! The chunk consists of a 1-byte header followed by filtered (and optionally compressed)
//! alpha data.
//!
//! # ALPH Header Byte Layout
//!
//! ```text
//! ┌─────────┬─────────┬──────────────┬────────────────────┐
//! │ bits 7:6│ bits 5:4│ bits 3:2     │ bits 1:0           │
//! │reserved │pre-proc │filter method │compression method  │
//! └─────────┴─────────┴──────────────┴────────────────────┘
//! ```
14
15use crate::error::{CodecError, CodecResult};
16
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------
20
/// How the alpha payload that follows the ALPH header byte is stored.
///
/// The discriminant is the value carried in bits 1:0 of the header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlphaCompression {
    /// Alpha values are stored raw (after filtering), one byte per pixel.
    NoCompression = 0b00,
    /// Alpha values are carried inside a VP8L (WebP lossless) bitstream.
    WebPLossless = 0b01,
}
29
/// Spatial predictor applied to the alpha plane before compression.
///
/// The discriminant is the value carried in bits 3:2 of the header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlphaFilter {
    /// No prediction; values are stored unchanged.
    None = 0b00,
    /// Prediction from the left neighbour.
    Horizontal = 0b01,
    /// Prediction from the pixel directly above.
    Vertical = 0b10,
    /// Gradient prediction `left + top - top_left`, clamped to `[0, 255]`.
    Gradient = 0b11,
}
42
43/// Parsed ALPH chunk header.
44#[derive(Debug, Clone)]
45pub struct AlphaHeader {
46    /// Compression method used for the alpha data.
47    pub compression: AlphaCompression,
48    /// Spatial filter applied prior to compression.
49    pub filter: AlphaFilter,
50    /// Pre-processing level (0 = none, 1 = level reduction).
51    pub pre_processing: u8,
52}
53
// ---------------------------------------------------------------------------
// Header parsing helpers
// ---------------------------------------------------------------------------
57
58impl AlphaCompression {
59    fn from_bits(bits: u8) -> CodecResult<Self> {
60        match bits & 0x03 {
61            0 => Ok(Self::NoCompression),
62            1 => Ok(Self::WebPLossless),
63            v => Err(CodecError::InvalidBitstream(format!(
64                "unknown alpha compression method: {v}"
65            ))),
66        }
67    }
68}
69
70impl AlphaFilter {
71    fn from_bits(bits: u8) -> CodecResult<Self> {
72        match bits & 0x03 {
73            0 => Ok(Self::None),
74            1 => Ok(Self::Horizontal),
75            2 => Ok(Self::Vertical),
76            3 => Ok(Self::Gradient),
77            // unreachable because of the mask, but the compiler does not know
78            _ => Err(CodecError::InvalidBitstream(
79                "unknown alpha filter method".to_string(),
80            )),
81        }
82    }
83}
84
85impl AlphaHeader {
86    /// Parse the single header byte from the beginning of an ALPH chunk payload.
87    pub fn parse(byte: u8) -> CodecResult<Self> {
88        let reserved = (byte >> 6) & 0x03;
89        if reserved != 0 {
90            return Err(CodecError::InvalidBitstream(format!(
91                "ALPH header reserved bits are non-zero: {reserved}"
92            )));
93        }
94        let compression = AlphaCompression::from_bits(byte & 0x03)?;
95        let filter = AlphaFilter::from_bits((byte >> 2) & 0x03)?;
96        let pre_processing = (byte >> 4) & 0x03;
97
98        Ok(Self {
99            compression,
100            filter,
101            pre_processing,
102        })
103    }
104
105    /// Serialize the header back to a single byte.
106    pub fn to_byte(&self) -> u8 {
107        let comp = self.compression as u8;
108        let filt = (self.filter as u8) << 2;
109        let prep = (self.pre_processing & 0x03) << 4;
110        comp | filt | prep
111    }
112}
113
// ---------------------------------------------------------------------------
// Gradient prediction helper
// ---------------------------------------------------------------------------
117
/// Gradient predictor: the estimate `left + top - top_left`, saturated to
/// the `u8` range `[0, 255]`.
#[inline]
fn gradient_predict(left: u8, top: u8, top_left: u8) -> u8 {
    // The estimate spans -255..=510, so compute it in a wider signed type.
    let estimate = i32::from(left) + i32::from(top) - i32::from(top_left);
    if estimate <= 0 {
        0
    } else if estimate >= 255 {
        u8::MAX
    } else {
        estimate as u8
    }
}
124
// ---------------------------------------------------------------------------
// Filtering (decode direction – reconstruction)
// ---------------------------------------------------------------------------
128
129/// Reconstruct original alpha values from filtered residuals.
130///
131/// This is the *decode* direction: we read the stored residual and add the
132/// prediction derived from already-reconstructed neighbours.
133fn apply_filter(data: &mut [u8], width: u32, height: u32, filter: AlphaFilter) {
134    let w = width as usize;
135    let h = height as usize;
136    let total = w * h;
137    if total == 0 || data.len() < total {
138        return;
139    }
140
141    match filter {
142        AlphaFilter::None => { /* nothing to do */ }
143
144        AlphaFilter::Horizontal => {
145            for y in 0..h {
146                let row_start = y * w;
147                // First pixel in row: predicted from 0 (or left=0)
148                // data[row_start] already holds the correct value (residual + 0)
149                for x in 1..w {
150                    let idx = row_start + x;
151                    let left = data[idx - 1];
152                    data[idx] = data[idx].wrapping_add(left);
153                }
154            }
155        }
156
157        AlphaFilter::Vertical => {
158            // First row: no prediction (residual = raw)
159            for y in 1..h {
160                for x in 0..w {
161                    let idx = y * w + x;
162                    let top = data[idx - w];
163                    data[idx] = data[idx].wrapping_add(top);
164                }
165            }
166        }
167
168        AlphaFilter::Gradient => {
169            // First row: horizontal-only prediction
170            for x in 1..w {
171                data[x] = data[x].wrapping_add(data[x - 1]);
172            }
173            for y in 1..h {
174                let row_start = y * w;
175                // First pixel of row: predict from top only
176                data[row_start] = data[row_start].wrapping_add(data[row_start - w]);
177
178                for x in 1..w {
179                    let idx = row_start + x;
180                    let left = data[idx - 1];
181                    let top = data[idx - w];
182                    let top_left = data[idx - w - 1];
183                    let pred = gradient_predict(left, top, top_left);
184                    data[idx] = data[idx].wrapping_add(pred);
185                }
186            }
187        }
188    }
189}
190
191// ---------------------------------------------------------------------------
192// Inverse filtering (encode direction – produce residuals)
193// ---------------------------------------------------------------------------
194
195/// Produce filtered residuals from raw alpha values.
196///
197/// This is the *encode* direction: given original alpha values we compute
198/// `residual = original - prediction` (using wrapping arithmetic).
199fn apply_inverse_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> Vec<u8> {
200    let w = width as usize;
201    let h = height as usize;
202    let total = w * h;
203    if total == 0 {
204        return Vec::new();
205    }
206
207    match filter {
208        AlphaFilter::None => data[..total].to_vec(),
209
210        AlphaFilter::Horizontal => {
211            let mut out = vec![0u8; total];
212            for y in 0..h {
213                let row_start = y * w;
214                out[row_start] = data[row_start]; // first pixel – no prediction
215                for x in 1..w {
216                    let idx = row_start + x;
217                    let left = data[idx - 1];
218                    out[idx] = data[idx].wrapping_sub(left);
219                }
220            }
221            out
222        }
223
224        AlphaFilter::Vertical => {
225            let mut out = vec![0u8; total];
226            // First row – no prediction
227            out[..w].copy_from_slice(&data[..w]);
228            for y in 1..h {
229                for x in 0..w {
230                    let idx = y * w + x;
231                    let top = data[idx - w];
232                    out[idx] = data[idx].wrapping_sub(top);
233                }
234            }
235            out
236        }
237
238        AlphaFilter::Gradient => {
239            let mut out = vec![0u8; total];
240            out[0] = data[0]; // top-left corner
241
242            // First row: horizontal prediction
243            for x in 1..w {
244                out[x] = data[x].wrapping_sub(data[x - 1]);
245            }
246            for y in 1..h {
247                let row_start = y * w;
248                // First pixel of row: predict from top only
249                out[row_start] = data[row_start].wrapping_sub(data[row_start - w]);
250
251                for x in 1..w {
252                    let idx = row_start + x;
253                    let left = data[idx - 1];
254                    let top = data[idx - w];
255                    let top_left = data[idx - w - 1];
256                    let pred = gradient_predict(left, top, top_left);
257                    out[idx] = data[idx].wrapping_sub(pred);
258                }
259            }
260            out
261        }
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Heuristic: choose the best filter for encoding
267// ---------------------------------------------------------------------------
268
269/// Score a filter by summing the absolute residuals – lower is better
270/// (residuals closer to zero compress better).
271fn score_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> u64 {
272    let residuals = apply_inverse_filter(data, width, height, filter);
273    residuals
274        .iter()
275        .map(|&b| {
276            // Treat residual as signed offset from zero (wrapping distance)
277            let v = b as i16;
278            let d = if v > 128 { 256 - v } else { v };
279            d as u64
280        })
281        .sum()
282}
283
284/// Select the filter that produces the smallest total residual.
285fn select_best_filter(data: &[u8], width: u32, height: u32) -> AlphaFilter {
286    let filters = [
287        AlphaFilter::None,
288        AlphaFilter::Horizontal,
289        AlphaFilter::Vertical,
290        AlphaFilter::Gradient,
291    ];
292
293    let mut best_filter = AlphaFilter::None;
294    let mut best_score = u64::MAX;
295
296    for &f in &filters {
297        let s = score_filter(data, width, height, f);
298        if s < best_score {
299            best_score = s;
300            best_filter = f;
301        }
302    }
303
304    best_filter
305}
306
307// ---------------------------------------------------------------------------
308// Public API
309// ---------------------------------------------------------------------------
310
311/// Decode alpha channel from ALPH chunk payload.
312///
313/// `data` is the raw ALPH chunk payload (starting with the header byte).
314/// Returns a `Vec<u8>` of length `width * height` with reconstructed alpha
315/// values in row-major order.
316pub fn decode_alpha(data: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
317    if data.is_empty() {
318        return Err(CodecError::InvalidBitstream(
319            "ALPH chunk is empty".to_string(),
320        ));
321    }
322
323    let total = (width as usize)
324        .checked_mul(height as usize)
325        .ok_or_else(|| {
326            CodecError::InvalidParameter(format!(
327                "alpha plane dimensions overflow: {width} x {height}"
328            ))
329        })?;
330
331    if total == 0 {
332        return Ok(Vec::new());
333    }
334
335    let header = AlphaHeader::parse(data[0])?;
336    let payload = &data[1..];
337
338    match header.compression {
339        AlphaCompression::NoCompression => {
340            if payload.len() < total {
341                return Err(CodecError::BufferTooSmall {
342                    needed: total,
343                    have: payload.len(),
344                });
345            }
346            let mut alpha = payload[..total].to_vec();
347            apply_filter(&mut alpha, width, height, header.filter);
348            Ok(alpha)
349        }
350        AlphaCompression::WebPLossless => Err(CodecError::UnsupportedFeature(
351            "VP8L-compressed alpha channel is not yet supported".to_string(),
352        )),
353    }
354}
355
356/// Encode alpha channel into an ALPH chunk payload.
357///
358/// `alpha` must contain exactly `width * height` bytes in row-major order.
359/// Returns the complete chunk payload: 1-byte header followed by the
360/// (optionally filtered) alpha data, using no compression.
361pub fn encode_alpha(alpha: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
362    let total = (width as usize)
363        .checked_mul(height as usize)
364        .ok_or_else(|| {
365            CodecError::InvalidParameter(format!(
366                "alpha plane dimensions overflow: {width} x {height}"
367            ))
368        })?;
369
370    if alpha.len() < total {
371        return Err(CodecError::BufferTooSmall {
372            needed: total,
373            have: alpha.len(),
374        });
375    }
376
377    if total == 0 {
378        // Degenerate case: just a header byte
379        let hdr = AlphaHeader {
380            compression: AlphaCompression::NoCompression,
381            filter: AlphaFilter::None,
382            pre_processing: 0,
383        };
384        return Ok(vec![hdr.to_byte()]);
385    }
386
387    let input = &alpha[..total];
388
389    // Choose the best filter heuristically
390    let best_filter = select_best_filter(input, width, height);
391
392    let header = AlphaHeader {
393        compression: AlphaCompression::NoCompression,
394        filter: best_filter,
395        pre_processing: 0,
396    };
397
398    let residuals = apply_inverse_filter(input, width, height, best_filter);
399
400    let mut out = Vec::with_capacity(1 + residuals.len());
401    out.push(header.to_byte());
402    out.extend_from_slice(&residuals);
403    Ok(out)
404}
405
406// ---------------------------------------------------------------------------
407// Tests
408// ---------------------------------------------------------------------------
409
410#[cfg(test)]
411mod tests {
412    use super::*;
413
414    // -- Header round-trip ---------------------------------------------------
415
416    #[test]
417    fn header_roundtrip_no_compression_no_filter() {
418        let hdr = AlphaHeader {
419            compression: AlphaCompression::NoCompression,
420            filter: AlphaFilter::None,
421            pre_processing: 0,
422        };
423        let byte = hdr.to_byte();
424        assert_eq!(byte, 0x00);
425        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
426        assert_eq!(parsed.compression, AlphaCompression::NoCompression);
427        assert_eq!(parsed.filter, AlphaFilter::None);
428        assert_eq!(parsed.pre_processing, 0);
429    }
430
431    #[test]
432    fn header_roundtrip_all_fields() {
433        // compression=0, filter=gradient(3), pre_processing=1
434        // byte = 0 | (3 << 2) | (1 << 4) = 0 | 12 | 16 = 28
435        let hdr = AlphaHeader {
436            compression: AlphaCompression::NoCompression,
437            filter: AlphaFilter::Gradient,
438            pre_processing: 1,
439        };
440        let byte = hdr.to_byte();
441        assert_eq!(byte, 0x1C);
442        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
443        assert_eq!(parsed.compression, AlphaCompression::NoCompression);
444        assert_eq!(parsed.filter, AlphaFilter::Gradient);
445        assert_eq!(parsed.pre_processing, 1);
446    }
447
448    #[test]
449    fn header_roundtrip_webp_lossless_horizontal() {
450        // compression=1, filter=horizontal(1), pre_processing=0
451        // byte = 1 | (1 << 2) = 5
452        let hdr = AlphaHeader {
453            compression: AlphaCompression::WebPLossless,
454            filter: AlphaFilter::Horizontal,
455            pre_processing: 0,
456        };
457        let byte = hdr.to_byte();
458        assert_eq!(byte, 0x05);
459        let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
460        assert_eq!(parsed.compression, AlphaCompression::WebPLossless);
461        assert_eq!(parsed.filter, AlphaFilter::Horizontal);
462    }
463
464    #[test]
465    fn header_reserved_bits_rejected() {
466        // Set reserved bits (bit 6)
467        let result = AlphaHeader::parse(0x40);
468        assert!(result.is_err());
469    }
470
471    // -- Filter round-trips --------------------------------------------------
472
473    #[test]
474    fn filter_none_roundtrip() {
475        let original: Vec<u8> = (0..12).collect();
476        let w = 4u32;
477        let h = 3u32;
478
479        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::None);
480        assert_eq!(residuals, original);
481
482        let mut reconstructed = residuals;
483        apply_filter(&mut reconstructed, w, h, AlphaFilter::None);
484        assert_eq!(reconstructed, original);
485    }
486
487    #[test]
488    fn filter_horizontal_roundtrip() {
489        let original: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120];
490        let w = 4u32;
491        let h = 3u32;
492
493        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Horizontal);
494
495        // Verify residuals manually for first row:
496        // res[0] = 10 (first pixel, no prediction)
497        // res[1] = 20 - 10 = 10
498        // res[2] = 30 - 20 = 10
499        // res[3] = 40 - 30 = 10
500        assert_eq!(residuals[0], 10);
501        assert_eq!(residuals[1], 10);
502        assert_eq!(residuals[2], 10);
503        assert_eq!(residuals[3], 10);
504
505        let mut reconstructed = residuals;
506        apply_filter(&mut reconstructed, w, h, AlphaFilter::Horizontal);
507        assert_eq!(reconstructed, original);
508    }
509
510    #[test]
511    fn filter_vertical_roundtrip() {
512        let original: Vec<u8> = vec![10, 20, 30, 40, 15, 25, 35, 45, 20, 30, 40, 50];
513        let w = 4u32;
514        let h = 3u32;
515
516        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Vertical);
517
518        // First row: unchanged
519        assert_eq!(&residuals[0..4], &[10, 20, 30, 40]);
520        // Second row: res[i] = orig[i] - orig[i-w]
521        // 15-10=5, 25-20=5, 35-30=5, 45-40=5
522        assert_eq!(&residuals[4..8], &[5, 5, 5, 5]);
523
524        let mut reconstructed = residuals;
525        apply_filter(&mut reconstructed, w, h, AlphaFilter::Vertical);
526        assert_eq!(reconstructed, original);
527    }
528
529    #[test]
530    fn filter_gradient_roundtrip() {
531        let original: Vec<u8> = vec![100, 110, 120, 130, 105, 115, 125, 135, 110, 120, 130, 140];
532        let w = 4u32;
533        let h = 3u32;
534
535        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
536        let mut reconstructed = residuals;
537        apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
538        assert_eq!(reconstructed, original);
539    }
540
541    #[test]
542    fn filter_gradient_known_vector() {
543        // 2x2 image
544        // [100, 150]
545        // [120, 170]
546        let original: Vec<u8> = vec![100, 150, 120, 170];
547        let w = 2u32;
548        let h = 2u32;
549
550        let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
551
552        // Pixel (0,0): no prediction => 100
553        assert_eq!(residuals[0], 100);
554        // Pixel (1,0): horizontal prediction from left => 150 - 100 = 50
555        assert_eq!(residuals[1], 50);
556        // Pixel (0,1): vertical prediction from top => 120 - 100 = 20
557        assert_eq!(residuals[2], 20);
558        // Pixel (1,1): gradient = left(120) + top(150) - top_left(100) = 170 => 170 - 170 = 0
559        assert_eq!(residuals[3], 0);
560
561        let mut reconstructed = residuals;
562        apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
563        assert_eq!(reconstructed, original);
564    }
565
566    #[test]
567    fn gradient_predict_clamp_high() {
568        // left=200, top=200, top_left=0 => 200+200-0 = 400 => clamped to 255
569        assert_eq!(gradient_predict(200, 200, 0), 255);
570    }
571
572    #[test]
573    fn gradient_predict_clamp_low() {
574        // left=0, top=0, top_left=200 => 0+0-200 = -200 => clamped to 0
575        assert_eq!(gradient_predict(0, 0, 200), 0);
576    }
577
578    #[test]
579    fn gradient_predict_normal() {
580        assert_eq!(gradient_predict(100, 80, 60), 120);
581    }
582
583    // -- Encode / Decode round-trip ------------------------------------------
584
585    #[test]
586    fn encode_decode_roundtrip_uniform() {
587        let w = 8u32;
588        let h = 6u32;
589        let alpha = vec![128u8; (w * h) as usize];
590
591        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
592        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
593        assert_eq!(decoded, alpha);
594    }
595
596    #[test]
597    fn encode_decode_roundtrip_gradient_data() {
598        let w = 16u32;
599        let h = 8u32;
600        let mut alpha = vec![0u8; (w * h) as usize];
601        for y in 0..h as usize {
602            for x in 0..w as usize {
603                alpha[y * w as usize + x] = ((x * 16 + y * 8) & 0xFF) as u8;
604            }
605        }
606
607        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
608        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
609        assert_eq!(decoded, alpha);
610    }
611
612    #[test]
613    fn encode_decode_roundtrip_random_like() {
614        // Pseudo-random data generated from a simple LCG
615        let w = 10u32;
616        let h = 10u32;
617        let mut alpha = vec![0u8; (w * h) as usize];
618        let mut state: u32 = 0xDEAD_BEEF;
619        for byte in alpha.iter_mut() {
620            state = state.wrapping_mul(1664525).wrapping_add(1013904223);
621            *byte = (state >> 16) as u8;
622        }
623
624        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
625        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
626        assert_eq!(decoded, alpha);
627    }
628
629    #[test]
630    fn encode_decode_roundtrip_single_pixel() {
631        let alpha = vec![42u8];
632        let encoded = encode_alpha(&alpha, 1, 1).expect("encode should succeed");
633        let decoded = decode_alpha(&encoded, 1, 1).expect("decode should succeed");
634        assert_eq!(decoded, alpha);
635    }
636
637    #[test]
638    fn encode_decode_roundtrip_single_row() {
639        let alpha: Vec<u8> = (0..=255).collect();
640        let encoded = encode_alpha(&alpha, 256, 1).expect("encode should succeed");
641        let decoded = decode_alpha(&encoded, 256, 1).expect("decode should succeed");
642        assert_eq!(decoded, alpha);
643    }
644
645    #[test]
646    fn encode_decode_roundtrip_single_column() {
647        let alpha: Vec<u8> = (0..128).collect();
648        let encoded = encode_alpha(&alpha, 1, 128).expect("encode should succeed");
649        let decoded = decode_alpha(&encoded, 1, 128).expect("decode should succeed");
650        assert_eq!(decoded, alpha);
651    }
652
653    // -- Edge cases / error paths --------------------------------------------
654
655    #[test]
656    fn decode_empty_chunk_is_error() {
657        let result = decode_alpha(&[], 4, 4);
658        assert!(result.is_err());
659    }
660
661    #[test]
662    fn decode_truncated_payload_is_error() {
663        // Header byte for no-compression, no-filter + only 3 bytes but we need 16
664        let data = vec![0x00, 1, 2, 3];
665        let result = decode_alpha(&data, 4, 4);
666        assert!(result.is_err());
667    }
668
669    #[test]
670    fn decode_vp8l_alpha_is_unsupported() {
671        // Header byte with compression = 1 (WebP lossless)
672        let data = vec![0x01, 0, 0, 0, 0];
673        let result = decode_alpha(&data, 2, 2);
674        assert!(result.is_err());
675        let err_msg = format!("{}", result.expect_err("should be error"));
676        assert!(err_msg.contains("not yet supported"));
677    }
678
679    #[test]
680    fn encode_too_short_input_is_error() {
681        let alpha = vec![0u8; 3]; // need 4 for 2x2
682        let result = encode_alpha(&alpha, 2, 2);
683        assert!(result.is_err());
684    }
685
686    #[test]
687    fn encode_decode_zero_dimensions() {
688        let alpha: Vec<u8> = Vec::new();
689        let encoded = encode_alpha(&alpha, 0, 0).expect("encode 0x0 should succeed");
690        let decoded = decode_alpha(&encoded, 0, 0).expect("decode 0x0 should succeed");
691        assert!(decoded.is_empty());
692    }
693
694    #[test]
695    fn overflow_dimensions_rejected() {
696        let result = encode_alpha(&[0], u32::MAX, u32::MAX);
697        assert!(result.is_err());
698    }
699
700    // -- Known ALPH chunk test vectors ---------------------------------------
701
702    #[test]
703    fn known_vector_no_filter_no_compression() {
704        // Manually constructed ALPH chunk: header=0x00, followed by raw alpha
705        let alpha_raw = vec![255, 128, 64, 0, 200, 100, 50, 25];
706        let w = 4u32;
707        let h = 2u32;
708
709        let mut chunk = vec![0x00u8]; // no compression, no filter, no pre-processing
710        chunk.extend_from_slice(&alpha_raw);
711
712        let decoded = decode_alpha(&chunk, w, h).expect("decode should succeed");
713        assert_eq!(decoded, alpha_raw);
714    }
715
716    #[test]
717    fn known_vector_horizontal_filter() {
718        // 4x2 image alpha: [10, 20, 30, 40, 50, 60, 70, 80]
719        // Horizontal residuals:
720        //   row0: [10, 10, 10, 10]  (20-10=10, 30-20=10, 40-30=10)
721        //   row1: [50, 10, 10, 10]  (60-50=10, 70-60=10, 80-70=10)
722        let expected = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
723        let residuals = vec![10u8, 10, 10, 10, 50, 10, 10, 10];
724
725        // header byte: compression=0, filter=horizontal(1) => (1 << 2) = 0x04
726        let mut chunk = vec![0x04u8];
727        chunk.extend_from_slice(&residuals);
728
729        let decoded = decode_alpha(&chunk, 4, 2).expect("decode should succeed");
730        assert_eq!(decoded, expected);
731    }
732
733    #[test]
734    fn known_vector_vertical_filter() {
735        // 3x3 image alpha:
736        // [10, 20, 30]
737        // [15, 25, 35]
738        // [20, 30, 40]
739        // Vertical residuals:
740        //   row0: [10, 20, 30] (first row unchanged)
741        //   row1: [5, 5, 5]   (15-10, 25-20, 35-30)
742        //   row2: [5, 5, 5]   (20-15, 30-25, 40-35)
743        let expected = vec![10u8, 20, 30, 15, 25, 35, 20, 30, 40];
744        let residuals = vec![10u8, 20, 30, 5, 5, 5, 5, 5, 5];
745
746        // header byte: compression=0, filter=vertical(2) => (2 << 2) = 0x08
747        let mut chunk = vec![0x08u8];
748        chunk.extend_from_slice(&residuals);
749
750        let decoded = decode_alpha(&chunk, 3, 3).expect("decode should succeed");
751        assert_eq!(decoded, expected);
752    }
753
754    #[test]
755    fn known_vector_gradient_filter() {
756        // 2x2 image: [100, 150, 120, 170]
757        // Gradient residuals (computed in filter_gradient_known_vector test):
758        //   [100, 50, 20, 0]
759        let expected = vec![100u8, 150, 120, 170];
760        let residuals = vec![100u8, 50, 20, 0];
761
762        // header byte: compression=0, filter=gradient(3) => (3 << 2) = 0x0C
763        let mut chunk = vec![0x0Cu8];
764        chunk.extend_from_slice(&residuals);
765
766        let decoded = decode_alpha(&chunk, 2, 2).expect("decode should succeed");
767        assert_eq!(decoded, expected);
768    }
769
770    // -- Filter selection heuristic ------------------------------------------
771
772    #[test]
773    fn select_best_filter_for_uniform_data() {
774        // All same value => None, Horizontal, and Vertical filters produce
775        // near-zero residuals. Gradient has slightly higher residuals for
776        // the boundary pixels. The best filter should produce a score no
777        // worse than the None filter.
778        let data = vec![128u8; 64];
779        let best = select_best_filter(&data, 8, 8);
780        let best_score = score_filter(&data, 8, 8, best);
781        let none_score = score_filter(&data, 8, 8, AlphaFilter::None);
782        assert!(best_score <= none_score);
783    }
784
785    #[test]
786    fn select_best_filter_for_horizontal_ramp() {
787        // Each row is a horizontal ramp with constant step.
788        // Gradient filter perfectly predicts this pattern (left + top - top_left
789        // cancels out for linear data), so it ties with or beats horizontal.
790        let mut data = vec![0u8; 64];
791        for y in 0..8usize {
792            for x in 0..8usize {
793                data[y * 8 + x] = (x * 30) as u8;
794            }
795        }
796        let best = select_best_filter(&data, 8, 8);
797        let best_score = score_filter(&data, 8, 8, best);
798        let horiz_score = score_filter(&data, 8, 8, AlphaFilter::Horizontal);
799        // The chosen filter must be at least as good as horizontal
800        assert!(best_score <= horiz_score);
801    }
802
803    #[test]
804    fn select_best_filter_for_vertical_ramp() {
805        // Each column is a vertical ramp => gradient filter also handles this well
806        // since it subsumes both horizontal and vertical prediction.
807        let mut data = vec![0u8; 64];
808        for y in 0..8usize {
809            for x in 0..8usize {
810                data[y * 8 + x] = (y * 30) as u8;
811            }
812        }
813        let best = select_best_filter(&data, 8, 8);
814        let best_score = score_filter(&data, 8, 8, best);
815        let vert_score = score_filter(&data, 8, 8, AlphaFilter::Vertical);
816        // The chosen filter must be at least as good as vertical
817        assert!(best_score <= vert_score);
818    }
819
820    // -- Wrapping arithmetic correctness -------------------------------------
821
822    #[test]
823    fn filter_horizontal_wrapping() {
824        // Ensure wrapping works: 250, 10 => residual = 10 - 250 = -240 => 16 (wrapping)
825        // Reconstruction: 16 + 250 = 266 => 10 (wrapping)
826        let original = vec![250u8, 10];
827        let residuals = apply_inverse_filter(&original, 2, 1, AlphaFilter::Horizontal);
828        assert_eq!(residuals[0], 250);
829        assert_eq!(residuals[1], 10u8.wrapping_sub(250));
830
831        let mut reconstructed = residuals;
832        apply_filter(&mut reconstructed, 2, 1, AlphaFilter::Horizontal);
833        assert_eq!(reconstructed, original);
834    }
835
836    #[test]
837    fn filter_vertical_wrapping() {
838        let original = vec![5u8, 250]; // w=1, h=2
839        let residuals = apply_inverse_filter(&original, 1, 2, AlphaFilter::Vertical);
840        assert_eq!(residuals[0], 5);
841        assert_eq!(residuals[1], 250u8.wrapping_sub(5));
842
843        let mut reconstructed = residuals;
844        apply_filter(&mut reconstructed, 1, 2, AlphaFilter::Vertical);
845        assert_eq!(reconstructed, original);
846    }
847
848    // -- Larger stress test --------------------------------------------------
849
850    #[test]
851    fn encode_decode_large_plane() {
852        let w = 320u32;
853        let h = 240u32;
854        let total = (w * h) as usize;
855        let mut alpha = vec![0u8; total];
856        let mut state: u64 = 42;
857        for byte in alpha.iter_mut() {
858            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
859            *byte = (state >> 33) as u8;
860        }
861
862        let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
863        // Encoded data = 1 header byte + total alpha bytes
864        assert_eq!(encoded.len(), 1 + total);
865
866        let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
867        assert_eq!(decoded, alpha);
868    }
869
870    // -- Header byte exhaustive coverage -------------------------------------
871
872    #[test]
873    fn all_valid_header_bytes_parse() {
874        // Valid combinations: reserved=0, compression in {0,1},
875        // filter in {0,1,2,3}, pre_processing in {0,1,2,3}
876        for comp in 0..=1u8 {
877            for filt in 0..=3u8 {
878                for prep in 0..=3u8 {
879                    let byte = comp | (filt << 2) | (prep << 4);
880                    let hdr = AlphaHeader::parse(byte)
881                        .unwrap_or_else(|e| panic!("valid byte {byte:#04x} failed: {e}"));
882                    assert_eq!(hdr.to_byte(), byte);
883                }
884            }
885        }
886    }
887
888    #[test]
889    fn all_reserved_header_bytes_rejected() {
890        for reserved in 1..=3u8 {
891            let byte = reserved << 6;
892            assert!(
893                AlphaHeader::parse(byte).is_err(),
894                "reserved={reserved} should be rejected"
895            );
896        }
897    }
898}