#![deny(missing_docs)]
pub mod boundarys;
pub mod dwt;
pub mod iter;
pub mod lwt;
pub mod simd;
pub mod utils;
use ndwt_macros::{generate_wavelet_enum, generate_wavelet_match_arms};
use num_traits::{FromPrimitive, MulAdd, Num, NumAssignOps, NumOps};
use std::{fmt::Debug, ops::Neg};
// Generates one zero-size marker struct per `(Name, width)` pair, with an associated
// `WIDTH` constant, a `new` constructor and a `Default` impl.
macro_rules! gen_wavelet_struct {
    (
        $( ($name:ident, $width:expr) ),* $(,)?
    ) => {
        $(
            /// Zero-size marker struct representing a specific wavelet.
            ///
            /// `WIDTH` is the filter length (number of coefficients). Use [`crate::Wavelet`]
            /// when the wavelet has to be selected by value rather than by type.
            pub struct $name;

            impl $name {
                /// Filter length (number of coefficients) of this wavelet.
                pub const WIDTH: usize = $width;

                /// Creates the (zero-sized) marker value.
                pub const fn new() -> Self {
                    Self
                }
            }

            impl Default for $name {
                fn default() -> Self {
                    Self::new()
                }
            }
        )*
    };
}
/// Marker types for the Daubechies wavelets `Daubechies1` to `Daubechies10`.
///
/// `DaubechiesN::WIDTH` (the filter length) is `2 * N`.
pub mod daubechies {
    gen_wavelet_struct!(
        (Daubechies1, 2),
        (Daubechies2, 4),
        (Daubechies3, 6),
        (Daubechies4, 8),
        (Daubechies5, 10),
        (Daubechies6, 12),
        (Daubechies7, 14),
        (Daubechies8, 16),
        (Daubechies9, 18),
        (Daubechies10, 20),
    );
}
/// Marker types for the Symlet wavelets `Symlet4` to `Symlet6`.
///
/// `SymletN::WIDTH` (the filter length) is `2 * N`.
pub mod symlet {
    gen_wavelet_struct!((Symlet4, 8), (Symlet5, 10), (Symlet6, 12),);
}
/// Marker types for the Coiflet wavelets `Coiflet1` to `Coiflet3`.
///
/// `CoifletN::WIDTH` (the filter length) is `6 * N`.
pub mod coiflet {
    gen_wavelet_struct!((Coiflet1, 6), (Coiflet2, 12), (Coiflet3, 18),);
}
pub mod bior {
gen_wavelet_struct!((Bior1_3, 6));
gen_wavelet_struct!((Bior1_5, 10));
gen_wavelet_struct!((Bior2_2, 6));
gen_wavelet_struct!((Bior2_4, 10));
gen_wavelet_struct!((Bior2_6, 14));
gen_wavelet_struct!((Bior2_8, 18));
gen_wavelet_struct!((Bior3_1, 4));
gen_wavelet_struct!((Bior3_3, 8));
gen_wavelet_struct!((Bior3_5, 12));
gen_wavelet_struct!((Bior3_7, 16));
gen_wavelet_struct!((Bior3_9, 20));
gen_wavelet_struct!((Bior4_2, 8));
gen_wavelet_struct!((Bior4_4, 12));
gen_wavelet_struct!((Bior4_6, 16));
gen_wavelet_struct!((Bior5_5, 14));
gen_wavelet_struct!((Bior6_8, 22));
gen_wavelet_struct!((CDF5_3, 6));
gen_wavelet_struct!((CDF9_7, 10));
}
/// Maximum number of decomposition levels for a signal of length `n` and a filter of
/// length `width`.
///
/// A level is counted each time the current length is at least `2 * (width - 1)`, after
/// which the length is halved (rounding up). Filters with fewer than two taps yield `0`.
#[inline]
pub fn max_level(width: usize, n: usize) -> usize {
    // For width 0 or 1 the threshold below would be 0, and since `ceil(0 / 2) == 0` and
    // `ceil(1 / 2) == 1` the loop would never terminate.
    if width < 2 {
        return 0;
    }
    // Saturate so absurdly large widths cannot overflow; `len.div_ceil(2) < len` for
    // every `len >= 2`, so the loop always terminates.
    let threshold = (width - 1).saturating_mul(2);
    let mut lvl = 0;
    let mut len = n;
    while len >= threshold {
        lvl += 1;
        len = len.div_ceil(2);
    }
    lvl
}
/// Maximum decomposition level usable on every one of `axes` of an array with the given
/// `shape`, for a filter of length `width`.
///
/// This is the smallest [`max_level`] over the selected axes, or `0` when `axes` is empty.
///
/// # Panics
///
/// Panics if any entry of `axes` is not a valid index into `shape`.
#[inline]
pub fn max_level_nd(width: usize, shape: &[usize], axes: &[usize]) -> usize {
    // Validate every axis up front so a bad axis reports a clear message instead of an
    // out-of-bounds index panic.
    for (i, &ax) in axes.iter().enumerate() {
        assert!(
            ax < shape.len(),
            "Requested axis[{i}]={ax} is beyond the dimensionality of shape: {}",
            shape.len()
        );
    }
    let per_axis = axes.iter().map(|&ax| shape[ax]).map(|len| max_level(width, len));
    per_axis.min().unwrap_or(0)
}
// Generates the `Wavelet` type (implemented below) through the `ndwt_macros` proc macro.
// NOTE(review): the parenthesised list presumably becomes the enum's derive set
// (Clone, Copy, Debug, PartialEq, Eq, Hash), and the trailing empty `{}` group is an
// extension point whose meaning is defined by `generate_wavelet_enum!` — confirm against
// the macro's definition.
generate_wavelet_enum!(
    Wavelet,
    (Clone, Copy, Debug, PartialEq, Eq, Hash),
    {
    }
);
impl Wavelet {
    /// Maximum decomposition level for a signal of length `n` with this wavelet.
    ///
    /// Forwards to [`crate::max_level`] with this wavelet's [`Wavelet::width`].
    pub fn max_level(&self, n: usize) -> usize {
        max_level(self.width(), n)
    }

    /// Filter length (number of coefficients) of this wavelet, read from the `WIDTH`
    /// constant of its marker struct.
    pub fn width(&self) -> usize {
        // Bring every marker struct into scope unqualified: the generated arms refer to
        // them by bare name through the `#wvlt` placeholder.
        use bior::*;
        use coiflet::*;
        use daubechies::*;
        use symlet::*;
        // NOTE(review): presumably expands to a `match self { ... }` with one arm per
        // variant, substituting `#wvlt` by that variant's marker type — confirm against
        // the ndwt_macros definition.
        generate_wavelet_match_arms! {Self, self, { #wvlt::WIDTH,}}
    }
}
/// Fused "scale, then add" operation.
///
/// The implementations in this crate compute `self * scalar + addend`. The scalar type
/// may differ from `Self`; for example, a complex value is scaled by a real number.
pub trait MulScalarAdd<S = Self, Rhs = Self> {
    /// Type produced by [`MulScalarAdd::mul_add`].
    type Output;

    /// Multiplies `self` by `scalar` and adds `addend`.
    fn mul_add(self, scalar: S, addend: Rhs) -> Self::Output;
}
/// Blanket impl: any `T` with a `num_traits::MulAdd<T, T, Output = T>` is its own scalar.
///
/// This covers the primitive numbers registered with `impl_transformable!`, and the call
/// is forwarded straight to `num_traits::MulAdd::mul_add`.
impl<T> MulScalarAdd<T, T> for T
where
    T: num_traits::MulAdd<T, T, Output = T>,
{
    type Output = T;

    #[inline(always)]
    fn mul_add(self, a: T, b: T) -> T {
        num_traits::MulAdd::mul_add(self, a, b)
    }
}
/// Element types the wavelet transforms can operate on.
///
/// A `Transformable` value supports arithmetic with itself and with its associated
/// [`Transformable::Scalar`] type. It also provides a fused scalar multiply-add via
/// [`MulScalarAdd`].
pub trait Transformable:
    NumOps
    + NumOps<Self::Scalar>
    + Clone
    + Neg<Output = Self>
    + NumAssignOps
    + NumAssignOps<Self::Scalar>
    + MulScalarAdd<Self::Scalar, Self, Output = Self>
{
    /// Scalar type that values of `Self` are multiplied by.
    ///
    /// In this crate it is `Self` for the primitive numbers and `T` for `Complex<T>`.
    type Scalar: FromPrimitive + Copy + NumOps + std::fmt::Debug;

    /// Fused `self * b + c`, via [`MulScalarAdd::mul_add`].
    #[inline(always)]
    fn mul_add_op(self, b: Self::Scalar, c: Self) -> Self {
        self.mul_add(b, c)
    }

    /// Fused `(-self) * b + c`, via [`MulScalarAdd::mul_add`].
    #[inline(always)]
    fn neg_mul_add_op(self, b: Self::Scalar, c: Self) -> Self {
        (-self).mul_add(b, c)
    }

    /// Converts an `isize` into the scalar type.
    ///
    /// # Panics
    ///
    /// Panics if `x` cannot be represented by `Self::Scalar`.
    #[inline(always)]
    fn scalar_type_from_isize(x: isize) -> Self::Scalar {
        // A bare `unwrap` would not say which value failed to convert.
        Self::Scalar::from_isize(x).unwrap_or_else(|| {
            panic!("scalar_type_from_isize: {x} is not representable in the scalar type")
        })
    }

    /// Converts an `f64` into the scalar type.
    ///
    /// # Panics
    ///
    /// Panics if `x` cannot be represented by `Self::Scalar` (as decided by its
    /// `FromPrimitive::from_f64`).
    #[inline(always)]
    fn scalar_type_from_f64(x: f64) -> Self::Scalar {
        Self::Scalar::from_f64(x).unwrap_or_else(|| {
            panic!("scalar_type_from_f64: {x} is not representable in the scalar type")
        })
    }
}
macro_rules! impl_transformable {
($T:ty) => {
impl Transformable for $T {
type Scalar = Self;
}
};
}
impl_transformable!(i8);
impl_transformable!(i16);
impl_transformable!(i32);
impl_transformable!(i64);
impl_transformable!(i128);
impl_transformable!(isize);
impl_transformable!(f32);
impl_transformable!(f64);
/// Scales a complex number by a real scalar and adds another complex number.
///
/// The operation is component-wise: `re = self.re * a + b.re` and `im = self.im * a + b.im`,
/// each computed with `T`'s fused `MulAdd`.
impl<T> MulScalarAdd<T, num_complex::Complex<T>> for num_complex::Complex<T>
where
    T: MulAdd<Output = T> + Clone,
{
    type Output = Self;

    #[inline(always)]
    fn mul_add(self, a: T, b: Self) -> Self::Output {
        let num_complex::Complex { re: x_re, im: x_im } = self;
        let num_complex::Complex { re: y_re, im: y_im } = b;
        // `a` is needed twice: clone it for the real part and move it into the imaginary one.
        let re = T::mul_add(x_re, a.clone(), y_re);
        let im = T::mul_add(x_im, a, y_im);
        num_complex::Complex { re, im }
    }
}
/// Complex values are transformed with their real component type `T` as the scalar.
impl<T> Transformable for num_complex::Complex<T>
where
    T: Num
        + Copy
        + Debug
        + FromPrimitive
        + MulAdd<Output = T>
        + Neg<Output = T>
        + NumAssignOps,
{
    type Scalar = T;
}
// Vector width in bits from which the per-type chunk sizes below are derived.
// NOTE(review): presumably chosen to match a 256-bit SIMD register — confirm.
const N_BITS: usize = 256;

/// Marker trait recording `N`, the number of `T` elements that fit in a 256-bit chunk.
pub trait ChunkWidth<T, const N: usize> {}

// Defines `$name` as the number of `$t` elements in `N_BITS` bits and implements
// `ChunkWidth<$t, $name>` for `$t`. The element width comes from `size_of`, so no per-type
// bit count has to be maintained by hand (and cannot drift from the real layout).
macro_rules! impl_chunk_size {
    ($name:ident, $t:ty) => {
        const $name: usize = N_BITS / (8 * std::mem::size_of::<$t>());
        impl ChunkWidth<$t, $name> for $t {}
    };
}
impl_chunk_size! {N_I8, i8}
impl_chunk_size! {N_I16, i16}
impl_chunk_size! {N_I32, i32}
impl_chunk_size! {N_I64, i64}
impl_chunk_size! {N_I128, i128}
impl_chunk_size! {N_ISIZE, isize}
impl_chunk_size! {N_F32, f32}
impl_chunk_size! {N_F64, f64}
impl_chunk_size! {N_C32, num_complex::Complex32}
impl_chunk_size! {N_C64, num_complex::Complex64}
/// Numerical assertion helpers for tests.
#[doc(hidden)]
pub mod tests {
    /// Asserts that two slices are element-wise equal within tolerance.
    ///
    /// This follows `numpy.isclose` with `equal_nan = true`. A pair of finite values `(a, d)`
    /// is close when `|a - d| <= rtol * |d| + atol`. A pair with a non-finite value is close
    /// only when both values are equal (same-signed infinities) or both are NaN. A NaN or
    /// infinity on just one side, or infinities of opposite sign, count as mismatches.
    ///
    /// # Panics
    ///
    /// Panics if the slices have different lengths, or if any pair is not close. The message
    /// reports the number of mismatches, the largest absolute and relative differences among
    /// them (NaN if any mismatching difference is NaN), and both slices.
    #[track_caller]
    pub fn test_approx_equal<T>(actual: &[T], desired: &[T], rtol: T, atol: T)
    where
        T: num_traits::Float + std::fmt::Debug,
    {
        let n_a = actual.len();
        let n_d = desired.len();
        assert_eq!(
            n_a, n_d,
            "Slice length mismatch:\n actual: {n_a}\n desired: {n_d}"
        );
        // `Float::max` discards NaN; propagate it so a NaN difference shows in the report.
        let nan_max = |m: T, x: T| {
            if m.is_nan() || x.is_nan() {
                T::nan()
            } else {
                m.max(x)
            }
        };
        let mut mismatch = 0usize;
        let mut max_adiff = T::zero();
        let mut max_rdiff = T::zero();
        for (&a, &d) in actual.iter().zip(desired) {
            let abs_diff = (a - d).abs();
            let close = if a.is_finite() && d.is_finite() {
                abs_diff <= rtol * d.abs() + atol
            } else {
                // A plain `abs_diff > tol` test would silently accept NaN differences
                // (NaN compares false) and inf vs -inf (`inf > inf` is false).
                a == d || (a.is_nan() && d.is_nan())
            };
            if close {
                continue;
            }
            mismatch += 1;
            max_adiff = nan_max(max_adiff, abs_diff);
            let r_diff = if d.abs() == T::zero() {
                T::infinity()
            } else {
                abs_diff / d.abs()
            };
            max_rdiff = nan_max(max_rdiff, r_diff);
        }
        if mismatch > 0 {
            panic!(
                "{}/{} mismatched elements:\n Maximum differences: absolute={:?}, relative={:?}\n actual:\n{:?}\n desired:\n{:?}",
                mismatch, n_a, max_adiff, max_rdiff, actual, desired
            );
        }
    }

    /// Asserts the adjoint identity `<f(u), v> ≈ <u, f_adj(v)>`.
    ///
    /// `f` receives `u` and a zero-initialised output buffer of length `v.len()`. `f_adj`
    /// receives `v` and a zero-initialised buffer of length `u.len()`. The two inner
    /// products `v1` and `v2` must satisfy `|v1 - v2| <= rtol * |v1| + atol`.
    ///
    /// # Panics
    ///
    /// Panics if the inner products are not within tolerance, including when either one
    /// is NaN.
    #[track_caller]
    pub fn test_approx_adjoint<F, FA, T>(f: F, f_adj: FA, u: &[T], v: &[T], rtol: T, atol: T)
    where
        F: Fn(&[T], &mut [T]),
        FA: Fn(&[T], &mut [T]),
        T: num_traits::Float + std::fmt::Debug,
    {
        let n_u = u.len();
        let n_v = v.len();
        let mut f_u = vec![T::zero(); n_v];
        let mut f_adj_v = vec![T::zero(); n_u];
        f(u, &mut f_u);
        let v1 = std::iter::zip(f_u, v.iter().cloned()).fold(T::zero(), |acc, (x, y)| acc + x * y);
        f_adj(v, &mut f_adj_v);
        let v2 =
            std::iter::zip(f_adj_v, u.iter().cloned()).fold(T::zero(), |acc, (x, y)| acc + x * y);
        let abs_diff = (v1 - v2).abs();
        let thresh = rtol * v1.abs() + atol;
        // `<=` (rather than a negated `>`) makes a NaN inner product fail the assertion.
        assert!(
            abs_diff <= thresh,
            "{v1:?} and {v2:?} are not equal to tolerance rtol={rtol:?}, atol={atol:?}
Absolute difference: {:?}
Relative difference: {:?}
",
            abs_diff,
            abs_diff / v1.abs()
        );
    }
}