#![allow(dead_code)]
use crate::TimeSeries;
use scirs2_core::ndarray::Array1;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
/// Wavelet families selectable for decomposition / analysis.
///
/// NOTE(review): the transforms in this file currently use the same
/// averaging/differencing kernel regardless of the variant chosen; the
/// variant is only echoed back in `wavelet_family` strings — confirm
/// whether per-family filters are intended.
#[derive(Debug, Clone, Copy)]
pub enum WaveletType {
    /// Haar wavelet (2-tap average/difference).
    Haar,
    /// Daubechies wavelet with 4 coefficients.
    Daubechies4,
    /// Symlet wavelet with 4 coefficients.
    Symlet4,
    /// Morlet wavelet (typically used for continuous transforms).
    Morlet,
}
/// Result of a multi-level discrete wavelet decomposition.
#[derive(Debug, Clone)]
pub struct WaveletDecomposition {
    // Coarsest-level approximation coefficients.
    pub approximation: Tensor,
    // One detail tensor per level; lengths halve at each level.
    pub details: Vec<Tensor>,
    // Debug-formatted name of the `WaveletType` used.
    pub wavelet_family: String,
    // Number of decomposition levels performed.
    pub level: usize,
}
/// Result of a continuous wavelet transform (CWT).
#[derive(Debug, Clone)]
pub struct CWTResult {
    // Coefficient matrix shaped `[n_scales, n_time]`.
    pub coefficients: Tensor,
    // Scales used, one per coefficient row.
    pub scales: Vec<f64>,
    // Frequency corresponding to each scale (see `scales_to_frequencies`).
    pub frequencies: Vec<f64>,
}
/// Builder-style driver for discrete wavelet decomposition of a `TimeSeries`.
pub struct WaveletDecomposer {
    // Requested wavelet family (echoed in results).
    wavelet_type: WaveletType,
    // Explicit level count; `None` means auto-select from series length.
    level: Option<usize>,
    // Boundary-extension mode name; defaults to "symmetric".
    // NOTE(review): currently stored but not consulted by the transform.
    mode: String,
}
impl WaveletDecomposer {
    /// Create a decomposer with default settings: auto-selected level and
    /// "symmetric" extension mode.
    pub fn new(wavelet_type: WaveletType) -> Self {
        Self {
            wavelet_type,
            level: None,
            mode: "symmetric".to_string(),
        }
    }

    /// Set an explicit number of decomposition levels.
    pub fn with_level(mut self, level: usize) -> Self {
        self.level = Some(level);
        self
    }

    /// Set the boundary-extension mode string.
    /// NOTE(review): `mode` is stored but not yet used by the transform —
    /// confirm intended semantics before relying on it.
    pub fn with_mode(mut self, mode: &str) -> Self {
        self.mode = mode.to_string();
        self
    }

    /// Multi-level discrete wavelet decomposition.
    ///
    /// Cascades the same averaging/differencing split as `single_level_dwt`:
    /// each level turns the current approximation of length `m` into an
    /// approximation and a detail band of length `m / 2` (a trailing sample
    /// of odd-length input is dropped, as in `single_level_dwt`).
    ///
    /// Returns the final approximation plus one detail tensor per level,
    /// finest detail first. Fails if the input tensor cannot be read or an
    /// output tensor cannot be built.
    ///
    /// NOTE(review): the Haar-style kernel is applied for every
    /// `WaveletType`; only the reported family name differs.
    pub fn decompose(&self, series: &TimeSeries) -> Result<WaveletDecomposition> {
        let data = series.values.to_vec().map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to convert tensor to vec: {}", e))
        })?;
        let level = self
            .level
            .unwrap_or_else(|| Self::max_decomposition_level(series.len()));
        let mut current = data;
        let mut details = Vec::with_capacity(level);
        for _lev in 0..level {
            let half = current.len() / 2;
            let mut approx = vec![0.0f32; half];
            let mut detail = vec![0.0f32; half];
            for i in 0..half {
                approx[i] = (current[2 * i] + current[2 * i + 1]) / 2.0;
                detail[i] = (current[2 * i] - current[2 * i + 1]) / 2.0;
            }
            details.push(Tensor::from_vec(detail, &[half])?);
            current = approx;
        }
        let approx_len = current.len();
        let approximation = Tensor::from_vec(current, &[approx_len])?;
        Ok(WaveletDecomposition {
            approximation,
            details,
            wavelet_family: format!("{:?}", self.wavelet_type),
            level,
        })
    }

    /// Invert `decompose`: starting from the approximation, fold the detail
    /// bands back in from coarsest to finest using the `single_level_idwt`
    /// recurrence (`x[2i] = a[i] + d[i]`, `x[2i+1] = a[i] - d[i]`).
    ///
    /// For power-of-two inputs this exactly reconstructs the original
    /// samples; odd-length levels lose their dropped trailing sample.
    pub fn reconstruct(&self, decomposition: &WaveletDecomposition) -> Result<TimeSeries> {
        let mut current = decomposition.approximation.to_vec()?;
        // Details were pushed finest-first, so reconstruct in reverse.
        for detail_tensor in decomposition.details.iter().rev() {
            let detail = detail_tensor.to_vec()?;
            // Guard against mismatched band lengths rather than panicking.
            let half = current.len().min(detail.len());
            let mut recon = vec![0.0f32; half * 2];
            for i in 0..half {
                recon[2 * i] = current[i] + detail[i];
                recon[2 * i + 1] = current[i] - detail[i];
            }
            current = recon;
        }
        let total_len = current.len();
        let tensor = Tensor::from_vec(current, &[total_len])?;
        Ok(TimeSeries::new(tensor))
    }

    /// Maximum sensible level count: `floor(log2(n))`, but at least 1.
    fn max_decomposition_level(n: usize) -> usize {
        ((n as f64).log2().floor() as usize).max(1)
    }

    /// One Haar-style analysis step: pairwise averages (approximation) and
    /// pairwise half-differences (detail). A trailing sample of odd-length
    /// input is dropped.
    pub fn single_level_dwt(&self, series: &TimeSeries) -> Result<(Tensor, Tensor)> {
        let data = series.values.to_vec()?;
        let half_n = data.len() / 2;
        let mut approx_data = vec![0.0f32; half_n];
        let mut detail_data = vec![0.0f32; half_n];
        for i in 0..half_n {
            approx_data[i] = (data[2 * i] + data[2 * i + 1]) / 2.0;
            detail_data[i] = (data[2 * i] - data[2 * i + 1]) / 2.0;
        }
        let approx_tensor = Tensor::from_vec(approx_data, &[half_n])?;
        let detail_tensor = Tensor::from_vec(detail_data, &[half_n])?;
        Ok((approx_tensor, detail_tensor))
    }

    /// One Haar-style synthesis step, the exact inverse of
    /// `single_level_dwt`: `x[2i] = a[i] + d[i]`, `x[2i+1] = a[i] - d[i]`.
    pub fn single_level_idwt(&self, approx: &Tensor, detail: &Tensor) -> Result<TimeSeries> {
        let approx_data = approx.to_vec()?;
        let detail_data = detail.to_vec()?;
        // Use the shorter band so mismatched inputs cannot index out of range.
        let half_n = approx_data.len().min(detail_data.len());
        let mut recon_data = vec![0.0f32; half_n * 2];
        for i in 0..half_n {
            recon_data[2 * i] = approx_data[i] + detail_data[i];
            recon_data[2 * i + 1] = approx_data[i] - detail_data[i];
        }
        let total_len = recon_data.len();
        let tensor = Tensor::from_vec(recon_data, &[total_len])?;
        Ok(TimeSeries::new(tensor))
    }
}
/// Builder-style driver for a continuous wavelet transform over a `TimeSeries`.
pub struct CWTAnalyzer {
    // Requested wavelet family.
    wavelet_type: WaveletType,
    // Explicit scales; `None` means auto-generate a log-spaced range.
    scales: Option<Vec<f64>>,
    // Sampling period used when converting scales to frequencies; defaults to 1.0.
    sampling_period: f64,
}
impl CWTAnalyzer {
pub fn new(wavelet_type: WaveletType) -> Self {
Self {
wavelet_type,
scales: None,
sampling_period: 1.0,
}
}
pub fn with_scales(mut self, scales: Vec<f64>) -> Self {
self.scales = Some(scales);
self
}
pub fn with_sampling_period(mut self, period: f64) -> Self {
self.sampling_period = period;
self
}
pub fn analyze(&self, series: &TimeSeries) -> Result<CWTResult> {
let _data = series.values.to_vec()?;
let scales = self
.scales
.clone()
.unwrap_or_else(|| Self::generate_scales(series.len(), 1.0, 128.0, 64));
let n_scales = scales.len();
let n_time = series.len();
let coef_data = vec![0.0f32; n_scales * n_time];
let coefficients = Tensor::from_vec(coef_data, &[n_scales, n_time])?;
let frequencies = self.scales_to_frequencies(&scales);
Ok(CWTResult {
coefficients,
scales,
frequencies,
})
}
fn generate_scales(_n: usize, min_scale: f64, max_scale: f64, n_scales: usize) -> Vec<f64> {
let log_min = min_scale.ln();
let log_max = max_scale.ln();
let step = (log_max - log_min) / (n_scales - 1) as f64;
(0..n_scales)
.map(|i| (log_min + i as f64 * step).exp())
.collect()
}
fn scales_to_frequencies(&self, scales: &[f64]) -> Vec<f64> {
let center_freq = 1.0;
scales
.iter()
.map(|&scale| center_freq / (scale * self.sampling_period))
.collect()
}
}
/// Builder for a full wavelet-packet decomposition of a `TimeSeries`.
pub struct WaveletPacketDecomposer {
    // Requested wavelet family (echoed in the resulting tree).
    wavelet_type: WaveletType,
    // Depth of the packet tree to build.
    level: usize,
}
impl WaveletPacketDecomposer {
    /// Create a packet decomposer for the given family and tree depth.
    pub fn new(wavelet_type: WaveletType, level: usize) -> Self {
        Self {
            wavelet_type,
            level,
        }
    }

    /// Build the full wavelet-packet tree down to `self.level`.
    ///
    /// Every node (approximation AND detail) is split at each level with the
    /// Haar-style average/difference kernel. Node paths are strings over the
    /// alphabet {"a", "d"} ("a" = approximation child, "d" = detail child);
    /// the root is the empty path, so `WaveletPacketTree::leaf_nodes()`
    /// (which filters on `path.len() == level`) now actually finds leaves.
    ///
    /// NOTE(review): the Haar kernel is used for every `WaveletType`; only
    /// the reported family name differs.
    pub fn decompose(&self, series: &TimeSeries) -> Result<WaveletPacketTree> {
        let data = series.values.to_vec()?;
        let mut nodes = std::collections::HashMap::new();
        let root_len = data.len();
        nodes.insert("".to_string(), Tensor::from_vec(data.clone(), &[root_len])?);
        // Breadth-first expansion: split every node on the current frontier.
        let mut frontier = vec![(String::new(), data)];
        for _lev in 0..self.level {
            let mut next = Vec::with_capacity(frontier.len() * 2);
            for (path, samples) in frontier {
                let half = samples.len() / 2;
                let mut approx = vec![0.0f32; half];
                let mut detail = vec![0.0f32; half];
                for i in 0..half {
                    approx[i] = (samples[2 * i] + samples[2 * i + 1]) / 2.0;
                    detail[i] = (samples[2 * i] - samples[2 * i + 1]) / 2.0;
                }
                let a_path = format!("{}a", path);
                let d_path = format!("{}d", path);
                nodes.insert(a_path.clone(), Tensor::from_vec(approx.clone(), &[half])?);
                nodes.insert(d_path.clone(), Tensor::from_vec(detail.clone(), &[half])?);
                next.push((a_path, approx));
                next.push((d_path, detail));
            }
            frontier = next;
        }
        Ok(WaveletPacketTree {
            nodes,
            level: self.level,
            wavelet_family: format!("{:?}", self.wavelet_type),
        })
    }
}
/// Wavelet-packet coefficient tree keyed by node path.
#[derive(Debug, Clone)]
pub struct WaveletPacketTree {
    // Coefficients per node; keys are paths (root = ""), leaves have
    // `path.len() == level` (see `leaf_nodes`).
    pub nodes: std::collections::HashMap<String, Tensor>,
    // Depth of the tree.
    pub level: usize,
    // Debug-formatted name of the `WaveletType` used.
    pub wavelet_family: String,
}
impl WaveletPacketTree {
    /// Look up the coefficient tensor stored at `path` ("" is the root).
    pub fn get_node(&self, path: &str) -> Option<&Tensor> {
        self.nodes.get(path)
    }

    /// Collect the tensors whose path length equals the tree depth, i.e.
    /// the deepest-level nodes.
    pub fn leaf_nodes(&self) -> Vec<&Tensor> {
        let mut leaves = Vec::new();
        for (path, tensor) in &self.nodes {
            if path.len() == self.level {
                leaves.push(tensor);
            }
        }
        leaves
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::Tensor;

    /// 128-sample fixture: sin(t) plus a weaker 5x-frequency harmonic.
    fn create_test_series() -> TimeSeries {
        let data: Vec<f32> = (0..128)
            .map(|i| {
                let t = i as f32 * 0.1;
                t.sin() + 0.5 * (5.0 * t).sin()
            })
            .collect();
        let tensor = Tensor::from_vec(data, &[128]).expect("Tensor should succeed");
        TimeSeries::new(tensor)
    }

    #[test]
    fn test_wavelet_decomposer_creation() {
        let decomposer = WaveletDecomposer::new(WaveletType::Haar);
        assert_eq!(decomposer.mode, "symmetric");
    }

    #[test]
    fn test_wavelet_decomposition() {
        let decomposer = WaveletDecomposer::new(WaveletType::Haar).with_level(3);
        let decomposition = decomposer
            .decompose(&create_test_series())
            .expect("decomposition should succeed");
        assert_eq!(decomposition.level, 3);
        assert_eq!(decomposition.details.len(), 3);
        assert!(decomposition.approximation.shape().dims()[0] > 0);
    }

    #[test]
    fn test_wavelet_reconstruction() {
        let series = create_test_series();
        let decomposer = WaveletDecomposer::new(WaveletType::Haar).with_level(2);
        let decomposition = decomposer
            .decompose(&series)
            .expect("decomposition should succeed");
        let reconstructed = decomposer
            .reconstruct(&decomposition)
            .expect("reconstruction should succeed");
        // Allow a small amount of boundary shrinkage.
        assert!(reconstructed.len() >= series.len() - 10);
    }

    #[test]
    fn test_single_level_dwt() {
        let decomposer = WaveletDecomposer::new(WaveletType::Haar);
        let (approx, detail) = decomposer
            .single_level_dwt(&create_test_series())
            .expect("single-level DWT should succeed");
        assert!(approx.shape().dims()[0] > 0);
        assert!(detail.shape().dims()[0] > 0);
    }

    #[test]
    fn test_cwt_analyzer() {
        let series = create_test_series();
        let analyzer = CWTAnalyzer::new(WaveletType::Morlet).with_sampling_period(0.1);
        let result = analyzer.analyze(&series).expect("analysis should succeed");
        assert!(result.coefficients.shape().dims()[0] > 0);
        assert_eq!(result.coefficients.shape().dims()[1], series.len());
        assert!(!result.scales.is_empty());
        assert!(!result.frequencies.is_empty());
    }

    #[test]
    fn test_max_decomposition_level() {
        // floor(log2(n)) for the power-of-two cases below.
        for &(n, expected) in [(128usize, 7usize), (256, 8), (64, 6)].iter() {
            assert_eq!(WaveletDecomposer::max_decomposition_level(n), expected);
        }
    }

    #[test]
    fn test_wavelet_packet_decomposer() {
        let tree = WaveletPacketDecomposer::new(WaveletType::Daubechies4, 2)
            .decompose(&create_test_series())
            .expect("decomposition should succeed");
        assert_eq!(tree.level, 2);
        assert!(!tree.nodes.is_empty());
    }
}