1use crate::error::{CodecError, CodecResult};
27
/// Number of candidate predictors evaluated for every sample.
///
/// It sizes the prediction arrays and the per-predictor error table, and
/// equals the number of `Predictor` variants.
const NUM_PREDICTORS: usize = 6;
30
/// Identifies one of the candidate sample predictors.
///
/// The discriminant of each variant is its position in the array returned by
/// `compute_predictions`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Predictor {
    /// Always predicts `0`.
    Zero = 0,
    /// Predicts the value of the west (left) neighbour.
    West = 1,
    /// Predicts the value of the north (upper) neighbour.
    North = 2,
    /// Predicts the truncating average of the west and north neighbours.
    AvgWN = 3,
    /// Predicts `north + west - north_west`, clamped to the range spanned by
    /// the west and north neighbours.
    Gradient = 4,
    /// Predicts `(3 * north + 3 * west - north_west + north_east) / 6`.
    Weighted = 5,
}

impl Predictor {
    /// Maps a predictor index back to its variant.
    ///
    /// Indices `0..=4` select the variant with that discriminant; every other
    /// index (including `5`) yields [`Predictor::Weighted`].
    fn from_index(idx: usize) -> Self {
        // Variants that have an explicit index; anything past the end of this
        // table falls through to the default below.
        const LEADING: [Predictor; 5] = [
            Predictor::Zero,
            Predictor::West,
            Predictor::North,
            Predictor::AvgWN,
            Predictor::Gradient,
        ];
        LEADING.get(idx).copied().unwrap_or(Self::Weighted)
    }
}
61
/// Forward reversible colour transform (the YCoCg-R lifting steps).
///
/// Converts an `(r, g, b)` triple into `(y, co, cg)` using integer lifting
/// steps that `inverse_rct` undoes in reverse order.
///
/// All additions and subtractions wrap on overflow, so the function is total
/// over `i32` and never panics, even in builds with overflow checks enabled.
/// For inputs whose intermediate values fit in `i32` (any ordinary sample
/// range) the result is the same as with plain arithmetic.
pub fn forward_rct(r: i32, g: i32, b: i32) -> (i32, i32, i32) {
    let co = r.wrapping_sub(b);
    // `>> 1` is an arithmetic shift, i.e. floor division by two.
    let tmp = b.wrapping_add(co >> 1);
    let cg = g.wrapping_sub(tmp);
    let y = tmp.wrapping_add(cg >> 1);
    (y, co, cg)
}
78
/// Inverse reversible colour transform (the YCoCg-R lifting steps, undone).
///
/// Converts a `(y, co, cg)` triple back into `(r, g, b)` by running the
/// lifting steps of `forward_rct` in reverse order.
///
/// The inputs typically come from a decoded bitstream and may therefore be
/// arbitrary. All additions and subtractions wrap on overflow so that hostile
/// data cannot trigger an arithmetic panic; for inputs whose intermediate
/// values fit in `i32` the result is the same as with plain arithmetic.
pub fn inverse_rct(y: i32, co: i32, cg: i32) -> (i32, i32, i32) {
    let tmp = y.wrapping_sub(cg >> 1);
    let g = tmp.wrapping_add(cg);
    let b = tmp.wrapping_sub(co >> 1);
    let r = b.wrapping_add(co);
    (r, g, b)
}
89
/// A channel transform that can be registered on the encoder and decoder.
#[derive(Clone, Debug)]
pub enum ModularTransform {
    /// Reversible colour transform over three consecutive channels.
    Rct {
        /// Index of the first of the three channels
        /// (`begin_channel`, `begin_channel + 1`, `begin_channel + 2`).
        begin_channel: u32,
        /// Transform variant selector.
        /// NOTE(review): the encoder and decoder in this file ignore this
        /// field and always use `forward_rct` / `inverse_rct`.
        rct_type: u8,
    },
    /// Average/difference ("squeeze") transform over a run of channels.
    Squeeze {
        /// Direction and channel range of the squeeze.
        params: SqueezeParams,
    },
    /// Palette lookup: one channel of indices expands into colour components.
    Palette {
        /// Channel that holds the palette indices; the looked-up components
        /// are written to this channel and the ones following it.
        begin_channel: u32,
        /// Number of palette entries; must be non-zero and divide
        /// `palette.len()`.
        num_colors: u32,
        /// Flattened palette, entry-major: component `c` of entry `i` is
        /// `palette[i * (palette.len() / num_colors) + c]`.
        palette: Vec<i32>,
    },
}
115
/// Parameters of a [`ModularTransform::Squeeze`] step.
#[derive(Clone, Debug)]
pub struct SqueezeParams {
    /// `true` to squeeze along rows (pairs of horizontally adjacent samples),
    /// `false` to squeeze along columns.
    pub horizontal: bool,
    /// NOTE(review): not read by any code visible in this file.
    pub in_place: bool,
    /// Index of the first channel the squeeze applies to.
    pub begin_channel: u32,
    /// Number of consecutive channels, starting at `begin_channel`, that the
    /// squeeze applies to.
    pub num_channels: u32,
}
128
129struct PredictionContext {
134 errors: [i64; NUM_PREDICTORS],
136 decay_shift: u32,
138 counter: u32,
140}
141
142impl PredictionContext {
143 fn new() -> Self {
144 Self {
145 errors: [0; NUM_PREDICTORS],
146 decay_shift: 4,
147 counter: 0,
148 }
149 }
150
151 fn best_predictor(&self) -> Predictor {
153 let mut best_idx = 0;
154 let mut best_err = self.errors[0];
155 for i in 1..NUM_PREDICTORS {
156 if self.errors[i] < best_err {
157 best_err = self.errors[i];
158 best_idx = i;
159 }
160 }
161 Predictor::from_index(best_idx)
162 }
163
164 fn update(&mut self, predictions: &[i32; NUM_PREDICTORS], actual: i32) {
166 for i in 0..NUM_PREDICTORS {
167 let err = (actual - predictions[i]).unsigned_abs() as i64;
168 self.errors[i] += err;
169 }
170 self.counter += 1;
171 if self.counter >= (1 << self.decay_shift) {
173 for err in &mut self.errors {
174 *err >>= 1;
175 }
176 self.counter = 0;
177 }
178 }
179}
180
/// Collects the causal neighbourhood of the sample at `(x, y)`.
///
/// `channel` is a row-major buffer with `width` samples per row. Returns
/// `(west, north, north_west, north_east, north_north, west_west)`.
///
/// Fallbacks for neighbours that lie outside the image:
/// * `west`, `north` and `north_west` become `0`;
/// * `north_east` and `north_north` fall back to `north`;
/// * `west_west` falls back to `west`.
///
/// Only complete rows of `channel` are addressable; reads beyond them
/// yield `0`.
fn get_neighbors(channel: &[i32], width: u32, x: u32, y: u32) -> (i32, i32, i32, i32, i32, i32) {
    let stride = width as usize;
    let (col, row) = (x as usize, y as usize);
    // A zero stride has no addressable samples at all.
    let rows = if stride == 0 { 0 } else { channel.len() / stride };

    let sample = |px: usize, py: usize| -> i32 {
        if px < stride && py < rows {
            channel[py * stride + px]
        } else {
            0
        }
    };

    let has_left = col > 0;
    let has_up = row > 0;

    let west = if has_left { sample(col - 1, row) } else { 0 };
    let north = if has_up { sample(col, row - 1) } else { 0 };
    let north_west = if has_left && has_up {
        sample(col - 1, row - 1)
    } else {
        0
    };
    let north_east = if has_up && col + 1 < stride {
        sample(col + 1, row - 1)
    } else {
        north
    };
    let north_north = if row >= 2 { sample(col, row - 2) } else { north };
    let west_west = if col >= 2 { sample(col - 2, row) } else { west };

    (west, north, north_west, north_east, north_north, west_west)
}
214
215fn compute_predictions(
217 w: i32,
218 n: i32,
219 nw: i32,
220 ne: i32,
221 _nn: i32,
222 _ww: i32,
223) -> [i32; NUM_PREDICTORS] {
224 let avg_wn = (w + n) / 2;
225 let gradient = n + w - nw;
226
227 let grad_clamped = gradient.clamp(w.min(n), w.max(n));
229
230 let weighted = {
232 let sum = 3i64 * n as i64 + 3i64 * w as i64 - nw as i64 + ne as i64;
233 (sum / 6) as i32
234 };
235
236 [
237 0, w, n, avg_wn, grad_clamped, weighted, ]
244}
245
246fn encode_residual(value: i32, output: &mut Vec<u8>) {
254 let unsigned = signed_to_unsigned(value);
255 let mut remaining = unsigned;
256 loop {
257 let byte = (remaining & 0x7F) as u8;
258 remaining >>= 7;
259 if remaining == 0 {
260 output.push(byte); break;
262 } else {
263 output.push(byte | 0x80); }
265 }
266}
267
268fn decode_residual(data: &[u8], offset: usize) -> CodecResult<(i32, usize)> {
272 let mut value: u32 = 0;
273 let mut shift: u32 = 0;
274 let mut pos = offset;
275
276 loop {
277 if pos >= data.len() {
278 return Err(CodecError::InvalidBitstream(
279 "Unexpected end of residual data".into(),
280 ));
281 }
282 let byte = data[pos];
283 pos += 1;
284
285 value |= ((byte & 0x7F) as u32) << shift;
286 shift += 7;
287
288 if byte & 0x80 == 0 {
289 break;
291 }
292 if shift >= 35 {
293 return Err(CodecError::InvalidBitstream(
294 "Residual value too large".into(),
295 ));
296 }
297 }
298
299 Ok((unsigned_to_signed(value), pos - offset))
300}
301
/// Zigzag-maps a signed value to an unsigned one.
///
/// Small magnitudes map to small results: `0, -1, 1, -2, 2, …` become
/// `0, 1, 2, 3, 4, …`. The mapping is defined for every `i32`;
/// `i32::MIN` maps to `u32::MAX` instead of overflowing on negation.
fn signed_to_unsigned(value: i32) -> u32 {
    // `value >> 31` is all ones for negative values and zero otherwise, so the
    // XOR flips the doubled value exactly when the input is negative.
    ((value << 1) ^ (value >> 31)) as u32
}
312
/// Inverse of the zigzag mapping: `0, 1, 2, 3, 4, …` become
/// `0, -1, 1, -2, 2, …`.
///
/// The mapping is defined for every `u32`; `u32::MAX` maps to `i32::MIN`
/// instead of overflowing on `value + 1`.
fn unsigned_to_signed(value: u32) -> i32 {
    // The low bit carries the sign: negating it gives an all-ones mask for
    // odd inputs, which turns the halved value `h` into `-h - 1`.
    ((value >> 1) as i32) ^ -((value & 1) as i32)
}
321
/// Decoder for the residual stream produced by `ModularEncoder`.
pub struct ModularDecoder {
    /// Transforms in registration order; `decode_image` inverts them in
    /// reverse order after the samples have been reconstructed.
    transforms: Vec<ModularTransform>,
}
326
327impl ModularDecoder {
328 pub fn new() -> Self {
330 Self {
331 transforms: Vec::new(),
332 }
333 }
334
335 pub fn add_transform(&mut self, transform: ModularTransform) {
337 self.transforms.push(transform);
338 }
339
340 pub fn decode_image(
344 &mut self,
345 data: &[u8],
346 width: u32,
347 height: u32,
348 channels: u32,
349 _bit_depth: u8,
350 ) -> CodecResult<Vec<Vec<i32>>> {
351 if width == 0 || height == 0 {
352 return Err(CodecError::InvalidParameter(
353 "Image dimensions must be non-zero".into(),
354 ));
355 }
356
357 let pixel_count = width as usize * height as usize;
358 let mut result_channels: Vec<Vec<i32>> = Vec::with_capacity(channels as usize);
359 let mut data_offset = 0usize;
360
361 for _ch in 0..channels {
362 let mut channel_data = vec![0i32; pixel_count];
363 let mut ctx = PredictionContext::new();
364
365 for y in 0..height {
366 for x in 0..width {
367 let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
368 get_neighbors(&channel_data, width, x, y);
369 let predictions =
370 compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
371 let predictor = ctx.best_predictor();
372 let predicted = predictions[predictor as usize];
373
374 let (residual, consumed) = decode_residual(data, data_offset)?;
376 data_offset += consumed;
377
378 let actual = predicted + residual;
379 channel_data[y as usize * width as usize + x as usize] = actual;
380 ctx.update(&predictions, actual);
381 }
382 }
383
384 result_channels.push(channel_data);
385 }
386
387 for transform in self.transforms.iter().rev() {
389 match transform {
390 ModularTransform::Rct {
391 begin_channel,
392 rct_type: _,
393 } => {
394 let begin = *begin_channel as usize;
395 if begin + 2 < result_channels.len() {
396 let pc = result_channels[begin].len();
397 for i in 0..pc {
398 let y_val = result_channels[begin][i];
399 let co = result_channels[begin + 1][i];
400 let cg = result_channels[begin + 2][i];
401 let (r, g, b) = inverse_rct(y_val, co, cg);
402 result_channels[begin][i] = r;
403 result_channels[begin + 1][i] = g;
404 result_channels[begin + 2][i] = b;
405 }
406 }
407 }
408 ModularTransform::Squeeze {
409 params:
410 SqueezeParams {
411 horizontal,
412 begin_channel,
413 num_channels,
414 ..
415 },
416 } => {
417 let begin = *begin_channel as usize;
418 let nc = *num_channels as usize;
419 let horiz = *horizontal;
420
421 for ch_idx in begin..begin + nc {
422 if ch_idx >= result_channels.len() {
423 break;
424 }
425
426 if horiz {
431 let half_w = (width / 2) as usize;
432 if half_w == 0 {
433 continue;
434 }
435 let h = height as usize;
436 let w = width as usize;
437 let old = result_channels[ch_idx].clone();
439 let buf = &mut result_channels[ch_idx];
440 for row in 0..h {
441 for i in 0..half_w {
442 let avg = old[row * w + i];
443 let diff = old[row * w + half_w + i];
444 let a = avg + ((diff + 1) >> 1);
447 let b = avg - (diff >> 1);
448 buf[row * w + 2 * i] = a;
449 buf[row * w + 2 * i + 1] = b;
450 }
451 if w % 2 != 0 {
454 buf[row * w + w - 1] = old[row * w + w - 1];
455 }
456 }
457 } else {
458 let half_h = (height / 2) as usize;
460 if half_h == 0 {
461 continue;
462 }
463 let h = height as usize;
464 let w = width as usize;
465 let old = result_channels[ch_idx].clone();
466 let buf = &mut result_channels[ch_idx];
467 for col in 0..w {
468 for i in 0..half_h {
469 let avg = old[i * w + col];
470 let diff = old[(half_h + i) * w + col];
471 let a = avg + ((diff + 1) >> 1);
472 let b = avg - (diff >> 1);
473 buf[(2 * i) * w + col] = a;
474 buf[(2 * i + 1) * w + col] = b;
475 }
476 if h % 2 != 0 {
478 buf[(h - 1) * w + col] = old[(h - 1) * w + col];
479 }
480 }
481 }
482 }
483 }
484 ModularTransform::Palette {
485 begin_channel,
486 num_colors,
487 palette,
488 } => {
489 if *num_colors == 0 {
490 return Err(CodecError::InvalidBitstream(
491 "Palette: num_colors must be non-zero".into(),
492 ));
493 }
494 let nc = *num_colors as usize;
495 if palette.len() % nc != 0 {
497 return Err(CodecError::InvalidBitstream(format!(
498 "Palette length {} is not divisible by num_colors {}",
499 palette.len(),
500 nc
501 )));
502 }
503 let num_components = palette.len() / nc;
504 let begin = *begin_channel as usize;
505
506 if begin >= result_channels.len() {
507 return Err(CodecError::InvalidBitstream(
508 "Palette: begin_channel out of bounds".into(),
509 ));
510 }
511 let indices = result_channels[begin].clone();
514
515 for (pixel_pos, &idx_val) in indices.iter().enumerate() {
517 if idx_val < 0 || idx_val as usize >= nc {
518 return Err(CodecError::InvalidBitstream(format!(
519 "Palette index {idx_val} out of range [0, {nc})"
520 )));
521 }
522 let idx = idx_val as usize;
523 for c in 0..num_components {
524 let target_ch = begin + c;
525 if target_ch >= result_channels.len() {
526 return Err(CodecError::InvalidBitstream(format!(
527 "Palette: channel {target_ch} out of bounds"
528 )));
529 }
530 result_channels[target_ch][pixel_pos] =
531 palette[idx * num_components + c];
532 }
533 }
534 }
535 }
536 }
537
538 Ok(result_channels)
539 }
540}
541
impl Default for ModularDecoder {
    /// Equivalent to [`ModularDecoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
547
/// Encoder that writes predicted-residual streams readable by
/// `ModularDecoder`.
pub struct ModularEncoder {
    /// Transforms in registration order; `encode_image` walks them in this
    /// order before prediction.
    transforms: Vec<ModularTransform>,
    /// Effort level, kept within `1..=9` by `with_effort`.
    /// NOTE(review): not read by any code visible in this file.
    effort: u8,
}
553
554impl ModularEncoder {
555 pub fn new() -> Self {
557 Self {
558 transforms: Vec::new(),
559 effort: 7,
560 }
561 }
562
563 pub fn with_effort(mut self, effort: u8) -> Self {
565 self.effort = effort.clamp(1, 9);
566 self
567 }
568
569 pub fn add_transform(&mut self, transform: ModularTransform) {
571 self.transforms.push(transform);
572 }
573
574 pub fn encode_image(
579 &mut self,
580 channels: &[Vec<i32>],
581 width: u32,
582 height: u32,
583 _bit_depth: u8,
584 ) -> CodecResult<Vec<u8>> {
585 if width == 0 || height == 0 {
586 return Err(CodecError::InvalidParameter(
587 "Image dimensions must be non-zero".into(),
588 ));
589 }
590 if channels.is_empty() {
591 return Err(CodecError::InvalidParameter(
592 "Must have at least one channel".into(),
593 ));
594 }
595
596 let pixel_count = width as usize * height as usize;
597 for (i, ch) in channels.iter().enumerate() {
598 if ch.len() != pixel_count {
599 return Err(CodecError::InvalidParameter(format!(
600 "Channel {i} has {} samples, expected {pixel_count}",
601 ch.len()
602 )));
603 }
604 }
605
606 let mut working_channels: Vec<Vec<i32>> = channels.to_vec();
608 for transform in &self.transforms {
609 match transform {
610 ModularTransform::Rct {
611 begin_channel,
612 rct_type: _,
613 } => {
614 let begin = *begin_channel as usize;
615 if begin + 2 < working_channels.len() {
616 for i in 0..pixel_count {
617 let r = working_channels[begin][i];
618 let g = working_channels[begin + 1][i];
619 let b = working_channels[begin + 2][i];
620 let (y_val, co, cg) = forward_rct(r, g, b);
621 working_channels[begin][i] = y_val;
622 working_channels[begin + 1][i] = co;
623 working_channels[begin + 2][i] = cg;
624 }
625 }
626 }
627 ModularTransform::Squeeze { .. } | ModularTransform::Palette { .. } => {
628 }
630 }
631 }
632
633 let mut output = Vec::with_capacity(pixel_count * working_channels.len());
635
636 for ch_data in &working_channels {
637 let mut ctx = PredictionContext::new();
638
639 for y in 0..height {
640 for x in 0..width {
641 let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
642 get_neighbors(ch_data, width, x, y);
643 let predictions =
644 compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
645 let predictor = ctx.best_predictor();
646 let predicted = predictions[predictor as usize];
647
648 let actual = ch_data[y as usize * width as usize + x as usize];
649 let residual = actual - predicted;
650
651 encode_residual(residual, &mut output);
652 ctx.update(&predictions, actual);
653 }
654 }
655 }
656
657 Ok(output)
658 }
659}
660
impl Default for ModularEncoder {
    /// Equivalent to [`ModularEncoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
666
// Unit tests for the colour transform, residual coding, predictors and the
// encoder/decoder round trip.
//
// NOTE(review): every test carries `#[ignore]`, so none of them run unless
// the test binary is invoked with `--ignored`; the reason is not visible in
// this file.
#[cfg(test)]
mod tests {
    use super::*;

    // forward_rct followed by inverse_rct must return the original triple.
    #[test]
    #[ignore]
    fn test_rct_roundtrip() {
        let test_values = [
            (0, 0, 0),
            (255, 255, 255),
            (128, 64, 32),
            (0, 255, 0),
            (255, 0, 0),
            (0, 0, 255),
            (100, 200, 50),
            (1, 1, 1),
        ];

        for (r, g, b) in test_values {
            let (y, co, cg) = forward_rct(r, g, b);
            let (r2, g2, b2) = inverse_rct(y, co, cg);
            assert_eq!(
                (r, g, b),
                (r2, g2, b2),
                "RCT roundtrip failed for ({r}, {g}, {b})"
            );
        }
    }

    // The colour transform must also round-trip negative samples.
    #[test]
    #[ignore]
    fn test_rct_negative_values() {
        let (y, co, cg) = forward_rct(-10, 20, -30);
        let (r, g, b) = inverse_rct(y, co, cg);
        assert_eq!((r, g, b), (-10, 20, -30));
    }

    // Zigzag mapping and its inverse must agree on -100..=100.
    #[test]
    #[ignore]
    fn test_signed_unsigned_roundtrip() {
        for v in -100..=100 {
            let u = signed_to_unsigned(v);
            let v2 = unsigned_to_signed(u);
            assert_eq!(v, v2, "Zigzag roundtrip failed for {v}");
        }
    }

    // Pins the interleaving order 0, -1, 1, -2, 2 -> 0, 1, 2, 3, 4.
    #[test]
    #[ignore]
    fn test_zigzag_ordering() {
        assert_eq!(signed_to_unsigned(0), 0);
        assert_eq!(signed_to_unsigned(-1), 1);
        assert_eq!(signed_to_unsigned(1), 2);
        assert_eq!(signed_to_unsigned(-2), 3);
        assert_eq!(signed_to_unsigned(2), 4);
    }

    // Several residuals written back to back must decode in sequence, with
    // the consumed byte counts advancing the offset correctly.
    #[test]
    #[ignore]
    fn test_residual_encode_decode_roundtrip() {
        let test_values = [0, 1, -1, 127, -128, 1000, -1000, 65535, -65536, 0];
        let mut encoded = Vec::new();
        for &v in &test_values {
            encode_residual(v, &mut encoded);
        }

        let mut offset = 0;
        for &expected in &test_values {
            let (decoded, consumed) = decode_residual(&encoded, offset).expect("decode ok");
            assert_eq!(
                decoded, expected,
                "Residual roundtrip failed for {expected}"
            );
            offset += consumed;
        }
    }

    // On a flat neighbourhood the gradient, west and north predictors all
    // return the flat value.
    #[test]
    #[ignore]
    fn test_gradient_predictor() {
        let predictions = compute_predictions(100, 100, 100, 100, 100, 100);
        assert_eq!(predictions[Predictor::Gradient as usize], 100);
        assert_eq!(predictions[Predictor::West as usize], 100);
        assert_eq!(predictions[Predictor::North as usize], 100);
    }

    // With only one non-zero neighbour the clamped gradient equals it.
    #[test]
    #[ignore]
    fn test_gradient_predictor_edge() {
        let predictions = compute_predictions(10, 0, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);

        let predictions = compute_predictions(0, 10, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);
    }

    // A fresh context prefers index 0; after one update the predictor with
    // zero error (West here) wins.
    #[test]
    #[ignore]
    fn test_prediction_context() {
        let mut ctx = PredictionContext::new();
        assert_eq!(ctx.best_predictor(), Predictor::Zero);

        let predictions = [0, 100, 50, 75, 90, 80];
        ctx.update(&predictions, 100);

        assert_eq!(ctx.best_predictor(), Predictor::West);
    }

    // Neighbours of the top-left corner are all zero; an interior sample of a
    // 3x3 image sees its real neighbours.
    #[test]
    #[ignore]
    fn test_get_neighbors_corner() {
        let channel = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let (w, n, nw, ne, nn, ww) = get_neighbors(&channel, 3, 0, 0);
        assert_eq!((w, n, nw, ne, nn, ww), (0, 0, 0, 0, 0, 0));

        let (w, n, nw, ne, _nn, _ww) = get_neighbors(&channel, 3, 1, 1);
        assert_eq!(w, 4);
        assert_eq!(n, 2);
        assert_eq!(nw, 1);
        assert_eq!(ne, 3);
    }

    // Round trip of a constant 4x4 single-channel image.
    #[test]
    #[ignore]
    fn test_modular_encode_decode_flat() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel = vec![128i32; pixel_count];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    // Round trip of an 8x4 image whose samples are x + 10 * y.
    #[test]
    #[ignore]
    fn test_modular_encode_decode_gradient() {
        let width = 8u32;
        let height = 4u32;
        let mut channel = Vec::with_capacity((width * height) as usize);
        for y in 0..height {
            for x in 0..width {
                channel.push((x + y * 10) as i32);
            }
        }

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    // Round trip of three channels with the same RCT transform registered on
    // both the encoder and the decoder.
    #[test]
    #[ignore]
    fn test_modular_encode_decode_with_rct() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let r: Vec<i32> = (0..pixel_count).map(|i| (i * 3) as i32 % 256).collect();
        let g: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 5 + 50) as i32 % 256)
            .collect();
        let b: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 7 + 100) as i32 % 256)
            .collect();

        let rct = ModularTransform::Rct {
            begin_channel: 0,
            rct_type: 0,
        };

        let mut encoder = ModularEncoder::new();
        encoder.add_transform(rct.clone());
        let encoded = encoder
            .encode_image(&[r.clone(), g.clone(), b.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        decoder.add_transform(rct);
        let decoded = decoder
            .decode_image(&encoded, width, height, 3, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], r, "Red channel mismatch");
        assert_eq!(decoded[1], g, "Green channel mismatch");
        assert_eq!(decoded[2], b, "Blue channel mismatch");
    }

    // A zero width or height must be rejected by the encoder.
    #[test]
    #[ignore]
    fn test_modular_zero_dimensions_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[vec![0i32]], 0, 1, 8).is_err());
        assert!(encoder.encode_image(&[vec![0i32]], 1, 0, 8).is_err());
    }

    // An empty channel list must be rejected by the encoder.
    #[test]
    #[ignore]
    fn test_modular_empty_channels_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[], 1, 1, 8).is_err());
    }

    // Round trip of two independent channels without any transform.
    #[test]
    #[ignore]
    fn test_modular_multichannel() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let ch0: Vec<i32> = (0..pixel_count).map(|i| (i * 11 % 256) as i32).collect();
        let ch1: Vec<i32> = (0..pixel_count).map(|i| (i * 17 % 256) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[ch0.clone(), ch1.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 2, 8)
            .expect("decode ok");

        assert_eq!(decoded[0], ch0);
        assert_eq!(decoded[1], ch1);
    }

    // Round trip of samples up to 60000, which need multi-byte residuals.
    #[test]
    #[ignore]
    fn test_modular_large_values() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel: Vec<i32> = (0..pixel_count).map(|i| (i * 4000) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 16)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 16)
            .expect("decode ok");

        assert_eq!(decoded[0], channel);
    }
}