use alloc::vec;
use alloc::vec::Vec;
use crate::math::floor_f32;
// Lifting coefficients of the irreversible 9/7 wavelet as applied by `forward_lift_97`:
// ALPHA and GAMMA drive the two predict steps (odd samples), BETA and DELTA the two update
// steps (even samples).
const ALPHA: f32 = -1.586_134_3;
const BETA: f32 = -0.052_980_117;
const GAMMA: f32 = 0.882_911_1;
const DELTA: f32 = 0.443_506_87;
/// Final scaling factor of the 9/7 lifting; odd (high-pass) samples are multiplied by it.
const KAPPA: f32 = 1.230_174_1;
/// Reciprocal of [`KAPPA`], applied to even (low-pass) samples. Derived from `KAPPA` rather
/// than restating the literal so the pair can never drift apart.
const INV_KAPPA: f32 = 1.0 / KAPPA;
/// Output of [`forward_dwt`]: the final low-pass band together with the detail bands of
/// every decomposition level that was performed.
#[derive(Debug)]
pub(crate) struct DwtDecomposition {
    /// Final LL band, stored row-major with `ll_width` samples per row.
    pub(crate) ll: Vec<f32>,
    /// Number of columns in `ll`.
    pub(crate) ll_width: u32,
    /// Number of rows in `ll`.
    pub(crate) ll_height: u32,
    /// Detail bands, one entry per level, ordered from the coarsest level to the finest.
    pub(crate) levels: Vec<DwtLevel>,
}
/// Detail sub-bands produced by a single level of [`forward_dwt`].
///
/// Every band is row-major with a row length equal to its own width. The "low" sizes are the
/// rounded-up halves and the "high" sizes the rounded-down halves of the region that was
/// transformed at this level.
#[derive(Debug)]
pub(crate) struct DwtLevel {
    /// Horizontally high-pass, vertically low-pass band: `high_width` x `low_height`.
    pub(crate) hl: Vec<f32>,
    /// Horizontally low-pass, vertically high-pass band: `low_width` x `high_height`.
    pub(crate) lh: Vec<f32>,
    /// High-pass in both directions: `high_width` x `high_height`.
    pub(crate) hh: Vec<f32>,
    /// `ceil(w / 2)` of the transformed region width; the low-pass width kept for the next level.
    pub(crate) low_width: u32,
    /// `ceil(h / 2)` of the transformed region height.
    pub(crate) low_height: u32,
    /// `floor(w / 2)` of the transformed region width.
    pub(crate) high_width: u32,
    /// `floor(h / 2)` of the transformed region height.
    pub(crate) high_height: u32,
}
pub(crate) fn forward_dwt(
samples: &[f32],
width: u32,
height: u32,
num_levels: u8,
reversible: bool,
) -> DwtDecomposition {
let w = width as usize;
let h = height as usize;
let mut buffer = samples.to_vec();
let mut current_width = w;
let mut current_height = h;
let mut levels = Vec::with_capacity(num_levels as usize);
for _ in 0..num_levels {
if current_width < 2 && current_height < 2 {
break;
}
if current_height >= 2 {
let mut col_buf = vec![0.0f32; current_height];
for x in 0..current_width {
for y in 0..current_height {
col_buf[y] = buffer[y * w + x];
}
if reversible {
forward_lift_53(&mut col_buf[..current_height]);
} else {
forward_lift_97(&mut col_buf[..current_height]);
}
let num_low = current_height.div_ceil(2);
for i in 0..num_low {
buffer[i * w + x] = col_buf[i * 2];
}
for i in 0..(current_height / 2) {
buffer[(num_low + i) * w + x] = col_buf[i * 2 + 1];
}
}
}
if current_width >= 2 {
let mut row_buf = vec![0.0f32; current_width];
for y in 0..current_height {
let row_start = y * w;
row_buf[..current_width]
.copy_from_slice(&buffer[row_start..row_start + current_width]);
if reversible {
forward_lift_53(&mut row_buf[..current_width]);
} else {
forward_lift_97(&mut row_buf[..current_width]);
}
let num_low = current_width.div_ceil(2);
for i in 0..num_low {
buffer[row_start + i] = row_buf[i * 2];
}
for i in 0..(current_width / 2) {
buffer[row_start + num_low + i] = row_buf[i * 2 + 1];
}
}
}
let low_w = current_width.div_ceil(2);
let low_h = current_height.div_ceil(2);
let high_w = current_width / 2;
let high_h = current_height / 2;
let mut hl = vec![0.0f32; high_w * low_h];
let mut lh = vec![0.0f32; low_w * high_h];
let mut hh = vec![0.0f32; high_w * high_h];
for y in 0..low_h {
for x in 0..high_w {
hl[y * high_w + x] = buffer[y * w + low_w + x];
}
}
for y in 0..high_h {
for x in 0..low_w {
lh[y * low_w + x] = buffer[(low_h + y) * w + x];
}
}
for y in 0..high_h {
for x in 0..high_w {
hh[y * high_w + x] = buffer[(low_h + y) * w + low_w + x];
}
}
levels.push(DwtLevel {
hl,
lh,
hh,
low_width: low_w as u32,
low_height: low_h as u32,
high_width: high_w as u32,
high_height: high_h as u32,
});
current_width = low_w;
current_height = low_h;
}
let mut ll = vec![0.0f32; current_width * current_height];
for y in 0..current_height {
for x in 0..current_width {
ll[y * current_width + x] = buffer[y * w + x];
}
}
levels.reverse();
DwtDecomposition {
ll,
ll_width: current_width as u32,
ll_height: current_height as u32,
levels,
}
}
/// Forward reversible 5/3 lifting, in place.
///
/// On return the even indices hold low-pass coefficients and the odd indices hold high-pass
/// coefficients; the caller is responsible for deinterleaving them. Signal edges are
/// mirrored without repeating the boundary sample (`x[-1] = x[1]`, `x[n] = x[n - 2]`).
/// Slices shorter than two samples are left unchanged; even lengths are delegated to
/// `forward_lift_53_even`.
fn forward_lift_53(data: &mut [f32]) {
    let n = data.len();
    if n < 2 {
        return;
    }
    if n.is_multiple_of(2) {
        forward_lift_53_even(data);
        return;
    }
    // From here on `n` is odd (>= 3), so the last sample is even-indexed: every odd sample
    // has two in-range neighbours, and only the update step needs mirrored neighbours.

    // Predict: d[i] = x[i] - floor((x[i - 1] + x[i + 1]) / 2).
    for i in (1..n).step_by(2) {
        data[i] -= floor_f32((data[i - 1] + data[i + 1]) * 0.5);
    }
    // Update: s[i] = x[i] + floor((d[i - 1] + d[i + 1] + 2) / 4), with d[-1] mirrored to
    // d[1] at the start and d[n] mirrored to d[n - 2] at the end.
    data[0] += floor_f32((data[1] + data[1]) * 0.25 + 0.5);
    for i in (2..n - 1).step_by(2) {
        data[i] += floor_f32((data[i - 1] + data[i + 1]) * 0.25 + 0.5);
    }
    data[n - 1] += floor_f32((data[n - 2] + data[n - 2]) * 0.25 + 0.5);
}
/// Forward reversible 5/3 lifting specialised for even-length slices (at least two samples).
///
/// Produces the same interleaved output as the general 5/3 path: even indices become
/// low-pass and odd indices high-pass coefficients. With an even length, only the last odd
/// sample (predict) and the first even sample (update) need a mirrored neighbour, so those
/// two are handled outside the loops and the loops need no boundary checks.
fn forward_lift_53_even(data: &mut [f32]) {
    let n = data.len();
    debug_assert!(n >= 2);
    debug_assert!(n.is_multiple_of(2));
    let last = n - 1;

    // Predict: each interior odd sample loses floor of the mean of its even neighbours.
    let mut odd = 1;
    while odd < last {
        let sum = data[odd - 1] + data[odd + 1];
        data[odd] -= floor_f32(sum * 0.5);
        odd += 2;
    }
    // The trailing odd sample's right neighbour mirrors back onto data[last - 1].
    data[last] -= floor_f32(data[last - 1]);

    // Update: data[-1] mirrors onto data[1], so the first even sample sees data[1] twice.
    data[0] += floor_f32(data[1] * 0.5 + 0.5);
    let mut even = 2;
    while even < n {
        let sum = data[even - 1] + data[even + 1];
        data[even] += floor_f32(sum * 0.25 + 0.5);
        even += 2;
    }
}
/// Forward irreversible 9/7 lifting, in place.
///
/// Runs the four lifting steps and then the scaling step:
/// 1. `ALPHA` predict on odd samples.
/// 2. `BETA` update on even samples.
/// 3. `GAMMA` predict on odd samples.
/// 4. `DELTA` update on even samples.
/// 5. Scaling: even (low-pass) samples are multiplied by `INV_KAPPA` and odd (high-pass)
///    samples by `KAPPA`.
///
/// Edges are mirrored without repeating the boundary sample (`x[-1] = x[1]`,
/// `x[n] = x[n - 2]`). Slices shorter than two samples are left untouched.
fn forward_lift_97(data: &mut [f32]) {
    /// Adds `coeff * (left + right)` to every odd-indexed sample. A trailing odd sample
    /// takes its right neighbour from the last even index.
    fn predict(data: &mut [f32], coeff: f32) {
        let n = data.len();
        let mirror = if n.is_multiple_of(2) { n - 2 } else { n - 1 };
        let mut i = 1;
        while i < n {
            let right = if i + 1 < n { data[i + 1] } else { data[mirror] };
            data[i] += coeff * (data[i - 1] + right);
            i += 2;
        }
    }

    /// Adds `coeff * (left + right)` to every even-indexed sample. Index 0 takes `data[1]`
    /// as its left neighbour, and a trailing even sample reuses its left neighbour on the
    /// right.
    fn update(data: &mut [f32], coeff: f32) {
        let n = data.len();
        let mut i = 0;
        while i < n {
            let left = if i == 0 { data[1] } else { data[i - 1] };
            let right = if i + 1 < n { data[i + 1] } else { left };
            data[i] += coeff * (left + right);
            i += 2;
        }
    }

    if data.len() < 2 {
        return;
    }
    predict(data, ALPHA);
    update(data, BETA);
    predict(data, GAMMA);
    update(data, DELTA);
    for (idx, sample) in data.iter_mut().enumerate() {
        *sample *= if idx % 2 == 0 { INV_KAPPA } else { KAPPA };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise comparison with an absolute tolerance; also requires equal lengths.
    fn approx_eq_slice(a: &[f32], b: &[f32], eps: f32) -> bool {
        a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() < eps)
    }

    /// Deterministic, integer-valued test signal with values in `-128.0..128.0`.
    fn test_signal(len: usize) -> Vec<f32> {
        (0..len)
            .map(|idx| ((idx * 37 + idx / 3) & 0xff) as f32 - 128.0)
            .collect()
    }

    #[test]
    fn test_forward_53_basic() {
        let mut data = vec![10.0, 20.0, 30.0, 40.0];
        forward_lift_53(&mut data);
        inverse_lift_53(&mut data);
        assert!(approx_eq_slice(&data, &[10.0, 20.0, 30.0, 40.0], 0.001));
    }

    #[test]
    fn forward_53_round_trips_for_odd_and_even_lengths() {
        // Exercises both the odd-length general path and the even-length fast path; the
        // integer-valued signal makes the reversible transform exactly invertible.
        for len in [2usize, 3, 4, 5, 8, 9, 64, 65] {
            let original = test_signal(len);
            let mut data = original.clone();
            forward_lift_53(&mut data);
            inverse_lift_53(&mut data);
            assert_eq!(data, original, "len={len}");
        }
    }

    #[test]
    fn forward_53_even_fast_path_matches_reference_for_common_tile_widths() {
        for len in [2usize, 4, 8, 64, 512] {
            let mut expected = test_signal(len);
            let mut actual = expected.clone();
            forward_lift_53_reference(&mut expected);
            forward_lift_53_even(&mut actual);
            assert_eq!(actual, expected, "len={len}");
        }
    }

    #[test]
    fn forward_53_odd_lengths_match_reference() {
        for len in [3usize, 5, 9, 65, 513] {
            let mut expected = test_signal(len);
            let mut actual = expected.clone();
            forward_lift_53_reference(&mut expected);
            forward_lift_53(&mut actual);
            assert_eq!(actual, expected, "len={len}");
        }
    }

    #[test]
    fn test_forward_97_round_trip() {
        for len in [2usize, 3, 8, 9, 64, 65] {
            let original = test_signal(len);
            let mut data = original.clone();
            forward_lift_97(&mut data);
            crate::j2c::idwt::test_irreversible_filter_97i(&mut data, len, 0);
            assert!(
                approx_eq_slice(&data, &original, 0.01),
                "len={len} data={data:?} original={original:?}"
            );
        }
    }

    #[test]
    fn forward_lift_97_places_constant_signal_in_low_pass() {
        for len in [2usize, 3, 8, 9, 64, 65] {
            let mut data = vec![50.0; len];
            forward_lift_97(&mut data);
            for &low in data.iter().step_by(2) {
                assert!((low - 50.0).abs() < 0.001, "len={len} data={data:?}");
            }
            for &high in data.iter().skip(1).step_by(2) {
                assert!(high.abs() < 0.001, "len={len} data={data:?}");
            }
        }
    }

    #[test]
    fn test_forward_dwt_53_single_level() {
        let samples: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let decomp = forward_dwt(&samples, 4, 4, 1, true);
        assert_eq!(decomp.ll_width, 2);
        assert_eq!(decomp.ll_height, 2);
        assert_eq!(decomp.levels.len(), 1);
    }

    #[test]
    fn test_forward_dwt_97_multi_level() {
        let samples: Vec<f32> = (0..64).map(|x| x as f32).collect();
        let decomp = forward_dwt(&samples, 8, 8, 3, false);
        assert_eq!(decomp.levels.len(), 3);
        assert_eq!(decomp.ll_width, 1);
        assert_eq!(decomp.ll_height, 1);
    }

    #[test]
    fn test_odd_dimensions() {
        let samples: Vec<f32> = (0..15).map(|x| x as f32).collect();
        let decomp = forward_dwt(&samples, 5, 3, 1, true);
        assert_eq!(decomp.ll_width, 3);
        assert_eq!(decomp.ll_height, 2);
        assert_eq!(decomp.levels[0].high_width, 2);
        assert_eq!(decomp.levels[0].high_height, 1);
    }

    #[test]
    fn forward_dwt_53_constant_image_has_zero_detail_bands() {
        // A constant image must end up entirely in LL; a mis-copied band would expose 7.0.
        let samples = vec![7.0f32; 5 * 3];
        let decomp = forward_dwt(&samples, 5, 3, 1, true);
        assert_eq!(decomp.ll, vec![7.0; 3 * 2]);
        let level = &decomp.levels[0];
        assert_eq!(level.hl, vec![0.0; 2 * 2]);
        assert_eq!(level.lh, vec![0.0; 3]);
        assert_eq!(level.hh, vec![0.0; 2]);
    }

    #[test]
    fn forward_dwt_97_constant_image_concentrates_in_ll() {
        let samples = vec![50.0f32; 8 * 8];
        let decomp = forward_dwt(&samples, 8, 8, 2, false);
        assert!(
            decomp.ll.iter().all(|v| (v - 50.0).abs() < 0.01),
            "ll={:?}",
            decomp.ll
        );
        for level in &decomp.levels {
            for band in [&level.hl, &level.lh, &level.hh] {
                assert!(band.iter().all(|v| v.abs() < 0.01), "band={band:?}");
            }
        }
    }

    #[test]
    fn forward_dwt_stops_once_ll_is_a_single_sample() {
        let samples: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let decomp = forward_dwt(&samples, 4, 4, 10, true);
        assert_eq!(decomp.levels.len(), 2);
        assert_eq!((decomp.ll_width, decomp.ll_height), (1, 1));
        // Levels are returned coarsest first.
        assert_eq!(decomp.levels[0].low_width, 1);
        assert_eq!(decomp.levels[1].low_width, 2);
    }

    #[test]
    fn forward_dwt_with_zero_levels_returns_input_as_ll() {
        let samples: Vec<f32> = (0..6).map(|x| x as f32).collect();
        let decomp = forward_dwt(&samples, 3, 2, 0, true);
        assert_eq!(decomp.ll, samples);
        assert_eq!((decomp.ll_width, decomp.ll_height), (3, 2));
        assert!(decomp.levels.is_empty());
    }

    /// Test-only inverse of the 5/3 lifting: undoes the update step, then the predict step,
    /// using the same edge mirroring as the forward transform.
    fn inverse_lift_53(data: &mut [f32]) {
        let n = data.len();
        if n < 2 {
            return;
        }
        for i in (0..n).step_by(2) {
            let left = if i > 0 { data[i - 1] } else { data[1] };
            let right = if i + 1 < n { data[i + 1] } else { left };
            data[i] -= ((left + right) * 0.25 + 0.5).floor();
        }
        let last_even = if n.is_multiple_of(2) { n - 2 } else { n - 1 };
        for i in (1..n).step_by(2) {
            let left = data[i - 1];
            let right = if i + 1 < n {
                data[i + 1]
            } else {
                data[last_even]
            };
            data[i] += ((left + right) * 0.5).floor();
        }
    }

    /// Straightforward 5/3 forward lifting, used as the oracle for the optimised paths.
    fn forward_lift_53_reference(data: &mut [f32]) {
        let n = data.len();
        if n < 2 {
            return;
        }
        let last_even = if n.is_multiple_of(2) { n - 2 } else { n - 1 };
        for i in (1..n).step_by(2) {
            let left = data[i - 1];
            let right = if i + 1 < n {
                data[i + 1]
            } else {
                data[last_even]
            };
            data[i] -= ((left + right) * 0.5).floor();
        }
        for i in (0..n).step_by(2) {
            let left = if i > 0 { data[i - 1] } else { data[1] };
            let right = if i + 1 < n { data[i + 1] } else { left };
            data[i] += ((left + right) * 0.25 + 0.5).floor();
        }
    }
}