use alloc::vec;
use alloc::vec::Vec;
use super::build::{Decomposition, SubBand};
use super::codestream::WaveletTransform;
use super::decode::{DecompositionStorage, TileDecodeContext};
use super::rect::IntRect;
use crate::j2c::Header;
use crate::math::{self, Level, SIMD_WIDTH, Simd, dispatch, f32x8};
/// Coefficient buffer together with the rectangle it describes.
///
/// `apply` in this module fills a value with exactly these fields
/// (`tile_ctx.idwt_output`) — presumably of this type; confirm in `decode`.
pub(crate) struct IDWTOutput {
/// Samples stored row by row, `rect.width()` values per row.
pub(crate) coefficients: Vec<f32>,
/// Rectangle covered by `coefficients`.
pub(crate) rect: IntRect,
}
impl IDWTOutput {
    /// Builds a placeholder value: no coefficients and a rectangle that
    /// stretches from the origin to `u32::MAX` on both axes.
    pub(crate) fn dummy() -> Self {
        let rect = IntRect::from_ltrb(0, 0, u32::MAX, u32::MAX);
        let coefficients = Vec::new();
        Self { coefficients, rect }
    }
}
/// Result descriptor of a single `filter_2d` pass: only the rectangle of the
/// level that was just synthesised (the samples live in the caller's buffer).
struct IDWTTempOutput {
/// Rectangle of the decomposition that was processed.
rect: IntRect,
}
/// Reconstructs the samples of one tile component by running the inverse
/// wavelet transform over its decomposition levels.
///
/// The trailing `header.skipped_resolution_levels` decompositions of the
/// component are left out. If none remain, the LL sub-band is copied to the
/// output unchanged. Otherwise every level is synthesised from the result of
/// the previous one, alternating between `tile_ctx.idwt_scratch_buffer` and
/// `tile_ctx.idwt_output` in such a way that the last level always ends up
/// in the output buffer.
///
/// On return `tile_ctx.idwt_output` holds the samples and their rectangle.
pub(crate) fn apply(
    storage: &DecompositionStorage<'_>,
    tile_ctx: &mut TileDecodeContext<'_>,
    component_idx: usize,
    header: &Header<'_>,
    transform: WaveletTransform,
) {
    /// Lower and upper element-count estimate for the buffer of one level.
    fn buffer_bounds(level: &Decomposition) -> (usize, usize) {
        let w = level.rect.width() as usize;
        let h = level.rect.height() as usize;
        (w * h, (w + 1) * (h + 1))
    }

    let component = &storage.tile_decompositions[component_idx];
    let all_levels = &storage.decompositions[component.decompositions.clone()];
    let kept = all_levels
        .len()
        .saturating_sub(header.skipped_resolution_levels as usize);
    let levels = &all_levels[..kept];
    let ll_sub_band = &storage.sub_bands[component.first_ll_sub_band];

    let scratch_buf = &mut tile_ctx.idwt_scratch_buffer;
    let output = &mut tile_ctx.idwt_output;

    let Some((first_level, later_levels)) = levels.split_first() else {
        // Nothing to synthesise: the LL band already is the result.
        output.coefficients.clear();
        output
            .coefficients
            .extend_from_slice(&storage.coefficients[ll_sub_band.coefficients.clone()]);
        output.rect = ll_sub_band.rect;
        return;
    };

    // The final level is written into `output`; make sure it has room so the
    // capacity assertion in the interleaving step holds.
    let final_level = later_levels.last().unwrap_or(first_level);
    let (needed, generous) = buffer_bounds(final_level);
    if output.coefficients.len() < needed {
        let additional = generous - output.coefficients.len();
        output.coefficients.reserve_exact(additional);
    }

    // The second-to-last level is the last one written to the scratch buffer.
    if let [.., second_last, _] = levels {
        let (needed, generous) = buffer_bounds(second_last);
        if scratch_buf.len() < needed {
            let additional = generous - scratch_buf.len();
            scratch_buf.reserve_exact(additional);
        }
    }

    // Start in the scratch buffer exactly when the number of levels is even:
    // the target flips once per further level, so the last write hits `output`.
    let mut into_scratch = levels.len() % 2 == 0;
    let first_target: &mut Vec<f32> = if into_scratch {
        &mut *scratch_buf
    } else {
        &mut output.coefficients
    };
    let mut latest = filter_2d(
        IDWTInput::from_sub_band(ll_sub_band, storage),
        first_target,
        first_level,
        transform,
        storage,
    );

    for level in later_levels {
        into_scratch = !into_scratch;
        latest = if into_scratch {
            // Previous result is in `output`, next one goes to the scratch buffer.
            filter_2d(
                IDWTInput::from_output(&output.coefficients),
                scratch_buf,
                level,
                transform,
                storage,
            )
        } else {
            // Previous result is in the scratch buffer, next one goes to `output`.
            filter_2d(
                IDWTInput::from_output(scratch_buf),
                &mut output.coefficients,
                level,
                transform,
                storage,
            )
        };
    }

    output.rect = latest.rect;
}
/// Borrowed low-pass (LL) input of one synthesis step: either the stored LL
/// sub-band or the samples produced by the previous level.
struct IDWTInput<'a> {
/// The borrowed input samples.
coefficients: &'a [f32],
}
impl<'a> IDWTInput<'a> {
fn from_sub_band(sub_band: &'a SubBand, storage: &'a DecompositionStorage<'_>) -> Self {
IDWTInput {
coefficients: &storage.coefficients[sub_band.coefficients.clone()],
}
}
fn from_output(coefficients: &'a [f32]) -> Self {
IDWTInput { coefficients }
}
}
/// Synthesises one decomposition level into `coefficients`.
///
/// The four sub-bands are first interleaved into the buffer; unless the
/// level's rectangle is empty, the rows and then the columns are filtered
/// in place. Returns the rectangle of the processed level.
fn filter_2d(
    input: IDWTInput<'_>,
    coefficients: &mut Vec<f32>,
    decomposition: &Decomposition,
    transform: WaveletTransform,
    storage: &DecompositionStorage<'_>,
) -> IDWTTempOutput {
    let rect = decomposition.rect;

    interleave_samples(input, decomposition, coefficients, storage);

    let has_samples = rect.width() > 0 && rect.height() > 0;
    if has_samples {
        filter_horizontal(coefficients, rect, transform);
        filter_vertical(coefficients, rect, transform);
    }

    IDWTTempOutput { rect }
}
fn interleave_samples(
input: IDWTInput<'_>,
decomposition: &Decomposition,
coefficients: &mut Vec<f32>,
storage: &DecompositionStorage<'_>,
) {
let level = Level::new();
dispatch!(level, simd => {
interleave_samples_inner::<_>(simd, input, decomposition, coefficients, storage);
});
}
/// Weaves the LL input and the level's HL, LH and HH sub-bands into one
/// row-major buffer of `width * height` samples.
///
/// Rows at even absolute positions receive LL/HL samples, rows at odd
/// absolute positions receive LH/HH samples; within a row the band whose
/// samples sit at the parity of the first column comes first.
///
/// # Panics
///
/// Panics if `coefficients` has less capacity than `width * height`.
#[inline(always)]
fn interleave_samples_inner<S: Simd>(
    simd: S,
    input: IDWTInput<'_>,
    decomposition: &Decomposition,
    coefficients: &mut Vec<f32>,
    storage: &DecompositionStorage<'_>,
) {
    let rect = decomposition.rect;
    let width = rect.width() as usize;
    let height = rect.height() as usize;

    // The caller is expected to have reserved the space up front, so the
    // `resize` below never needs to reallocate.
    assert!(coefficients.capacity() >= width * height);
    coefficients.resize(width * height, 0.0);

    let ll = input.coefficients;
    let hl_range = storage.sub_bands[decomposition.sub_bands[0]]
        .coefficients
        .clone();
    let hl = &storage.coefficients[hl_range];
    let lh_range = storage.sub_bands[decomposition.sub_bands[1]]
        .coefficients
        .clone();
    let lh = &storage.coefficients[lh_range];
    let hh_range = storage.sub_bands[decomposition.sub_bands[2]]
        .coefficients
        .clone();
    let hh = &storage.coefficients[hh_range];

    // Number of even (low-pass) and odd (high-pass) positions per axis.
    let low_cols = (rect.x1.div_ceil(2) - rect.x0.div_ceil(2)) as usize;
    let high_cols = (rect.x1 / 2 - rect.x0 / 2) as usize;
    let low_rows = (rect.y1.div_ceil(2) - rect.y0.div_ceil(2)) as usize;
    let high_rows = (rect.y1 / 2 - rect.y0 / 2) as usize;

    let starts_on_even_col = rect.x0 % 2 == 0;
    let starts_on_even_row = rect.y0 % 2 == 0;

    // Widths of the band rows that supply the first / second sample of each pair.
    let (first_w, second_w) = if starts_on_even_col {
        (low_cols, high_cols)
    } else {
        (high_cols, low_cols)
    };
    // First output row that takes vertically low-pass / high-pass samples.
    let (low_row_start, high_row_start) = if starts_on_even_row {
        (0, 1)
    } else {
        (1, 0)
    };
    let (low_first, low_second, high_first, high_second) = if starts_on_even_col {
        (ll, hl, lh, hh)
    } else {
        (hl, ll, hh, lh)
    };

    let passes = [
        (low_first, low_second, low_row_start, low_rows),
        (high_first, high_second, high_row_start, high_rows),
    ];
    for (first_band, second_band, start_row, num_rows) in passes {
        interleave_rows(
            simd,
            first_band,
            second_band,
            first_w,
            second_w,
            coefficients,
            width,
            height,
            start_row,
            num_rows,
        );
    }
}
/// Interleaves up to `num_rows` rows of two bands into every second row of
/// `output`, beginning at `start_row` and stopping at `height`.
///
/// Row `v` of `first_band` (`first_w` samples) and of `second_band`
/// (`second_w` samples) are woven into output row `start_row + 2 * v`.
#[inline(always)]
fn interleave_rows<S: Simd>(
    simd: S,
    first_band: &[f32],
    second_band: &[f32],
    first_w: usize,
    second_w: usize,
    output: &mut [f32],
    width: usize,
    height: usize,
    start_row: usize,
    num_rows: usize,
) {
    let target_rows = (start_row..height).step_by(2).take(num_rows);
    for (band_row, out_row) in target_rows.enumerate() {
        let first_row = &first_band[band_row * first_w..][..first_w];
        let second_row = &second_band[band_row * second_w..][..second_w];
        let out_slice = &mut output[out_row * width..][..width];
        interleave_row(simd, first_row, second_row, out_slice);
    }
}
/// Writes `first[0], second[0], first[1], second[1], …` into `output`.
///
/// Full groups of `SIMD_WIDTH` pairs go through the vector zip; the leftover
/// pairs are copied one by one. If `first` holds one sample more than
/// `second`, that sample is appended after the last pair.
#[inline(always)]
fn interleave_row<S: Simd>(simd: S, first: &[f32], second: &[f32], output: &mut [f32]) {
    let paired = first.len().min(second.len());
    let vector_end = paired / SIMD_WIDTH * SIMD_WIDTH;

    for base in (0..vector_end).step_by(SIMD_WIDTH) {
        let f = f32x8::from_slice(simd, &first[base..base + SIMD_WIDTH]);
        let s = f32x8::from_slice(simd, &second[base..base + SIMD_WIDTH]);
        f.zip_low(s).store(&mut output[base * 2..]);
        f.zip_high(s).store(&mut output[base * 2 + SIMD_WIDTH..]);
    }

    for i in vector_end..paired {
        output[i * 2] = first[i];
        output[i * 2 + 1] = second[i];
    }

    // Unpaired trailing sample of the first band, if any.
    if let Some(&extra) = first.get(paired) {
        output[paired * 2] = extra;
    }
}
/// Applies the horizontal synthesis filter to each of the first
/// `rect.height()` rows of the row-major buffer `coefficients`.
fn filter_horizontal(coefficients: &mut [f32], rect: IntRect, transform: WaveletTransform) {
    let width = rect.width() as usize;
    let rows = coefficients
        .chunks_exact_mut(width)
        .take(rect.height() as usize);
    let x0 = rect.x0 as usize;

    for row in rows {
        filter_row(row, width, x0, transform);
    }
}
/// Filters a single row in place with the selected transform.
///
/// A row of exactly one sample is not lifted: it is halved when `x0` (the
/// absolute column of the sample) is odd and left untouched otherwise.
fn filter_row(scanline: &mut [f32], width: usize, x0: usize, transform: WaveletTransform) {
    if width != 1 {
        return match transform {
            WaveletTransform::Reversible53 => reversible_filter_53r(scanline, width, x0),
            WaveletTransform::Irreversible97 => irreversible_filter_97i(scanline, width, x0),
        };
    }

    if x0 % 2 != 0 {
        scanline[0] *= 0.5;
    }
}
/// In-place horizontal synthesis for the reversible 5/3 transform.
///
/// `x0` is the absolute column of `scanline[0]`; its parity decides which
/// local indices hold even and which hold odd samples. Even samples are
/// updated first, then odd samples from the already updated neighbours.
fn reversible_filter_53r(scanline: &mut [f32], width: usize, x0: usize) {
    // Local index of the first sample at an even / odd absolute position.
    let even_start = x0 % 2;
    let odd_start = 1 - even_start;

    // even -= floor((left + right + 2) / 4)
    filter_step_horizontal(
        scanline,
        width,
        even_start,
        #[inline(always)]
        |sample, prev, next| sample - math::floor_f32(math::mul_add(prev + next, 0.25, 0.5)),
    );

    // odd += floor((left + right) / 2)
    filter_step_horizontal(
        scanline,
        width,
        odd_start,
        #[inline(always)]
        |sample, prev, next| sample + math::floor_f32((prev + next) * 0.5),
    );
}
/// In-place horizontal synthesis for the irreversible 9/7 transform.
///
/// Samples are first rescaled (`KAPPA` at even absolute positions,
/// `1 / KAPPA` at odd ones), followed by four lifting passes that alternate
/// between the even and the odd samples. `x0` is the absolute column of
/// `scanline[0]`.
fn irreversible_filter_97i(scanline: &mut [f32], width: usize, x0: usize) {
    // Lifting weights, stored negated so every pass is a fused multiply-add.
    const NEG_ALPHA: f32 = 1.586_134_3;
    const NEG_BETA: f32 = 0.052_980_117;
    const NEG_GAMMA: f32 = -0.882_911_1;
    const NEG_DELTA: f32 = -0.443_506_87;
    const KAPPA: f32 = 1.230_174_1;
    const INV_KAPPA: f32 = 1.0 / KAPPA;

    // Local index of the first sample at an even / odd absolute position.
    let even_start = x0 % 2;
    let odd_start = 1 - even_start;

    // Scale factors for even and odd *local* indices respectively.
    let (scale_at_even_idx, scale_at_odd_idx) = if even_start == 0 {
        (KAPPA, INV_KAPPA)
    } else {
        (INV_KAPPA, KAPPA)
    };
    for (idx, sample) in scanline[..width].iter_mut().enumerate() {
        *sample *= if idx % 2 == 0 {
            scale_at_even_idx
        } else {
            scale_at_odd_idx
        };
    }

    filter_step_horizontal(
        scanline,
        width,
        even_start,
        #[inline(always)]
        |sample, prev, next| math::mul_add(prev + next, NEG_DELTA, sample),
    );
    filter_step_horizontal(
        scanline,
        width,
        odd_start,
        #[inline(always)]
        |sample, prev, next| math::mul_add(prev + next, NEG_GAMMA, sample),
    );
    filter_step_horizontal(
        scanline,
        width,
        even_start,
        #[inline(always)]
        |sample, prev, next| math::mul_add(prev + next, NEG_BETA, sample),
    );
    filter_step_horizontal(
        scanline,
        width,
        odd_start,
        #[inline(always)]
        |sample, prev, next| math::mul_add(prev + next, NEG_ALPHA, sample),
    );
}
/// Runs one lifting pass over a row: every second sample, starting at local
/// index `first`, is replaced by `f(sample, left, right)`.
///
/// Neighbours that would fall outside `0..width` at the first and last
/// sample are obtained by mirroring at the row boundary. The callers in this
/// file only pass `width >= 2`.
#[inline(always)]
fn filter_step_horizontal(
    scanline: &mut [f32],
    width: usize,
    first: usize,
    f: impl Fn(f32, f32, f32) -> f32,
) {
    let touches_first_sample = first == 0;

    // Leading sample: its left neighbour is mirrored.
    if touches_first_sample {
        let left = periodic_symmetric_extension_left(0, 1);
        let right = periodic_symmetric_extension_right(0, 1, width);
        scanline[0] = f(scanline[0], scanline[left], scanline[right]);
    }

    // Interior samples have both neighbours inside the row.
    let last = width - 1;
    let interior_start = if touches_first_sample { 2 } else { 1 };
    for i in (interior_start..last).step_by(2) {
        scanline[i] = f(scanline[i], scanline[i - 1], scanline[i + 1]);
    }

    // Trailing sample: its right neighbour is mirrored.
    if width > 1 && last % 2 == first {
        let left = periodic_symmetric_extension_left(last, 1);
        let right = periodic_symmetric_extension_right(last, 1, width);
        scanline[last] = f(scanline[last], scanline[left], scanline[right]);
    }
}
/// Runs one vertical lifting pass over a row-major buffer: every second row,
/// starting at `first`, is replaced by `f(row, row_above, row_below)`.
///
/// Rows outside `0..height` are obtained by mirroring at the boundary.
/// Columns `0..simd_width` are processed `SIMD_WIDTH` at a time with
/// `f_simd`; the remaining columns up to `width` use `f_scalar`.
/// Despite its name, `scanline` is the whole buffer.
#[inline(always)]
fn filter_step_vertical<S: Simd>(
    simd: S,
    scanline: &mut [f32],
    height: usize,
    width: usize,
    simd_width: usize,
    first: usize,
    f_simd: impl Fn(f32x8<S>, f32x8<S>, f32x8<S>) -> f32x8<S>,
    f_scalar: impl Fn(f32, f32, f32) -> f32,
) {
    for row in (first..height).step_by(2) {
        let row_above = periodic_symmetric_extension_left(row, 1);
        let row_below = periodic_symmetric_extension_right(row, 1, height);

        // Offsets of the three participating rows inside the buffer.
        let current = row * width;
        let above = row_above * width;
        let below = row_below * width;

        for base_column in (0..simd_width).step_by(SIMD_WIDTH) {
            let centre = f32x8::from_slice(simd, &scanline[current + base_column..][..SIMD_WIDTH]);
            let upper = f32x8::from_slice(simd, &scanline[above + base_column..][..SIMD_WIDTH]);
            let lower = f32x8::from_slice(simd, &scanline[below + base_column..][..SIMD_WIDTH]);
            f_simd(centre, upper, lower)
                .store(&mut scanline[current + base_column..][..SIMD_WIDTH]);
        }

        for col in simd_width..width {
            let updated = f_scalar(
                scanline[current + col],
                scanline[above + col],
                scanline[below + col],
            );
            scanline[current + col] = updated;
        }
    }
}
/// Index of the sample `offset` positions before `idx`, mirrored at index 0
/// when it would fall below the start (i.e. `|idx - offset|`).
#[inline(always)]
fn periodic_symmetric_extension_left(idx: usize, offset: usize) -> usize {
    if idx >= offset {
        idx - offset
    } else {
        offset - idx
    }
}
/// Index of the sample `offset` positions after `idx`, mirrored at the last
/// valid index (`length - 1`) when it would fall at or beyond `length`.
#[inline(always)]
fn periodic_symmetric_extension_right(idx: usize, offset: usize, length: usize) -> usize {
    let shifted = idx + offset;
    match shifted.checked_sub(length) {
        // Past the end: reflect the overshoot back into the range.
        Some(overshoot) => length - 2 - overshoot,
        None => shifted,
    }
}
fn filter_vertical(coefficients: &mut [f32], rect: IntRect, transform: WaveletTransform) {
dispatch!(Level::new(), simd => filter_vertical_impl(simd, coefficients, rect, transform));
}
/// Filters all columns of a row-major buffer in place with the selected
/// transform. Despite its name, `scanline` is the whole buffer.
///
/// A buffer of exactly one row is not lifted: its samples are halved when
/// `rect.y0` is odd and left untouched otherwise.
#[inline(always)]
fn filter_vertical_impl<S: Simd>(
    simd: S,
    scanline: &mut [f32],
    rect: IntRect,
    transform: WaveletTransform,
) {
    let width = rect.width() as usize;
    let height = rect.height() as usize;
    let y0 = rect.y0 as usize;

    if height != 1 {
        match transform {
            WaveletTransform::Reversible53 => {
                reversible_filter_53r_simd(simd, scanline, height, width, y0);
            }
            WaveletTransform::Irreversible97 => {
                irreversible_filter_97i_simd(simd, scanline, height, width, y0);
            }
        }
        return;
    }

    // Single row at an even position: nothing to do.
    if y0 % 2 == 0 {
        return;
    }

    // Single row at an odd position: halve every sample, vector part first.
    let simd_width = width / SIMD_WIDTH * SIMD_WIDTH;
    let (vector_part, tail) = scanline[..width].split_at_mut(simd_width);
    for lane in vector_part.chunks_exact_mut(SIMD_WIDTH) {
        let mut loaded = f32x8::from_slice(simd, &lane[..]);
        loaded *= 0.5;
        loaded.store(lane);
    }
    for sample in tail {
        *sample *= 0.5;
    }
}
/// In-place vertical synthesis for the reversible 5/3 transform.
///
/// `y0` is the absolute row of the buffer's first row; its parity decides
/// which local rows hold even and which hold odd samples. Even rows are
/// updated first, then odd rows from the already updated neighbours. Each
/// pass has a vector form and a scalar form for the columns left over.
#[inline(always)]
fn reversible_filter_53r_simd<S: Simd>(
    simd: S,
    scanline: &mut [f32],
    height: usize,
    width: usize,
    y0: usize,
) {
    // Local index of the first row at an even / odd absolute position.
    let even_start = y0 % 2;
    let odd_start = 1 - even_start;
    // Number of leading columns that can be handled in full vectors.
    let vector_columns = width / SIMD_WIDTH * SIMD_WIDTH;

    // even -= floor((above + below + 2) / 4)
    filter_step_vertical(
        simd,
        scanline,
        height,
        width,
        vector_columns,
        even_start,
        #[inline(always)]
        |row, above, below| row - ((above + below + 2.0) * 0.25).floor(),
        #[inline(always)]
        |row, above, below| row - math::floor_f32(math::mul_add(above + below, 0.25, 0.5)),
    );

    // odd += floor((above + below) / 2)
    filter_step_vertical(
        simd,
        scanline,
        height,
        width,
        vector_columns,
        odd_start,
        #[inline(always)]
        |row, above, below| row + ((above + below) * 0.5).floor(),
        #[inline(always)]
        |row, above, below| row + math::floor_f32((above + below) * 0.5),
    );
}
#[inline(always)]
fn irreversible_filter_97i_simd<S: Simd>(
simd: S,
scanline: &mut [f32],
height: usize,
width: usize,
y0: usize,
) {
const NEG_ALPHA: f32 = 1.586_134_3;
const NEG_BETA: f32 = 0.052_980_117;
const NEG_GAMMA: f32 = -0.882_911_1;
const NEG_DELTA: f32 = -0.443_506_87;
const KAPPA: f32 = 1.230_174_1;
const INV_KAPPA: f32 = 1.0 / KAPPA;
let neg_alpha = f32x8::splat(simd, NEG_ALPHA);
let neg_beta = f32x8::splat(simd, NEG_BETA);
let neg_gamma = f32x8::splat(simd, NEG_GAMMA);
let neg_delta = f32x8::splat(simd, NEG_DELTA);
let kappa = f32x8::splat(simd, KAPPA);
let inv_kappa = f32x8::splat(simd, INV_KAPPA);
let first_even = y0 % 2;
let first_odd = 1 - first_even;
let simd_width = width / SIMD_WIDTH * SIMD_WIDTH;
let (k0, k1, k0_simd, k1_simd) = if first_even == 0 {
(KAPPA, INV_KAPPA, kappa, inv_kappa)
} else {
(INV_KAPPA, KAPPA, inv_kappa, kappa)
};
for row in (0..height.saturating_sub(1)).step_by(2) {
for base_column in (0..simd_width).step_by(SIMD_WIDTH) {
let mut vals0 =
f32x8::from_slice(simd, &scanline[row * width + base_column..][..SIMD_WIDTH]);
let mut vals1 = f32x8::from_slice(
simd,
&scanline[(row + 1) * width + base_column..][..SIMD_WIDTH],
);
vals0 = vals0 * k0_simd;
vals1 = vals1 * k1_simd;
vals0.store(&mut scanline[row * width + base_column..][..SIMD_WIDTH]);
vals1.store(&mut scanline[(row + 1) * width + base_column..][..SIMD_WIDTH]);
}
for col in simd_width..width {
scanline[row * width + col] *= k0;
scanline[(row + 1) * width + col] *= k1;
}
}
if height % 2 == 1 {
let row = height - 1;
for base_column in (0..simd_width).step_by(SIMD_WIDTH) {
let mut vals =
f32x8::from_slice(simd, &scanline[row * width + base_column..][..SIMD_WIDTH]);
vals = vals * k0_simd;
vals.store(&mut scanline[row * width + base_column..][..SIMD_WIDTH]);
}
for col in simd_width..width {
scanline[row * width + col] *= k0;
}
}
filter_step_vertical(
simd,
scanline,
height,
width,
simd_width,
first_even,
#[inline(always)]
|s1, s2, s3| (s2 + s3).mul_add(neg_delta, s1),
#[inline(always)]
|s1, s2, s3| math::mul_add(s2 + s3, NEG_DELTA, s1),
);
filter_step_vertical(
simd,
scanline,
height,
width,
simd_width,
first_odd,
#[inline(always)]
|s1, s2, s3| (s2 + s3).mul_add(neg_gamma, s1),
#[inline(always)]
|s1, s2, s3| math::mul_add(s2 + s3, NEG_GAMMA, s1),
);
filter_step_vertical(
simd,
scanline,
height,
width,
simd_width,
first_even,
#[inline(always)]
|s1, s2, s3| (s2 + s3).mul_add(neg_beta, s1),
#[inline(always)]
|s1, s2, s3| math::mul_add(s2 + s3, NEG_BETA, s1),
);
filter_step_vertical(
simd,
scanline,
height,
width,
simd_width,
first_odd,
#[inline(always)]
|s1, s2, s3| (s2 + s3).mul_add(neg_alpha, s1),
#[inline(always)]
|s1, s2, s3| math::mul_add(s2 + s3, NEG_ALPHA, s1),
);
}