torsh-series 0.1.2

Time series analysis components for ToRSh - powered by SciRS2
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
//! Wavelet decomposition for time series analysis
//!
//! Wavelets provide a time-frequency representation of signals, making them ideal
//! for analyzing non-stationary time series with features at different scales.
//!
//! NOTE: This module provides the structure for wavelet analysis.
//! Full implementation will use scirs2-signal when the wavelet API is available.

// Framework infrastructure - components designed for future use
#![allow(dead_code)]
use crate::TimeSeries;
use scirs2_core::ndarray::Array1;
// TODO: Uncomment when scirs2-signal wavelet API is available
// use scirs2_signal::wavelets::{dwt, idwt, wavelet_decompose, wavelet_reconstruct, WaveletType};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;

/// Wavelet type enumeration (placeholder until scirs2-signal provides it)
#[derive(Debug, Clone, Copy)]
pub enum WaveletType {
    Haar,
    Daubechies4,
    Symlet4,
    Morlet,
}

/// Wavelet decomposition result
#[derive(Debug, Clone)]
pub struct WaveletDecomposition {
    /// Approximation coefficients (low frequency)
    pub approximation: Tensor,
    /// Detail coefficients (high frequency) for each level
    pub details: Vec<Tensor>,
    /// Wavelet family used
    pub wavelet_family: String,
    /// Decomposition level
    pub level: usize,
}

/// Continuous Wavelet Transform (CWT) result
#[derive(Debug, Clone)]
pub struct CWTResult {
    /// Wavelet coefficients [scale x time]
    pub coefficients: Tensor,
    /// Scales used
    pub scales: Vec<f64>,
    /// Frequencies corresponding to scales
    pub frequencies: Vec<f64>,
}

/// Wavelet-based time series decomposer
pub struct WaveletDecomposer {
    wavelet_type: WaveletType,
    level: Option<usize>,
    mode: String,
}

impl WaveletDecomposer {
    /// Create a new wavelet decomposer
    ///
    /// # Arguments
    /// * `wavelet_type` - Type of wavelet to use (e.g., Haar, Daubechies, Symlet)
    pub fn new(wavelet_type: WaveletType) -> Self {
        Self {
            wavelet_type,
            level: None,
            mode: "symmetric".to_string(),
        }
    }

    /// Set decomposition level
    pub fn with_level(mut self, level: usize) -> Self {
        self.level = Some(level);
        self
    }

    /// Set boundary extension mode
    pub fn with_mode(mut self, mode: &str) -> Self {
        self.mode = mode.to_string();
        self
    }

    /// Perform Discrete Wavelet Transform (DWT) decomposition
    pub fn decompose(&self, series: &TimeSeries) -> Result<WaveletDecomposition> {
        // Convert TimeSeries to Array1
        let data = series.values.to_vec().map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to convert tensor to vec: {}", e))
        })?;
        let _ts_array = Array1::from_vec(data.clone());

        // Determine decomposition level if not specified
        let level = self
            .level
            .unwrap_or_else(|| Self::max_decomposition_level(series.len()));

        // TODO: Use scirs2-signal wavelet_decompose when available
        // For now, return placeholder decomposition
        let n = series.len();
        let mut current_len = n;
        let mut details = Vec::with_capacity(level);

        // Create placeholder detail coefficients (simplified Haar-like decomposition)
        for _lev in 0..level {
            let detail_len = current_len / 2;
            let detail_data = vec![0.0f32; detail_len];
            let detail_tensor = Tensor::from_vec(detail_data, &[detail_len])?;
            details.push(detail_tensor);
            current_len = detail_len;
        }

        // Approximation coefficients (same length as last detail level)
        let approx_data = vec![0.0f32; current_len];
        let approximation = Tensor::from_vec(approx_data, &[current_len])?;

        Ok(WaveletDecomposition {
            approximation,
            details,
            wavelet_family: format!("{:?}", self.wavelet_type),
            level,
        })
    }

    /// Reconstruct time series from wavelet decomposition
    pub fn reconstruct(&self, decomposition: &WaveletDecomposition) -> Result<TimeSeries> {
        // TODO: Use scirs2-signal wavelet_reconstruct when available
        // For now, return placeholder reconstruction
        let approx_len = decomposition.approximation.shape().dims()[0];
        let total_len = approx_len * (2_usize.pow(decomposition.level as u32));

        let recon_data = vec![0.0f32; total_len];
        let tensor = Tensor::from_vec(recon_data, &[total_len])?;

        Ok(TimeSeries::new(tensor))
    }

    /// Calculate maximum useful decomposition level
    fn max_decomposition_level(n: usize) -> usize {
        // Maximum level is floor(log2(n))
        ((n as f64).log2().floor() as usize).max(1)
    }

    /// Perform single-level DWT
    pub fn single_level_dwt(&self, series: &TimeSeries) -> Result<(Tensor, Tensor)> {
        let data = series.values.to_vec()?;
        let n = data.len();

        // TODO: Use scirs2-signal dwt when available
        // Placeholder implementation (simplified Haar transform)
        let half_n = n / 2;
        let mut approx_data = vec![0.0f32; half_n];
        let mut detail_data = vec![0.0f32; half_n];

        for i in 0..half_n {
            if 2 * i + 1 < n {
                approx_data[i] = (data[2 * i] + data[2 * i + 1]) / 2.0;
                detail_data[i] = (data[2 * i] - data[2 * i + 1]) / 2.0;
            }
        }

        let approx_tensor = Tensor::from_vec(approx_data, &[half_n])?;
        let detail_tensor = Tensor::from_vec(detail_data, &[half_n])?;

        Ok((approx_tensor, detail_tensor))
    }

    /// Perform single-level inverse DWT
    pub fn single_level_idwt(&self, approx: &Tensor, detail: &Tensor) -> Result<TimeSeries> {
        let approx_data = approx.to_vec()?;
        let detail_data = detail.to_vec()?;

        // TODO: Use scirs2-signal idwt when available
        // Placeholder implementation (simplified inverse Haar transform)
        let half_n = approx_data.len();
        let mut recon_data = vec![0.0f32; half_n * 2];

        for i in 0..half_n {
            recon_data[2 * i] = approx_data[i] + detail_data[i];
            recon_data[2 * i + 1] = approx_data[i] - detail_data[i];
        }

        let tensor = Tensor::from_vec(recon_data.clone(), &[recon_data.len()])?;
        Ok(TimeSeries::new(tensor))
    }
}

/// Continuous Wavelet Transform (CWT) analyzer
pub struct CWTAnalyzer {
    wavelet_type: WaveletType,
    scales: Option<Vec<f64>>,
    sampling_period: f64,
}

impl CWTAnalyzer {
    /// Create a new CWT analyzer
    pub fn new(wavelet_type: WaveletType) -> Self {
        Self {
            wavelet_type,
            scales: None,
            sampling_period: 1.0,
        }
    }

    /// Set scales for CWT
    pub fn with_scales(mut self, scales: Vec<f64>) -> Self {
        self.scales = Some(scales);
        self
    }

    /// Set sampling period
    pub fn with_sampling_period(mut self, period: f64) -> Self {
        self.sampling_period = period;
        self
    }

    /// Perform Continuous Wavelet Transform
    pub fn analyze(&self, series: &TimeSeries) -> Result<CWTResult> {
        let _data = series.values.to_vec()?;

        // Generate scales if not provided
        let scales = self
            .scales
            .clone()
            .unwrap_or_else(|| Self::generate_scales(series.len(), 1.0, 128.0, 64));

        // TODO: Use scirs2-signal cwt when available
        // Placeholder implementation
        let n_scales = scales.len();
        let n_time = series.len();
        let coef_data = vec![0.0f32; n_scales * n_time];
        let coefficients = Tensor::from_vec(coef_data, &[n_scales, n_time])?;

        // Calculate frequencies from scales
        let frequencies = self.scales_to_frequencies(&scales);

        Ok(CWTResult {
            coefficients,
            scales,
            frequencies,
        })
    }

    /// Generate logarithmically spaced scales
    fn generate_scales(_n: usize, min_scale: f64, max_scale: f64, n_scales: usize) -> Vec<f64> {
        let log_min = min_scale.ln();
        let log_max = max_scale.ln();
        let step = (log_max - log_min) / (n_scales - 1) as f64;

        (0..n_scales)
            .map(|i| (log_min + i as f64 * step).exp())
            .collect()
    }

    /// Convert scales to frequencies
    fn scales_to_frequencies(&self, scales: &[f64]) -> Vec<f64> {
        // For Morlet wavelet, center frequency is typically ~1.0
        // Frequency = center_freq / (scale * sampling_period)
        let center_freq = 1.0;
        scales
            .iter()
            .map(|&scale| center_freq / (scale * self.sampling_period))
            .collect()
    }
}

/// Wavelet packet decomposition for more flexible time-frequency analysis
pub struct WaveletPacketDecomposer {
    wavelet_type: WaveletType,
    level: usize,
}

impl WaveletPacketDecomposer {
    /// Create a new wavelet packet decomposer
    pub fn new(wavelet_type: WaveletType, level: usize) -> Self {
        Self {
            wavelet_type,
            level,
        }
    }

    /// Decompose into wavelet packet tree
    pub fn decompose(&self, series: &TimeSeries) -> Result<WaveletPacketTree> {
        let _data = series.values.to_vec()?;

        // TODO: Use scirs2-signal wavelet_packet_decompose when available
        // Placeholder implementation
        let mut nodes = std::collections::HashMap::new();

        // Create root node
        let root_data = vec![0.0f32; series.len()];
        let root_tensor = Tensor::from_vec(root_data, &[series.len()])?;
        nodes.insert("".to_string(), root_tensor);

        Ok(WaveletPacketTree {
            nodes,
            level: self.level,
            wavelet_family: format!("{:?}", self.wavelet_type),
        })
    }
}

/// Wavelet packet tree structure
#[derive(Debug, Clone)]
pub struct WaveletPacketTree {
    /// Nodes indexed by path (e.g., "a" for approximation, "d" for detail)
    pub nodes: std::collections::HashMap<String, Tensor>,
    /// Decomposition level
    pub level: usize,
    /// Wavelet family
    pub wavelet_family: String,
}

impl WaveletPacketTree {
    /// Get node at specific path
    pub fn get_node(&self, path: &str) -> Option<&Tensor> {
        self.nodes.get(path)
    }

    /// Get all leaf nodes at the deepest level
    pub fn leaf_nodes(&self) -> Vec<&Tensor> {
        self.nodes
            .iter()
            .filter(|(path, _)| path.len() == self.level)
            .map(|(_, tensor)| tensor)
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::Tensor;

    fn create_test_series() -> TimeSeries {
        // Create a simple time series with multiple frequencies
        let mut data = Vec::with_capacity(128);
        for i in 0..128 {
            let t = i as f32 * 0.1;
            // Low frequency + high frequency components
            let val = (t).sin() + 0.5 * (5.0 * t).sin();
            data.push(val);
        }
        let tensor = Tensor::from_vec(data, &[128]).expect("Tensor should succeed");
        TimeSeries::new(tensor)
    }

    #[test]
    fn test_wavelet_decomposer_creation() {
        let decomposer = WaveletDecomposer::new(WaveletType::Haar);
        assert_eq!(decomposer.mode, "symmetric");
    }

    #[test]
    fn test_wavelet_decomposition() {
        let series = create_test_series();
        let decomposer = WaveletDecomposer::new(WaveletType::Haar).with_level(3);
        let decomposition = decomposer
            .decompose(&series)
            .expect("decomposition should succeed");

        assert_eq!(decomposition.level, 3);
        assert_eq!(decomposition.details.len(), 3);
        assert!(decomposition.approximation.shape().dims()[0] > 0);
    }

    #[test]
    fn test_wavelet_reconstruction() {
        let series = create_test_series();
        let decomposer = WaveletDecomposer::new(WaveletType::Haar).with_level(2);
        let decomposition = decomposer
            .decompose(&series)
            .expect("decomposition should succeed");
        let reconstructed = decomposer
            .reconstruct(&decomposition)
            .expect("reconstruction should succeed");

        // Reconstructed should have similar length to original
        assert!(reconstructed.len() >= series.len() - 10); // Allow small difference due to padding
    }

    #[test]
    fn test_single_level_dwt() {
        let series = create_test_series();
        let decomposer = WaveletDecomposer::new(WaveletType::Haar);
        let (approx, detail) = decomposer
            .single_level_dwt(&series)
            .expect("single-level DWT should succeed");

        assert!(approx.shape().dims()[0] > 0);
        assert!(detail.shape().dims()[0] > 0);
    }

    #[test]
    fn test_cwt_analyzer() {
        let series = create_test_series();
        let analyzer = CWTAnalyzer::new(WaveletType::Morlet).with_sampling_period(0.1);
        let result = analyzer.analyze(&series).expect("analysis should succeed");

        assert!(result.coefficients.shape().dims()[0] > 0); // scales
        assert_eq!(result.coefficients.shape().dims()[1], series.len()); // time
        assert!(!result.scales.is_empty());
        assert!(!result.frequencies.is_empty());
    }

    #[test]
    fn test_max_decomposition_level() {
        assert_eq!(WaveletDecomposer::max_decomposition_level(128), 7);
        assert_eq!(WaveletDecomposer::max_decomposition_level(256), 8);
        assert_eq!(WaveletDecomposer::max_decomposition_level(64), 6);
    }

    #[test]
    fn test_wavelet_packet_decomposer() {
        let series = create_test_series();
        let decomposer = WaveletPacketDecomposer::new(WaveletType::Daubechies4, 2);
        let tree = decomposer
            .decompose(&series)
            .expect("decomposition should succeed");

        assert_eq!(tree.level, 2);
        assert!(!tree.nodes.is_empty());
    }
}