1use crate::error::{CodecError, CodecResult};
27
/// Number of candidate predictions evaluated for every sample.
///
/// This is the length of the array returned by `compute_predictions` and of
/// the per-predictor error table kept by `PredictionContext`.
const NUM_PREDICTORS: usize = 6;
30
/// Identifies one of the candidate sample predictors.
///
/// Each discriminant doubles as the slot of that predictor inside the array
/// produced by `compute_predictions`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Predictor {
    /// Always predicts zero.
    Zero = 0,
    /// Predicts the sample to the left.
    West = 1,
    /// Predicts the sample directly above.
    North = 2,
    /// Predicts the average of the west and north samples.
    AvgWN = 3,
    /// Predicts `north + west - north_west`, clamped between west and north.
    Gradient = 4,
    /// Predicts a weighted blend of west, north, north-west and north-east.
    Weighted = 5,
}

impl Predictor {
    /// Maps a prediction-array slot back to its predictor.
    ///
    /// Slots `0..=4` map to their named variants; every other index
    /// (including `5`) resolves to [`Predictor::Weighted`].
    fn from_index(idx: usize) -> Self {
        // Explicit slots; anything outside the table falls through to `Weighted`.
        const BY_SLOT: [Predictor; 5] = [
            Predictor::Zero,
            Predictor::West,
            Predictor::North,
            Predictor::AvgWN,
            Predictor::Gradient,
        ];

        BY_SLOT.get(idx).copied().unwrap_or(Self::Weighted)
    }
}
61
/// Forward reversible colour transform.
///
/// Converts an `(r, g, b)` triple into `(y, co, cg)` with integer lifting
/// steps (the YCoCg-R construction), so the original triple can be recovered
/// exactly by applying the same steps in reverse order.
///
/// All additions and subtractions wrap on overflow. The lifting steps are
/// invertible modulo 2^32, so extreme inputs neither panic in debug builds
/// nor lose information.
///
/// # Returns
///
/// The tuple `(y, co, cg)`.
pub fn forward_rct(r: i32, g: i32, b: i32) -> (i32, i32, i32) {
    let co = r.wrapping_sub(b);
    // `tmp` is the red/blue midpoint, rounded towards negative infinity.
    let tmp = b.wrapping_add(co >> 1);
    let cg = g.wrapping_sub(tmp);
    let y = tmp.wrapping_add(cg >> 1);
    (y, co, cg)
}
78
/// Inverse reversible colour transform.
///
/// Undoes the lifting steps of `forward_rct`, turning `(y, co, cg)` back into
/// `(r, g, b)`.
///
/// All additions and subtractions wrap on overflow, so arbitrary (possibly
/// hostile) decoded values cannot trigger an overflow panic, and any triple
/// produced by a wrapping forward transform is restored exactly.
///
/// # Returns
///
/// The tuple `(r, g, b)`.
pub fn inverse_rct(y: i32, co: i32, cg: i32) -> (i32, i32, i32) {
    let tmp = y.wrapping_sub(cg >> 1);
    let g = tmp.wrapping_add(cg);
    let b = tmp.wrapping_sub(co >> 1);
    let r = b.wrapping_add(co);
    (r, g, b)
}
89
/// A channel transform that can be registered on the modular encoder and
/// decoder.
///
/// Only `Rct` is acted upon by `encode_image` / `decode_image` in this file;
/// `Squeeze` and `Palette` are accepted there but leave the channels untouched.
#[derive(Clone, Debug)]
pub enum ModularTransform {
    /// Reversible colour transform applied to the three consecutive channels
    /// starting at `begin_channel`.
    Rct {
        /// Index of the first of the three channels.
        begin_channel: u32,
        /// Transform variant selector. The encoder and decoder in this file
        /// ignore it (they bind it as `rct_type: _`).
        rct_type: u8,
    },
    /// Squeeze transform, described by [`SqueezeParams`].
    Squeeze {
        /// Parameters of the squeeze step.
        params: SqueezeParams,
    },
    /// Palette transform.
    // NOTE(review): no code in this file reads these fields; the descriptions
    // below are inferred from the field names only; confirm against the
    // palette implementation.
    Palette {
        /// Presumably the first channel covered by the palette.
        begin_channel: u32,
        /// Presumably the number of palette entries.
        num_colors: u32,
        /// Palette data; the layout is not established by this file.
        palette: Vec<i32>,
    },
}
115
/// Parameters of a [`ModularTransform::Squeeze`] step.
// NOTE(review): nothing in this file reads these fields; the descriptions are
// inferred from the field names only; confirm against the squeeze
// implementation.
#[derive(Clone, Debug)]
pub struct SqueezeParams {
    /// Presumably selects a horizontal (as opposed to vertical) squeeze.
    pub horizontal: bool,
    /// Presumably selects in-place operation.
    pub in_place: bool,
    /// Presumably the first channel the squeeze applies to.
    pub begin_channel: u32,
    /// Presumably the number of channels the squeeze applies to.
    pub num_channels: u32,
}
128
129struct PredictionContext {
134 errors: [i64; NUM_PREDICTORS],
136 decay_shift: u32,
138 counter: u32,
140}
141
142impl PredictionContext {
143 fn new() -> Self {
144 Self {
145 errors: [0; NUM_PREDICTORS],
146 decay_shift: 4,
147 counter: 0,
148 }
149 }
150
151 fn best_predictor(&self) -> Predictor {
153 let mut best_idx = 0;
154 let mut best_err = self.errors[0];
155 for i in 1..NUM_PREDICTORS {
156 if self.errors[i] < best_err {
157 best_err = self.errors[i];
158 best_idx = i;
159 }
160 }
161 Predictor::from_index(best_idx)
162 }
163
164 fn update(&mut self, predictions: &[i32; NUM_PREDICTORS], actual: i32) {
166 for i in 0..NUM_PREDICTORS {
167 let err = (actual - predictions[i]).unsigned_abs() as i64;
168 self.errors[i] += err;
169 }
170 self.counter += 1;
171 if self.counter >= (1 << self.decay_shift) {
173 for err in &mut self.errors {
174 *err >>= 1;
175 }
176 self.counter = 0;
177 }
178 }
179}
180
/// Fetches the causal neighbourhood of the sample at `(x, y)`.
///
/// `channel` is read as a row-major grid that is `width` samples wide.
///
/// # Returns
///
/// `(west, north, north_west, north_east, north_north, west_west)`.
/// Missing neighbours are substituted as follows:
///
/// * west, north and north-west default to `0`;
/// * north-east and north-north fall back to north;
/// * west-west falls back to west.
fn get_neighbors(channel: &[i32], width: u32, x: u32, y: u32) -> (i32, i32, i32, i32, i32, i32) {
    let stride = width as usize;
    let (col, row) = (x as usize, y as usize);
    // A zero stride yields zero rows, so every lookup below resolves to 0.
    let rows = channel.len().checked_div(stride).unwrap_or(0);

    // Bounds-checked sample lookup; anything outside the grid reads as 0.
    let sample = |c: usize, r: usize| -> i32 {
        if c >= stride || r >= rows {
            0
        } else {
            channel[r * stride + c]
        }
    };

    let has_left = col > 0;
    let has_above = row > 0;

    let west = if has_left { sample(col - 1, row) } else { 0 };
    let north = if has_above { sample(col, row - 1) } else { 0 };
    let north_west = if has_left && has_above {
        sample(col - 1, row - 1)
    } else {
        0
    };
    let north_east = if has_above && col + 1 < stride {
        sample(col + 1, row - 1)
    } else {
        north
    };
    let north_north = if row >= 2 { sample(col, row - 2) } else { north };
    let west_west = if col >= 2 { sample(col - 2, row) } else { west };

    (west, north, north_west, north_east, north_north, west_west)
}
214
215fn compute_predictions(
217 w: i32,
218 n: i32,
219 nw: i32,
220 ne: i32,
221 _nn: i32,
222 _ww: i32,
223) -> [i32; NUM_PREDICTORS] {
224 let avg_wn = (w + n) / 2;
225 let gradient = n + w - nw;
226
227 let grad_clamped = gradient.clamp(w.min(n), w.max(n));
229
230 let weighted = {
232 let sum = 3i64 * n as i64 + 3i64 * w as i64 - nw as i64 + ne as i64;
233 (sum / 6) as i32
234 };
235
236 [
237 0, w, n, avg_wn, grad_clamped, weighted, ]
244}
245
246fn encode_residual(value: i32, output: &mut Vec<u8>) {
254 let unsigned = signed_to_unsigned(value);
255 let mut remaining = unsigned;
256 loop {
257 let byte = (remaining & 0x7F) as u8;
258 remaining >>= 7;
259 if remaining == 0 {
260 output.push(byte); break;
262 } else {
263 output.push(byte | 0x80); }
265 }
266}
267
268fn decode_residual(data: &[u8], offset: usize) -> CodecResult<(i32, usize)> {
272 let mut value: u32 = 0;
273 let mut shift: u32 = 0;
274 let mut pos = offset;
275
276 loop {
277 if pos >= data.len() {
278 return Err(CodecError::InvalidBitstream(
279 "Unexpected end of residual data".into(),
280 ));
281 }
282 let byte = data[pos];
283 pos += 1;
284
285 value |= ((byte & 0x7F) as u32) << shift;
286 shift += 7;
287
288 if byte & 0x80 == 0 {
289 break;
291 }
292 if shift >= 35 {
293 return Err(CodecError::InvalidBitstream(
294 "Residual value too large".into(),
295 ));
296 }
297 }
298
299 Ok((unsigned_to_signed(value), pos - offset))
300}
301
/// Zigzag-maps a signed value onto the unsigned range.
///
/// Small magnitudes map to small codes: `0 -> 0`, `-1 -> 1`, `1 -> 2`,
/// `-2 -> 3`, `2 -> 4`, and so on up to `i32::MIN -> u32::MAX`.
///
/// The branch-free form is total over `i32`; the previous implementation
/// negated the input and therefore overflowed for `i32::MIN`.
fn signed_to_unsigned(value: i32) -> u32 {
    // `value >> 31` is all ones for negative inputs and zero otherwise, so
    // the XOR flips every bit of the doubled value exactly when it is negative.
    ((value << 1) ^ (value >> 31)) as u32
}
312
/// Inverse of the zigzag mapping performed by `signed_to_unsigned`.
///
/// Even codes decode to non-negative values and odd codes to negative ones:
/// `0 -> 0`, `1 -> -1`, `2 -> 1`, `3 -> -2`, and so on up to
/// `u32::MAX -> i32::MIN`.
///
/// The branch-free form is total over `u32`; the previous implementation
/// computed `value + 1`, which overflowed for `u32::MAX` (panicking in debug
/// builds and decoding to `0` in release builds).
fn unsigned_to_signed(value: u32) -> i32 {
    // `-(value & 1)` is all ones for odd codes and zero for even ones, so the
    // XOR applies the bitwise complement exactly for negative results.
    ((value >> 1) as i32) ^ -((value & 1) as i32)
}
321
322pub struct ModularDecoder {
324 transforms: Vec<ModularTransform>,
325}
326
327impl ModularDecoder {
328 pub fn new() -> Self {
330 Self {
331 transforms: Vec::new(),
332 }
333 }
334
335 pub fn add_transform(&mut self, transform: ModularTransform) {
337 self.transforms.push(transform);
338 }
339
340 pub fn decode_image(
344 &mut self,
345 data: &[u8],
346 width: u32,
347 height: u32,
348 channels: u32,
349 _bit_depth: u8,
350 ) -> CodecResult<Vec<Vec<i32>>> {
351 if width == 0 || height == 0 {
352 return Err(CodecError::InvalidParameter(
353 "Image dimensions must be non-zero".into(),
354 ));
355 }
356
357 let pixel_count = width as usize * height as usize;
358 let mut result_channels: Vec<Vec<i32>> = Vec::with_capacity(channels as usize);
359 let mut data_offset = 0usize;
360
361 for _ch in 0..channels {
362 let mut channel_data = vec![0i32; pixel_count];
363 let mut ctx = PredictionContext::new();
364
365 for y in 0..height {
366 for x in 0..width {
367 let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
368 get_neighbors(&channel_data, width, x, y);
369 let predictions =
370 compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
371 let predictor = ctx.best_predictor();
372 let predicted = predictions[predictor as usize];
373
374 let (residual, consumed) = decode_residual(data, data_offset)?;
376 data_offset += consumed;
377
378 let actual = predicted + residual;
379 channel_data[y as usize * width as usize + x as usize] = actual;
380 ctx.update(&predictions, actual);
381 }
382 }
383
384 result_channels.push(channel_data);
385 }
386
387 for transform in self.transforms.iter().rev() {
389 match transform {
390 ModularTransform::Rct {
391 begin_channel,
392 rct_type: _,
393 } => {
394 let begin = *begin_channel as usize;
395 if begin + 2 < result_channels.len() {
396 let pc = result_channels[begin].len();
397 for i in 0..pc {
398 let y_val = result_channels[begin][i];
399 let co = result_channels[begin + 1][i];
400 let cg = result_channels[begin + 2][i];
401 let (r, g, b) = inverse_rct(y_val, co, cg);
402 result_channels[begin][i] = r;
403 result_channels[begin + 1][i] = g;
404 result_channels[begin + 2][i] = b;
405 }
406 }
407 }
408 ModularTransform::Squeeze { .. } | ModularTransform::Palette { .. } => {
409 }
411 }
412 }
413
414 Ok(result_channels)
415 }
416}
417
418impl Default for ModularDecoder {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
424pub struct ModularEncoder {
426 transforms: Vec<ModularTransform>,
427 effort: u8,
428}
429
430impl ModularEncoder {
431 pub fn new() -> Self {
433 Self {
434 transforms: Vec::new(),
435 effort: 7,
436 }
437 }
438
439 pub fn with_effort(mut self, effort: u8) -> Self {
441 self.effort = effort.clamp(1, 9);
442 self
443 }
444
445 pub fn add_transform(&mut self, transform: ModularTransform) {
447 self.transforms.push(transform);
448 }
449
450 pub fn encode_image(
455 &mut self,
456 channels: &[Vec<i32>],
457 width: u32,
458 height: u32,
459 _bit_depth: u8,
460 ) -> CodecResult<Vec<u8>> {
461 if width == 0 || height == 0 {
462 return Err(CodecError::InvalidParameter(
463 "Image dimensions must be non-zero".into(),
464 ));
465 }
466 if channels.is_empty() {
467 return Err(CodecError::InvalidParameter(
468 "Must have at least one channel".into(),
469 ));
470 }
471
472 let pixel_count = width as usize * height as usize;
473 for (i, ch) in channels.iter().enumerate() {
474 if ch.len() != pixel_count {
475 return Err(CodecError::InvalidParameter(format!(
476 "Channel {i} has {} samples, expected {pixel_count}",
477 ch.len()
478 )));
479 }
480 }
481
482 let mut working_channels: Vec<Vec<i32>> = channels.to_vec();
484 for transform in &self.transforms {
485 match transform {
486 ModularTransform::Rct {
487 begin_channel,
488 rct_type: _,
489 } => {
490 let begin = *begin_channel as usize;
491 if begin + 2 < working_channels.len() {
492 for i in 0..pixel_count {
493 let r = working_channels[begin][i];
494 let g = working_channels[begin + 1][i];
495 let b = working_channels[begin + 2][i];
496 let (y_val, co, cg) = forward_rct(r, g, b);
497 working_channels[begin][i] = y_val;
498 working_channels[begin + 1][i] = co;
499 working_channels[begin + 2][i] = cg;
500 }
501 }
502 }
503 ModularTransform::Squeeze { .. } | ModularTransform::Palette { .. } => {
504 }
506 }
507 }
508
509 let mut output = Vec::with_capacity(pixel_count * working_channels.len());
511
512 for ch_data in &working_channels {
513 let mut ctx = PredictionContext::new();
514
515 for y in 0..height {
516 for x in 0..width {
517 let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
518 get_neighbors(ch_data, width, x, y);
519 let predictions =
520 compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
521 let predictor = ctx.best_predictor();
522 let predicted = predictions[predictor as usize];
523
524 let actual = ch_data[y as usize * width as usize + x as usize];
525 let residual = actual - predicted;
526
527 encode_residual(residual, &mut output);
528 ctx.update(&predictions, actual);
529 }
530 }
531 }
532
533 Ok(output)
534 }
535}
536
537impl Default for ModularEncoder {
538 fn default() -> Self {
539 Self::new()
540 }
541}
542
// Unit tests for the colour transform, residual coding, predictors and the
// encoder/decoder round trip.
//
// NOTE(review): every test below is marked `#[ignore]`, so a plain
// `cargo test` skips all of them (they only run with `-- --ignored`).
// Confirm whether that is intentional.
#[cfg(test)]
mod tests {
    use super::*;

    // The colour transform must be exactly invertible for typical 8-bit triples.
    #[test]
    #[ignore]
    fn test_rct_roundtrip() {
        let test_values = [
            (0, 0, 0),
            (255, 255, 255),
            (128, 64, 32),
            (0, 255, 0),
            (255, 0, 0),
            (0, 0, 255),
            (100, 200, 50),
            (1, 1, 1),
        ];

        for (r, g, b) in test_values {
            let (y, co, cg) = forward_rct(r, g, b);
            let (r2, g2, b2) = inverse_rct(y, co, cg);
            assert_eq!(
                (r, g, b),
                (r2, g2, b2),
                "RCT roundtrip failed for ({r}, {g}, {b})"
            );
        }
    }

    // Negative inputs must round-trip as well.
    #[test]
    #[ignore]
    fn test_rct_negative_values() {
        let (y, co, cg) = forward_rct(-10, 20, -30);
        let (r, g, b) = inverse_rct(y, co, cg);
        assert_eq!((r, g, b), (-10, 20, -30));
    }

    // The signed/unsigned mapping must be a bijection on a small range.
    #[test]
    #[ignore]
    fn test_signed_unsigned_roundtrip() {
        for v in -100..=100 {
            let u = signed_to_unsigned(v);
            let v2 = unsigned_to_signed(u);
            assert_eq!(v, v2, "Zigzag roundtrip failed for {v}");
        }
    }

    // Pins the interleaving order: 0, -1, 1, -2, 2, ...
    #[test]
    #[ignore]
    fn test_zigzag_ordering() {
        assert_eq!(signed_to_unsigned(0), 0);
        assert_eq!(signed_to_unsigned(-1), 1);
        assert_eq!(signed_to_unsigned(1), 2);
        assert_eq!(signed_to_unsigned(-2), 3);
        assert_eq!(signed_to_unsigned(2), 4);
    }

    // Several residuals written back to back must decode in sequence, with
    // the reported byte counts advancing the offset correctly.
    #[test]
    #[ignore]
    fn test_residual_encode_decode_roundtrip() {
        let test_values = [0, 1, -1, 127, -128, 1000, -1000, 65535, -65536, 0];
        let mut encoded = Vec::new();
        for &v in &test_values {
            encode_residual(v, &mut encoded);
        }

        let mut offset = 0;
        for &expected in &test_values {
            let (decoded, consumed) = decode_residual(&encoded, offset).expect("decode ok");
            assert_eq!(
                decoded, expected,
                "Residual roundtrip failed for {expected}"
            );
            offset += consumed;
        }
    }

    // On a flat neighbourhood the gradient, west and north predictors agree.
    #[test]
    #[ignore]
    fn test_gradient_predictor() {
        let predictions = compute_predictions(100, 100, 100, 100, 100, 100);
        assert_eq!(predictions[Predictor::Gradient as usize], 100);
        assert_eq!(predictions[Predictor::West as usize], 100);
        assert_eq!(predictions[Predictor::North as usize], 100);
    }

    // With a zero north-west sample the gradient equals west + north.
    #[test]
    #[ignore]
    fn test_gradient_predictor_edge() {
        let predictions = compute_predictions(10, 0, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);

        let predictions = compute_predictions(0, 10, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);
    }

    // A fresh context picks slot 0; after one update the exact predictor wins.
    #[test]
    #[ignore]
    fn test_prediction_context() {
        let mut ctx = PredictionContext::new();
        assert_eq!(ctx.best_predictor(), Predictor::Zero);

        let predictions = [0, 100, 50, 75, 90, 80];
        ctx.update(&predictions, 100);

        assert_eq!(ctx.best_predictor(), Predictor::West);
    }

    // Neighbour lookup at the top-left corner and at an interior sample of a
    // 3x3 grid.
    #[test]
    #[ignore]
    fn test_get_neighbors_corner() {
        let channel = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let (w, n, nw, ne, nn, ww) = get_neighbors(&channel, 3, 0, 0);
        assert_eq!((w, n, nw, ne, nn, ww), (0, 0, 0, 0, 0, 0));

        let (w, n, nw, ne, _nn, _ww) = get_neighbors(&channel, 3, 1, 1);
        assert_eq!(w, 4);
        assert_eq!(n, 2);
        assert_eq!(nw, 1);
        assert_eq!(ne, 3);
    }

    // Round trip of a constant single-channel image.
    #[test]
    #[ignore]
    fn test_modular_encode_decode_flat() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel = vec![128i32; pixel_count];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    // Round trip of a linear ramp (value = x + 10 * y).
    #[test]
    #[ignore]
    fn test_modular_encode_decode_gradient() {
        let width = 8u32;
        let height = 4u32;
        let mut channel = Vec::with_capacity((width * height) as usize);
        for y in 0..height {
            for x in 0..width {
                channel.push((x + y * 10) as i32);
            }
        }

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    // Three-channel round trip with the colour transform registered on both
    // the encoder and the decoder.
    #[test]
    #[ignore]
    fn test_modular_encode_decode_with_rct() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let r: Vec<i32> = (0..pixel_count).map(|i| (i * 3) as i32 % 256).collect();
        let g: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 5 + 50) as i32 % 256)
            .collect();
        let b: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 7 + 100) as i32 % 256)
            .collect();

        let rct = ModularTransform::Rct {
            begin_channel: 0,
            rct_type: 0,
        };

        let mut encoder = ModularEncoder::new();
        encoder.add_transform(rct.clone());
        let encoded = encoder
            .encode_image(&[r.clone(), g.clone(), b.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        decoder.add_transform(rct);
        let decoded = decoder
            .decode_image(&encoded, width, height, 3, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], r, "Red channel mismatch");
        assert_eq!(decoded[1], g, "Green channel mismatch");
        assert_eq!(decoded[2], b, "Blue channel mismatch");
    }

    // A zero width or height must be rejected by the encoder.
    #[test]
    #[ignore]
    fn test_modular_zero_dimensions_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[vec![0i32]], 0, 1, 8).is_err());
        assert!(encoder.encode_image(&[vec![0i32]], 1, 0, 8).is_err());
    }

    // An empty channel list must be rejected by the encoder.
    #[test]
    #[ignore]
    fn test_modular_empty_channels_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[], 1, 1, 8).is_err());
    }

    // Two independent channels must round-trip without a transform.
    #[test]
    #[ignore]
    fn test_modular_multichannel() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let ch0: Vec<i32> = (0..pixel_count).map(|i| (i * 11 % 256) as i32).collect();
        let ch1: Vec<i32> = (0..pixel_count).map(|i| (i * 17 % 256) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[ch0.clone(), ch1.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 2, 8)
            .expect("decode ok");

        assert_eq!(decoded[0], ch0);
        assert_eq!(decoded[1], ch1);
    }

    // Values well beyond 8 bits exercise multi-byte residuals.
    #[test]
    #[ignore]
    fn test_modular_large_values() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel: Vec<i32> = (0..pixel_count).map(|i| (i * 4000) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 16)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 16)
            .expect("decode ok");

        assert_eq!(decoded[0], channel);
    }
}