use crate::conversions::katana::md_nx3::interpolate_out_function;
use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
use crate::conversions::md_lut::{MultidimensionalLut, tetra_3i_to_any_vec};
use crate::profile::LutDataType;
use crate::safe_math::{SafeMul, SafePowi};
use crate::trc::lut_interp_linear_float;
use crate::{
CmsError, DataColorSpace, Layout, MalformedSize, PointeeSizeExpressible, TransformOptions,
};
use num_traits::AsPrimitive;
use std::array::from_fn;
use std::marker::PhantomData;
/// LUT stage with `input_inks` input channels and three output channels.
///
/// Samples of type `T` go through the per-channel `linearization` curves,
/// then the multidimensional `clut`, then the three `output` curves.
#[derive(Default)]
struct KatanaLutNx3<T> {
    /// Number of input channels; also the number of CLUT axes.
    input_inks: usize,
    /// Bit depth used to normalise samples when `T::FINITE` is set.
    bit_depth: usize,
    /// Grid size applied to each of the `input_inks` CLUT axes.
    grid_size: u8,
    /// One curve per input channel, applied before the CLUT lookup.
    linearization: Vec<Vec<f32>>,
    /// Flattened CLUT samples.
    clut: Vec<f32>,
    /// One curve per output channel, applied after the CLUT lookup.
    output: [Vec<f32>; 3],
    /// Ties the stage to its sample type without storing a `T`.
    _phantom: PhantomData<T>,
}
/// LUT stage with three input channels and `output_inks` output channels.
///
/// Each input triple goes through the three `linearization` curves, then the
/// `clut`, then the per-ink `output` curves, before being written as `T`.
struct KatanaLut3xN<T> {
    /// Number of channels produced by the CLUT lookup.
    output_inks: usize,
    /// Bit depth used to scale results when `T::FINITE` is set.
    bit_depth: usize,
    /// Grid size of the CLUT.
    grid_size: u8,
    /// Channel layout of the destination buffer.
    dst_layout: Layout,
    /// Colour space of the destination; together with an RGBA layout it
    /// triggers the alpha fill.
    target_color_space: DataColorSpace,
    /// One curve per input channel, applied before the CLUT lookup.
    linearization: [Vec<f32>; 3],
    /// Flattened CLUT samples.
    clut: Vec<f32>,
    /// One curve per output ink, applied after the CLUT lookup.
    output: Vec<Vec<f32>>,
    /// Ties the stage to its sample type without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaLutNx3<T> {
    /// Converts interleaved `input_inks`-channel samples into interleaved
    /// three-channel `f32` values.
    ///
    /// Each pixel is normalised, run through the input curves, looked up in
    /// the CLUT and finally run through the three output curves.
    ///
    /// # Errors
    /// Returns [`CmsError::LaneMultipleOfChannels`] when `input.len()` is not
    /// a multiple of `input_inks`.
    fn to_pcs_impl(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
        let ink_count = self.input_inks;
        if input.len() % ink_count != 0 {
            return Err(CmsError::LaneMultipleOfChannels);
        }

        // When `T::FINITE` is set, samples are scaled by 1 / (2^bit_depth - 1);
        // otherwise they pass through unscaled.
        let norm_value = if T::FINITE {
            let max_code = (1u32 << self.bit_depth) - 1;
            1.0 / max_code as f32
        } else {
            1.0
        };

        // The first `ink_count` axes share one grid size; the rest stay zero.
        let mut grid_sizes = [0u8; 16];
        for axis in grid_sizes.iter_mut().take(ink_count) {
            *axis = self.grid_size;
        }
        let md_lut = MultidimensionalLut::new(grid_sizes, ink_count, 3);

        let layout = Layout::from_inks(ink_count);
        let channels = layout.channels();
        // Scratch buffer for the linearised inks of one pixel, reused per pixel.
        let mut inks = vec![0.; ink_count];
        let mut pcs = vec![0.; (input.len() / channels) * 3];
        let fetcher = interpolate_out_function(layout);

        for (pixel, out) in input.chunks_exact(channels).zip(pcs.chunks_exact_mut(3)) {
            for ((slot, &raw), curve) in inks.iter_mut().zip(pixel).zip(&self.linearization) {
                *slot = lut_interp_linear_float(raw.as_() * norm_value, curve);
            }

            let sample = fetcher(&md_lut, &self.clut, &inks);

            for (channel, (value, curve)) in out.iter_mut().zip(&self.output).enumerate() {
                *value = lut_interp_linear_float(sample.v[channel], curve);
            }
        }

        Ok(pcs)
    }
}
impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaInitialStage<f32, T>
    for KatanaLutNx3<T>
{
    /// Runs the N→3 LUT over `input`.
    ///
    /// # Errors
    /// Returns [`CmsError::LaneMultipleOfChannels`] when `input.len()` is not
    /// a multiple of `input_inks`.
    fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
        match input.len() % self.input_inks {
            0 => self.to_pcs_impl(input),
            _ => Err(CmsError::LaneMultipleOfChannels),
        }
    }
}
impl<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>> KatanaFinalStage<f32, T>
    for KatanaLut3xN<T>
where
    f32: AsPrimitive<T>,
{
    /// Converts interleaved three-channel `src` values into `dst`, laid out
    /// according to `dst_layout`.
    ///
    /// Each triple goes through the input curves, the CLUT and the per-ink
    /// output curves. When `T::FINITE` is set the result is scaled to
    /// `2^bit_depth - 1`, rounded and clamped; otherwise it is stored as is.
    /// For an RGBA layout with an RGB target, the fourth channel of every
    /// pixel is set to the scale value.
    ///
    /// # Errors
    /// Returns [`CmsError::LaneMultipleOfChannels`] when `src.len()` is not a
    /// multiple of 3.
    fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
        if src.len() % 3 != 0 {
            return Err(CmsError::LaneMultipleOfChannels);
        }

        let ink_count = self.output_inks;
        // NOTE(review): the grid is filled for `output_inks` axes although the
        // LUT is declared with 3 inputs — confirm against
        // `MultidimensionalLut::new` that fewer than 3 output inks is handled.
        let grid_sizes: [u8; 16] =
            from_fn(|axis| if axis >= ink_count { 0 } else { self.grid_size });
        let md_lut = MultidimensionalLut::new(grid_sizes, 3, ink_count);

        let finite = T::FINITE;
        let scale_value = if finite {
            ((1u32 << self.bit_depth) - 1) as f32
        } else {
            1.0
        };
        // Converts one processed ink into the destination sample type.
        let store = |value: f32| -> T {
            let scaled = value * scale_value;
            if finite {
                scaled.round().max(0.).min(scale_value).as_()
            } else {
                scaled.as_()
            }
        };

        let channels = self.dst_layout.channels();
        // Scratch buffer for the CLUT result of one pixel, reused per pixel.
        let mut working = vec![0.; ink_count];

        for (triple, pixel) in src.chunks_exact(3).zip(dst.chunks_exact_mut(channels)) {
            let x = lut_interp_linear_float(triple[0], &self.linearization[0]);
            let y = lut_interp_linear_float(triple[1], &self.linearization[1]);
            let z = lut_interp_linear_float(triple[2], &self.linearization[2]);

            tetra_3i_to_any_vec(&md_lut, &self.clut, x, y, z, &mut working, ink_count);

            for (ink, curve) in working.iter_mut().zip(&self.output) {
                *ink = lut_interp_linear_float(*ink, curve);
            }

            // Only as many channels as there are inks get written here.
            for (sample, &ink) in pixel.iter_mut().zip(&working) {
                *sample = store(ink);
            }
        }

        let fill_alpha =
            self.dst_layout == Layout::Rgba && self.target_color_space == DataColorSpace::Rgb;
        if fill_alpha {
            let opaque: T = scale_value.as_();
            for pixel in dst.chunks_exact_mut(channels) {
                pixel[3] = opaque;
            }
        }

        Ok(())
    }
}
/// Builds the `inks`-input, 3-output LUT stage from `lut`.
///
/// # Errors
/// - [`CmsError::UnsupportedProfileConnection`] when `lut` does not have
///   `inks` input channels or does not have exactly 3 output channels.
/// - [`CmsError::MalformedClut`] when the CLUT length differs from
///   `grid_points ^ input_channels * output_channels`.
/// - [`CmsError::MalformedCurveLutTable`] when the input or output table is
///   too short for the declared number of entries per channel.
/// - Whatever `safe_powi` / `safe_mul` report while computing the CLUT length.
fn katana_make_lut_nx3<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>>(
    inks: usize,
    lut: &LutDataType,
    _: TransformOptions,
    _: DataColorSpace,
    bit_depth: usize,
) -> Result<KatanaLutNx3<T>, CmsError> {
    if inks != lut.num_input_channels as usize || lut.num_output_channels != 3 {
        return Err(CmsError::UnsupportedProfileConnection);
    }

    let expected_clut: usize = (lut.num_clut_grid_points as usize)
        .safe_powi(lut.num_input_channels as u32)?
        .safe_mul(lut.num_output_channels as usize)?;
    let clut = lut.clut_table.to_clut_f32();
    if clut.len() != expected_clut {
        return Err(CmsError::MalformedClut(MalformedSize {
            size: clut.len(),
            expected: expected_clut,
        }));
    }

    // Input curves: `inks` consecutive runs of `in_entries` values.
    let in_entries = lut.num_input_table_entries as usize;
    let input_table = lut.input_table.to_clut_f32();
    if input_table.len() < in_entries * inks {
        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
            size: input_table.len(),
            expected: in_entries * inks,
        }));
    }
    let linearization: Vec<Vec<f32>> = (0..inks)
        .map(|channel| {
            let start = channel * in_entries;
            input_table[start..start + in_entries].to_vec()
        })
        .collect();

    // Output curves: three consecutive runs of `out_entries` values.
    let out_entries = lut.num_output_table_entries as usize;
    let output_table = lut.output_table.to_clut_f32();
    if output_table.len() < out_entries * 3 {
        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
            size: output_table.len(),
            expected: out_entries * 3,
        }));
    }
    let output_curve =
        |channel: usize| output_table[channel * out_entries..(channel + 1) * out_entries].to_vec();

    Ok(KatanaLutNx3::<T> {
        linearization,
        clut,
        grid_size: lut.num_clut_grid_points,
        output: [output_curve(0), output_curve(1), output_curve(2)],
        input_inks: inks,
        _phantom: PhantomData,
        bit_depth,
    })
}
/// Builds the 3-input, `inks`-output LUT stage from `lut`.
///
/// # Errors
/// - [`CmsError::UnsupportedProfileConnection`] when `lut` does not have
///   exactly 3 input channels.
/// - [`CmsError::InvalidInksCountForProfile`] when, for an RGB target, `lut`
///   does not have 3 output channels or `dst_layout` is neither RGB nor RGBA;
///   for any other target, when the LUT output channel count differs from
///   `dst_layout.channels()`.
/// - [`CmsError::MalformedClut`] when the CLUT length differs from
///   `grid_points ^ input_channels * output_channels`.
/// - [`CmsError::MalformedCurveLutTable`] when the input or output table is
///   too short for the declared number of entries per channel.
/// - Whatever `safe_powi` / `safe_mul` report while computing the CLUT length.
fn katana_make_lut_3xn<T: Copy + PointeeSizeExpressible + AsPrimitive<f32>>(
    inks: usize,
    dst_layout: Layout,
    lut: &LutDataType,
    _: TransformOptions,
    target_color_space: DataColorSpace,
    bit_depth: usize,
) -> Result<KatanaLut3xN<T>, CmsError> {
    if lut.num_input_channels as usize != 3 {
        return Err(CmsError::UnsupportedProfileConnection);
    }
    if target_color_space == DataColorSpace::Rgb {
        // The previous `!= 3 || != 4` test was always true and rejected every
        // RGB target. Exactly 3 channels are required: the CLUT length below
        // is validated with `num_output_channels`, so the table must have the
        // same channel count the stage is built with.
        if lut.num_output_channels != 3 {
            return Err(CmsError::InvalidInksCountForProfile);
        }
        // Likewise `!= Rgb || != Rgba` was always true; reject only layouts
        // that are neither RGB nor RGBA.
        if dst_layout != Layout::Rgb && dst_layout != Layout::Rgba {
            return Err(CmsError::InvalidInksCountForProfile);
        }
    } else if lut.num_output_channels as usize != dst_layout.channels() {
        return Err(CmsError::InvalidInksCountForProfile);
    }

    let clut_length: usize = (lut.num_clut_grid_points as usize)
        .safe_powi(lut.num_input_channels as u32)?
        .safe_mul(lut.num_output_channels as usize)?;
    let clut_table = lut.clut_table.to_clut_f32();
    if clut_table.len() != clut_length {
        return Err(CmsError::MalformedClut(MalformedSize {
            size: clut_table.len(),
            expected: clut_length,
        }));
    }

    // Input curves: three consecutive runs of `in_entries` values.
    let in_entries = lut.num_input_table_entries as usize;
    let linearization_table = lut.input_table.to_clut_f32();
    if linearization_table.len() < in_entries * 3 {
        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
            size: linearization_table.len(),
            expected: in_entries * 3,
        }));
    }
    let linear_curve0 = linearization_table[..in_entries].to_vec();
    let linear_curve1 = linearization_table[in_entries..in_entries * 2].to_vec();
    let linear_curve2 = linearization_table[in_entries * 2..in_entries * 3].to_vec();

    // Output curves: `inks` consecutive runs of `out_entries` values.
    let out_entries = lut.num_output_table_entries as usize;
    let gamma_table = lut.output_table.to_clut_f32();
    if gamma_table.len() < out_entries * inks {
        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
            size: gamma_table.len(),
            expected: out_entries * inks,
        }));
    }
    let gamma: Vec<Vec<f32>> = (0..inks)
        .map(|x| gamma_table[x * out_entries..(x + 1) * out_entries].to_vec())
        .collect();

    let transform = KatanaLut3xN::<T> {
        linearization: [linear_curve0, linear_curve1, linear_curve2],
        clut: clut_table,
        grid_size: lut.num_clut_grid_points,
        output: gamma,
        output_inks: inks,
        _phantom: PhantomData,
        target_color_space,
        dst_layout,
        bit_depth,
    };
    Ok(transform)
}
/// Validates `src_layout` against `lut` and returns the boxed N→3 input stage.
///
/// # Errors
/// - [`CmsError::InvalidAtoBLut`] when `pcs` is RGB and `lut` does not have
///   3 input channels.
/// - [`CmsError::InvalidInksCountForProfile`] when `pcs` is RGB and
///   `src_layout` is neither RGB nor RGBA, or when `pcs` is not RGB and the
///   LUT input channel count differs from `src_layout.channels()`.
/// - Any error returned by `katana_make_lut_nx3`.
pub(crate) fn katana_input_make_lut_nx3<
    T: Copy + PointeeSizeExpressible + AsPrimitive<f32> + Send + Sync,
>(
    src_layout: Layout,
    inks: usize,
    lut: &LutDataType,
    options: TransformOptions,
    pcs: DataColorSpace,
    bit_depth: usize,
) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError> {
    let rgb_pcs = pcs == DataColorSpace::Rgb;
    if rgb_pcs && lut.num_input_channels != 3 {
        return Err(CmsError::InvalidAtoBLut);
    }

    let layout_fits = if rgb_pcs {
        src_layout == Layout::Rgba || src_layout == Layout::Rgb
    } else {
        lut.num_input_channels == src_layout.channels() as u8
    };
    if !layout_fits {
        return Err(CmsError::InvalidInksCountForProfile);
    }

    let stage = katana_make_lut_nx3::<T>(inks, lut, options, pcs, bit_depth)?;
    Ok(Box::new(stage))
}
/// Returns the boxed 3→N output stage for `dst_layout`.
///
/// The stage is built with 3 inks for an RGB target and with
/// `dst_layout.channels()` inks otherwise.
///
/// # Errors
/// Any error returned by `katana_make_lut_3xn`.
pub(crate) fn katana_output_make_lut_3xn<
    T: Copy + PointeeSizeExpressible + AsPrimitive<f32> + Send + Sync,
>(
    dst_layout: Layout,
    lut: &LutDataType,
    options: TransformOptions,
    target_color_space: DataColorSpace,
    bit_depth: usize,
) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
where
    f32: AsPrimitive<T>,
{
    let real_inks = if target_color_space != DataColorSpace::Rgb {
        dst_layout.channels()
    } else {
        3
    };

    let stage: Box<dyn KatanaFinalStage<f32, T> + Send + Sync> =
        Box::new(katana_make_lut_3xn::<T>(
            real_inks,
            dst_layout,
            lut,
            options,
            target_color_space,
            bit_depth,
        )?);
    Ok(stage)
}