Skip to main content

oximedia_codec/jpegxl/
modular.rs

1//! Modular sub-codec for JPEG-XL lossless encoding.
2//!
3//! The Modular mode is the backbone of JPEG-XL lossless compression. It operates
4//! by applying reversible transforms (RCT, Squeeze) to decorrelate channels,
5//! predicting each sample using adaptive weighted predictors, and entropy-coding
6//! the residuals.
7//!
8//! ## Pipeline
9//!
10//! 1. **Reversible Color Transform (RCT)**: Converts RGB to YCoCg-R, which
11//!    decorrelates color channels for better compression.
//! 2. **Prediction**: Each sample is predicted from its causal neighbors
//!    (W, N, NW, NE); for every sample the predictor with the lowest recently
//!    accumulated error is used, so the choice adapts to local statistics.
//! 3. **Residual coding**: The prediction errors are zigzag-mapped to unsigned
//!    integers and written as variable-length 7-bit groups.
15//!
16//! ## Predictors
17//!
18//! JPEG-XL defines several predictors:
19//! - Zero: predict 0 (for first pixels)
20//! - West (left neighbor)
21//! - North (top neighbor)
22//! - Average of W and N
23//! - Gradient: N + W - NW (MED/median edge detector)
//! - Weighted: fixed-weight combination (3N + 3W - NW + NE) / 6
25
26use crate::error::{CodecError, CodecResult};
27
/// Number of distinct predictors evaluated for every sample.
const NUM_PREDICTORS: usize = 6;

/// Sample predictors available to the modular sub-codec.
///
/// The discriminant of each variant is the index used to pick its value out
/// of the per-sample prediction array.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Predictor {
    /// Always predicts 0.
    Zero = 0,
    /// The west (left) neighbour, W.
    West = 1,
    /// The north (upper) neighbour, N.
    North = 2,
    /// Truncating mean of W and N.
    AvgWN = 3,
    /// Gradient N + W - NW, clamped to `[min(W, N), max(W, N)]`.
    Gradient = 4,
    /// Fixed-weight blend `(3N + 3W - NW + NE) / 6`.
    Weighted = 5,
}

impl Predictor {
    /// Every predictor, ordered by discriminant.
    const ALL: [Predictor; NUM_PREDICTORS] = [
        Predictor::Zero,
        Predictor::West,
        Predictor::North,
        Predictor::AvgWN,
        Predictor::Gradient,
        Predictor::Weighted,
    ];

    /// Map an index back to its predictor.
    ///
    /// Indices past the last variant saturate to [`Predictor::Weighted`].
    fn from_index(idx: usize) -> Self {
        Self::ALL.get(idx).copied().unwrap_or(Self::Weighted)
    }
}
61
/// Forward Reversible Color Transform (RGB -> YCoCg-R).
///
/// An integer, lifting-based variant of YCoCg that `inverse_rct` undoes
/// exactly. The lifting steps are:
///
/// - Co = R - B
/// - t  = B + (Co >> 1)
/// - Cg = G - t
/// - Y  = t + (Cg >> 1)
///
/// Returns `(Y, Co, Cg)`.
pub fn forward_rct(r: i32, g: i32, b: i32) -> (i32, i32, i32) {
    let chroma_orange = r - b;
    let base = b + (chroma_orange >> 1);
    let chroma_green = g - base;
    let luma = base + (chroma_green >> 1);
    (luma, chroma_orange, chroma_green)
}
78
/// Inverse Reversible Color Transform (YCoCg-R -> RGB).
///
/// Undoes the lifting steps of `forward_rct` in reverse order, so
/// `inverse_rct(forward_rct(r, g, b)) == (r, g, b)` for every integer input
/// that does not overflow. Returns `(R, G, B)`.
pub fn inverse_rct(y: i32, co: i32, cg: i32) -> (i32, i32, i32) {
    let base = y - (cg >> 1);
    let green = base + cg;
    let blue = base - (co >> 1);
    let red = blue + co;
    (red, green, blue)
}
89
/// A reversible transform applied to a set of modular channels.
///
/// The decoder undoes registered transforms in the reverse of the order in
/// which they were added.
#[derive(Clone, Debug)]
pub enum ModularTransform {
    /// Reversible Color Transform over the three consecutive channels that
    /// start at `begin_channel`.
    Rct {
        /// Index of the first of the three channels.
        begin_channel: u32,
        /// RCT variant; 0 selects YCoCg-R.
        rct_type: u8,
    },
    /// Wavelet-like squeeze that splits channels into averages and
    /// differences (used for progressive decoding).
    Squeeze {
        /// Which channels to squeeze and in which direction.
        params: SqueezeParams,
    },
    /// Palette (indexed-color) transform.
    Palette {
        /// Channel that holds the palette indices.
        begin_channel: u32,
        /// Number of palette entries.
        num_colors: u32,
        /// Palette colors stored entry after entry, with the components of
        /// each entry interleaved.
        palette: Vec<i32>,
    },
}

/// Configuration of a [`ModularTransform::Squeeze`] step.
#[derive(Clone, Debug)]
pub struct SqueezeParams {
    /// Squeeze along rows when `true`, along columns otherwise.
    pub horizontal: bool,
    /// Squeeze in place instead of creating new channels.
    pub in_place: bool,
    /// Index of the first channel to squeeze.
    pub begin_channel: u32,
    /// How many consecutive channels to squeeze.
    pub num_channels: u32,
}
128
129/// Context for adaptive prediction weight selection.
130///
131/// Tracks prediction errors to adaptively choose the best predictor
132/// for each pixel context.
133struct PredictionContext {
134    /// Accumulated absolute errors for each predictor.
135    errors: [i64; NUM_PREDICTORS],
136    /// Decay factor for error accumulation (shift right by this amount).
137    decay_shift: u32,
138    /// Counter for periodic error decay.
139    counter: u32,
140}
141
142impl PredictionContext {
143    fn new() -> Self {
144        Self {
145            errors: [0; NUM_PREDICTORS],
146            decay_shift: 4,
147            counter: 0,
148        }
149    }
150
151    /// Select the predictor with the lowest accumulated error.
152    fn best_predictor(&self) -> Predictor {
153        let mut best_idx = 0;
154        let mut best_err = self.errors[0];
155        for i in 1..NUM_PREDICTORS {
156            if self.errors[i] < best_err {
157                best_err = self.errors[i];
158                best_idx = i;
159            }
160        }
161        Predictor::from_index(best_idx)
162    }
163
164    /// Update error accumulators after observing the actual value.
165    fn update(&mut self, predictions: &[i32; NUM_PREDICTORS], actual: i32) {
166        for i in 0..NUM_PREDICTORS {
167            let err = (actual - predictions[i]).unsigned_abs() as i64;
168            self.errors[i] += err;
169        }
170        self.counter += 1;
171        // Periodic decay to adapt to changing statistics
172        if self.counter >= (1 << self.decay_shift) {
173            for err in &mut self.errors {
174                *err >>= 1;
175            }
176            self.counter = 0;
177        }
178    }
179}
180
/// Fetch the causal neighbourhood of sample `(x, y)` in a row-major channel
/// whose rows are `width` samples long.
///
/// Returns `(W, N, NW, NE, NN, WW)`. Neighbours outside the image are
/// substituted: W, N and NW become 0; NE and NN fall back to N; WW falls back
/// to W. A position beyond the rows present in `channel` also reads as 0.
fn get_neighbors(channel: &[i32], width: u32, x: u32, y: u32) -> (i32, i32, i32, i32, i32, i32) {
    let stride = width as usize;
    // `checked_div` keeps a zero width harmless: no lookup can succeed then.
    let rows = channel.len().checked_div(stride).unwrap_or(0);
    let (col, row) = (x as usize, y as usize);

    let sample = |c: usize, r: usize| -> i32 {
        if c < stride && r < rows {
            channel[r * stride + c]
        } else {
            0
        }
    };

    let has_west = col > 0;
    let has_north = row > 0;

    let west = if has_west { sample(col - 1, row) } else { 0 };
    let north = if has_north { sample(col, row - 1) } else { 0 };
    let north_west = if has_west && has_north {
        sample(col - 1, row - 1)
    } else {
        0
    };
    let north_east = if has_north && col + 1 < stride {
        sample(col + 1, row - 1)
    } else {
        north
    };
    let north_north = if row >= 2 { sample(col, row - 2) } else { north };
    let west_west = if col >= 2 { sample(col - 2, row) } else { west };

    (west, north, north_west, north_east, north_north, west_west)
}
214
215/// Compute all predictor values for a given set of neighbors.
216fn compute_predictions(
217    w: i32,
218    n: i32,
219    nw: i32,
220    ne: i32,
221    _nn: i32,
222    _ww: i32,
223) -> [i32; NUM_PREDICTORS] {
224    let avg_wn = (w + n) / 2;
225    let gradient = n + w - nw;
226
227    // Clamp gradient to the range [min(W,N), max(W,N)] for stability
228    let grad_clamped = gradient.clamp(w.min(n), w.max(n));
229
230    // Weighted predictor: adaptive combination
231    let weighted = {
232        let sum = 3i64 * n as i64 + 3i64 * w as i64 - nw as i64 + ne as i64;
233        (sum / 6) as i32
234    };
235
236    [
237        0,            // Zero
238        w,            // West
239        n,            // North
240        avg_wn,       // Average(W, N)
241        grad_clamped, // Gradient (clamped)
242        weighted,     // Weighted
243    ]
244}
245
246/// Encode a signed residual into a variable-length byte sequence.
247///
248/// Encoding scheme:
249/// - Map signed to unsigned via zigzag: 0->0, -1->1, 1->2, -2->3, ...
250/// - Write unsigned value in 7-bit chunks with high bit as continuation flag:
251///   - If high bit = 1, more bytes follow
252///   - If high bit = 0, this is the last byte
253fn encode_residual(value: i32, output: &mut Vec<u8>) {
254    let unsigned = signed_to_unsigned(value);
255    let mut remaining = unsigned;
256    loop {
257        let byte = (remaining & 0x7F) as u8;
258        remaining >>= 7;
259        if remaining == 0 {
260            output.push(byte); // high bit = 0, last byte
261            break;
262        } else {
263            output.push(byte | 0x80); // high bit = 1, more bytes follow
264        }
265    }
266}
267
268/// Decode a variable-length encoded residual.
269///
270/// Returns (decoded_value, bytes_consumed).
271fn decode_residual(data: &[u8], offset: usize) -> CodecResult<(i32, usize)> {
272    let mut value: u32 = 0;
273    let mut shift: u32 = 0;
274    let mut pos = offset;
275
276    loop {
277        if pos >= data.len() {
278            return Err(CodecError::InvalidBitstream(
279                "Unexpected end of residual data".into(),
280            ));
281        }
282        let byte = data[pos];
283        pos += 1;
284
285        value |= ((byte & 0x7F) as u32) << shift;
286        shift += 7;
287
288        if byte & 0x80 == 0 {
289            // Last byte
290            break;
291        }
292        if shift >= 35 {
293            return Err(CodecError::InvalidBitstream(
294                "Residual value too large".into(),
295            ));
296        }
297    }
298
299    Ok((unsigned_to_signed(value), pos - offset))
300}
301
/// Map a signed residual to an unsigned value for entropy coding.
///
/// Uses the standard zigzag mapping: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
/// The shift/xor form is defined for every `i32`, including `i32::MIN`,
/// which maps to `u32::MAX`.
fn signed_to_unsigned(value: i32) -> u32 {
    // `value >> 31` is an arithmetic shift: all ones for negatives, else 0.
    // Bits shifted out by `<< 1` are discarded without a panic.
    ((value << 1) ^ (value >> 31)) as u32
}
312
/// Map an unsigned value back to a signed residual (inverse zigzag).
///
/// 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ... Defined for every `u32`, so
/// arbitrary values from an untrusted stream cannot cause an overflow panic.
fn unsigned_to_signed(value: u32) -> i32 {
    // `value >> 1` fits in i32; xor with all ones (odd inputs) negates-minus-one.
    ((value >> 1) as i32) ^ -((value & 1) as i32)
}
321
322/// Modular decoder for JPEG-XL lossless images.
323pub struct ModularDecoder {
324    transforms: Vec<ModularTransform>,
325}
326
327impl ModularDecoder {
328    /// Create a new modular decoder.
329    pub fn new() -> Self {
330        Self {
331            transforms: Vec::new(),
332        }
333    }
334
335    /// Add a transform to be applied during decoding (inverse order).
336    pub fn add_transform(&mut self, transform: ModularTransform) {
337        self.transforms.push(transform);
338    }
339
340    /// Decode an image from variable-length coded residual data.
341    ///
342    /// Returns one `Vec<i32>` per channel, each of length `width * height`.
343    pub fn decode_image(
344        &mut self,
345        data: &[u8],
346        width: u32,
347        height: u32,
348        channels: u32,
349        _bit_depth: u8,
350    ) -> CodecResult<Vec<Vec<i32>>> {
351        if width == 0 || height == 0 {
352            return Err(CodecError::InvalidParameter(
353                "Image dimensions must be non-zero".into(),
354            ));
355        }
356
357        let pixel_count = width as usize * height as usize;
358        let mut result_channels: Vec<Vec<i32>> = Vec::with_capacity(channels as usize);
359        let mut data_offset = 0usize;
360
361        for _ch in 0..channels {
362            let mut channel_data = vec![0i32; pixel_count];
363            let mut ctx = PredictionContext::new();
364
365            for y in 0..height {
366                for x in 0..width {
367                    let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
368                        get_neighbors(&channel_data, width, x, y);
369                    let predictions =
370                        compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
371                    let predictor = ctx.best_predictor();
372                    let predicted = predictions[predictor as usize];
373
374                    // Decode residual
375                    let (residual, consumed) = decode_residual(data, data_offset)?;
376                    data_offset += consumed;
377
378                    let actual = predicted + residual;
379                    channel_data[y as usize * width as usize + x as usize] = actual;
380                    ctx.update(&predictions, actual);
381                }
382            }
383
384            result_channels.push(channel_data);
385        }
386
387        // Apply inverse transforms in reverse order
388        for transform in self.transforms.iter().rev() {
389            match transform {
390                ModularTransform::Rct {
391                    begin_channel,
392                    rct_type: _,
393                } => {
394                    let begin = *begin_channel as usize;
395                    if begin + 2 < result_channels.len() {
396                        let pc = result_channels[begin].len();
397                        for i in 0..pc {
398                            let y_val = result_channels[begin][i];
399                            let co = result_channels[begin + 1][i];
400                            let cg = result_channels[begin + 2][i];
401                            let (r, g, b) = inverse_rct(y_val, co, cg);
402                            result_channels[begin][i] = r;
403                            result_channels[begin + 1][i] = g;
404                            result_channels[begin + 2][i] = b;
405                        }
406                    }
407                }
408                ModularTransform::Squeeze {
409                    params:
410                        SqueezeParams {
411                            horizontal,
412                            begin_channel,
413                            num_channels,
414                            ..
415                        },
416                } => {
417                    let begin = *begin_channel as usize;
418                    let nc = *num_channels as usize;
419                    let horiz = *horizontal;
420
421                    for ch_idx in begin..begin + nc {
422                        if ch_idx >= result_channels.len() {
423                            break;
424                        }
425
426                        // The channel data is laid out as avg half followed by
427                        // residual half along the squeezed dimension.
428                        // Horizontal squeeze: first W/2 cols = avg, next W/2 = residual.
429                        // Vertical squeeze:   first H/2 rows = avg, next H/2 = residual.
430                        if horiz {
431                            let half_w = (width / 2) as usize;
432                            if half_w == 0 {
433                                continue;
434                            }
435                            let h = height as usize;
436                            let w = width as usize;
437                            // Build a new buffer of the same size.
438                            let old = result_channels[ch_idx].clone();
439                            let buf = &mut result_channels[ch_idx];
440                            for row in 0..h {
441                                for i in 0..half_w {
442                                    let avg = old[row * w + i];
443                                    let diff = old[row * w + half_w + i];
444                                    // a = avg + ((diff + 1) >> 1)
445                                    // b = avg - (diff >> 1)
446                                    let a = avg + ((diff + 1) >> 1);
447                                    let b = avg - (diff >> 1);
448                                    buf[row * w + 2 * i] = a;
449                                    buf[row * w + 2 * i + 1] = b;
450                                }
451                                // If width is odd, the last column (index w-1) has no
452                                // pair in the residual half; it was stored as-is.
453                                if w % 2 != 0 {
454                                    buf[row * w + w - 1] = old[row * w + w - 1];
455                                }
456                            }
457                        } else {
458                            // Vertical squeeze
459                            let half_h = (height / 2) as usize;
460                            if half_h == 0 {
461                                continue;
462                            }
463                            let h = height as usize;
464                            let w = width as usize;
465                            let old = result_channels[ch_idx].clone();
466                            let buf = &mut result_channels[ch_idx];
467                            for col in 0..w {
468                                for i in 0..half_h {
469                                    let avg = old[i * w + col];
470                                    let diff = old[(half_h + i) * w + col];
471                                    let a = avg + ((diff + 1) >> 1);
472                                    let b = avg - (diff >> 1);
473                                    buf[(2 * i) * w + col] = a;
474                                    buf[(2 * i + 1) * w + col] = b;
475                                }
476                                // If height is odd, the last row has no pair.
477                                if h % 2 != 0 {
478                                    buf[(h - 1) * w + col] = old[(h - 1) * w + col];
479                                }
480                            }
481                        }
482                    }
483                }
484                ModularTransform::Palette {
485                    begin_channel,
486                    num_colors,
487                    palette,
488                } => {
489                    if *num_colors == 0 {
490                        return Err(CodecError::InvalidBitstream(
491                            "Palette: num_colors must be non-zero".into(),
492                        ));
493                    }
494                    let nc = *num_colors as usize;
495                    // Derive number of components from palette size.
496                    if palette.len() % nc != 0 {
497                        return Err(CodecError::InvalidBitstream(format!(
498                            "Palette length {} is not divisible by num_colors {}",
499                            palette.len(),
500                            nc
501                        )));
502                    }
503                    let num_components = palette.len() / nc;
504                    let begin = *begin_channel as usize;
505
506                    if begin >= result_channels.len() {
507                        return Err(CodecError::InvalidBitstream(
508                            "Palette: begin_channel out of bounds".into(),
509                        ));
510                    }
511                    // Indices are stored in the first (begin_channel) channel.
512                    // Read indices into a temporary to avoid borrow conflicts.
513                    let indices = result_channels[begin].clone();
514
515                    // Validate and apply: expand index channel into num_components channels.
516                    for (pixel_pos, &idx_val) in indices.iter().enumerate() {
517                        if idx_val < 0 || idx_val as usize >= nc {
518                            return Err(CodecError::InvalidBitstream(format!(
519                                "Palette index {idx_val} out of range [0, {nc})"
520                            )));
521                        }
522                        let idx = idx_val as usize;
523                        for c in 0..num_components {
524                            let target_ch = begin + c;
525                            if target_ch >= result_channels.len() {
526                                return Err(CodecError::InvalidBitstream(format!(
527                                    "Palette: channel {target_ch} out of bounds"
528                                )));
529                            }
530                            result_channels[target_ch][pixel_pos] =
531                                palette[idx * num_components + c];
532                        }
533                    }
534                }
535            }
536        }
537
538        Ok(result_channels)
539    }
540}
541
542impl Default for ModularDecoder {
543    fn default() -> Self {
544        Self::new()
545    }
546}
547
548/// Modular encoder for JPEG-XL lossless images.
549pub struct ModularEncoder {
550    transforms: Vec<ModularTransform>,
551    effort: u8,
552}
553
554impl ModularEncoder {
555    /// Create a new modular encoder.
556    pub fn new() -> Self {
557        Self {
558            transforms: Vec::new(),
559            effort: 7,
560        }
561    }
562
563    /// Set encoding effort (1-9).
564    pub fn with_effort(mut self, effort: u8) -> Self {
565        self.effort = effort.clamp(1, 9);
566        self
567    }
568
569    /// Add a transform to be applied during encoding.
570    pub fn add_transform(&mut self, transform: ModularTransform) {
571        self.transforms.push(transform);
572    }
573
574    /// Encode channels into a compressed byte stream.
575    ///
576    /// Input: one `Vec<i32>` per channel, each of length `width * height`.
577    /// Returns the variable-length coded residual data.
578    pub fn encode_image(
579        &mut self,
580        channels: &[Vec<i32>],
581        width: u32,
582        height: u32,
583        _bit_depth: u8,
584    ) -> CodecResult<Vec<u8>> {
585        if width == 0 || height == 0 {
586            return Err(CodecError::InvalidParameter(
587                "Image dimensions must be non-zero".into(),
588            ));
589        }
590        if channels.is_empty() {
591            return Err(CodecError::InvalidParameter(
592                "Must have at least one channel".into(),
593            ));
594        }
595
596        let pixel_count = width as usize * height as usize;
597        for (i, ch) in channels.iter().enumerate() {
598            if ch.len() != pixel_count {
599                return Err(CodecError::InvalidParameter(format!(
600                    "Channel {i} has {} samples, expected {pixel_count}",
601                    ch.len()
602                )));
603            }
604        }
605
606        // Apply forward transforms
607        let mut working_channels: Vec<Vec<i32>> = channels.to_vec();
608        for transform in &self.transforms {
609            match transform {
610                ModularTransform::Rct {
611                    begin_channel,
612                    rct_type: _,
613                } => {
614                    let begin = *begin_channel as usize;
615                    if begin + 2 < working_channels.len() {
616                        for i in 0..pixel_count {
617                            let r = working_channels[begin][i];
618                            let g = working_channels[begin + 1][i];
619                            let b = working_channels[begin + 2][i];
620                            let (y_val, co, cg) = forward_rct(r, g, b);
621                            working_channels[begin][i] = y_val;
622                            working_channels[begin + 1][i] = co;
623                            working_channels[begin + 2][i] = cg;
624                        }
625                    }
626                }
627                ModularTransform::Squeeze { .. } | ModularTransform::Palette { .. } => {
628                    // Not yet implemented
629                }
630            }
631        }
632
633        // Encode residuals channel by channel
634        let mut output = Vec::with_capacity(pixel_count * working_channels.len());
635
636        for ch_data in &working_channels {
637            let mut ctx = PredictionContext::new();
638
639            for y in 0..height {
640                for x in 0..width {
641                    let (w_val, n_val, nw_val, ne_val, nn_val, ww_val) =
642                        get_neighbors(ch_data, width, x, y);
643                    let predictions =
644                        compute_predictions(w_val, n_val, nw_val, ne_val, nn_val, ww_val);
645                    let predictor = ctx.best_predictor();
646                    let predicted = predictions[predictor as usize];
647
648                    let actual = ch_data[y as usize * width as usize + x as usize];
649                    let residual = actual - predicted;
650
651                    encode_residual(residual, &mut output);
652                    ctx.update(&predictions, actual);
653                }
654            }
655        }
656
657        Ok(output)
658    }
659}
660
661impl Default for ModularEncoder {
662    fn default() -> Self {
663        Self::new()
664    }
665}
666
/// Unit tests for the modular sub-codec.
///
/// None of these tests is `#[ignore]`d: they are all fast and deterministic,
/// so they run as part of a plain `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rct_roundtrip() {
        let test_values = [
            (0, 0, 0),
            (255, 255, 255),
            (128, 64, 32),
            (0, 255, 0),
            (255, 0, 0),
            (0, 0, 255),
            (100, 200, 50),
            (1, 1, 1),
        ];

        for (r, g, b) in test_values {
            let (y, co, cg) = forward_rct(r, g, b);
            let (r2, g2, b2) = inverse_rct(y, co, cg);
            assert_eq!(
                (r, g, b),
                (r2, g2, b2),
                "RCT roundtrip failed for ({r}, {g}, {b})"
            );
        }
    }

    #[test]
    fn test_rct_negative_values() {
        let (y, co, cg) = forward_rct(-10, 20, -30);
        let (r, g, b) = inverse_rct(y, co, cg);
        assert_eq!((r, g, b), (-10, 20, -30));
    }

    #[test]
    fn test_signed_unsigned_roundtrip() {
        for v in -100..=100 {
            let u = signed_to_unsigned(v);
            let v2 = unsigned_to_signed(u);
            assert_eq!(v, v2, "Zigzag roundtrip failed for {v}");
        }
    }

    #[test]
    fn test_zigzag_ordering() {
        assert_eq!(signed_to_unsigned(0), 0);
        assert_eq!(signed_to_unsigned(-1), 1);
        assert_eq!(signed_to_unsigned(1), 2);
        assert_eq!(signed_to_unsigned(-2), 3);
        assert_eq!(signed_to_unsigned(2), 4);
    }

    #[test]
    fn test_residual_encode_decode_roundtrip() {
        let test_values = [0, 1, -1, 127, -128, 1000, -1000, 65535, -65536, 0];
        let mut encoded = Vec::new();
        for &v in &test_values {
            encode_residual(v, &mut encoded);
        }

        let mut offset = 0;
        for &expected in &test_values {
            let (decoded, consumed) = decode_residual(&encoded, offset).expect("decode ok");
            assert_eq!(
                decoded, expected,
                "Residual roundtrip failed for {expected}"
            );
            offset += consumed;
        }
    }

    #[test]
    fn test_gradient_predictor() {
        let predictions = compute_predictions(100, 100, 100, 100, 100, 100);
        assert_eq!(predictions[Predictor::Gradient as usize], 100);
        assert_eq!(predictions[Predictor::West as usize], 100);
        assert_eq!(predictions[Predictor::North as usize], 100);
    }

    #[test]
    fn test_gradient_predictor_edge() {
        let predictions = compute_predictions(10, 0, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);

        let predictions = compute_predictions(0, 10, 0, 0, 0, 0);
        assert_eq!(predictions[Predictor::Gradient as usize], 10);
    }

    #[test]
    fn test_prediction_context() {
        let mut ctx = PredictionContext::new();
        assert_eq!(ctx.best_predictor(), Predictor::Zero);

        let predictions = [0, 100, 50, 75, 90, 80];
        ctx.update(&predictions, 100);

        assert_eq!(ctx.best_predictor(), Predictor::West);
    }

    #[test]
    fn test_get_neighbors_corner() {
        let channel = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let (w, n, nw, ne, nn, ww) = get_neighbors(&channel, 3, 0, 0);
        assert_eq!((w, n, nw, ne, nn, ww), (0, 0, 0, 0, 0, 0));

        let (w, n, nw, ne, _nn, _ww) = get_neighbors(&channel, 3, 1, 1);
        assert_eq!(w, 4);
        assert_eq!(n, 2);
        assert_eq!(nw, 1);
        assert_eq!(ne, 3);
    }

    #[test]
    fn test_modular_encode_decode_flat() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel = vec![128i32; pixel_count];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    #[test]
    fn test_modular_encode_decode_gradient() {
        let width = 8u32;
        let height = 4u32;
        let mut channel = Vec::with_capacity((width * height) as usize);
        for y in 0..height {
            for x in 0..width {
                channel.push((x + y * 10) as i32);
            }
        }

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], channel);
    }

    #[test]
    fn test_modular_encode_decode_with_rct() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let r: Vec<i32> = (0..pixel_count).map(|i| (i * 3) as i32 % 256).collect();
        let g: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 5 + 50) as i32 % 256)
            .collect();
        let b: Vec<i32> = (0..pixel_count)
            .map(|i| (i * 7 + 100) as i32 % 256)
            .collect();

        let rct = ModularTransform::Rct {
            begin_channel: 0,
            rct_type: 0,
        };

        let mut encoder = ModularEncoder::new();
        encoder.add_transform(rct.clone());
        let encoded = encoder
            .encode_image(&[r.clone(), g.clone(), b.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        decoder.add_transform(rct);
        let decoded = decoder
            .decode_image(&encoded, width, height, 3, 8)
            .expect("decode ok");

        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], r, "Red channel mismatch");
        assert_eq!(decoded[1], g, "Green channel mismatch");
        assert_eq!(decoded[2], b, "Blue channel mismatch");
    }

    #[test]
    fn test_modular_zero_dimensions_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[vec![0i32]], 0, 1, 8).is_err());
        assert!(encoder.encode_image(&[vec![0i32]], 1, 0, 8).is_err());
    }

    #[test]
    fn test_modular_empty_channels_error() {
        let mut encoder = ModularEncoder::new();
        assert!(encoder.encode_image(&[], 1, 1, 8).is_err());
    }

    #[test]
    fn test_modular_multichannel() {
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;

        let ch0: Vec<i32> = (0..pixel_count).map(|i| (i * 11 % 256) as i32).collect();
        let ch1: Vec<i32> = (0..pixel_count).map(|i| (i * 17 % 256) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[ch0.clone(), ch1.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 2, 8)
            .expect("decode ok");

        assert_eq!(decoded[0], ch0);
        assert_eq!(decoded[1], ch1);
    }

    #[test]
    fn test_modular_large_values() {
        // Test with 16-bit range values
        let width = 4u32;
        let height = 4u32;
        let pixel_count = (width * height) as usize;
        let channel: Vec<i32> = (0..pixel_count).map(|i| (i * 4000) as i32).collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 16)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 16)
            .expect("decode ok");

        assert_eq!(decoded[0], channel);
    }

    #[test]
    fn test_modular_odd_dimensions_roundtrip() {
        // Odd sizes exercise the right-edge (NE fallback) and row handling.
        let width = 5u32;
        let height = 3u32;
        let channel: Vec<i32> = (0..width * height)
            .map(|i| ((i * 37) % 251) as i32 - 100)
            .collect();

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel.clone()], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        let decoded = decoder
            .decode_image(&encoded, width, height, 1, 8)
            .expect("decode ok");

        assert_eq!(decoded, vec![channel]);
    }

    #[test]
    fn test_modular_truncated_stream_error() {
        let width = 4u32;
        let height = 4u32;
        let channel = vec![128i32; (width * height) as usize];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[channel], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        assert!(decoder
            .decode_image(&encoded[..encoded.len() - 1], width, height, 1, 8)
            .is_err());
        assert!(decoder.decode_image(&[], width, height, 1, 8).is_err());
    }

    #[test]
    fn test_palette_decode_expands_indices() {
        let width = 2u32;
        let height = 2u32;
        let indices = vec![0i32, 1, 2, 1];
        let placeholder = vec![0i32; 4];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[indices, placeholder], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        decoder.add_transform(ModularTransform::Palette {
            begin_channel: 0,
            num_colors: 3,
            palette: vec![10, 20, 30, 40, 50, 60],
        });
        let decoded = decoder
            .decode_image(&encoded, width, height, 2, 8)
            .expect("decode ok");

        assert_eq!(decoded[0], vec![10, 30, 50, 30]);
        assert_eq!(decoded[1], vec![20, 40, 60, 40]);
    }

    #[test]
    fn test_palette_decode_rejects_out_of_range_index() {
        let width = 2u32;
        let height = 2u32;
        let indices = vec![0i32, 3, 0, 0];

        let mut encoder = ModularEncoder::new();
        let encoded = encoder
            .encode_image(&[indices], width, height, 8)
            .expect("encode ok");

        let mut decoder = ModularDecoder::new();
        decoder.add_transform(ModularTransform::Palette {
            begin_channel: 0,
            num_colors: 3,
            palette: vec![1, 2, 3],
        });
        assert!(decoder.decode_image(&encoded, width, height, 1, 8).is_err());
    }
}