1use crate::error::{Error, Result};
23use crate::modular::Channel;
24
/// Reversible color transform (RCT) selector.
///
/// The raw value packs two fields: `value / 7` is the channel permutation index
/// and `value % 7` is the decorrelating transform index. Six permutations times
/// seven transforms gives the meaningful range `0..=41` (see
/// [`RctType::is_valid`]). This numbering appears to follow the JPEG XL
/// modular-mode RCT layout.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct RctType(pub u8);

impl RctType {
    /// Identity permutation combined with the YCoCg lifting transform (transform 6).
    pub const YCOCG: RctType = RctType(6);

    /// The no-op type: `forward_rct` and `inverse_rct` return immediately.
    pub const NONE: RctType = RctType(0);

    /// Identity permutation with transform 3: channel 0 is subtracted from
    /// channels 1 and 2.
    ///
    /// NOTE(review): with channels stored in R, G, B order this subtracts the
    /// first channel (red), not green. Confirm the channel order this constant
    /// is meant for.
    pub const SUBTRACT_GREEN: RctType = RctType(3);

    /// Permutation index (`value / 7`); `0..=5` for valid types.
    pub fn permutation(&self) -> usize {
        (self.0 / 7) as usize
    }

    /// Transform index (`value % 7`); always in `0..=6`.
    pub fn transform(&self) -> usize {
        (self.0 % 7) as usize
    }

    /// Returns `true` only for raw value `0`.
    ///
    /// Other types whose transform index is 0 (e.g. `RctType(7)`) are not
    /// no-ops: they still permute the channels.
    pub fn is_noop(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if the raw value lies in the defined range `0..=41`.
    pub fn is_valid(&self) -> bool {
        self.0 < 42
    }
}
55
56pub fn forward_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
65 if rct_type.is_noop() {
66 return Ok(());
67 }
68
69 if channels.len() < begin_c + 3 {
71 return Err(Error::InvalidInput(
72 "RCT requires at least 3 channels".to_string(),
73 ));
74 }
75
76 let w = channels[begin_c].width();
78 let h = channels[begin_c].height();
79 for c in &channels[begin_c..begin_c + 3] {
80 if c.width() != w || c.height() != h {
81 return Err(Error::InvalidInput(
82 "RCT requires channels with same dimensions".to_string(),
83 ));
84 }
85 }
86
87 let permutation = rct_type.permutation();
88 let transform = rct_type.transform();
89
90 let (idx0, idx1, idx2) = permute_indices(permutation);
92
93 for y in 0..h {
96 let row0: Vec<i32> = channels[begin_c + idx0].row(y).to_vec();
98 let row1: Vec<i32> = channels[begin_c + idx1].row(y).to_vec();
99 let row2: Vec<i32> = channels[begin_c + idx2].row(y).to_vec();
100
101 let (out0, out1, out2) = forward_rct_row_copy(&row0, &row1, &row2, transform);
103
104 channels[begin_c].row_mut(y).copy_from_slice(&out0);
109 channels[begin_c + 1].row_mut(y).copy_from_slice(&out1);
110 channels[begin_c + 2].row_mut(y).copy_from_slice(&out2);
111 }
112
113 Ok(())
114}
115
/// Returns, for each of the three transform inputs, the offset (relative to
/// `begin_c`) of the channel it is read from.
///
/// With channels stored as R, G, B, indices `0..=5` give the input orders
/// RGB, GBR, BRG, RBG, GRB and BGR. Any other index falls back to the
/// identity order `(0, 1, 2)`.
fn permute_indices(permutation: usize) -> (usize, usize, usize) {
    const ORDERS: [(usize, usize, usize); 6] = [
        (0, 1, 2),
        (1, 2, 0),
        (2, 0, 1),
        (0, 2, 1),
        (1, 0, 2),
        (2, 1, 0),
    ];

    ORDERS.get(permutation).copied().unwrap_or((0, 1, 2))
}
130
/// Computes one row of the forward RCT on the three already-permuted input
/// rows and returns the three output rows as new vectors.
///
/// `transform` selects the decorrelation applied at each pixel `x`:
///
/// - `0`: none (outputs are copies of the inputs)
/// - `1`: `c2 -= c0`
/// - `2`: `c1 -= c0`
/// - `3`: `c1 -= c0` and `c2 -= c0`
/// - `4`: `c1 -= (c0 + c2) >> 1`
/// - `5`: `c1 -= (c0 + c2) >> 1` and `c2 -= c0`, both using the input `c2`
/// - `6`: YCoCg lifting of `(c0, c1, c2)` into `(Y, Co, Cg)`
///
/// Any other value also leaves the rows unchanged.
///
/// All arithmetic wraps on `i32` overflow, so extreme sample values cannot
/// trigger an overflow panic in debug builds. Each step is a lifting step,
/// which stays exactly invertible under wrapping arithmetic.
///
/// # Panics
///
/// Panics if `c1` or `c2` is shorter than `c0`.
fn forward_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();

    match transform {
        1 => {
            for x in 0..w {
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        2 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(c0[x]);
            }
        }
        3 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(c0[x]);
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        4 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(c0[x].wrapping_add(c2[x]) >> 1);
            }
        }
        5 => {
            for x in 0..w {
                // The predictor for the second channel uses the *input* third
                // channel; the inverse restores the third channel first.
                out1[x] = c1[x].wrapping_sub(c0[x].wrapping_add(c2[x]) >> 1);
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        6 => {
            for x in 0..w {
                let (r, g, b) = (c0[x], c1[x], c2[x]);

                let co = r.wrapping_sub(b);
                let tmp = b.wrapping_add(co >> 1);
                let cg = g.wrapping_sub(tmp);

                out0[x] = tmp.wrapping_add(cg >> 1);
                out1[x] = co;
                out2[x] = cg;
            }
        }
        // 0 is a pure permutation: the outputs are already the copied inputs.
        // Values above 6 cannot come from `RctType::transform` (it is `% 7`)
        // and are likewise left unchanged.
        _ => {}
    }

    (out0, out1, out2)
}
210
211pub fn inverse_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
215 if rct_type.is_noop() {
216 return Ok(());
217 }
218
219 if channels.len() < begin_c + 3 {
220 return Err(Error::InvalidInput(
221 "RCT requires at least 3 channels".to_string(),
222 ));
223 }
224
225 let h = channels[begin_c].height();
226
227 let permutation = rct_type.permutation();
228 let transform = rct_type.transform();
229
230 let (idx0, idx1, idx2) = permute_indices(permutation);
233
234 for y in 0..h {
235 let row0: Vec<i32> = channels[begin_c].row(y).to_vec();
237 let row1: Vec<i32> = channels[begin_c + 1].row(y).to_vec();
238 let row2: Vec<i32> = channels[begin_c + 2].row(y).to_vec();
239
240 let (out0, out1, out2) = inverse_rct_row_copy(&row0, &row1, &row2, transform);
242
243 channels[begin_c + idx0].row_mut(y).copy_from_slice(&out0);
245 channels[begin_c + idx1].row_mut(y).copy_from_slice(&out1);
246 channels[begin_c + idx2].row_mut(y).copy_from_slice(&out2);
247 }
248
249 Ok(())
250}
251
/// Computes one row of the inverse RCT and returns the three reconstructed
/// rows (still in permuted order) as new vectors.
///
/// `transform` selects which decorrelation to undo at each pixel `x`:
///
/// - `0`: none (outputs are copies of the inputs)
/// - `1`: `c2 += c0`
/// - `2`: `c1 += c0`
/// - `3`: `c1 += c0` and `c2 += c0`
/// - `4`: `c1 += (c0 + c2) >> 1`
/// - `5`: `c2 += c0` first, then `c1 += (c0 + restored c2) >> 1`
/// - `6`: inverse YCoCg lifting of `(Y, Co, Cg)` back to `(c0, c1, c2)`
///
/// Any other value also leaves the rows unchanged.
///
/// All arithmetic wraps on `i32` overflow, so values read from a malformed or
/// adversarial source cannot trigger an overflow panic in debug builds.
///
/// # Panics
///
/// Panics if `c1` or `c2` is shorter than `c0`.
fn inverse_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();

    match transform {
        1 => {
            for x in 0..w {
                out2[x] = c2[x].wrapping_add(c0[x]);
            }
        }
        2 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(c0[x]);
            }
        }
        3 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(c0[x]);
                out2[x] = c2[x].wrapping_add(c0[x]);
            }
        }
        4 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(c0[x].wrapping_add(c2[x]) >> 1);
            }
        }
        5 => {
            for x in 0..w {
                // Restore the third channel first: the forward predictor for
                // the second channel was computed from the original third one.
                out2[x] = c2[x].wrapping_add(c0[x]);
                out1[x] = c1[x].wrapping_add(c0[x].wrapping_add(out2[x]) >> 1);
            }
        }
        6 => {
            for x in 0..w {
                let (luma, co, cg) = (c0[x], c1[x], c2[x]);

                let tmp = luma.wrapping_sub(cg >> 1);
                let g = cg.wrapping_add(tmp);
                let b = tmp.wrapping_sub(co >> 1);

                out0[x] = b.wrapping_add(co);
                out1[x] = g;
                out2[x] = b;
            }
        }
        // 0 is a pure permutation: the outputs are already the copied inputs.
        // Values above 6 cannot come from `RctType::transform` (it is `% 7`)
        // and are likewise left unchanged.
        _ => {}
    }

    (out0, out1, out2)
}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three `w` x `h` channels filled in row-major order from
    /// `values`, one `(c0, c1, c2)` triple per pixel.
    fn make_test_channels(w: usize, h: usize, values: &[(i32, i32, i32)]) -> Vec<Channel> {
        let mut c0 = Channel::new(w, h).unwrap();
        let mut c1 = Channel::new(w, h).unwrap();
        let mut c2 = Channel::new(w, h).unwrap();

        for (i, &(r, g, b)) in values.iter().enumerate() {
            let x = i % w;
            let y = i / w;
            c0.set(x, y, r);
            c1.set(x, y, g);
            c2.set(x, y, b);
        }

        vec![c0, c1, c2]
    }

    /// Asserts that the first three channels hold `expected`, laid out in
    /// row-major order with width `w`. `label` appears in failure messages.
    fn assert_channels_eq(
        channels: &[Channel],
        w: usize,
        expected: &[(i32, i32, i32)],
        label: &str,
    ) {
        for (i, &(r, g, b)) in expected.iter().enumerate() {
            let x = i % w;
            let y = i / w;
            assert_eq!(channels[0].get(x, y), r, "R mismatch at {} ({})", i, label);
            assert_eq!(channels[1].get(x, y), g, "G mismatch at {} ({})", i, label);
            assert_eq!(channels[2].get(x, y), b, "B mismatch at {} ({})", i, label);
        }
    }

    #[test]
    fn test_ycocg_roundtrip() {
        let original = vec![(100, 150, 200), (255, 0, 128), (50, 50, 50), (0, 255, 0)];
        let mut channels = make_test_channels(2, 2, &original);

        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();
        inverse_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        assert_channels_eq(&channels, 2, &original, "YCoCg");
    }

    #[test]
    fn test_subtract_green_roundtrip() {
        let original = vec![(100, 150, 200), (255, 0, 128)];
        let mut channels = make_test_channels(2, 1, &original);

        forward_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();
        inverse_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();

        assert_channels_eq(&channels, 2, &original, "SUBTRACT_GREEN");
    }

    #[test]
    fn test_all_transforms_roundtrip() {
        // The last row holds negative samples, which exercise the arithmetic
        // right shifts used by transforms 4, 5 and 6.
        let original = vec![
            (100, 150, 200),
            (255, 0, 128),
            (50, 50, 50),
            (0, 255, 0),
            (-100, 37, -1),
            (-32768, 65535, -3),
        ];

        for rct_type in 0..42 {
            let mut channels = make_test_channels(2, 3, &original);

            forward_rct(&mut channels, 0, RctType(rct_type)).unwrap();
            inverse_rct(&mut channels, 0, RctType(rct_type)).unwrap();

            assert_channels_eq(&channels, 2, &original, &format!("rct_type {}", rct_type));
        }
    }

    #[test]
    fn test_ycocg_decorrelation() {
        // Gray pixels (equal components) should carry no chroma after YCoCg.
        let values: Vec<(i32, i32, i32)> = (0..8).map(|i| (i * 10, i * 10, i * 10)).collect();
        let mut channels = make_test_channels(8, 1, &values);

        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();

        for i in 0..8 {
            assert_eq!(
                channels[1].get(i, 0),
                0,
                "Co should be 0 for gray, got {} at {}",
                channels[1].get(i, 0),
                i
            );
            assert_eq!(
                channels[2].get(i, 0),
                0,
                "Cg should be 0 for gray, got {} at {}",
                channels[2].get(i, 0),
                i
            );
        }
    }

    #[test]
    fn test_noop() {
        let original = vec![(100, 150, 200)];
        let mut channels = make_test_channels(1, 1, &original);

        forward_rct(&mut channels, 0, RctType::NONE).unwrap();
        assert_channels_eq(&channels, 1, &original, "forward NONE");

        inverse_rct(&mut channels, 0, RctType::NONE).unwrap();
        assert_channels_eq(&channels, 1, &original, "inverse NONE");
    }

    #[test]
    fn test_permutation_only_moves_channels() {
        // Type 7 is permutation 1 (inputs read from channels 1, 2, 0) with
        // transform 0, so the forward pass only rotates the channels.
        let original = vec![(100, 150, 200)];
        let mut channels = make_test_channels(1, 1, &original);

        forward_rct(&mut channels, 0, RctType(7)).unwrap();
        assert_channels_eq(&channels, 1, &[(150, 200, 100)], "forward type 7");

        inverse_rct(&mut channels, 0, RctType(7)).unwrap();
        assert_channels_eq(&channels, 1, &original, "inverse type 7");
    }

    #[test]
    fn test_rejects_too_few_channels() {
        let mut channels = make_test_channels(1, 1, &[(1, 2, 3)]);

        assert!(forward_rct(&mut channels, 1, RctType::YCOCG).is_err());
        assert!(inverse_rct(&mut channels, 1, RctType::YCOCG).is_err());
        assert!(forward_rct(&mut channels[..2], 0, RctType::YCOCG).is_err());
    }

    #[test]
    fn test_forward_rejects_mismatched_dimensions() {
        let mut channels = vec![
            Channel::new(2, 2).unwrap(),
            Channel::new(2, 2).unwrap(),
            Channel::new(1, 2).unwrap(),
        ];

        assert!(forward_rct(&mut channels, 0, RctType::YCOCG).is_err());
    }
}