1use crate::error::{CodecError, CodecResult};
16
/// Compression method stored in the two low bits of the ALPH header byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlphaCompression {
    /// Alpha bytes follow the header uncompressed (possibly filtered).
    NoCompression = 0b00,
    /// Alpha data is a VP8L (WebP lossless) bitstream.
    WebPLossless = 0b01,
}
29
/// Spatial prediction filter stored in bits 2-3 of the ALPH header byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlphaFilter {
    /// No spatial prediction.
    None = 0b00,
    /// Horizontal prediction from the left neighbour.
    Horizontal = 0b01,
    /// Vertical prediction from the pixel above.
    Vertical = 0b10,
    /// Gradient prediction: `clip(left + top - top_left)`.
    Gradient = 0b11,
}
42
43#[derive(Debug, Clone)]
45pub struct AlphaHeader {
46 pub compression: AlphaCompression,
48 pub filter: AlphaFilter,
50 pub pre_processing: u8,
52}
53
54impl AlphaCompression {
59 fn from_bits(bits: u8) -> CodecResult<Self> {
60 match bits & 0x03 {
61 0 => Ok(Self::NoCompression),
62 1 => Ok(Self::WebPLossless),
63 v => Err(CodecError::InvalidBitstream(format!(
64 "unknown alpha compression method: {v}"
65 ))),
66 }
67 }
68}
69
70impl AlphaFilter {
71 fn from_bits(bits: u8) -> CodecResult<Self> {
72 match bits & 0x03 {
73 0 => Ok(Self::None),
74 1 => Ok(Self::Horizontal),
75 2 => Ok(Self::Vertical),
76 3 => Ok(Self::Gradient),
77 _ => Err(CodecError::InvalidBitstream(
79 "unknown alpha filter method".to_string(),
80 )),
81 }
82 }
83}
84
85impl AlphaHeader {
86 pub fn parse(byte: u8) -> CodecResult<Self> {
88 let reserved = (byte >> 6) & 0x03;
89 if reserved != 0 {
90 return Err(CodecError::InvalidBitstream(format!(
91 "ALPH header reserved bits are non-zero: {reserved}"
92 )));
93 }
94 let compression = AlphaCompression::from_bits(byte & 0x03)?;
95 let filter = AlphaFilter::from_bits((byte >> 2) & 0x03)?;
96 let pre_processing = (byte >> 4) & 0x03;
97
98 Ok(Self {
99 compression,
100 filter,
101 pre_processing,
102 })
103 }
104
105 pub fn to_byte(&self) -> u8 {
107 let comp = self.compression as u8;
108 let filt = (self.filter as u8) << 2;
109 let prep = (self.pre_processing & 0x03) << 4;
110 comp | filt | prep
111 }
112}
113
/// Gradient predictor: `left + top - top_left`, saturated to `0..=255`.
#[inline]
fn gradient_predict(left: u8, top: u8, top_left: u8) -> u8 {
    let estimate = i32::from(left) + i32::from(top) - i32::from(top_left);
    // The conversion fails exactly when the estimate leaves the u8 range;
    // its sign tells which bound to saturate to.
    u8::try_from(estimate).unwrap_or(if estimate < 0 { u8::MIN } else { u8::MAX })
}
124
125fn apply_filter(data: &mut [u8], width: u32, height: u32, filter: AlphaFilter) {
134 let w = width as usize;
135 let h = height as usize;
136 let total = w * h;
137 if total == 0 || data.len() < total {
138 return;
139 }
140
141 match filter {
142 AlphaFilter::None => { }
143
144 AlphaFilter::Horizontal => {
145 for y in 0..h {
146 let row_start = y * w;
147 for x in 1..w {
150 let idx = row_start + x;
151 let left = data[idx - 1];
152 data[idx] = data[idx].wrapping_add(left);
153 }
154 }
155 }
156
157 AlphaFilter::Vertical => {
158 for y in 1..h {
160 for x in 0..w {
161 let idx = y * w + x;
162 let top = data[idx - w];
163 data[idx] = data[idx].wrapping_add(top);
164 }
165 }
166 }
167
168 AlphaFilter::Gradient => {
169 for x in 1..w {
171 data[x] = data[x].wrapping_add(data[x - 1]);
172 }
173 for y in 1..h {
174 let row_start = y * w;
175 data[row_start] = data[row_start].wrapping_add(data[row_start - w]);
177
178 for x in 1..w {
179 let idx = row_start + x;
180 let left = data[idx - 1];
181 let top = data[idx - w];
182 let top_left = data[idx - w - 1];
183 let pred = gradient_predict(left, top, top_left);
184 data[idx] = data[idx].wrapping_add(pred);
185 }
186 }
187 }
188 }
189}
190
191fn apply_inverse_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> Vec<u8> {
200 let w = width as usize;
201 let h = height as usize;
202 let total = w * h;
203 if total == 0 {
204 return Vec::new();
205 }
206
207 match filter {
208 AlphaFilter::None => data[..total].to_vec(),
209
210 AlphaFilter::Horizontal => {
211 let mut out = vec![0u8; total];
212 for y in 0..h {
213 let row_start = y * w;
214 out[row_start] = data[row_start]; for x in 1..w {
216 let idx = row_start + x;
217 let left = data[idx - 1];
218 out[idx] = data[idx].wrapping_sub(left);
219 }
220 }
221 out
222 }
223
224 AlphaFilter::Vertical => {
225 let mut out = vec![0u8; total];
226 out[..w].copy_from_slice(&data[..w]);
228 for y in 1..h {
229 for x in 0..w {
230 let idx = y * w + x;
231 let top = data[idx - w];
232 out[idx] = data[idx].wrapping_sub(top);
233 }
234 }
235 out
236 }
237
238 AlphaFilter::Gradient => {
239 let mut out = vec![0u8; total];
240 out[0] = data[0]; for x in 1..w {
244 out[x] = data[x].wrapping_sub(data[x - 1]);
245 }
246 for y in 1..h {
247 let row_start = y * w;
248 out[row_start] = data[row_start].wrapping_sub(data[row_start - w]);
250
251 for x in 1..w {
252 let idx = row_start + x;
253 let left = data[idx - 1];
254 let top = data[idx - w];
255 let top_left = data[idx - w - 1];
256 let pred = gradient_predict(left, top, top_left);
257 out[idx] = data[idx].wrapping_sub(pred);
258 }
259 }
260 out
261 }
262 }
263}
264
265fn score_filter(data: &[u8], width: u32, height: u32, filter: AlphaFilter) -> u64 {
272 let residuals = apply_inverse_filter(data, width, height, filter);
273 residuals
274 .iter()
275 .map(|&b| {
276 let v = b as i16;
278 let d = if v > 128 { 256 - v } else { v };
279 d as u64
280 })
281 .sum()
282}
283
284fn select_best_filter(data: &[u8], width: u32, height: u32) -> AlphaFilter {
286 let filters = [
287 AlphaFilter::None,
288 AlphaFilter::Horizontal,
289 AlphaFilter::Vertical,
290 AlphaFilter::Gradient,
291 ];
292
293 let mut best_filter = AlphaFilter::None;
294 let mut best_score = u64::MAX;
295
296 for &f in &filters {
297 let s = score_filter(data, width, height, f);
298 if s < best_score {
299 best_score = s;
300 best_filter = f;
301 }
302 }
303
304 best_filter
305}
306
307pub fn decode_alpha(data: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
317 if data.is_empty() {
318 return Err(CodecError::InvalidBitstream(
319 "ALPH chunk is empty".to_string(),
320 ));
321 }
322
323 let total = (width as usize)
324 .checked_mul(height as usize)
325 .ok_or_else(|| {
326 CodecError::InvalidParameter(format!(
327 "alpha plane dimensions overflow: {width} x {height}"
328 ))
329 })?;
330
331 if total == 0 {
332 return Ok(Vec::new());
333 }
334
335 let header = AlphaHeader::parse(data[0])?;
336 let payload = &data[1..];
337
338 match header.compression {
339 AlphaCompression::NoCompression => {
340 if payload.len() < total {
341 return Err(CodecError::BufferTooSmall {
342 needed: total,
343 have: payload.len(),
344 });
345 }
346 let mut alpha = payload[..total].to_vec();
347 apply_filter(&mut alpha, width, height, header.filter);
348 Ok(alpha)
349 }
350 AlphaCompression::WebPLossless => Err(CodecError::UnsupportedFeature(
351 "VP8L-compressed alpha channel is not yet supported".to_string(),
352 )),
353 }
354}
355
356pub fn encode_alpha(alpha: &[u8], width: u32, height: u32) -> CodecResult<Vec<u8>> {
362 let total = (width as usize)
363 .checked_mul(height as usize)
364 .ok_or_else(|| {
365 CodecError::InvalidParameter(format!(
366 "alpha plane dimensions overflow: {width} x {height}"
367 ))
368 })?;
369
370 if alpha.len() < total {
371 return Err(CodecError::BufferTooSmall {
372 needed: total,
373 have: alpha.len(),
374 });
375 }
376
377 if total == 0 {
378 let hdr = AlphaHeader {
380 compression: AlphaCompression::NoCompression,
381 filter: AlphaFilter::None,
382 pre_processing: 0,
383 };
384 return Ok(vec![hdr.to_byte()]);
385 }
386
387 let input = &alpha[..total];
388
389 let best_filter = select_best_filter(input, width, height);
391
392 let header = AlphaHeader {
393 compression: AlphaCompression::NoCompression,
394 filter: best_filter,
395 pre_processing: 0,
396 };
397
398 let residuals = apply_inverse_filter(input, width, height, best_filter);
399
400 let mut out = Vec::with_capacity(1 + residuals.len());
401 out.push(header.to_byte());
402 out.extend_from_slice(&residuals);
403 Ok(out)
404}
405
406#[cfg(test)]
411mod tests {
412 use super::*;
413
414 #[test]
417 fn header_roundtrip_no_compression_no_filter() {
418 let hdr = AlphaHeader {
419 compression: AlphaCompression::NoCompression,
420 filter: AlphaFilter::None,
421 pre_processing: 0,
422 };
423 let byte = hdr.to_byte();
424 assert_eq!(byte, 0x00);
425 let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
426 assert_eq!(parsed.compression, AlphaCompression::NoCompression);
427 assert_eq!(parsed.filter, AlphaFilter::None);
428 assert_eq!(parsed.pre_processing, 0);
429 }
430
431 #[test]
432 fn header_roundtrip_all_fields() {
433 let hdr = AlphaHeader {
436 compression: AlphaCompression::NoCompression,
437 filter: AlphaFilter::Gradient,
438 pre_processing: 1,
439 };
440 let byte = hdr.to_byte();
441 assert_eq!(byte, 0x1C);
442 let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
443 assert_eq!(parsed.compression, AlphaCompression::NoCompression);
444 assert_eq!(parsed.filter, AlphaFilter::Gradient);
445 assert_eq!(parsed.pre_processing, 1);
446 }
447
448 #[test]
449 fn header_roundtrip_webp_lossless_horizontal() {
450 let hdr = AlphaHeader {
453 compression: AlphaCompression::WebPLossless,
454 filter: AlphaFilter::Horizontal,
455 pre_processing: 0,
456 };
457 let byte = hdr.to_byte();
458 assert_eq!(byte, 0x05);
459 let parsed = AlphaHeader::parse(byte).expect("parse should succeed");
460 assert_eq!(parsed.compression, AlphaCompression::WebPLossless);
461 assert_eq!(parsed.filter, AlphaFilter::Horizontal);
462 }
463
464 #[test]
465 fn header_reserved_bits_rejected() {
466 let result = AlphaHeader::parse(0x40);
468 assert!(result.is_err());
469 }
470
471 #[test]
474 fn filter_none_roundtrip() {
475 let original: Vec<u8> = (0..12).collect();
476 let w = 4u32;
477 let h = 3u32;
478
479 let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::None);
480 assert_eq!(residuals, original);
481
482 let mut reconstructed = residuals;
483 apply_filter(&mut reconstructed, w, h, AlphaFilter::None);
484 assert_eq!(reconstructed, original);
485 }
486
487 #[test]
488 fn filter_horizontal_roundtrip() {
489 let original: Vec<u8> = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120];
490 let w = 4u32;
491 let h = 3u32;
492
493 let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Horizontal);
494
495 assert_eq!(residuals[0], 10);
501 assert_eq!(residuals[1], 10);
502 assert_eq!(residuals[2], 10);
503 assert_eq!(residuals[3], 10);
504
505 let mut reconstructed = residuals;
506 apply_filter(&mut reconstructed, w, h, AlphaFilter::Horizontal);
507 assert_eq!(reconstructed, original);
508 }
509
510 #[test]
511 fn filter_vertical_roundtrip() {
512 let original: Vec<u8> = vec![10, 20, 30, 40, 15, 25, 35, 45, 20, 30, 40, 50];
513 let w = 4u32;
514 let h = 3u32;
515
516 let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Vertical);
517
518 assert_eq!(&residuals[0..4], &[10, 20, 30, 40]);
520 assert_eq!(&residuals[4..8], &[5, 5, 5, 5]);
523
524 let mut reconstructed = residuals;
525 apply_filter(&mut reconstructed, w, h, AlphaFilter::Vertical);
526 assert_eq!(reconstructed, original);
527 }
528
529 #[test]
530 fn filter_gradient_roundtrip() {
531 let original: Vec<u8> = vec![100, 110, 120, 130, 105, 115, 125, 135, 110, 120, 130, 140];
532 let w = 4u32;
533 let h = 3u32;
534
535 let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
536 let mut reconstructed = residuals;
537 apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
538 assert_eq!(reconstructed, original);
539 }
540
541 #[test]
542 fn filter_gradient_known_vector() {
543 let original: Vec<u8> = vec![100, 150, 120, 170];
547 let w = 2u32;
548 let h = 2u32;
549
550 let residuals = apply_inverse_filter(&original, w, h, AlphaFilter::Gradient);
551
552 assert_eq!(residuals[0], 100);
554 assert_eq!(residuals[1], 50);
556 assert_eq!(residuals[2], 20);
558 assert_eq!(residuals[3], 0);
560
561 let mut reconstructed = residuals;
562 apply_filter(&mut reconstructed, w, h, AlphaFilter::Gradient);
563 assert_eq!(reconstructed, original);
564 }
565
566 #[test]
567 fn gradient_predict_clamp_high() {
568 assert_eq!(gradient_predict(200, 200, 0), 255);
570 }
571
572 #[test]
573 fn gradient_predict_clamp_low() {
574 assert_eq!(gradient_predict(0, 0, 200), 0);
576 }
577
578 #[test]
579 fn gradient_predict_normal() {
580 assert_eq!(gradient_predict(100, 80, 60), 120);
581 }
582
583 #[test]
586 fn encode_decode_roundtrip_uniform() {
587 let w = 8u32;
588 let h = 6u32;
589 let alpha = vec![128u8; (w * h) as usize];
590
591 let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
592 let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
593 assert_eq!(decoded, alpha);
594 }
595
596 #[test]
597 fn encode_decode_roundtrip_gradient_data() {
598 let w = 16u32;
599 let h = 8u32;
600 let mut alpha = vec![0u8; (w * h) as usize];
601 for y in 0..h as usize {
602 for x in 0..w as usize {
603 alpha[y * w as usize + x] = ((x * 16 + y * 8) & 0xFF) as u8;
604 }
605 }
606
607 let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
608 let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
609 assert_eq!(decoded, alpha);
610 }
611
612 #[test]
613 fn encode_decode_roundtrip_random_like() {
614 let w = 10u32;
616 let h = 10u32;
617 let mut alpha = vec![0u8; (w * h) as usize];
618 let mut state: u32 = 0xDEAD_BEEF;
619 for byte in alpha.iter_mut() {
620 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
621 *byte = (state >> 16) as u8;
622 }
623
624 let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
625 let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
626 assert_eq!(decoded, alpha);
627 }
628
629 #[test]
630 fn encode_decode_roundtrip_single_pixel() {
631 let alpha = vec![42u8];
632 let encoded = encode_alpha(&alpha, 1, 1).expect("encode should succeed");
633 let decoded = decode_alpha(&encoded, 1, 1).expect("decode should succeed");
634 assert_eq!(decoded, alpha);
635 }
636
637 #[test]
638 fn encode_decode_roundtrip_single_row() {
639 let alpha: Vec<u8> = (0..=255).collect();
640 let encoded = encode_alpha(&alpha, 256, 1).expect("encode should succeed");
641 let decoded = decode_alpha(&encoded, 256, 1).expect("decode should succeed");
642 assert_eq!(decoded, alpha);
643 }
644
645 #[test]
646 fn encode_decode_roundtrip_single_column() {
647 let alpha: Vec<u8> = (0..128).collect();
648 let encoded = encode_alpha(&alpha, 1, 128).expect("encode should succeed");
649 let decoded = decode_alpha(&encoded, 1, 128).expect("decode should succeed");
650 assert_eq!(decoded, alpha);
651 }
652
653 #[test]
656 fn decode_empty_chunk_is_error() {
657 let result = decode_alpha(&[], 4, 4);
658 assert!(result.is_err());
659 }
660
661 #[test]
662 fn decode_truncated_payload_is_error() {
663 let data = vec![0x00, 1, 2, 3];
665 let result = decode_alpha(&data, 4, 4);
666 assert!(result.is_err());
667 }
668
669 #[test]
670 fn decode_vp8l_alpha_is_unsupported() {
671 let data = vec![0x01, 0, 0, 0, 0];
673 let result = decode_alpha(&data, 2, 2);
674 assert!(result.is_err());
675 let err_msg = format!("{}", result.expect_err("should be error"));
676 assert!(err_msg.contains("not yet supported"));
677 }
678
679 #[test]
680 fn encode_too_short_input_is_error() {
681 let alpha = vec![0u8; 3]; let result = encode_alpha(&alpha, 2, 2);
683 assert!(result.is_err());
684 }
685
686 #[test]
687 fn encode_decode_zero_dimensions() {
688 let alpha: Vec<u8> = Vec::new();
689 let encoded = encode_alpha(&alpha, 0, 0).expect("encode 0x0 should succeed");
690 let decoded = decode_alpha(&encoded, 0, 0).expect("decode 0x0 should succeed");
691 assert!(decoded.is_empty());
692 }
693
694 #[test]
695 fn overflow_dimensions_rejected() {
696 let result = encode_alpha(&[0], u32::MAX, u32::MAX);
697 assert!(result.is_err());
698 }
699
700 #[test]
703 fn known_vector_no_filter_no_compression() {
704 let alpha_raw = vec![255, 128, 64, 0, 200, 100, 50, 25];
706 let w = 4u32;
707 let h = 2u32;
708
709 let mut chunk = vec![0x00u8]; chunk.extend_from_slice(&alpha_raw);
711
712 let decoded = decode_alpha(&chunk, w, h).expect("decode should succeed");
713 assert_eq!(decoded, alpha_raw);
714 }
715
716 #[test]
717 fn known_vector_horizontal_filter() {
718 let expected = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
723 let residuals = vec![10u8, 10, 10, 10, 50, 10, 10, 10];
724
725 let mut chunk = vec![0x04u8];
727 chunk.extend_from_slice(&residuals);
728
729 let decoded = decode_alpha(&chunk, 4, 2).expect("decode should succeed");
730 assert_eq!(decoded, expected);
731 }
732
733 #[test]
734 fn known_vector_vertical_filter() {
735 let expected = vec![10u8, 20, 30, 15, 25, 35, 20, 30, 40];
744 let residuals = vec![10u8, 20, 30, 5, 5, 5, 5, 5, 5];
745
746 let mut chunk = vec![0x08u8];
748 chunk.extend_from_slice(&residuals);
749
750 let decoded = decode_alpha(&chunk, 3, 3).expect("decode should succeed");
751 assert_eq!(decoded, expected);
752 }
753
754 #[test]
755 fn known_vector_gradient_filter() {
756 let expected = vec![100u8, 150, 120, 170];
760 let residuals = vec![100u8, 50, 20, 0];
761
762 let mut chunk = vec![0x0Cu8];
764 chunk.extend_from_slice(&residuals);
765
766 let decoded = decode_alpha(&chunk, 2, 2).expect("decode should succeed");
767 assert_eq!(decoded, expected);
768 }
769
770 #[test]
773 fn select_best_filter_for_uniform_data() {
774 let data = vec![128u8; 64];
779 let best = select_best_filter(&data, 8, 8);
780 let best_score = score_filter(&data, 8, 8, best);
781 let none_score = score_filter(&data, 8, 8, AlphaFilter::None);
782 assert!(best_score <= none_score);
783 }
784
785 #[test]
786 fn select_best_filter_for_horizontal_ramp() {
787 let mut data = vec![0u8; 64];
791 for y in 0..8usize {
792 for x in 0..8usize {
793 data[y * 8 + x] = (x * 30) as u8;
794 }
795 }
796 let best = select_best_filter(&data, 8, 8);
797 let best_score = score_filter(&data, 8, 8, best);
798 let horiz_score = score_filter(&data, 8, 8, AlphaFilter::Horizontal);
799 assert!(best_score <= horiz_score);
801 }
802
803 #[test]
804 fn select_best_filter_for_vertical_ramp() {
805 let mut data = vec![0u8; 64];
808 for y in 0..8usize {
809 for x in 0..8usize {
810 data[y * 8 + x] = (y * 30) as u8;
811 }
812 }
813 let best = select_best_filter(&data, 8, 8);
814 let best_score = score_filter(&data, 8, 8, best);
815 let vert_score = score_filter(&data, 8, 8, AlphaFilter::Vertical);
816 assert!(best_score <= vert_score);
818 }
819
820 #[test]
823 fn filter_horizontal_wrapping() {
824 let original = vec![250u8, 10];
827 let residuals = apply_inverse_filter(&original, 2, 1, AlphaFilter::Horizontal);
828 assert_eq!(residuals[0], 250);
829 assert_eq!(residuals[1], 10u8.wrapping_sub(250));
830
831 let mut reconstructed = residuals;
832 apply_filter(&mut reconstructed, 2, 1, AlphaFilter::Horizontal);
833 assert_eq!(reconstructed, original);
834 }
835
836 #[test]
837 fn filter_vertical_wrapping() {
838 let original = vec![5u8, 250]; let residuals = apply_inverse_filter(&original, 1, 2, AlphaFilter::Vertical);
840 assert_eq!(residuals[0], 5);
841 assert_eq!(residuals[1], 250u8.wrapping_sub(5));
842
843 let mut reconstructed = residuals;
844 apply_filter(&mut reconstructed, 1, 2, AlphaFilter::Vertical);
845 assert_eq!(reconstructed, original);
846 }
847
848 #[test]
851 fn encode_decode_large_plane() {
852 let w = 320u32;
853 let h = 240u32;
854 let total = (w * h) as usize;
855 let mut alpha = vec![0u8; total];
856 let mut state: u64 = 42;
857 for byte in alpha.iter_mut() {
858 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
859 *byte = (state >> 33) as u8;
860 }
861
862 let encoded = encode_alpha(&alpha, w, h).expect("encode should succeed");
863 assert_eq!(encoded.len(), 1 + total);
865
866 let decoded = decode_alpha(&encoded, w, h).expect("decode should succeed");
867 assert_eq!(decoded, alpha);
868 }
869
870 #[test]
873 fn all_valid_header_bytes_parse() {
874 for comp in 0..=1u8 {
877 for filt in 0..=3u8 {
878 for prep in 0..=3u8 {
879 let byte = comp | (filt << 2) | (prep << 4);
880 let hdr = AlphaHeader::parse(byte)
881 .unwrap_or_else(|e| panic!("valid byte {byte:#04x} failed: {e}"));
882 assert_eq!(hdr.to_byte(), byte);
883 }
884 }
885 }
886 }
887
888 #[test]
889 fn all_reserved_header_bytes_rejected() {
890 for reserved in 1..=3u8 {
891 let byte = reserved << 6;
892 assert!(
893 AlphaHeader::parse(byte).is_err(),
894 "reserved={reserved} should be rejected"
895 );
896 }
897 }
898}