use crate::DType;
use super::types::Wavelet;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Discrete wavelet transform operations over tensors living on runtime `R`.
///
/// Implementors provide a single-level and a multi-level 1-D transform plus a
/// single-level 2-D transform, each paired with its inverse. Every method takes
/// the [`Wavelet`] to apply and the [`ExtensionMode`] that decides how the input
/// is extended past its boundaries.
pub trait DwtAlgorithms<R: Runtime<DType = DType>> {
    /// Single-level 1-D forward transform of `x`.
    ///
    /// Returns the approximation and detail coefficients as a [`DwtResult`].
    fn dwt(
        &self,
        x: &Tensor<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
    ) -> Result<DwtResult<R>>;

    /// Inverse of [`DwtAlgorithms::dwt`]: rebuilds a signal from `coeffs`.
    ///
    /// NOTE(review): `wavelet` and `mode` are presumably expected to match the
    /// forward transform — confirm with the implementation.
    fn idwt(
        &self,
        coeffs: &DwtResult<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
    ) -> Result<Tensor<R>>;

    /// Multi-level 1-D decomposition of `x`, `level` levels deep.
    ///
    /// The result holds one approximation tensor and one detail tensor per level
    /// (see [`WavedecResult`]).
    fn wavedec(
        &self,
        x: &Tensor<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
        level: usize,
    ) -> Result<WavedecResult<R>>;

    /// Inverse of [`DwtAlgorithms::wavedec`]: reconstructs a signal from `coeffs`.
    fn waverec(
        &self,
        coeffs: &WavedecResult<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
    ) -> Result<Tensor<R>>;

    /// Single-level 2-D forward transform of `x`.
    ///
    /// Returns the four sub-bands as a [`Dwt2dResult`].
    fn dwt2(
        &self,
        x: &Tensor<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
    ) -> Result<Dwt2dResult<R>>;

    /// Inverse of [`DwtAlgorithms::dwt2`]: rebuilds a 2-D signal from its sub-bands.
    fn idwt2(
        &self,
        coeffs: &Dwt2dResult<R>,
        wavelet: &Wavelet,
        mode: ExtensionMode,
    ) -> Result<Tensor<R>>;
}
/// Continuous wavelet transform operations over tensors living on runtime `R`.
pub trait CwtAlgorithms<R: Runtime<DType = DType>> {
    /// Continuous wavelet transform of `x` evaluated at the given `scales`.
    ///
    /// The output carries real and imaginary coefficient tensors together with a
    /// `scales` tensor; see [`CwtResult`].
    fn cwt(
        &self,
        x: &Tensor<R>,
        scales: &Tensor<R>,
        wavelet: &Wavelet,
    ) -> Result<CwtResult<R>>;
}
/// Boundary-handling strategy used when a transform reads past either end of
/// its input.
///
/// The variant names follow common wavelet-library vocabulary (e.g. the
/// PyWavelets signal extension modes). The exact padding rule for each variant
/// is defined by the transform implementation; confirm it there before relying
/// on a specific formula.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionMode {
    /// Presumably zero padding. This is the default mode.
    #[default]
    Zero,
    /// Presumably repeats the edge sample (constant extension) — confirm.
    Constant,
    /// Presumably mirrors the signal about its edges — confirm.
    Symmetric,
    /// Presumably wraps the signal around (periodic extension) — confirm.
    Periodic,
    /// Presumably extrapolates from the slope at the edges — confirm.
    Smooth,
}
/// Output of a single-level 1-D discrete wavelet transform
/// (see [`DwtAlgorithms::dwt`]).
#[derive(Debug, Clone)]
pub struct DwtResult<R: Runtime<DType = DType>> {
    /// Approximation coefficients.
    pub approx: Tensor<R>,
    /// Detail coefficients.
    pub detail: Tensor<R>,
}
/// Output of a multi-level 1-D wavelet decomposition
/// (see [`DwtAlgorithms::wavedec`]).
///
/// Holds a single approximation tensor plus one detail tensor per level.
#[derive(Debug, Clone)]
pub struct WavedecResult<R: Runtime<DType = DType>> {
    /// Approximation coefficients (presumably from the deepest level — confirm).
    pub approx: Tensor<R>,
    /// Detail coefficients, one tensor per level.
    ///
    /// NOTE(review): whether index 0 is the coarsest or the finest level is
    /// decided by the producer of this value — confirm before relying on it.
    pub details: Vec<Tensor<R>>,
}

impl<R: Runtime<DType = DType>> WavedecResult<R> {
    /// Number of decomposition levels, i.e. the number of detail tensors.
    pub fn num_levels(&self) -> usize {
        self.details.len()
    }

    /// Detail coefficients for a 1-based `level`.
    ///
    /// `level == 1` maps to `details[0]`. Returns `None` when `level` is `0` or
    /// greater than [`num_levels`](Self::num_levels).
    pub fn detail(&self, level: usize) -> Option<&Tensor<R>> {
        // `checked_sub` rejects level 0; `get` rejects anything past the end.
        let index = level.checked_sub(1)?;
        self.details.get(index)
    }
}
/// Output of a single-level 2-D discrete wavelet transform
/// (see [`DwtAlgorithms::dwt2`]).
///
/// The four sub-bands are presumably named by the filter (low/high) applied
/// along each axis. NOTE(review): which letter refers to rows and which to
/// columns is fixed by the implementation — confirm before attaching a
/// horizontal/vertical meaning to `lh` and `hl`.
#[derive(Debug, Clone)]
pub struct Dwt2dResult<R: Runtime<DType = DType>> {
    /// Low/low sub-band (the approximation).
    pub ll: Tensor<R>,
    /// Mixed sub-band: low, then high.
    pub lh: Tensor<R>,
    /// Mixed sub-band: high, then low.
    pub hl: Tensor<R>,
    /// High/high sub-band (conventionally the diagonal detail).
    pub hh: Tensor<R>,
}
/// Output of a continuous wavelet transform (see [`CwtAlgorithms::cwt`]).
///
/// Complex coefficients are stored as two tensors holding the real and the
/// imaginary parts. NOTE(review): the layout is presumably
/// `(num_scales, signal_len)` — confirm against the implementation.
#[derive(Debug, Clone)]
pub struct CwtResult<R: Runtime<DType = DType>> {
    /// Real part of the coefficients.
    pub coeffs_real: Tensor<R>,
    /// Imaginary part of the coefficients; expected to match `coeffs_real` in
    /// element count and shape.
    pub coeffs_imag: Tensor<R>,
    /// Scales the transform was evaluated at.
    pub scales: Tensor<R>,
}

impl<R: Runtime<DType = DType>> CwtResult<R> {
    /// Element-wise modulus `|re + i*im|` of the coefficients.
    ///
    /// Uses [`f64::hypot`] rather than `sqrt(re*re + im*im)`, so very large
    /// components do not overflow to infinity and very small ones do not
    /// underflow to zero. The output has `coeffs_real`'s shape and device.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` is part of the public API.
    pub fn magnitude(&self) -> Result<Tensor<R>> {
        self.map_complex(f64::hypot)
    }

    /// Element-wise argument `atan2(im, re)` of the coefficients, in radians
    /// within `[-pi, pi]`. The output has `coeffs_real`'s shape and device.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` is part of the public API.
    pub fn phase(&self) -> Result<Tensor<R>> {
        self.map_complex(|re, im| im.atan2(re))
    }

    /// Copies both coefficient tensors to host vectors, applies `f(re, im)` per
    /// element, and builds a new tensor with `coeffs_real`'s shape and device.
    fn map_complex(&self, f: impl Fn(f64, f64) -> f64) -> Result<Tensor<R>> {
        // NOTE(review): coefficients are read as f64; confirm the tensors are
        // always F64 (or that this conversion is supported for other dtypes).
        let re: Vec<f64> = self.coeffs_real.to_vec();
        let im: Vec<f64> = self.coeffs_imag.to_vec();
        // `zip` silently truncates on a length mismatch, which would yield a
        // buffer inconsistent with `shape` below; catch that in debug builds.
        debug_assert_eq!(
            re.len(),
            im.len(),
            "CwtResult: real and imaginary coefficient counts differ"
        );
        let values: Vec<f64> = re.iter().zip(&im).map(|(&r, &i)| f(r, i)).collect();
        let shape = self.coeffs_real.shape().to_vec();
        Ok(Tensor::from_slice(&values, &shape, self.coeffs_real.device()))
    }
}