Skip to main content

jxl_encoder/modular/
rct.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Reversible Color Transform (RCT) for modular encoding.
6//!
7//! RCT decorrelates color channels to improve compression. The most effective
8//! transform is YCoCg (rct_type=6), which libjxl uses by default.
9//!
10//! RCT types encode: permutation (rct_type / 7) and transform (rct_type % 7).
11//!
12//! Permutations (0-5): RGB, GBR, BRG, RBG, GRB, BGR
13//! Transforms (0-6) — libjxl uses `second = type >> 1`, `third = type & 1`:
14//!   0: No transform (permutation only)
15//!   1: Third -= First
16//!   2: Second -= First
17//!   3: Second -= First, Third -= First
18//!   4: Second -= (First + Third) >> 1
19//!   5: Second -= (First + Third) >> 1, Third -= First
20//!   6: YCoCg (most effective for typical images)
21
22use crate::error::{Error, Result};
23use crate::modular::Channel;
24
/// Reversible Color Transform selector (valid range 0-41).
///
/// The value packs two fields: the channel permutation (`value / 7`, 0-5) and
/// the transform applied to the permuted channels (`value % 7`, 0-6). See the
/// module documentation for the meaning of each field.
///
/// The derived `Default` is `RctType(0)` ([`RctType::NONE`], no transform).
/// libjxl's usual choice is [`RctType::YCOCG`] (6), which must be requested
/// explicitly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct RctType(pub u8);

impl RctType {
    /// YCoCg transform with identity permutation (most effective for typical images).
    pub const YCOCG: RctType = RctType(6);

    /// No transform.
    pub const NONE: RctType = RctType(0);

    /// Transform 3 with identity permutation: for channels in R, G, B order this
    /// computes `G -= R` and `B -= R`.
    ///
    /// Despite the name, the channel subtracted is the *first* one (red for RGB
    /// input). Green-first subtraction (`R -= G`, `B -= G`) is `RctType(31)`
    /// (GRB permutation + transform 3), which yields channels G, R-G, B-G.
    pub const SUBTRACT_GREEN: RctType = RctType(3);

    /// Get the permutation index (0-5 for valid types).
    pub fn permutation(&self) -> usize {
        (self.0 / 7) as usize
    }

    /// Get the transform type (0-6).
    pub fn transform(&self) -> usize {
        (self.0 % 7) as usize
    }

    /// Check if this is a no-op (type 0: identity permutation, no transform).
    pub fn is_noop(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if the value is within the range defined by the format
    /// (0-41, i.e. permutation 0-5 combined with transform 0-6).
    pub fn is_valid(&self) -> bool {
        self.0 < 42
    }
}
55
56/// Apply forward RCT to three channels in-place.
57///
58/// # Arguments
59/// * `channels` - Three channels (RGB or similar) to transform
60/// * `rct_type` - RCT type (0-41)
61///
62/// # Returns
63/// Ok(()) if transform was applied, or error if channels don't match.
64pub fn forward_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
65    if rct_type.is_noop() {
66        return Ok(());
67    }
68
69    // Validate we have at least 3 channels starting from begin_c
70    if channels.len() < begin_c + 3 {
71        return Err(Error::InvalidInput(
72            "RCT requires at least 3 channels".to_string(),
73        ));
74    }
75
76    // Validate all three channels have same dimensions
77    let w = channels[begin_c].width();
78    let h = channels[begin_c].height();
79    for c in &channels[begin_c..begin_c + 3] {
80        if c.width() != w || c.height() != h {
81            return Err(Error::InvalidInput(
82                "RCT requires channels with same dimensions".to_string(),
83            ));
84        }
85    }
86
87    let permutation = rct_type.permutation();
88    let transform = rct_type.transform();
89
90    // Get permuted input indices
91    let (idx0, idx1, idx2) = permute_indices(permutation);
92
93    // Apply transform row by row
94    // We need to work around borrow checker by copying data
95    for y in 0..h {
96        // Read from PERMUTED input indices (permutation selects which channel is "first", etc.)
97        let row0: Vec<i32> = channels[begin_c + idx0].row(y).to_vec();
98        let row1: Vec<i32> = channels[begin_c + idx1].row(y).to_vec();
99        let row2: Vec<i32> = channels[begin_c + idx2].row(y).to_vec();
100
101        // Apply transform
102        let (out0, out1, out2) = forward_rct_row_copy(&row0, &row1, &row2, transform);
103
104        // Write back SEQUENTIALLY to channels 0, 1, 2.
105        // libjxl encoder writes transformed output to sequential indices.
106        // The decoder reads sequentially, applies inverse transform, then
107        // applies the permutation to outputs to recover the original channel order.
108        channels[begin_c].row_mut(y).copy_from_slice(&out0);
109        channels[begin_c + 1].row_mut(y).copy_from_slice(&out1);
110        channels[begin_c + 2].row_mut(y).copy_from_slice(&out2);
111    }
112
113    Ok(())
114}
115
/// Map a permutation index to the channel offsets that act as the
/// "first", "second" and "third" inputs of the RCT.
///
/// Index meanings: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR.
/// Any index outside 0-5 yields the identity order `(0, 1, 2)`.
fn permute_indices(permutation: usize) -> (usize, usize, usize) {
    if permutation >= 6 {
        return (0, 1, 2);
    }
    // Indices 0-2 are cyclic rotations; 3-5 are the same rotations with the
    // last two slots swapped, which `flip` (0 or 1) accomplishes.
    let flip = permutation / 3;
    let first = permutation % 3;
    let second = (permutation + 1 + flip) % 3;
    let third = (permutation + 2 - flip) % 3;
    (first, second, third)
}
130
/// Apply the forward RCT `transform` (0-6) to one row, returning new rows.
///
/// `c0`, `c1`, `c2` are the "first", "second" and "third" rows after the
/// caller has applied the permutation; the returned rows keep that order.
/// Transforms 1-5 decompose as `second = transform >> 1` (0 = keep,
/// 1 = subtract First, 2 = subtract `(First + Third) >> 1`) and
/// `third = transform & 1` (1 = subtract First from Third); 6 is YCoCg.
/// Transform 0 and unknown values return the rows unchanged.
///
/// All arithmetic wraps on `i32` overflow, matching the wrapping adds libjxl's
/// inverse uses, so results always round-trip and debug builds do not panic on
/// extreme (e.g. high-bit-depth) samples. In-range inputs give exactly the
/// same results as plain arithmetic.
///
/// `c1` and `c2` are expected to be at least as long as `c0`; non-trivial
/// transforms panic on shorter rows.
fn forward_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();

    match transform {
        6 => {
            // YCoCg-R lifting:
            //   Co = R - B; tmp = B + (Co >> 1); Cg = G - tmp; Y = tmp + (Cg >> 1)
            for x in 0..w {
                let (r, g, b) = (c0[x], c1[x], c2[x]);
                let co = r.wrapping_sub(b);
                let tmp = b.wrapping_add(co >> 1);
                let cg = g.wrapping_sub(tmp);
                out0[x] = tmp.wrapping_add(cg >> 1);
                out1[x] = co;
                out2[x] = cg;
            }
        }
        1..=5 => {
            let second = transform >> 1;
            let third = (transform & 1) == 1;
            for x in 0..w {
                let (first, sec, thr) = (c0[x], c1[x], c2[x]);
                if third {
                    out2[x] = thr.wrapping_sub(first);
                }
                // The average uses the ORIGINAL third value, as libjxl does.
                match second {
                    1 => out1[x] = sec.wrapping_sub(first),
                    2 => out1[x] = sec.wrapping_sub(first.wrapping_add(thr) >> 1),
                    _ => {}
                }
            }
        }
        // 0: permutation only (handled by the caller); anything else is unknown.
        _ => {}
    }

    (out0, out1, out2)
}
210
211/// Apply inverse RCT to three channels in-place.
212///
213/// This reverses the forward transform for decoding.
214pub fn inverse_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
215    if rct_type.is_noop() {
216        return Ok(());
217    }
218
219    if channels.len() < begin_c + 3 {
220        return Err(Error::InvalidInput(
221            "RCT requires at least 3 channels".to_string(),
222        ));
223    }
224
225    let h = channels[begin_c].height();
226
227    let permutation = rct_type.permutation();
228    let transform = rct_type.transform();
229
230    // Decoder convention: read sequentially, apply inverse transform,
231    // then write to permuted output indices to recover original channel order.
232    let (idx0, idx1, idx2) = permute_indices(permutation);
233
234    for y in 0..h {
235        // Read SEQUENTIALLY from channels 0, 1, 2
236        let row0: Vec<i32> = channels[begin_c].row(y).to_vec();
237        let row1: Vec<i32> = channels[begin_c + 1].row(y).to_vec();
238        let row2: Vec<i32> = channels[begin_c + 2].row(y).to_vec();
239
240        // Apply inverse transform
241        let (out0, out1, out2) = inverse_rct_row_copy(&row0, &row1, &row2, transform);
242
243        // Write back to PERMUTED output indices
244        channels[begin_c + idx0].row_mut(y).copy_from_slice(&out0);
245        channels[begin_c + idx1].row_mut(y).copy_from_slice(&out1);
246        channels[begin_c + idx2].row_mut(y).copy_from_slice(&out2);
247    }
248
249    Ok(())
250}
251
/// Apply the inverse RCT `transform` (0-6) to one row, returning new rows.
///
/// Inputs are the sequentially stored transformed rows; outputs are the
/// "first", "second" and "third" rows in permuted order (the caller scatters
/// them to their original channel positions). Forward operations are undone
/// in reverse order: Third is restored before Second, because Second's
/// average depends on the original Third. Transform 0 and unknown values
/// return the rows unchanged.
///
/// All arithmetic wraps on `i32` overflow, like libjxl's `PixelAdd`. This path
/// may see untrusted decoded data, so it must not panic in debug builds.
/// In-range inputs give exactly the same results as plain arithmetic.
///
/// `c1` and `c2` are expected to be at least as long as `c0`; non-trivial
/// transforms panic on shorter rows.
fn inverse_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();

    match transform {
        6 => {
            // Inverse YCoCg-R: Y = c0, Co = c1, Cg = c2
            //   tmp = Y - (Cg >> 1); G = Cg + tmp; B = tmp - (Co >> 1); R = B + Co
            for x in 0..w {
                let (y, co, cg) = (c0[x], c1[x], c2[x]);
                let tmp = y.wrapping_sub(cg >> 1);
                let g = cg.wrapping_add(tmp);
                let b = tmp.wrapping_sub(co >> 1);
                out0[x] = b.wrapping_add(co);
                out1[x] = g;
                out2[x] = b;
            }
        }
        1..=5 => {
            let second = transform >> 1;
            let third = (transform & 1) == 1;
            for x in 0..w {
                let first = c0[x];
                // Undo `Third -= First` first; Second's average needs the restored Third.
                let thr = if third { c2[x].wrapping_add(first) } else { c2[x] };
                out2[x] = thr;
                match second {
                    1 => out1[x] = c1[x].wrapping_add(first),
                    2 => out1[x] = c1[x].wrapping_add(first.wrapping_add(thr) >> 1),
                    _ => {}
                }
            }
        }
        // 0: no transform; anything else is unknown.
        _ => {}
    }

    (out0, out1, out2)
}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Build three `w x h` channels filled row-major from `(r, g, b)` triples.
    fn make_test_channels(w: usize, h: usize, values: &[(i32, i32, i32)]) -> Vec<Channel> {
        let mut c0 = Channel::new(w, h).unwrap();
        let mut c1 = Channel::new(w, h).unwrap();
        let mut c2 = Channel::new(w, h).unwrap();

        for (i, &(r, g, b)) in values.iter().enumerate() {
            let x = i % w;
            let y = i / w;
            c0.set(x, y, r);
            c1.set(x, y, g);
            c2.set(x, y, b);
        }

        vec![c0, c1, c2]
    }

    /// Read the three channels starting at `begin_c` at `(x, y)` as a tuple.
    fn pixel(channels: &[Channel], begin_c: usize, x: usize, y: usize) -> (i32, i32, i32) {
        (
            channels[begin_c].get(x, y),
            channels[begin_c + 1].get(x, y),
            channels[begin_c + 2].get(x, y),
        )
    }

    #[test]
    fn test_ycocg_roundtrip() {
        // Test YCoCg forward and inverse
        let original = vec![(100, 150, 200), (255, 0, 128), (50, 50, 50), (0, 255, 0)];
        let mut channels = make_test_channels(2, 2, &original);

        // Forward transform
        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        // Inverse transform
        inverse_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        // Check roundtrip
        for (i, &(r, g, b)) in original.iter().enumerate() {
            let x = i % 2;
            let y = i / 2;
            assert_eq!(channels[0].get(x, y), r, "R mismatch at {}", i);
            assert_eq!(channels[1].get(x, y), g, "G mismatch at {}", i);
            assert_eq!(channels[2].get(x, y), b, "B mismatch at {}", i);
        }
    }

    #[test]
    fn test_ycocg_forward_values() {
        // Pin the exact forward output so the transform can't drift while
        // still round-tripping (e.g. a swapped Co/Cg order).
        let original = vec![(100, 150, 200), (255, 0, 128)];
        let mut channels = make_test_channels(2, 1, &original);

        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        assert_eq!(pixel(&channels, 0, 0, 0), (150, -100, 0));
        assert_eq!(pixel(&channels, 0, 1, 0), (95, 127, -191));
    }

    #[test]
    fn test_subtract_green_roundtrip() {
        let original = vec![(100, 150, 200), (255, 0, 128)];
        let mut channels = make_test_channels(2, 1, &original);

        forward_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();
        inverse_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();

        for (i, &(r, g, b)) in original.iter().enumerate() {
            assert_eq!(channels[0].get(i, 0), r, "R mismatch at {}", i);
            assert_eq!(channels[1].get(i, 0), g, "G mismatch at {}", i);
            assert_eq!(channels[2].get(i, 0), b, "B mismatch at {}", i);
        }
    }

    #[test]
    fn test_permutation_only() {
        // rct_type 7 = permutation 1 (GBR) with transform 0: forward stores
        // G, B, R sequentially; inverse restores R, G, B.
        let mut channels = make_test_channels(1, 1, &[(1, 2, 3)]);

        forward_rct(&mut channels, 0, RctType(7)).unwrap();
        assert_eq!(pixel(&channels, 0, 0, 0), (2, 3, 1));

        inverse_rct(&mut channels, 0, RctType(7)).unwrap();
        assert_eq!(pixel(&channels, 0, 0, 0), (1, 2, 3));
    }

    #[test]
    fn test_all_transforms_roundtrip() {
        let original = vec![(100, 150, 200), (255, 0, 128), (50, 50, 50), (0, 255, 0)];

        // Test all 42 RCT types
        for rct_type in 0..42 {
            let mut channels = make_test_channels(2, 2, &original);

            forward_rct(&mut channels, 0, RctType(rct_type)).unwrap();
            inverse_rct(&mut channels, 0, RctType(rct_type)).unwrap();

            for (i, &(r, g, b)) in original.iter().enumerate() {
                let x = i % 2;
                let y = i / 2;
                assert_eq!(
                    channels[0].get(x, y),
                    r,
                    "R mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
                assert_eq!(
                    channels[1].get(x, y),
                    g,
                    "G mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
                assert_eq!(
                    channels[2].get(x, y),
                    b,
                    "B mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
            }
        }
    }

    #[test]
    fn test_begin_c_offset() {
        // Only the three channels starting at `begin_c` are touched.
        let mut channels = make_test_channels(1, 1, &[(9, 9, 9)]);
        channels.extend(make_test_channels(1, 1, &[(100, 150, 200)]));

        forward_rct(&mut channels, 3, RctType::YCOCG).unwrap();
        assert_eq!(pixel(&channels, 0, 0, 0), (9, 9, 9));
        assert_eq!(pixel(&channels, 3, 0, 0), (150, -100, 0));

        inverse_rct(&mut channels, 3, RctType::YCOCG).unwrap();
        assert_eq!(pixel(&channels, 0, 0, 0), (9, 9, 9));
        assert_eq!(pixel(&channels, 3, 0, 0), (100, 150, 200));
    }

    #[test]
    fn test_too_few_channels() {
        // Starting at index 1 leaves only two channels.
        let mut channels = make_test_channels(1, 1, &[(1, 2, 3)]);
        assert!(forward_rct(&mut channels, 1, RctType::YCOCG).is_err());
        assert!(inverse_rct(&mut channels, 1, RctType::YCOCG).is_err());
    }

    #[test]
    fn test_ycocg_decorrelation() {
        // A gray ramp (R == G == B) is perfectly correlated, so YCoCg should
        // move all the signal into Y and leave both chroma channels at zero.
        let values: Vec<(i32, i32, i32)> = (0..8).map(|i| (i * 10, i * 10, i * 10)).collect();
        let mut channels = make_test_channels(8, 1, &values);

        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        // Co = R - B and Cg = G - (B + (Co >> 1)) are both 0 for gray input.
        for i in 0..8 {
            assert_eq!(
                channels[1].get(i, 0),
                0,
                "Co should be 0 for gray, got {} at {}",
                channels[1].get(i, 0),
                i
            );
            assert_eq!(
                channels[2].get(i, 0),
                0,
                "Cg should be 0 for gray, got {} at {}",
                channels[2].get(i, 0),
                i
            );
        }
    }

    #[test]
    fn test_noop() {
        let original = vec![(100, 150, 200)];
        let mut channels = make_test_channels(1, 1, &original);
        let original_data = pixel(&channels, 0, 0, 0);

        forward_rct(&mut channels, 0, RctType::NONE).unwrap();

        assert_eq!(pixel(&channels, 0, 0, 0), original_data);
    }
}