1#[allow(unused_imports)]
6use super::functions::*;
/// Unit tests for the wavelet routines: Haar / Daubechies DWT, multilevel
/// decomposition, wavelet packets, CWT and scalograms, thresholding and
/// denoising, energy metrics, SWT/MODWT, lifting, and scale/frequency helpers.
///
/// Most tests check output shapes and simple invariants (non-negativity,
/// roundtrip reconstruction) rather than exact coefficient values.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wavelet_transform::DaubechiesWavelet;
    use crate::wavelet_transform::HaarTransform;
    use crate::wavelet_transform::LiftingHaar;
    use crate::wavelet_transform::MotherWavelet;
    use crate::wavelet_transform::ThresholdMode;
    use crate::wavelet_transform::WaveletFamily;
    use std::f64::consts::PI;

    /// Default absolute tolerance for floating-point comparisons.
    pub(super) const TOL: f64 = 1e-6;

    /// Returns `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Asserts that `actual` matches `expected` element-wise within `tol`.
    ///
    /// The lengths are compared first: `Iterator::zip` stops at the shorter
    /// input, so an element-wise loop alone would silently accept a truncated
    /// (or empty) result. Mismatch messages have the form
    /// `"<label> at <index>: <expected> vs <actual>"`.
    fn assert_slices_close(expected: &[f64], actual: &[f64], tol: f64, label: &str) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "{}: length {} vs expected {}",
            label,
            actual.len(),
            expected.len()
        );
        for (i, (&o, &r)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(approx_eq(o, r, tol), "{} at {}: {} vs {}", label, i, o, r);
        }
    }

    #[test]
    fn test_haar_forward_inverse() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let (a, d) = HaarTransform::forward(&signal);
        let reconstructed = HaarTransform::inverse(&a, &d);
        assert_slices_close(&signal, &reconstructed, TOL, "Haar mismatch");
    }

    #[test]
    fn test_haar_constant_signal() {
        let signal = vec![5.0; 8];
        let (a, d) = HaarTransform::forward(&signal);
        assert_eq!(a.len(), 4);
        // Guard against an empty detail vector making the loop below vacuous.
        assert_eq!(d.len(), a.len());
        for &di in &d {
            assert!(di.abs() < TOL, "Haar detail should be zero for constant");
        }
    }

    #[test]
    fn test_haar_multilevel() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = HaarTransform::forward_multilevel(&signal, 3);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].len(), 1);
    }

    #[test]
    fn test_dwt_haar_single_level() {
        let signal = vec![1.0, 3.0, 5.0, 7.0];
        let level = dwt_single(&signal, WaveletFamily::Haar);
        assert_eq!(level.approx.len(), 2);
        assert_eq!(level.detail.len(), 2);
    }

    #[test]
    fn test_dwt_multilevel_haar() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let decomp = dwt(&signal, WaveletFamily::Haar, 3);
        assert_eq!(decomp.details.len(), 3);
    }

    #[test]
    fn test_dwt_idwt_haar_roundtrip() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let decomp = dwt(&signal, WaveletFamily::Haar, 2);
        let rec = idwt(&decomp);
        assert_slices_close(&signal, &rec, 1e-4, "Haar roundtrip mismatch");
    }

    #[test]
    fn test_daubechies_db2_forward() {
        let db = DaubechiesWavelet::new(2);
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let level = db.forward(&signal);
        // NOTE(review): 5 (not 4) for an 8-sample input presumably reflects a
        // full-convolution output length with 4 filter taps — confirm against
        // DaubechiesWavelet::forward.
        assert_eq!(level.approx.len(), 5);
        assert_eq!(level.detail.len(), 5);
    }

    #[test]
    fn test_daubechies_db3_forward() {
        let db = DaubechiesWavelet::new(3);
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let level = db.forward(&signal);
        assert!(!level.approx.is_empty());
        assert!(!level.detail.is_empty());
    }

    #[test]
    fn test_daubechies_db4_forward() {
        let db = DaubechiesWavelet::new(4);
        let signal: Vec<f64> = (0..16).map(|i| (i as f64).sin()).collect();
        let level = db.forward(&signal);
        assert!(!level.approx.is_empty());
    }

    #[test]
    fn test_daubechies_db5_forward() {
        let db = DaubechiesWavelet::new(5);
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.5).cos()).collect();
        let level = db.forward(&signal);
        assert!(!level.approx.is_empty());
    }

    #[test]
    fn test_daubechies_db6_forward() {
        let db = DaubechiesWavelet::new(6);
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.3).sin()).collect();
        let level = db.forward(&signal);
        assert!(!level.approx.is_empty());
    }

    #[test]
    fn test_multiresolution_analysis() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * PI / 8.0).sin()).collect();
        let mra = multiresolution_analysis(&signal, WaveletFamily::Haar, 3);
        assert_eq!(mra.approximations.len(), 4);
        assert_eq!(mra.detail_contributions.len(), 3);
    }

    #[test]
    fn test_wavelet_packet_decompose() {
        let signal: Vec<f64> = (0..16).map(|i| (i as f64).sin()).collect();
        let tree = wavelet_packet_decompose(&signal, WaveletFamily::Haar, 2);
        assert!(tree.nodes.len() >= 2);
        assert_eq!(tree.nodes[0].len(), 1);
    }

    #[test]
    fn test_best_basis_selection() {
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.5).sin()).collect();
        let tree = wavelet_packet_decompose(&signal, WaveletFamily::Haar, 2);
        let basis = best_basis_selection(&tree);
        assert!(!basis.is_empty());
    }

    #[test]
    fn test_cwt_morlet() {
        let signal: Vec<f64> = (0..64)
            .map(|i| (2.0 * PI * i as f64 / 16.0).sin())
            .collect();
        let scales = log_scales(1.0, 8, 0.5);
        let result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        assert_eq!(result.coefficients.len(), scales.len());
        assert_eq!(result.coefficients[0].len(), signal.len());
    }

    #[test]
    fn test_cwt_mexican_hat() {
        let signal: Vec<f64> = (0..32).map(|i| (2.0 * PI * i as f64 / 8.0).sin()).collect();
        let scales = linear_scales(1.0, 5.0, 5);
        let result = cwt(&signal, &scales, MotherWavelet::mexican_hat(), 1.0);
        assert_eq!(result.coefficients.len(), 5);
    }

    #[test]
    fn test_morlet_evaluation() {
        let morlet = MotherWavelet::morlet();
        let v = morlet.evaluate(0.0);
        assert!(v > 0.9, "Morlet at t=0 should be close to 1");
    }

    #[test]
    fn test_mexican_hat_evaluation() {
        let mh = MotherWavelet::mexican_hat();
        let v0 = mh.evaluate(0.0);
        assert!(v0 > 0.0, "Mexican hat at t=0 should be positive");
        let v5 = mh.evaluate(5.0);
        assert!(v5.abs() < 0.01, "Mexican hat should decay at large t");
    }

    #[test]
    fn test_scalogram() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.5).sin()).collect();
        let scales = log_scales(1.0, 4, 0.5);
        let cwt_result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        let scalo = scalogram(&cwt_result);
        assert_eq!(scalo.energy.len(), 4);
        assert!(scalo.scale_energy.iter().all(|&e| e >= 0.0));
    }

    #[test]
    fn test_global_wavelet_spectrum() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.3).sin()).collect();
        let scales = log_scales(1.0, 6, 0.5);
        let cwt_result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        let scalo = scalogram(&cwt_result);
        let gws = global_wavelet_spectrum(&scalo);
        assert_eq!(gws.len(), 6);
        assert!(gws.iter().all(|&e| e >= 0.0));
    }

    #[test]
    fn test_hard_thresholding() {
        let coeffs = vec![0.1, 0.5, -0.3, 1.0, -0.05, 2.0];
        let result = apply_threshold(&coeffs, 0.4, ThresholdMode::Hard);
        assert_eq!(result[0], 0.0);
        assert_eq!(result[1], 0.5);
        assert_eq!(result[4], 0.0);
    }

    #[test]
    fn test_soft_thresholding() {
        let coeffs = vec![1.0, -2.0, 0.3];
        let result = apply_threshold(&coeffs, 0.5, ThresholdMode::Soft);
        assert!(approx_eq(result[0], 0.5, TOL));
        assert!(approx_eq(result[1], -1.5, TOL));
        assert_eq!(result[2], 0.0);
    }

    #[test]
    fn test_universal_threshold() {
        let thresh = universal_threshold(1.0, 100);
        assert!(thresh > 0.0);
        assert!(thresh < 5.0);
    }

    #[test]
    fn test_estimate_noise_sigma() {
        let noise: Vec<f64> = (0..100).map(|i| (i as f64 * 0.7).sin() * 0.1).collect();
        let sigma = estimate_noise_sigma(&noise);
        assert!(sigma > 0.0);
    }

    #[test]
    fn test_wavelet_denoise_soft() {
        let clean: Vec<f64> = (0..64)
            .map(|i| (2.0 * PI * i as f64 / 16.0).sin())
            .collect();
        // NOTE(review): this is a constant offset, not random noise; the test
        // only checks that the output length is preserved.
        let noisy: Vec<f64> = clean.iter().map(|&x| x + 0.1).collect();
        let denoised = wavelet_denoise(&noisy, WaveletFamily::Haar, 3, ThresholdMode::Soft, None);
        assert_eq!(denoised.len(), noisy.len());
    }

    #[test]
    fn test_wavelet_denoise_hard() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.4).cos()).collect();
        let denoised = wavelet_denoise(
            &signal,
            WaveletFamily::Db2,
            2,
            ThresholdMode::Hard,
            Some(0.1),
        );
        assert_eq!(denoised.len(), signal.len());
    }

    #[test]
    fn test_level_energy() {
        // Sum of squares: 1 + 4 + 9 = 14.
        let coeffs = vec![1.0, 2.0, 3.0];
        assert!(approx_eq(level_energy(&coeffs), 14.0, TOL));
    }

    #[test]
    fn test_energy_distribution() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.2).sin()).collect();
        let decomp = dwt(&signal, WaveletFamily::Haar, 3);
        let (detail_e, approx_e) = energy_distribution(&decomp);
        assert_eq!(detail_e.len(), 3);
        assert!(approx_e >= 0.0);
    }

    #[test]
    fn test_relative_energy() {
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.5).sin()).collect();
        let decomp = dwt(&signal, WaveletFamily::Haar, 2);
        let rel = relative_energy(&decomp);
        let total: f64 = rel.iter().sum();
        assert!(
            approx_eq(total, 1.0, 0.01),
            "Relative energy should sum to ~1.0"
        );
    }

    #[test]
    fn test_wavelet_entropy() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64).sin()).collect();
        let decomp = dwt(&signal, WaveletFamily::Haar, 3);
        let ent = wavelet_entropy(&decomp);
        assert!(ent >= 0.0);
    }

    #[test]
    fn test_swt() {
        // The stationary transform keeps the full signal length at every level.
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.3).cos()).collect();
        let result = swt(&signal, WaveletFamily::Haar, 2);
        assert_eq!(result.details.len(), 2);
        assert_eq!(result.details[0].len(), 16);
        assert_eq!(result.approx.len(), 16);
    }

    #[test]
    fn test_wavelet_cross_spectrum() {
        let x: Vec<f64> = (0..32).map(|i| (i as f64 * 0.3).sin()).collect();
        let y: Vec<f64> = (0..32).map(|i| (i as f64 * 0.3).cos()).collect();
        let scales = log_scales(1.0, 3, 1.0);
        let cwt_x = cwt(&x, &scales, MotherWavelet::morlet(), 1.0);
        let cwt_y = cwt(&y, &scales, MotherWavelet::morlet(), 1.0);
        let cross = wavelet_cross_spectrum(&cwt_x, &cwt_y);
        assert_eq!(cross.len(), 3);
    }

    #[test]
    fn test_wavelet_coherence() {
        let x: Vec<f64> = (0..16).map(|i| (i as f64 * 0.5).sin()).collect();
        let y = x.clone();
        let scales = log_scales(1.0, 3, 1.0);
        let cwt_x = cwt(&x, &scales, MotherWavelet::morlet(), 1.0);
        let cwt_y = cwt(&y, &scales, MotherWavelet::morlet(), 1.0);
        let coh = wavelet_coherence(&cwt_x, &cwt_y, 2);
        assert_eq!(coh.len(), 3);
        // Allow a small rounding overshoot above 1.0.
        for row in &coh {
            for &c in row {
                assert!((0.0..=1.0 + TOL).contains(&c));
            }
        }
    }

    #[test]
    fn test_log_scales() {
        let scales = log_scales(1.0, 5, 0.5);
        assert_eq!(scales.len(), 5);
        assert!(approx_eq(scales[0], 1.0, TOL));
        for i in 1..scales.len() {
            assert!(scales[i] > scales[i - 1]);
        }
    }

    #[test]
    fn test_linear_scales() {
        let scales = linear_scales(1.0, 10.0, 10);
        assert_eq!(scales.len(), 10);
        assert!(approx_eq(scales[0], 1.0, TOL));
        assert!(approx_eq(scales[9], 10.0, TOL));
    }

    #[test]
    fn test_detect_ridges() {
        let signal: Vec<f64> = (0..64)
            .map(|i| (2.0 * PI * i as f64 / 16.0).sin())
            .collect();
        let scales = log_scales(1.0, 10, 0.25);
        let cwt_result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        let scalo = scalogram(&cwt_result);
        let ridges = detect_ridges(&scalo);
        // Every ridge is a (scale index, time index) pair within bounds.
        assert!(
            ridges
                .iter()
                .all(|&(s, t)| s < scales.len() && t < signal.len())
        );
    }

    #[test]
    fn test_wavelet_compress() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.3).sin()).collect();
        let compressed = wavelet_compress(&signal, WaveletFamily::Haar, 3, 0.5);
        assert_eq!(compressed.len(), signal.len());
    }

    #[test]
    fn test_compression_ratio() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.2).sin()).collect();
        let decomp = dwt(&signal, WaveletFamily::Haar, 3);
        let ratio = compression_ratio(&decomp);
        assert!((0.0..=1.0).contains(&ratio));
    }

    #[test]
    fn test_modwt() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.4).sin()).collect();
        let result = modwt(&signal, WaveletFamily::Haar, 3);
        assert_eq!(result.details.len(), 3);
        assert_eq!(result.details[0].len(), 32);
        assert_eq!(result.approx.len(), 32);
    }

    #[test]
    fn test_lifting_haar_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut data = original.clone();
        LiftingHaar::forward(&mut data);
        LiftingHaar::inverse(&mut data);
        assert_slices_close(&original, &data, TOL, "Lifting roundtrip mismatch");
    }

    #[test]
    fn test_scale_frequency_roundtrip() {
        let wavelet = MotherWavelet::morlet();
        let dt = 0.01;
        let freq = 10.0;
        let scale = frequency_to_scale(freq, dt, wavelet);
        let freq_back = scale_to_frequency(scale, dt, wavelet);
        assert!(
            approx_eq(freq, freq_back, 0.01),
            "Scale-freq roundtrip: {} vs {}",
            freq,
            freq_back
        );
    }

    #[test]
    fn test_reconstruction_error() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert!(approx_eq(reconstruction_error(&a, &b), 0.0, TOL));
    }

    #[test]
    fn test_reconstruction_snr_perfect() {
        // Identical signals leave zero error, so the SNR is infinite.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(reconstruction_snr(&a, &b), f64::INFINITY);
    }

    #[test]
    fn test_bayes_shrink_denoise() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.4).sin()).collect();
        let denoised = bayes_shrink_denoise(&signal, WaveletFamily::Haar, 3, ThresholdMode::Soft);
        assert_eq!(denoised.len(), signal.len());
    }

    #[test]
    fn test_wavelet_features() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.2).cos()).collect();
        let decomp = dwt(&signal, WaveletFamily::Haar, 3);
        let features = wavelet_features(&decomp);
        assert_eq!(features.len(), 4);
    }

    #[test]
    fn test_cone_of_influence() {
        // The cone of influence should widen monotonically with scale.
        let scales = vec![1.0, 2.0, 4.0];
        let coi = cone_of_influence(&scales, MotherWavelet::morlet());
        assert_eq!(coi.len(), 3);
        assert!(coi[0] < coi[1]);
        assert!(coi[1] < coi[2]);
    }

    #[test]
    fn test_wavelet_variance() {
        let signal: Vec<f64> = (0..32).map(|i| (i as f64 * 0.5).sin()).collect();
        let scales = log_scales(1.0, 4, 0.5);
        let cwt_result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        let var = wavelet_variance(&cwt_result);
        assert_eq!(var.len(), 4);
        assert!(var.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_wavelet_power() {
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.3).sin()).collect();
        let scales = log_scales(1.0, 3, 1.0);
        let cwt_result = cwt(&signal, &scales, MotherWavelet::morlet(), 1.0);
        let power = wavelet_power(&cwt_result);
        assert_eq!(power.len(), 3);
        for row in &power {
            assert!(row.iter().all(|&p| p >= 0.0));
        }
    }
}