use std::any::Any;
use std::sync::Mutex;
use crate::api::JxlCmsTransformer;
use crate::error::Result;
use crate::render::RenderPipelineInPlaceStage;
use crate::render::simd_utils::{
deinterleave_2_dispatch, deinterleave_3_dispatch, deinterleave_4_dispatch,
interleave_2_dispatch, interleave_3_dispatch, interleave_4_dispatch,
};
/// Local state handed out by `CmsStage::init_local_state`: an owned
/// transformer together with its scratch buffers.
struct CmsLocalState {
    /// Transformer popped from `CmsStage::transformer_pool`.
    transformer: Box<dyn JxlCmsTransformer + Send + Sync>,
    /// Scratch buffer receiving interleaved transformer input
    /// (`CmsStage::input_buffer_size` samples).
    input_buffer: Vec<f32>,
    /// Scratch buffer receiving interleaved transformer output
    /// (`CmsStage::output_buffer_size` samples).
    output_buffer: Vec<f32>,
}
/// In-place render-pipeline stage that applies a colour-management transform
/// to rows of `f32` samples.
///
/// Multi-channel rows are interleaved into a scratch buffer, passed to a
/// [`JxlCmsTransformer`], and scattered back into the first `out_channels`
/// rows.
pub struct CmsStage {
    /// Transformers not yet claimed by a local state; every successful
    /// `init_local_state` call pops one.
    transformer_pool: Mutex<Vec<Box<dyn JxlCmsTransformer + Send + Sync>>>,
    /// Channels fed to the transformer per pixel (1-4).
    in_channels: usize,
    /// Channels produced by the transformer per pixel (1..=`in_channels`).
    out_channels: usize,
    /// Channel index of K; `Some` exactly when `in_channels == 4`.
    black_channel: Option<usize>,
    /// Length in samples of each local state's input scratch buffer.
    input_buffer_size: usize,
    /// Length in samples of each local state's output scratch buffer.
    output_buffer_size: usize,
}
impl CmsStage {
    /// Creates a CMS stage converting `in_channels` channels into `out_channels`.
    ///
    /// * `transformers` — pool of transformers; each successful
    ///   `init_local_state` call takes one out of it.
    /// * `in_channels` / `out_channels` — channel counts consumed and produced
    ///   per pixel by the transformer.
    /// * `black_channel` — channel index of K; must be `Some` exactly when
    ///   `in_channels == 4`, and must not overlap the colour channels `0..3`.
    /// * `max_pixels` — the largest row length `process_row_chunk` may be asked
    ///   to handle; scratch buffers are sized for it rounded up to a multiple
    ///   of 16.
    ///
    /// # Panics
    ///
    /// Panics if `in_channels` or `out_channels` is outside `1..=4`, if
    /// `out_channels > in_channels`, if `black_channel.is_some()` disagrees with
    /// `in_channels == 4`, if the black channel index is below 3, or if the
    /// scratch buffer sizes overflow `usize`.
    pub fn new(
        transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>>,
        in_channels: usize,
        out_channels: usize,
        black_channel: Option<usize>,
        max_pixels: usize,
    ) -> Self {
        assert!(
            (1..=4).contains(&in_channels),
            "CMS stage only supports 1-4 input channels, got {in_channels}"
        );
        assert!(
            (1..=4).contains(&out_channels),
            "CMS stage only supports 1-4 output channels, got {out_channels}"
        );
        assert!(
            out_channels <= in_channels,
            "out_channels ({out_channels}) must be <= in_channels ({in_channels})"
        );
        assert!(
            black_channel.is_some() == (in_channels == 4),
            "black_channel must be Some iff in_channels == 4"
        );
        // `uses_channel` claims colour channels 0..3 plus K, and
        // `process_row_chunk` reads a fourth row for 4-channel input. A K index
        // inside 0..3 would alias a colour channel and leave one fewer used
        // channel than that path indexes, so reject it up front.
        if let Some(k) = black_channel {
            assert!(
                k >= 3,
                "black_channel ({k}) must not overlap color channels 0..3"
            );
        }
        // Round up so the SIMD (de)interleave helpers can work on whole groups
        // of 16 samples.
        let padded_pixels = max_pixels.next_multiple_of(16);
        Self {
            transformer_pool: Mutex::new(transformers),
            in_channels,
            out_channels,
            black_channel,
            input_buffer_size: padded_pixels
                .checked_mul(in_channels)
                .expect("CMS input buffer size overflow"),
            output_buffer_size: padded_pixels
                .checked_mul(out_channels)
                .expect("CMS output buffer size overflow"),
        }
    }
}
impl std::fmt::Display for CmsStage {
    /// Writes a one-line summary of the conversion, e.g.
    /// `CMS transform: 4 channels (K at 3) -> 3 channels`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (from, to) = (self.in_channels, self.out_channels);
        match self.black_channel {
            Some(k) => write!(f, "CMS transform: {from} channels (K at {k}) -> {to} channels"),
            None => write!(f, "CMS transform: {from} channels -> {to} channels"),
        }
    }
}
impl RenderPipelineInPlaceStage for CmsStage {
    type Type = f32;

    /// Returns `true` for colour channels `0..min(in_channels, 3)` and for the
    /// black channel, if one is configured.
    fn uses_channel(&self, c: usize) -> bool {
        c < self.in_channels.min(3) || self.black_channel == Some(c)
    }

    /// Pops one transformer from the shared pool and allocates scratch buffers
    /// for it.
    ///
    /// Returns `Ok(None)` when the pool is empty; `process_row_chunk` then
    /// leaves rows untouched for that state.
    // NOTE(review): transformers are never pushed back into the pool, so once
    // the pool is drained every further local state silently skips the colour
    // conversion — confirm the pipeline never creates more states than there
    // are transformers.
    fn init_local_state(&self, _thread_index: usize) -> Result<Option<Box<dyn Any + Send>>> {
        let transformer = self.transformer_pool.lock().unwrap().pop();
        let Some(transformer) = transformer else {
            return Ok(None);
        };
        Ok(Some(Box::new(CmsLocalState {
            transformer,
            input_buffer: vec![0.0f32; self.input_buffer_size],
            output_buffer: vec![0.0f32; self.output_buffer_size],
        })))
    }

    /// Applies the colour transform to the first `xsize` pixels of `row` in
    /// place.
    ///
    /// `row[i]` is expected to hold the i-th channel accepted by
    /// `uses_channel`. For multi-channel input every row must be at least
    /// `xsize` rounded up to a multiple of 16 samples long, because
    /// (de)interleaving works on whole 16-sample groups. Without a local state
    /// the call is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `state` is not a `CmsLocalState` or if the transformer
    /// reports an error.
    fn process_row_chunk(
        &self,
        _position: (usize, usize),
        xsize: usize,
        row: &mut [&mut [f32]],
        state: Option<&mut (dyn Any + Send)>,
    ) {
        let Some(state) = state else {
            return;
        };
        let state: &mut CmsLocalState = state.downcast_mut().unwrap();
        debug_assert!(
            xsize * self.in_channels <= state.input_buffer.len(),
            "xsize {xsize} exceeds buffer capacity"
        );

        if self.in_channels == 1 && self.out_channels == 1 {
            // A single planar row already is the transformer's input layout,
            // so feed it directly instead of copying it into `input_buffer`.
            state
                .transformer
                .do_transform(&row[0][..xsize], &mut state.output_buffer[..xsize])
                .expect("CMS transform failed");
            row[0][..xsize].copy_from_slice(&state.output_buffer[..xsize]);
            return;
        }

        let xsize_padded = xsize.next_multiple_of(16);
        interleave_rows(row, &mut state.input_buffer, self.in_channels, xsize_padded);
        state
            .transformer
            .do_transform(
                &state.input_buffer[..xsize * self.in_channels],
                &mut state.output_buffer[..xsize * self.out_channels],
            )
            .expect("CMS transform failed");
        deinterleave_rows(
            &state.output_buffer,
            row,
            self.out_channels,
            xsize,
            xsize_padded,
        );
    }
}

/// Interleaves the first `channels` (2-4) planar rows, `len` samples each, into
/// `dst` for the transformer.
fn interleave_rows(rows: &[&mut [f32]], dst: &mut [f32], channels: usize, len: usize) {
    match channels {
        2 => {
            interleave_2_dispatch(&rows[0][..len], &rows[1][..len], &mut dst[..len * 2]);
        }
        3 => {
            interleave_3_dispatch(
                &rows[0][..len],
                &rows[1][..len],
                &rows[2][..len],
                &mut dst[..len * 3],
            );
        }
        4 => {
            interleave_4_dispatch(
                &rows[0][..len],
                &rows[1][..len],
                &rows[2][..len],
                &rows[3][..len],
                &mut dst[..len * 4],
            );
        }
        _ => unreachable!("CMS stage only supports 2-4 input channels here"),
    }
}

/// Scatters interleaved transformer output `src` back into the first
/// `channels` (1-4) planar rows.
///
/// One channel copies exactly `xsize` samples. Two or more channels move `len`
/// samples per row (`xsize` rounded up to 16), so each row's tail past `xsize`
/// receives samples the current transform did not produce.
fn deinterleave_rows(
    src: &[f32],
    rows: &mut [&mut [f32]],
    channels: usize,
    xsize: usize,
    len: usize,
) {
    match channels {
        1 => {
            rows[0][..xsize].copy_from_slice(&src[..xsize]);
        }
        2 => {
            let (r0, r1) = rows.split_at_mut(1);
            deinterleave_2_dispatch(&src[..len * 2], &mut r0[0][..len], &mut r1[0][..len]);
        }
        3 => {
            let (r0, rest) = rows.split_at_mut(1);
            let (r1, r2) = rest.split_at_mut(1);
            deinterleave_3_dispatch(
                &src[..len * 3],
                &mut r0[0][..len],
                &mut r1[0][..len],
                &mut r2[0][..len],
            );
        }
        4 => {
            let (r0, rest) = rows.split_at_mut(1);
            let (r1, rest) = rest.split_at_mut(1);
            let (r2, r3) = rest.split_at_mut(1);
            deinterleave_4_dispatch(
                &src[..len * 4],
                &mut r0[0][..len],
                &mut r1[0][..len],
                &mut r2[0][..len],
                &mut r3[0][..len],
            );
        }
        _ => unreachable!("CMS stage only supports 1-4 output channels"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Transformer that copies its input unchanged.
    struct IdentityTransformer;

    impl JxlCmsTransformer for IdentityTransformer {
        fn do_transform(&mut self, input: &[f32], output: &mut [f32]) -> Result<()> {
            output.copy_from_slice(input);
            Ok(())
        }

        fn do_transform_inplace(&mut self, _inout: &mut [f32]) -> Result<()> {
            Ok(())
        }
    }

    /// Transformer that doubles every sample.
    struct ScaleTransformer;

    impl JxlCmsTransformer for ScaleTransformer {
        fn do_transform(&mut self, input: &[f32], output: &mut [f32]) -> Result<()> {
            output
                .iter_mut()
                .zip(input)
                .for_each(|(dst, &src)| *dst = src * 2.0);
            Ok(())
        }

        fn do_transform_inplace(&mut self, inout: &mut [f32]) -> Result<()> {
            inout.iter_mut().for_each(|v| *v *= 2.0);
            Ok(())
        }
    }

    /// Transformer mapping interleaved 4-channel pixels to 3 channels: the
    /// first three components become `1 - x`, the fourth is dropped.
    struct FourToThreeTransformer;

    impl JxlCmsTransformer for FourToThreeTransformer {
        fn do_transform(&mut self, input: &[f32], output: &mut [f32]) -> Result<()> {
            for (src, dst) in input.chunks_exact(4).zip(output.chunks_exact_mut(3)) {
                for (d, s) in dst.iter_mut().zip(src) {
                    *d = 1.0 - *s;
                }
            }
            Ok(())
        }

        fn do_transform_inplace(&mut self, _inout: &mut [f32]) -> Result<()> {
            panic!("FourToThreeTransformer does not support in-place transform");
        }
    }

    /// Returns `values` zero-padded to the 16-sample row length used by these tests.
    fn padded_row(values: &[f32]) -> Vec<f32> {
        let mut row = values.to_vec();
        row.resize(16, 0.0);
        row
    }

    #[test]
    fn test_cms_stage_rgb_inplace() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(ScaleTransformer)];
        let stage = CmsStage::new(transformers, 3, 3, None, 16);
        let mut state = stage.init_local_state(0).unwrap().unwrap();

        let mut red = padded_row(&[1.0, 2.0, 3.0, 4.0]);
        let mut green = padded_row(&[0.5, 0.5, 0.5, 0.5]);
        let mut blue = padded_row(&[0.1, 0.2, 0.3, 0.4]);
        {
            let mut rows: Vec<&mut [f32]> = vec![&mut red, &mut green, &mut blue];
            stage.process_row_chunk((0, 0), 4, &mut rows, Some(state.as_mut()));
        }

        assert_eq!(red[0], 2.0);
        assert_eq!(red[1], 4.0);
        assert_eq!(green[0], 1.0);
        assert_eq!(blue[0], 0.2);
    }

    #[test]
    fn test_cms_stage_cmyk_to_rgb() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(FourToThreeTransformer)];
        let stage = CmsStage::new(transformers, 4, 3, Some(5), 16);
        let mut state = stage.init_local_state(0).unwrap().unwrap();

        let mut cyan = padded_row(&[0.2, 0.5]);
        let mut magenta = padded_row(&[0.3, 0.5]);
        let mut yellow = padded_row(&[0.4, 0.5]);
        let mut black = padded_row(&[0.1, 0.5]);
        {
            let mut rows: Vec<&mut [f32]> =
                vec![&mut cyan, &mut magenta, &mut yellow, &mut black];
            stage.process_row_chunk((0, 0), 2, &mut rows, Some(state.as_mut()));
        }

        assert!((cyan[0] - 0.8).abs() < 0.001);
        assert!((magenta[0] - 0.7).abs() < 0.001);
        assert!((yellow[0] - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_cms_stage_single_channel() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(ScaleTransformer)];
        let stage = CmsStage::new(transformers, 1, 1, None, 16);
        let mut state = stage.init_local_state(0).unwrap().unwrap();

        let mut gray = padded_row(&[1.0, 2.0, 3.0, 4.0]);
        {
            let mut rows: Vec<&mut [f32]> = vec![&mut gray];
            stage.process_row_chunk((0, 0), 4, &mut rows, Some(state.as_mut()));
        }

        assert_eq!(gray[0], 2.0);
        assert_eq!(gray[1], 4.0);
        assert_eq!(gray[2], 6.0);
        assert_eq!(gray[3], 8.0);
    }

    #[test]
    fn test_cms_stage_no_transformers() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> = vec![];
        let stage = CmsStage::new(transformers, 3, 3, None, 16);
        assert!(stage.init_local_state(0).unwrap().is_none());

        let mut row = padded_row(&[1.0, 2.0, 3.0, 4.0]);
        let expected = row.clone();
        {
            let mut rows: Vec<&mut [f32]> = vec![&mut row];
            stage.process_row_chunk((0, 0), 4, &mut rows, None);
        }

        assert_eq!(row, expected);
    }

    #[test]
    fn test_cms_stage_display() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(IdentityTransformer)];
        let rgb_text = CmsStage::new(transformers, 3, 3, None, 16).to_string();
        assert!(rgb_text.contains("3 channels -> 3 channels"));

        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(IdentityTransformer)];
        let cmyk_text = CmsStage::new(transformers, 4, 3, Some(5), 16).to_string();
        assert!(cmyk_text.contains("4 channels"));
        assert!(cmyk_text.contains("K at 5"));
        assert!(cmyk_text.contains("-> 3 channels"));
    }

    #[test]
    fn test_cms_stage_uses_channel() {
        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(IdentityTransformer)];
        let stage_rgb = CmsStage::new(transformers, 3, 3, None, 16);
        for (c, expected) in [(0, true), (1, true), (2, true), (3, false), (5, false)] {
            assert_eq!(stage_rgb.uses_channel(c), expected, "RGB stage, channel {c}");
        }

        let transformers: Vec<Box<dyn JxlCmsTransformer + Send + Sync>> =
            vec![Box::new(IdentityTransformer)];
        let stage_cmyk = CmsStage::new(transformers, 4, 3, Some(5), 16);
        let cmyk_expectations = [
            (0, true),
            (1, true),
            (2, true),
            (3, false),
            (4, false),
            (5, true),
            (6, false),
        ];
        for (c, expected) in cmyk_expectations {
            assert_eq!(stage_cmyk.uses_channel(c), expected, "CMYK stage, channel {c}");
        }
    }

    #[test]
    fn test_stage_consistency_cms() -> crate::error::Result<()> {
        crate::render::test::test_stage_consistency(
            || CmsStage::new(vec![Box::new(IdentityTransformer)], 3, 3, None, 512),
            (500, 500),
            3,
        )
    }
}