use super::encoder::VarDctEncoder;
use crate::headers::color_encoding::Primaries;
/// 3×3 matrix returned by [`primaries_to_srgb_matrix`] when the colour
/// encoding's primaries are `Primaries::P3`.
///
/// Stored row-major: when applied by [`apply_matrix_3x3`], output channel `i`
/// is the dot product of row `i` with the input `(r, g, b)`.
///
/// NOTE(review): going by the name this maps linear Display P3 RGB to linear
/// sRGB RGB; the derivation (white point, source of the digits) is not visible
/// here — confirm before editing the values.
#[rustfmt::skip]
// The literals are written with more digits than an `f32` can represent.
#[allow(clippy::excessive_precision)]
pub(crate) const P3_TO_SRGB: [[f32; 3]; 3] = [
[ 1.2249401763, -0.2249401763, 0.0000000000],
[-0.0420569547, 1.0420569547, 0.0000000000],
[-0.0196375546, -0.0786360456, 1.0982736001],
];
/// 3×3 matrix returned by [`primaries_to_srgb_matrix`] when the colour
/// encoding's primaries are `Primaries::Bt2100`.
///
/// Stored row-major: when applied by [`apply_matrix_3x3`], output channel `i`
/// is the dot product of row `i` with the input `(r, g, b)`.
///
/// NOTE(review): the constant is named after BT.2020 but selected for
/// `Bt2100`, presumably because both share the same primaries — confirm.
#[rustfmt::skip]
// The literals are written with more digits than an `f32` can represent.
#[allow(clippy::excessive_precision)]
pub(crate) const BT2020_TO_SRGB: [[f32; 3]; 3] = [
[ 1.6604910021, -0.5876411388, -0.0728498633],
[-0.1245504745, 1.1328998971, -0.0083494226],
[-0.0181507634, -0.1005788980, 1.1187296614],
];
/// Derives the 3×3 matrix that converts RGB expressed in the given primaries
/// into RGB expressed in sRGB primaries.
///
/// `r`, `g` and `b` are `(x, y)` chromaticity pairs. The white point is fixed
/// at `(0.3127, 0.3290)`. The result is row-major and rounded to `f32` only at
/// the very end; all intermediate arithmetic is `f64`.
///
/// # Panics
///
/// Panics with `"singular primaries matrix"` when the determinant of the
/// primaries' XYZ matrix is not larger than `1e-10` in magnitude (this also
/// catches a NaN determinant).
pub(crate) fn compute_primaries_to_srgb(
    r: (f64, f64),
    g: (f64, f64),
    b: (f64, f64),
) -> [[f32; 3]; 3] {
    type Mat3 = [[f64; 3]; 3];

    /// Chromaticity `(x, y)` to XYZ with Y normalised to 1.
    fn chromaticity_to_xyz((x, y): (f64, f64)) -> [f64; 3] {
        [x / y, 1.0, (1.0 - x - y) / y]
    }

    /// Determinant by cofactor expansion along the first row.
    fn det3(m: &Mat3) -> f64 {
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    /// Adjugate (transposed cofactor matrix); `adjugate(m) / det3(m)` is the inverse.
    fn adjugate(m: &Mat3) -> Mat3 {
        [
            [
                m[1][1] * m[2][2] - m[1][2] * m[2][1],
                m[0][2] * m[2][1] - m[0][1] * m[2][2],
                m[0][1] * m[1][2] - m[0][2] * m[1][1],
            ],
            [
                m[1][2] * m[2][0] - m[1][0] * m[2][2],
                m[0][0] * m[2][2] - m[0][2] * m[2][0],
                m[0][2] * m[1][0] - m[0][0] * m[1][2],
            ],
            [
                m[1][0] * m[2][1] - m[1][1] * m[2][0],
                m[0][1] * m[2][0] - m[0][0] * m[2][1],
                m[0][0] * m[1][1] - m[0][1] * m[1][0],
            ],
        ]
    }

    let [xr, yr, zr] = chromaticity_to_xyz(r);
    let [xg, yg, zg] = chromaticity_to_xyz(g);
    let [xb, yb, zb] = chromaticity_to_xyz(b);
    let white = chromaticity_to_xyz((0.3127, 0.3290));

    // Columns are the (unscaled) XYZ coordinates of the red, green and blue primaries.
    let primaries: Mat3 = [[xr, xg, xb], [yr, yg, yb], [zr, zg, zb]];

    let det = det3(&primaries);
    assert!(det.abs() > 1e-10, "singular primaries matrix");
    let inv_det = 1.0 / det;

    // Per-primary scale factors: solve `primaries * scale = white`, so that
    // RGB (1, 1, 1) lands exactly on the white point.
    let cof = adjugate(&primaries);
    let mut scale = [0.0f64; 3];
    for (s, row) in scale.iter_mut().zip(cof.iter()) {
        *s = (row[0] * white[0] + row[1] * white[1] + row[2] * white[2]) * inv_det;
    }

    // Source RGB -> XYZ: each primary column multiplied by its scale factor.
    let mut p2x = primaries;
    for row in p2x.iter_mut() {
        for (cell, s) in row.iter_mut().zip(scale.iter()) {
            *cell *= *s;
        }
    }

    #[allow(clippy::excessive_precision)]
    let srgb_to_xyz: Mat3 = [
        [0.4123907993, 0.3575843394, 0.1804807884],
        [0.2126390059, 0.7151686788, 0.0721923154],
        [0.0193308187, 0.1191947798, 0.9505321522],
    ];

    // XYZ -> sRGB is the inverse of the matrix above.
    let inv_srgb_det = 1.0 / det3(&srgb_to_xyz);
    let mut xyz_to_srgb = adjugate(&srgb_to_xyz);
    for row in xyz_to_srgb.iter_mut() {
        for cell in row.iter_mut() {
            *cell *= inv_srgb_det;
        }
    }

    // result = xyz_to_srgb * p2x, accumulated in f64 and narrowed per element.
    let mut result = [[0.0f32; 3]; 3];
    for (out_row, lhs_row) in result.iter_mut().zip(xyz_to_srgb.iter()) {
        for (j, out) in out_row.iter_mut().enumerate() {
            let acc = lhs_row
                .iter()
                .zip(p2x.iter())
                .fold(0.0f64, |acc, (&lhs, rhs_row)| acc + lhs * rhs_row[j]);
            *out = acc as f32;
        }
    }
    result
}
/// Selects the matrix that brings RGB in `ce`'s primaries to sRGB primaries.
///
/// Returns `None` for `Primaries::Srgb` (no conversion needed), the
/// precomputed tables for `P3` and `Bt2100`, and a matrix derived by
/// [`compute_primaries_to_srgb`] for `Custom`.
///
/// # Panics
///
/// Panics if `ce.primaries` is `Custom` but `ce.custom_primaries` is unset,
/// or if the custom primaries are singular (see [`compute_primaries_to_srgb`]).
pub(crate) fn primaries_to_srgb_matrix(
    ce: &crate::headers::color_encoding::ColorEncoding,
) -> Option<[[f32; 3]; 3]> {
    let matrix = match ce.primaries {
        // Already sRGB: tell the caller there is nothing to apply.
        Primaries::Srgb => return None,
        Primaries::P3 => P3_TO_SRGB,
        Primaries::Bt2100 => BT2020_TO_SRGB,
        Primaries::Custom => {
            let custom = ce
                .custom_primaries
                .as_ref()
                .expect("custom_primaries must be set when primaries is Custom");
            let red = (custom.red.x, custom.red.y);
            let green = (custom.green.x, custom.green.y);
            let blue = (custom.blue.x, custom.blue.y);
            compute_primaries_to_srgb(red, green, blue)
        }
    };
    Some(matrix)
}
/// Applies the row-major 3×3 matrix `m` in place to planar RGB samples.
///
/// For every index `i < r.len()` the triple `(r[i], g[i], b[i])` is replaced
/// by `m * (r[i], g[i], b[i])`. Elements of `g` and `b` beyond `r.len()` are
/// left untouched.
///
/// # Panics
///
/// Panics if `g` or `b` is shorter than `r`. The length check happens before
/// any sample is written, so a failed call leaves all three planes unmodified.
pub(crate) fn apply_matrix_3x3(r: &mut [f32], g: &mut [f32], b: &mut [f32], m: &[[f32; 3]; 3]) {
    let [[m00, m01, m02], [m10, m11, m12], [m20, m21, m22]] = *m;

    // Trim `g` and `b` to `r`'s length once: this is the only bounds check,
    // and the zipped loop below then needs none per element.
    let len = r.len();
    let g = &mut g[..len];
    let b = &mut b[..len];

    for ((rv, gv), bv) in r.iter_mut().zip(g.iter_mut()).zip(b.iter_mut()) {
        // Read all three inputs before writing any output of this pixel.
        let (ri, gi, bi) = (*rv, *gv, *bv);
        *rv = m00 * ri + m01 * gi + m02 * bi;
        *gv = m10 * ri + m11 * gi + m12 * bi;
        *bv = m20 * ri + m21 * gi + m22 * bi;
    }
}
impl VarDctEncoder {
    /// Converts an interleaved linear RGB image into three padded XYB planes.
    ///
    /// `linear_rgb` is read as `width * height` RGB triples in row-major
    /// order. Each returned plane holds `padded_width * padded_height`
    /// samples with a row stride of `padded_width`. Columns to the right of
    /// `width` repeat the last real pixel of their row, and rows below
    /// `height` repeat the last real row.
    ///
    /// If the encoder has a colour encoding with non-sRGB primaries, the
    /// samples are first brought to sRGB primaries via
    /// [`primaries_to_srgb_matrix`].
    ///
    /// Returns the planes as `(x, y, b)`.
    pub(crate) fn convert_to_xyb_padded(
        &self,
        width: usize,
        height: usize,
        padded_width: usize,
        padded_height: usize,
        linear_rgb: &[f32],
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let to_srgb = self
            .color_encoding
            .as_ref()
            .and_then(primaries_to_srgb_matrix);

        // NOTE(review): `vec_f32_dirty` presumably returns a buffer whose
        // contents are unspecified; the two passes below are what fill it.
        let plane_len = padded_width * padded_height;
        let mut plane_x = jxl_simd::vec_f32_dirty(plane_len);
        let mut plane_y = jxl_simd::vec_f32_dirty(plane_len);
        let mut plane_b = jxl_simd::vec_f32_dirty(plane_len);

        // Pass 1: the first `height` rows, including their right-hand padding.
        convert_rows_to_xyb(
            width,
            height,
            padded_width,
            linear_rgb,
            to_srgb.as_ref(),
            &mut plane_x,
            &mut plane_y,
            &mut plane_b,
        );

        // Pass 2: replicate the last real row into any rows below the image.
        let needs_bottom_padding = padded_height > height;
        if needs_bottom_padding {
            pad_bottom_three_channels(
                &mut plane_x,
                &mut plane_y,
                &mut plane_b,
                (height - 1) * padded_width,
                padded_width,
                height,
                padded_height,
            );
        }

        (plane_x, plane_y, plane_b)
    }
}
/// Fills rows `height..padded_height` of each plane with a copy of the row
/// that starts at `last_row_start`.
///
/// All three planes use a row stride of `padded_width`. The caller in this
/// file passes `last_row_start = (height - 1) * padded_width`, i.e. the last
/// real image row, which lies entirely before the rows being written.
///
/// With the `parallel` feature the three planes are padded concurrently via
/// `rayon::join`; otherwise they are padded one after another.
///
/// # Panics
///
/// Panics if a plane is too short to contain the source row or any of the
/// destination rows.
#[allow(clippy::too_many_arguments)]
fn pad_bottom_three_channels(
    xyb_x: &mut [f32],
    xyb_y: &mut [f32],
    xyb_b: &mut [f32],
    last_row_start: usize,
    padded_width: usize,
    height: usize,
    padded_height: usize,
) {
    /// Pads a single plane.
    fn pad_one(
        plane: &mut [f32],
        last_row_start: usize,
        padded_width: usize,
        height: usize,
        padded_height: usize,
    ) {
        // Copy straight from the source row inside the plane; no temporary
        // buffer is needed because source and destination rows are disjoint.
        let src = last_row_start..last_row_start + padded_width;
        for y in height..padded_height {
            plane.copy_within(src.clone(), y * padded_width);
        }
    }
    #[cfg(feature = "parallel")]
    {
        let (((), ()), ()) = rayon::join(
            || {
                rayon::join(
                    || pad_one(xyb_x, last_row_start, padded_width, height, padded_height),
                    || pad_one(xyb_y, last_row_start, padded_width, height, padded_height),
                )
            },
            || pad_one(xyb_b, last_row_start, padded_width, height, padded_height),
        );
    }
    #[cfg(not(feature = "parallel"))]
    {
        pad_one(xyb_x, last_row_start, padded_width, height, padded_height);
        pad_one(xyb_y, last_row_start, padded_width, height, padded_height);
        pad_one(xyb_b, last_row_start, padded_width, height, padded_height);
    }
}
/// Number of image rows in one strip handed to `convert_strip` by
/// `convert_rows_to_xyb`; with the `parallel` feature each strip is one unit
/// of parallel work. The final strip may contain fewer rows.
const XYB_STRIP_ROWS: usize = 16;
/// Converts the first `height` rows of the image to XYB, one strip at a time.
///
/// The output planes use a row stride of `padded_width`; only their first
/// `height * padded_width` samples are written here. The rows are cut into
/// strips of `XYB_STRIP_ROWS` rows (the last strip may be shorter) and each
/// strip is processed by `convert_strip`. With the `parallel` feature the
/// strips are distributed over rayon's thread pool, otherwise they run in
/// order on the calling thread.
#[allow(clippy::too_many_arguments)]
fn convert_rows_to_xyb(
    width: usize,
    height: usize,
    padded_width: usize,
    linear_rgb: &[f32],
    primaries_matrix: Option<&[[f32; 3]; 3]>,
    xyb_x: &mut [f32],
    xyb_y: &mut [f32],
    xyb_b: &mut [f32],
) {
    let strip_len = XYB_STRIP_ROWS * padded_width;
    let rows_len = height * padded_width;
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        let strips_x = xyb_x[..rows_len].par_chunks_mut(strip_len);
        let strips_y = xyb_y[..rows_len].par_chunks_mut(strip_len);
        let strips_b = xyb_b[..rows_len].par_chunks_mut(strip_len);
        strips_x
            .zip(strips_y)
            .zip(strips_b)
            .enumerate()
            .for_each(|(index, ((out_x, out_y), out_b))| {
                // Every strip before the last holds exactly XYB_STRIP_ROWS rows.
                let first_row = index * XYB_STRIP_ROWS;
                let rows = out_x.len() / padded_width;
                convert_strip(
                    width,
                    padded_width,
                    first_row,
                    rows,
                    linear_rgb,
                    primaries_matrix,
                    out_x,
                    out_y,
                    out_b,
                );
            });
    }
    #[cfg(not(feature = "parallel"))]
    {
        // `max(1)` only keeps `step_by` from rejecting a zero step; when
        // `padded_width` is 0, `rows_len` is 0 too and the loop never runs.
        for offset in (0..rows_len).step_by(strip_len.max(1)) {
            let this_len = strip_len.min(rows_len - offset);
            let span = offset..offset + this_len;
            convert_strip(
                width,
                padded_width,
                offset / padded_width,
                this_len / padded_width,
                linear_rgb,
                primaries_matrix,
                &mut xyb_x[span.clone()],
                &mut xyb_y[span.clone()],
                &mut xyb_b[span],
            );
        }
    }
}
/// Converts `strip_rows` consecutive image rows, starting at row `y_start`,
/// into the given XYB strip slices.
///
/// `linear_rgb` is the whole image as interleaved RGB triples with `width`
/// pixels per row. The strip slices use a row stride of `padded_width` and are
/// indexed relative to the strip's first row. For each row the samples are
/// de-interleaved, optionally multiplied by `primaries_matrix`, converted with
/// `jxl_simd::linear_rgb_to_xyb_batch`, and then the columns
/// `width..padded_width` are filled with the row's last converted pixel.
#[allow(clippy::too_many_arguments)]
fn convert_strip(
    width: usize,
    padded_width: usize,
    y_start: usize,
    strip_rows: usize,
    linear_rgb: &[f32],
    primaries_matrix: Option<&[[f32; 3]; 3]>,
    strip_x: &mut [f32],
    strip_y: &mut [f32],
    strip_b: &mut [f32],
) {
    // Planar scratch rows, allocated once per strip and overwritten for every row.
    let mut row_r = jxl_simd::vec_f32_dirty(width);
    let mut row_g = jxl_simd::vec_f32_dirty(width);
    let mut row_b = jxl_simd::vec_f32_dirty(width);
    for local_y in 0..strip_rows {
        let y = y_start + local_y;

        // Split this row's interleaved RGB triples into the planar buffers.
        let src_base = y * width * 3;
        let src = &linear_rgb[src_base..src_base + width * 3];
        for (x, px) in src.chunks_exact(3).enumerate() {
            row_r[x] = px[0];
            row_g[x] = px[1];
            row_b[x] = px[2];
        }

        if let Some(m) = primaries_matrix {
            apply_matrix_3x3(&mut row_r, &mut row_g, &mut row_b, m);
        }

        let dst_row = local_y * padded_width;
        let row_end = dst_row + width;
        jxl_simd::linear_rgb_to_xyb_batch(
            &row_r,
            &row_g,
            &row_b,
            &mut strip_x[dst_row..row_end],
            &mut strip_y[dst_row..row_end],
            &mut strip_b[dst_row..row_end],
        );
        #[cfg(feature = "debug-dc")]
        if y == 0 {
            eprintln!(
                "XYB[0,0]: linear_rgb=({:.6},{:.6},{:.6}) -> XYB=({:.6},{:.6},{:.6})",
                row_r[0], row_g[0], row_b[0], strip_x[0], strip_y[0], strip_b[0]
            );
        }

        // Right-hand padding: repeat the last real pixel of the row.
        if padded_width > width {
            let edge = row_end - 1;
            let (edge_x, edge_y, edge_b) = (strip_x[edge], strip_y[edge], strip_b[edge]);
            let pad = row_end..dst_row + padded_width;
            strip_x[pad.clone()].fill(edge_x);
            strip_y[pad.clone()].fill(edge_y);
            strip_b[pad].fill(edge_b);
        }
    }
}