use crate::error::{Error, Result};
use crate::modular::Channel;
/// Selector for a reversible colour transform (RCT) applied to three channels.
///
/// The raw value packs two fields as `permutation * 7 + transform`:
/// [`permutation`](Self::permutation) chooses the channel ordering and
/// [`transform`](Self::transform) chooses the arithmetic transform (0..=6).
/// Only the raw value 0 is a complete no-op; a non-zero permutation with
/// transform 0 still reorders channels.
#[derive(Clone, Copy, Debug, Default)]
pub struct RctType(pub u8);

impl RctType {
    /// No transform and identity channel order.
    pub const NONE: RctType = RctType(0);
    /// Transform 3 with identity channel order: channel 0 is subtracted from
    /// channels 1 and 2.
    ///
    /// NOTE(review): if channels 0, 1, 2 hold R, G, B this subtracts R, not G;
    /// confirm this raw value matches the intent of the name.
    pub const SUBTRACT_GREEN: RctType = RctType(3);
    /// Transform 6 (lifting YCoCg) with identity channel order.
    pub const YCOCG: RctType = RctType(6);

    /// Radix used to pack `transform` into the raw value.
    const TRANSFORM_COUNT: u8 = 7;

    /// Channel-ordering index (`raw / 7`).
    pub fn permutation(&self) -> usize {
        usize::from(self.0 / Self::TRANSFORM_COUNT)
    }

    /// Arithmetic transform index (`raw % 7`), always in `0..=6`.
    pub fn transform(&self) -> usize {
        usize::from(self.0 % Self::TRANSFORM_COUNT)
    }

    /// `true` only for the raw value 0 (no reordering, no arithmetic).
    pub fn is_noop(&self) -> bool {
        self.0 == Self::NONE.0
    }
}
pub fn forward_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
if rct_type.is_noop() {
return Ok(());
}
if channels.len() < begin_c + 3 {
return Err(Error::InvalidInput(
"RCT requires at least 3 channels".to_string(),
));
}
let w = channels[begin_c].width();
let h = channels[begin_c].height();
for c in &channels[begin_c..begin_c + 3] {
if c.width() != w || c.height() != h {
return Err(Error::InvalidInput(
"RCT requires channels with same dimensions".to_string(),
));
}
}
let permutation = rct_type.permutation();
let transform = rct_type.transform();
let (idx0, idx1, idx2) = permute_indices(permutation);
for y in 0..h {
let row0: Vec<i32> = channels[begin_c + idx0].row(y).to_vec();
let row1: Vec<i32> = channels[begin_c + idx1].row(y).to_vec();
let row2: Vec<i32> = channels[begin_c + idx2].row(y).to_vec();
let (out0, out1, out2) = forward_rct_row_copy(&row0, &row1, &row2, transform);
channels[begin_c].row_mut(y).copy_from_slice(&out0);
channels[begin_c + 1].row_mut(y).copy_from_slice(&out1);
channels[begin_c + 2].row_mut(y).copy_from_slice(&out2);
}
Ok(())
}
/// Maps an RCT permutation index to three channel offsets `(idx0, idx1, idx2)`
/// relative to the first transformed channel.
///
/// In this file, the forward RCT reads transform input `k` from offset `idx_k`, and
/// the inverse writes output `k` back to offset `idx_k`. Indices outside `0..=5`
/// fall back to the identity ordering `(0, 1, 2)`.
fn permute_indices(permutation: usize) -> (usize, usize, usize) {
    const ORDERINGS: [(usize, usize, usize); 6] = [
        (0, 1, 2),
        (1, 2, 0),
        (2, 0, 1),
        (0, 2, 1),
        (1, 0, 2),
        (2, 1, 0),
    ];
    ORDERINGS
        .get(permutation)
        .copied()
        .unwrap_or(ORDERINGS[0])
}
/// Applies forward RCT `transform` to one row and returns the three new rows.
///
/// `c0`, `c1`, `c2` are the input rows, already gathered in permuted order. `c1` and
/// `c2` are indexed up to `c0.len()`, so they must be at least that long. Transforms:
///
/// * 0 — identity
/// * 1 — `c2 -= c0`
/// * 2 — `c1 -= c0`
/// * 3 — `c1 -= c0`, `c2 -= c0`
/// * 4 — `c1 -= (c0 + c2) >> 1`
/// * 5 — `c1 -= (c0 + c2) >> 1` (using the original `c2`), `c2 -= c0`
/// * 6 — lifting YCoCg: `(r, g, b)` becomes `(y, co, cg)`
///
/// Any other `transform` value yields unmodified copies.
///
/// All arithmetic uses `wrapping_*` (modulo 2^32). Release builds already behave this
/// way with plain `+`/`-`; the explicit form keeps debug builds from panicking on
/// extreme samples. Every step adds or subtracts a value that can be recomputed from
/// the other outputs, so each step stays exactly invertible modulo 2^32.
fn forward_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    // `(a + b) >> 1` with a wrapping sum, i.e. the floored mean when no overflow occurs.
    let avg = |a: i32, b: i32| a.wrapping_add(b) >> 1;
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();
    match transform {
        1 => {
            for x in 0..w {
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        2 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(c0[x]);
            }
        }
        3 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(c0[x]);
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        4 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_sub(avg(c0[x], c2[x]));
            }
        }
        5 => {
            for x in 0..w {
                // The average uses the original third sample, which the inverse
                // reconstructs before it recomputes this average.
                out1[x] = c1[x].wrapping_sub(avg(c0[x], c2[x]));
                out2[x] = c2[x].wrapping_sub(c0[x]);
            }
        }
        6 => {
            for x in 0..w {
                let (r, g, b) = (c0[x], c1[x], c2[x]);
                let co = r.wrapping_sub(b);
                let tmp = b.wrapping_add(co >> 1);
                let cg = g.wrapping_sub(tmp);
                let y = tmp.wrapping_add(cg >> 1);
                out0[x] = y;
                out1[x] = co;
                out2[x] = cg;
            }
        }
        // 0 is the identity. Larger values cannot come from `RctType::transform`
        // (always < 7) and are passed through unchanged.
        _ => {}
    }
    (out0, out1, out2)
}
pub fn inverse_rct(channels: &mut [Channel], begin_c: usize, rct_type: RctType) -> Result<()> {
if rct_type.is_noop() {
return Ok(());
}
if channels.len() < begin_c + 3 {
return Err(Error::InvalidInput(
"RCT requires at least 3 channels".to_string(),
));
}
let h = channels[begin_c].height();
let permutation = rct_type.permutation();
let transform = rct_type.transform();
let (idx0, idx1, idx2) = permute_indices(permutation);
for y in 0..h {
let row0: Vec<i32> = channels[begin_c].row(y).to_vec();
let row1: Vec<i32> = channels[begin_c + 1].row(y).to_vec();
let row2: Vec<i32> = channels[begin_c + 2].row(y).to_vec();
let (out0, out1, out2) = inverse_rct_row_copy(&row0, &row1, &row2, transform);
channels[begin_c + idx0].row_mut(y).copy_from_slice(&out0);
channels[begin_c + idx1].row_mut(y).copy_from_slice(&out1);
channels[begin_c + idx2].row_mut(y).copy_from_slice(&out2);
}
Ok(())
}
/// Undoes forward RCT `transform` on one row and returns the reconstructed rows.
///
/// `c0`, `c1`, `c2` are the transformed rows as stored. The returned rows are in
/// permuted order; the caller scatters them back to their channels. `c1` and `c2` are
/// indexed up to `c0.len()`, so they must be at least that long. Transforms:
///
/// * 0 — identity
/// * 1 — `c2 += c0`
/// * 2 — `c1 += c0`
/// * 3 — `c1 += c0`, `c2 += c0`
/// * 4 — `c1 += (c0 + c2) >> 1`
/// * 5 — `c2 += c0` first, then `c1 += (c0 + c2) >> 1` using the reconstructed `c2`
/// * 6 — lifting YCoCg: `(y, co, cg)` becomes `(r, g, b)`
///
/// Any other `transform` value yields unmodified copies.
///
/// All arithmetic uses `wrapping_*` (modulo 2^32), which is what plain `+`/`-` do in
/// release builds. Stored residuals may come from corrupt or adversarial input, and
/// plain arithmetic would panic on overflow in debug builds.
fn inverse_rct_row_copy(
    c0: &[i32],
    c1: &[i32],
    c2: &[i32],
    transform: usize,
) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
    // `(a + b) >> 1` with a wrapping sum, i.e. the floored mean when no overflow occurs.
    let avg = |a: i32, b: i32| a.wrapping_add(b) >> 1;
    let w = c0.len();
    let mut out0 = c0.to_vec();
    let mut out1 = c1.to_vec();
    let mut out2 = c2.to_vec();
    match transform {
        1 => {
            for x in 0..w {
                out2[x] = c2[x].wrapping_add(c0[x]);
            }
        }
        2 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(c0[x]);
            }
        }
        3 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(c0[x]);
                out2[x] = c2[x].wrapping_add(c0[x]);
            }
        }
        4 => {
            for x in 0..w {
                out1[x] = c1[x].wrapping_add(avg(c0[x], c2[x]));
            }
        }
        5 => {
            for x in 0..w {
                // Rebuild the third sample first: the forward average used its
                // original (pre-subtraction) value.
                out2[x] = c2[x].wrapping_add(c0[x]);
                out1[x] = c1[x].wrapping_add(avg(c0[x], out2[x]));
            }
        }
        6 => {
            for x in 0..w {
                let (y, co, cg) = (c0[x], c1[x], c2[x]);
                let tmp = y.wrapping_sub(cg >> 1);
                let g = cg.wrapping_add(tmp);
                let b = tmp.wrapping_sub(co >> 1);
                let r = b.wrapping_add(co);
                out0[x] = r;
                out1[x] = g;
                out2[x] = b;
            }
        }
        // 0 is the identity. Larger values cannot come from `RctType::transform`
        // (always < 7) and are passed through unchanged.
        _ => {}
    }
    (out0, out1, out2)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three `w` x `h` channels and fills them in row-major order from `values`.
    /// The components of each tuple go to channels 0, 1 and 2 respectively.
    fn make_test_channels(w: usize, h: usize, values: &[(i32, i32, i32)]) -> Vec<Channel> {
        let mut channels: Vec<Channel> = (0..3).map(|_| Channel::new(w, h).unwrap()).collect();
        for (i, &(r, g, b)) in values.iter().enumerate() {
            let (x, y) = (i % w, i / w);
            for (channel, value) in channels.iter_mut().zip([r, g, b]) {
                channel.set(x, y, value);
            }
        }
        channels
    }

    #[test]
    fn test_ycocg_roundtrip() {
        let original = [(100, 150, 200), (255, 0, 128), (50, 50, 50), (0, 255, 0)];
        let mut channels = make_test_channels(2, 2, &original);
        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();
        inverse_rct(&mut channels, 0, RctType::YCOCG).unwrap();
        for (i, &(r, g, b)) in original.iter().enumerate() {
            let (x, y) = (i % 2, i / 2);
            assert_eq!(channels[0].get(x, y), r, "R mismatch at {}", i);
            assert_eq!(channels[1].get(x, y), g, "G mismatch at {}", i);
            assert_eq!(channels[2].get(x, y), b, "B mismatch at {}", i);
        }
    }

    #[test]
    fn test_subtract_green_roundtrip() {
        let original = [(100, 150, 200), (255, 0, 128)];
        let mut channels = make_test_channels(2, 1, &original);
        forward_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();
        inverse_rct(&mut channels, 0, RctType::SUBTRACT_GREEN).unwrap();
        for (x, &(r, g, b)) in original.iter().enumerate() {
            assert_eq!(channels[0].get(x, 0), r, "R mismatch at {}", x);
            assert_eq!(channels[1].get(x, 0), g, "G mismatch at {}", x);
            assert_eq!(channels[2].get(x, 0), b, "B mismatch at {}", x);
        }
    }

    #[test]
    fn test_all_transforms_roundtrip() {
        let original = [(100, 150, 200), (255, 0, 128), (50, 50, 50), (0, 255, 0)];
        for rct_type in 0u8..42 {
            let rct = RctType(rct_type);
            let mut channels = make_test_channels(2, 2, &original);
            forward_rct(&mut channels, 0, rct).unwrap();
            inverse_rct(&mut channels, 0, rct).unwrap();
            for (i, &(r, g, b)) in original.iter().enumerate() {
                let (x, y) = (i % 2, i / 2);
                assert_eq!(
                    channels[0].get(x, y),
                    r,
                    "R mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
                assert_eq!(
                    channels[1].get(x, y),
                    g,
                    "G mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
                assert_eq!(
                    channels[2].get(x, y),
                    b,
                    "B mismatch at {} for rct_type {}",
                    i,
                    rct_type
                );
            }
        }
    }

    #[test]
    fn test_ycocg_decorrelation() {
        let values: Vec<(i32, i32, i32)> = (0..8)
            .map(|i| {
                let level = i * 10;
                (level, level, level)
            })
            .collect();
        let mut channels = make_test_channels(8, 1, &values);
        forward_rct(&mut channels, 0, RctType::YCOCG).unwrap();
        for i in 0..8 {
            let co = channels[1].get(i, 0);
            let cg = channels[2].get(i, 0);
            assert_eq!(co, 0, "Co should be 0 for gray, got {} at {}", co, i);
            assert_eq!(cg, 0, "Cg should be 0 for gray, got {} at {}", cg, i);
        }
    }

    #[test]
    fn test_noop() {
        let mut channels = make_test_channels(1, 1, &[(100, 150, 200)]);
        let first_pixel =
            |chs: &[Channel]| (chs[0].get(0, 0), chs[1].get(0, 0), chs[2].get(0, 0));
        let before = first_pixel(&channels[..]);
        forward_rct(&mut channels, 0, RctType::NONE).unwrap();
        assert_eq!(first_pixel(&channels[..]), before);
    }
}