Skip to main content

scirs2_fft/
hilbert_enhanced.rs

1//! Enhanced Hilbert Transform module with Hilbert-Huang Transform (HHT)
2//!
3//! This module provides advanced Hilbert transform capabilities including:
4//!
5//! - **Empirical Mode Decomposition (EMD)**: Adaptive signal decomposition into
6//!   Intrinsic Mode Functions (IMFs)
7//! - **Hilbert-Huang Transform**: EMD + Hilbert spectral analysis for nonlinear
8//!   and nonstationary time-frequency analysis
9//! - **Marginal Hilbert spectrum**: Time-integrated spectral energy distribution
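//! - **Instantaneous energy**: Total Hilbert energy (squared IMF envelope) at each time point, summed over all IMFs
//! - **Hilbert spectral analysis**: Full time-frequency-energy representation
//! - **Ensemble EMD (EEMD)**: Noise-assisted EMD that averages many decompositions to reduce mode mixing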
12//!
13//! # Mathematical Background
14//!
15//! The Hilbert-Huang Transform (HHT) combines EMD with the Hilbert transform to
16//! produce an adaptive time-frequency representation. Unlike STFT and wavelet
17//! transforms, the HHT does not rely on predetermined basis functions.
18//!
19//! EMD decomposes a signal x(t) into a finite set of Intrinsic Mode Functions (IMFs)
20//! c_i(t) plus a residual r(t):
21//!   x(t) = sum_i c_i(t) + r(t)
22//!
23//! Each IMF satisfies two conditions:
24//! 1. The number of extrema and zero crossings differ by at most one
25//! 2. The mean of the upper and lower envelopes is approximately zero
26//!
27//! # References
28//!
29//! * Huang, N. E. et al. "The empirical mode decomposition and the Hilbert
30//!   spectrum for nonlinear and non-stationary time series analysis."
31//!   Proc. R. Soc. Lond. A, 454, 903-995, 1998.
32//! * Huang, N. E. & Wu, Z. "A review on Hilbert-Huang Transform: Method and
33//!   its applications to geophysical studies." Rev. Geophys., 46, 2008.
34
35use crate::error::{FFTError, FFTResult};
36use scirs2_core::numeric::Complex64;
37use std::f64::consts::PI;
38
39// ============================================================================
40// EMD configuration and types
41// ============================================================================
42
/// Tuning parameters for [`emd`] (Empirical Mode Decomposition).
///
/// Start from `EMDConfig::default()` and override individual fields with
/// struct-update syntax when needed.
#[derive(Debug, Clone)]
pub struct EMDConfig {
    /// Upper bound on the number of IMFs extracted before the remainder is
    /// returned as the residual.
    pub max_imfs: usize,
    /// Upper bound on the number of sifting passes spent on a single IMF.
    pub max_sift_iterations: usize,
    /// Threshold on the normalized squared change between successive sifting
    /// candidates (Huang's SD criterion, a Cauchy-type convergence test).
    pub sift_threshold: f64,
    /// Number of consecutive sifting passes whose SD must fall below
    /// `sift_threshold` before the candidate is accepted as an IMF.
    pub s_number: usize,
    /// Interpolation scheme used to build the upper and lower envelopes.
    pub envelope_method: EnvelopeMethod,
}

impl Default for EMDConfig {
    /// 20 IMFs, 500 sifting passes, SD threshold 0.05, S-number 4 and
    /// cubic-spline envelopes.
    fn default() -> Self {
        let envelope_method = EnvelopeMethod::CubicSpline;
        EMDConfig {
            max_imfs: 20,
            max_sift_iterations: 500,
            sift_threshold: 0.05,
            s_number: 4,
            envelope_method,
        }
    }
}

/// Interpolation scheme used to draw EMD envelopes through the extrema.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EnvelopeMethod {
    /// Natural cubic spline through the extrema (the standard EMD choice).
    CubicSpline,
    /// Piecewise-linear interpolation: cheaper, but the envelopes have corners.
    Linear,
}
78
/// Output of [`emd`] (and [`eemd`]): the extracted modes plus what is left over.
#[derive(Debug, Clone)]
pub struct EMDResult {
    /// Intrinsic Mode Functions in extraction order; the first IMF carries the
    /// fastest oscillations and later ones progressively slower content.
    pub imfs: Vec<Vec<f64>>,
    /// Input minus the sum of all IMFs (usually a slow trend, but not
    /// guaranteed monotonic if extraction stopped at `max_imfs`).
    pub residual: Vec<f64>,
    /// Sifting passes spent on each IMF, in the same order as `imfs`
    /// (all zeros for `eemd`, where the count is not meaningful).
    pub iterations: Vec<usize>,
}

/// Output of [`hht`]: the EMD modes together with their Hilbert attributes.
#[derive(Debug, Clone)]
pub struct HHTResult {
    /// IMFs produced by the EMD stage.
    pub imfs: Vec<Vec<f64>>,
    /// Instantaneous frequency (Hz) of each IMF, one value per sample.
    pub inst_frequencies: Vec<Vec<f64>>,
    /// Instantaneous amplitude (analytic-signal envelope) of each IMF,
    /// one value per sample.
    pub inst_amplitudes: Vec<Vec<f64>>,
    /// Residual left over by the EMD stage.
    pub residual: Vec<f64>,
}

/// Time-frequency energy map produced by [`hilbert_spectrum`].
#[derive(Debug, Clone)]
pub struct HilbertSpectrum {
    /// Sample times in seconds (`i / fs`).
    pub times: Vec<f64>,
    /// Frequency-bin centres in Hz.
    pub frequencies: Vec<f64>,
    /// Energy (squared amplitude), indexed as `energy[time][frequency_bin]`.
    pub energy: Vec<Vec<f64>>,
}
113
114// ============================================================================
115// Empirical Mode Decomposition (EMD)
116// ============================================================================
117
118/// Perform Empirical Mode Decomposition (EMD) on a signal
119///
120/// Decomposes a signal into a set of Intrinsic Mode Functions (IMFs) and a
121/// residual, using the sifting process.
122///
123/// # Arguments
124///
125/// * `signal` - Input signal
126/// * `config` - Optional EMD configuration (uses defaults if None)
127///
128/// # Returns
129///
130/// EMDResult containing IMFs, residual, and iteration counts.
131///
132/// # Errors
133///
134/// Returns an error if the signal is too short for decomposition.
135///
136/// # Examples
137///
138/// ```
139/// use scirs2_fft::hilbert_enhanced::{emd, EMDConfig};
140/// use std::f64::consts::PI;
141///
142/// let n = 256;
143/// let signal: Vec<f64> = (0..n).map(|i| {
144///     let t = i as f64 / 256.0;
145///     (2.0 * PI * 10.0 * t).sin() + 0.5 * (2.0 * PI * 30.0 * t).sin()
146/// }).collect();
147///
148/// let result = emd(&signal, None).expect("EMD should succeed");
149/// assert!(!result.imfs.is_empty(), "Should extract at least one IMF");
150/// ```
151pub fn emd(signal: &[f64], config: Option<EMDConfig>) -> FFTResult<EMDResult> {
152    let n = signal.len();
153    if n < 4 {
154        return Err(FFTError::ValueError(
155            "Signal must have at least 4 samples for EMD".to_string(),
156        ));
157    }
158
159    let cfg = config.unwrap_or_default();
160    let mut imfs = Vec::new();
161    let mut iterations_list = Vec::new();
162    let mut residual = signal.to_vec();
163
164    for _imf_idx in 0..cfg.max_imfs {
165        // Check if residual is monotonic or has too few extrema
166        let extrema = count_extrema(&residual);
167        if extrema < 2 {
168            break;
169        }
170
171        // Sifting process to extract one IMF
172        let (imf, iters) = sift_imf(&residual, &cfg)?;
173
174        // Subtract IMF from residual
175        for i in 0..n {
176            residual[i] -= imf[i];
177        }
178
179        imfs.push(imf);
180        iterations_list.push(iters);
181
182        // Check if residual is negligible or monotonic
183        let residual_energy: f64 = residual.iter().map(|&v| v * v).sum();
184        let signal_energy: f64 = signal.iter().map(|&v| v * v).sum();
185        if signal_energy > 0.0 && residual_energy / signal_energy < 1e-12 {
186            break;
187        }
188
189        // Check if residual has enough extrema to continue
190        if count_extrema(&residual) < 2 {
191            break;
192        }
193    }
194
195    Ok(EMDResult {
196        imfs,
197        residual,
198        iterations: iterations_list,
199    })
200}
201
/// Sifting process to extract one IMF from a signal
///
/// Repeatedly subtracts the mean of the upper and lower envelopes from the
/// candidate `h`. After each pass the normalized squared change
/// `SD = sum((h - prev_h)^2) / sum(prev_h^2)` is compared with
/// `config.sift_threshold`. The candidate is accepted once `config.s_number`
/// consecutive passes have SD below the threshold; a pass at or above the
/// threshold resets that counter. If `prev_h` has zero energy, the SD test is
/// skipped for that pass and the counter is left unchanged.
///
/// # Returns
///
/// `(imf, iterations)`: the final candidate and the number of passes run
/// (`config.max_sift_iterations` if the stopping rule was never met).
///
/// # Errors
///
/// Propagates errors from envelope interpolation.
fn sift_imf(signal: &[f64], config: &EMDConfig) -> FFTResult<(Vec<f64>, usize)> {
    let n = signal.len();
    let mut h = signal.to_vec();
    let mut prev_h = h.clone();
    // Number of consecutive passes whose SD was below the threshold.
    let mut s_count = 0;

    for iteration in 0..config.max_sift_iterations {
        // Find local maxima and minima
        let (max_pos, max_val) = find_local_maxima(&h);
        let (min_pos, min_val) = find_local_minima(&h);

        // Need at least 2 maxima and 2 minima for envelope interpolation
        // NOTE(review): find_local_maxima/find_local_minima always append both
        // endpoints, so each list has at least 2 entries and this early return
        // appears unreachable. If the intent was "at least 2 interior
        // extrema", the test would need to be `< 4`; confirm before changing.
        if max_pos.len() < 2 || min_pos.len() < 2 {
            return Ok((h, iteration + 1));
        }

        // Compute upper and lower envelopes
        let upper_env = interpolate_envelope(&max_pos, &max_val, n, config.envelope_method)?;
        let lower_env = interpolate_envelope(&min_pos, &min_val, n, config.envelope_method)?;

        // Compute mean envelope
        let mean_env: Vec<f64> = upper_env
            .iter()
            .zip(lower_env.iter())
            .map(|(&u, &l)| (u + l) / 2.0)
            .collect();

        // Subtract mean from current candidate
        for i in 0..n {
            h[i] -= mean_env[i];
        }

        // Check Cauchy convergence criterion (Huang's SD between successive
        // candidates, normalized by the previous candidate's energy)
        let diff_energy: f64 = h
            .iter()
            .zip(prev_h.iter())
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum();
        let h_energy: f64 = prev_h.iter().map(|&v| v * v).sum();

        if h_energy > 0.0 {
            let sd = diff_energy / h_energy;
            if sd < config.sift_threshold {
                // Check S-number criterion: require s_number consecutive hits
                s_count += 1;
                if s_count >= config.s_number {
                    return Ok((h, iteration + 1));
                }
            } else {
                s_count = 0;
            }
        }

        // Reuse prev_h's allocation rather than cloning a fresh Vec.
        prev_h.clone_from(&h);
    }

    Ok((h, config.max_sift_iterations))
}
261
/// Count strict local extrema: interior samples that lie strictly above, or
/// strictly below, both neighbours.
///
/// Endpoints and samples with an equal neighbour are never counted; a signal
/// shorter than 3 samples therefore has no extrema.
fn count_extrema(signal: &[f64]) -> usize {
    signal
        .windows(3)
        .filter(|w| {
            let (left, mid, right) = (w[0], w[1], w[2]);
            let is_peak = mid > left && mid > right;
            let is_trough = mid < left && mid < right;
            is_peak || is_trough
        })
        .count()
}
278
/// Locate the local maxima of `signal`, padded with both endpoints.
///
/// Returns `(positions, values)`, where positions are sample indices stored as
/// `f64`. The first and last samples are always included so that the upper
/// envelope spans the whole signal. An interior sample qualifies when it is
/// `>=` both neighbours and strictly greater than at least one of them, so
/// samples strictly inside a flat run are skipped.
///
/// Panics if `signal` is empty.
fn find_local_maxima(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let mut positions = vec![0.0];
    let mut values = vec![signal[0]];

    for (offset, w) in signal.windows(3).enumerate() {
        let (prev, cur, next) = (w[0], w[1], w[2]);
        let not_below_neighbours = cur >= prev && cur >= next;
        let strictly_above_one = cur > prev || cur > next;
        if not_below_neighbours && strictly_above_one {
            positions.push((offset + 1) as f64);
            values.push(cur);
        }
    }

    let last = signal.len() - 1;
    positions.push(last as f64);
    values.push(signal[last]);

    (positions, values)
}
305
/// Locate the local minima of `signal`, padded with both endpoints.
///
/// Returns `(positions, values)`, where positions are sample indices stored as
/// `f64`. The first and last samples are always included so that the lower
/// envelope spans the whole signal. An interior sample qualifies when it is
/// `<=` both neighbours and strictly smaller than at least one of them.
///
/// Panics if `signal` is empty.
fn find_local_minima(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let mut positions = vec![0.0];
    let mut values = vec![signal[0]];

    signal
        .windows(3)
        .enumerate()
        .filter(|(_, w)| w[1] <= w[0] && w[1] <= w[2] && (w[1] < w[0] || w[1] < w[2]))
        .for_each(|(offset, w)| {
            positions.push((offset + 1) as f64);
            values.push(w[1]);
        });

    let last = signal.len() - 1;
    positions.push(last as f64);
    values.push(signal[last]);

    (positions, values)
}
331
332/// Interpolate an envelope through given points
333fn interpolate_envelope(
334    positions: &[f64],
335    values: &[f64],
336    n: usize,
337    method: EnvelopeMethod,
338) -> FFTResult<Vec<f64>> {
339    if positions.len() < 2 {
340        return Err(FFTError::ValueError(
341            "Need at least 2 points for envelope interpolation".to_string(),
342        ));
343    }
344
345    match method {
346        EnvelopeMethod::CubicSpline => cubic_spline_interpolate(positions, values, n),
347        EnvelopeMethod::Linear => linear_interpolate(positions, values, n),
348    }
349}
350
/// Cubic spline interpolation (natural boundary conditions)
///
/// Fits a natural cubic spline through `(x_knots[i], y_knots[i])`, with the
/// quadratic coefficient `c` forced to zero at both end knots. The spline is
/// evaluated at the integer positions `0, 1, ..., n_out - 1`.
///
/// Segment `i` is `y[i] + b[i]*dx + c[i]*dx^2 + d[i]*dx^3` with
/// `dx = t - x_knots[i]`. Positions outside the knot range are extrapolated
/// with the first or last segment's cubic, as chosen by `find_segment`.
///
/// Falls back to `linear_interpolate` when there are exactly two knots or when
/// the knot positions are not strictly increasing.
///
/// # Errors
///
/// Returns an error if fewer than two knots are given, or if the tridiagonal
/// solve reports a (near-)singular system.
fn cubic_spline_interpolate(x_knots: &[f64], y_knots: &[f64], n_out: usize) -> FFTResult<Vec<f64>> {
    let m = x_knots.len();
    if m < 2 {
        return Err(FFTError::ValueError(
            "Need at least 2 knots for spline".to_string(),
        ));
    }

    if m == 2 {
        // Fall back to linear for just 2 points
        return linear_interpolate(x_knots, y_knots, n_out);
    }

    // Compute h[i] = x[i+1] - x[i]
    let mut h = Vec::with_capacity(m - 1);
    for i in 0..m - 1 {
        let hi = x_knots[i + 1] - x_knots[i];
        if hi <= 0.0 {
            // Non-monotonic x values; fall back to linear
            return linear_interpolate(x_knots, y_knots, n_out);
        }
        h.push(hi);
    }

    // Set up tridiagonal system for natural spline (c_0 = c_{m-1} = 0).
    // The unknowns are the interior coefficients c[1..=m-2].
    let n_eqs = m - 2;
    // Defensive only: m >= 3 at this point, so n_eqs >= 1.
    if n_eqs == 0 {
        return linear_interpolate(x_knots, y_knots, n_out);
    }

    let mut diag = vec![0.0; n_eqs];
    let mut upper = vec![0.0; n_eqs.saturating_sub(1)];
    let mut lower = vec![0.0; n_eqs.saturating_sub(1)];
    let mut rhs = vec![0.0; n_eqs];

    // Row i is the continuity condition at interior knot i + 1:
    //   h[i]*c[i] + 2*(h[i] + h[i+1])*c[i+1] + h[i+1]*c[i+2] = rhs[i]
    for i in 0..n_eqs {
        diag[i] = 2.0 * (h[i] + h[i + 1]);
        rhs[i] = 3.0
            * ((y_knots[i + 2] - y_knots[i + 1]) / h[i + 1] - (y_knots[i + 1] - y_knots[i]) / h[i]);
    }

    // Off-diagonals: rows i and i+1 are coupled through h[i+1] on both
    // sides, so the matrix is symmetric.
    let sub = n_eqs.saturating_sub(1);
    upper[..sub].copy_from_slice(&h[1..(sub + 1)]);
    lower[..sub].copy_from_slice(&h[1..(sub + 1)]);

    // Solve tridiagonal system (Thomas algorithm)
    let c_interior = solve_tridiagonal(&lower, &diag, &upper, &rhs)?;

    // Full c array with boundary conditions c[0] = c[m-1] = 0
    let mut c = vec![0.0; m];
    c[1..(n_eqs + 1)].copy_from_slice(&c_interior[..n_eqs]);

    // Compute b and d coefficients from c so each segment matches the knot
    // values at both ends.
    let mut b = vec![0.0; m - 1];
    let mut d = vec![0.0; m - 1];

    for i in 0..m - 1 {
        d[i] = (c[i + 1] - c[i]) / (3.0 * h[i]);
        b[i] = (y_knots[i + 1] - y_knots[i]) / h[i] - h[i] * (2.0 * c[i] + c[i + 1]) / 3.0;
    }

    // Evaluate spline at each output point (integer sample positions)
    let mut result = Vec::with_capacity(n_out);
    for t_idx in 0..n_out {
        let t = t_idx as f64;

        // Find the correct spline segment
        let seg = find_segment(x_knots, t);
        let dx = t - x_knots[seg];

        let val = y_knots[seg] + b[seg] * dx + c[seg] * dx * dx + d[seg] * dx * dx * dx;
        result.push(val);
    }

    Ok(result)
}
428
/// Find the index `i` of the interpolation segment `[x_knots[i], x_knots[i + 1])`
/// that contains `t`.
///
/// Points at or before the first knot map to segment 0. Points at or after
/// the last knot map to the final segment (`len - 2`), so callers can
/// extrapolate with the end polynomials. With fewer than two knots the result
/// is 0.
///
/// `x_knots` must be sorted in non-decreasing order, as produced by the
/// extrema finders, and `t` must not be NaN. Sorting allows an O(log m) binary
/// search instead of a linear scan. That matters because interpolators call
/// this once per output sample, which made the old scan O(n·m) per envelope.
fn find_segment(x_knots: &[f64], t: f64) -> usize {
    if let Some(&first) = x_knots.first() {
        if t <= first {
            return 0;
        }
    }
    // Number of knots <= t; the containing segment starts at the last of them.
    let knots_at_or_before = x_knots.partition_point(|&x| x <= t);
    knots_at_or_before
        .saturating_sub(1)
        .min(x_knots.len().saturating_sub(2))
}
442
/// Solve tridiagonal system Ax = rhs using Thomas algorithm
///
/// `diag` is the main diagonal and defines the system size `n`.
/// `lower[i - 1]` is the entry below the diagonal in row `i`, and `upper[i]`
/// is the entry above the diagonal in row `i`. Both are normally of length
/// `n - 1`. `upper[0]` is read unconditionally when `n >= 2`; any other
/// missing off-diagonal entry is treated as zero. `rhs` must have at least
/// `n` entries, otherwise indexing panics.
///
/// No pivoting is performed. The spline system built by
/// `cubic_spline_interpolate` is strictly diagonally dominant
/// (`2*(h[i] + h[i+1])` versus off-diagonals `h[i]`, `h[i+1]`), which keeps
/// the elimination well conditioned.
///
/// # Errors
///
/// Returns `FFTError::ComputationError` if any pivot has magnitude below
/// `1e-15`.
fn solve_tridiagonal(
    lower: &[f64],
    diag: &[f64],
    upper: &[f64],
    rhs: &[f64],
) -> FFTResult<Vec<f64>> {
    let n = diag.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    if n == 1 {
        // 1x1 system: x = rhs / diag.
        if diag[0].abs() < 1e-15 {
            return Err(FFTError::ComputationError(
                "Singular tridiagonal system".to_string(),
            ));
        }
        return Ok(vec![rhs[0] / diag[0]]);
    }

    // Forward elimination: c_prime / d_prime are the normalized
    // super-diagonal and right-hand side after eliminating the sub-diagonal.
    let mut c_prime = vec![0.0; n];
    let mut d_prime = vec![0.0; n];

    if diag[0].abs() < 1e-15 {
        return Err(FFTError::ComputationError(
            "Zero pivot in tridiagonal solve".to_string(),
        ));
    }

    c_prime[0] = upper[0] / diag[0];
    d_prime[0] = rhs[0] / diag[0];

    for i in 1..n {
        // Sub-diagonal entry of row i (zero if `lower` is too short).
        let l_val = if i > 0 && i - 1 < lower.len() {
            lower[i - 1]
        } else {
            0.0
        };
        let denom = diag[i] - l_val * c_prime[i - 1];
        if denom.abs() < 1e-15 {
            return Err(FFTError::ComputationError(
                "Near-singular tridiagonal system".to_string(),
            ));
        }
        // The last row has no super-diagonal entry.
        c_prime[i] = if i < n - 1 && i < upper.len() {
            upper[i] / denom
        } else {
            0.0
        };
        d_prime[i] = (rhs[i] - l_val * d_prime[i - 1]) / denom;
    }

    // Back substitution
    let mut x = vec![0.0; n];
    x[n - 1] = d_prime[n - 1];
    for i in (0..n - 1).rev() {
        x[i] = d_prime[i] - c_prime[i] * x[i + 1];
    }

    Ok(x)
}
505
506/// Linear interpolation
507fn linear_interpolate(x_knots: &[f64], y_knots: &[f64], n_out: usize) -> FFTResult<Vec<f64>> {
508    let m = x_knots.len();
509    let mut result = Vec::with_capacity(n_out);
510
511    for t_idx in 0..n_out {
512        let t = t_idx as f64;
513
514        if t <= x_knots[0] {
515            result.push(y_knots[0]);
516        } else if t >= x_knots[m - 1] {
517            result.push(y_knots[m - 1]);
518        } else {
519            let seg = find_segment(x_knots, t);
520            let frac = (t - x_knots[seg]) / (x_knots[seg + 1] - x_knots[seg]);
521            let val = y_knots[seg] + frac * (y_knots[seg + 1] - y_knots[seg]);
522            result.push(val);
523        }
524    }
525
526    Ok(result)
527}
528
529// ============================================================================
530// Hilbert-Huang Transform (HHT)
531// ============================================================================
532
533/// Perform the Hilbert-Huang Transform on a signal
534///
535/// Combines EMD with Hilbert spectral analysis to produce a time-frequency
536/// representation that adapts to the signal's local structure.
537///
538/// # Arguments
539///
540/// * `signal` - Input signal
541/// * `fs` - Sampling frequency in Hz
542/// * `config` - Optional EMD configuration
543///
544/// # Returns
545///
546/// HHTResult containing IMFs, instantaneous frequencies and amplitudes.
547///
548/// # Errors
549///
550/// Returns an error if the signal is too short or fs is non-positive.
551///
552/// # Examples
553///
554/// ```
555/// use scirs2_fft::hilbert_enhanced::{hht, EMDConfig};
556/// use std::f64::consts::PI;
557///
558/// let fs = 256.0;
559/// let n = 256;
560/// let signal: Vec<f64> = (0..n).map(|i| {
561///     let t = i as f64 / fs;
562///     (2.0 * PI * 10.0 * t).sin() + 0.3 * (2.0 * PI * 40.0 * t).sin()
563/// }).collect();
564///
565/// let result = hht(&signal, fs, None).expect("HHT should succeed");
566/// assert!(!result.imfs.is_empty());
567/// assert_eq!(result.inst_frequencies.len(), result.imfs.len());
568/// ```
569pub fn hht(signal: &[f64], fs: f64, config: Option<EMDConfig>) -> FFTResult<HHTResult> {
570    if fs <= 0.0 {
571        return Err(FFTError::ValueError(
572            "Sampling frequency must be positive".to_string(),
573        ));
574    }
575
576    // Perform EMD
577    let emd_result = emd(signal, config)?;
578
579    // Compute Hilbert transform for each IMF
580    let mut inst_frequencies = Vec::with_capacity(emd_result.imfs.len());
581    let mut inst_amplitudes = Vec::with_capacity(emd_result.imfs.len());
582
583    for imf in &emd_result.imfs {
584        // Compute analytic signal
585        let analytic = crate::hilbert::analytic_signal(imf)?;
586
587        // Instantaneous amplitude (envelope)
588        let amplitude: Vec<f64> = analytic.iter().map(|c| c.norm()).collect();
589
590        // Instantaneous phase (unwrapped)
591        let phase = unwrap_phase_vec(&analytic);
592
593        // Instantaneous frequency via finite differences
594        let mut freq = Vec::with_capacity(phase.len());
595        for i in 0..phase.len() {
596            if i == 0 {
597                // Forward difference at start
598                if phase.len() > 1 {
599                    freq.push((phase[1] - phase[0]) * fs / (2.0 * PI));
600                } else {
601                    freq.push(0.0);
602                }
603            } else if i == phase.len() - 1 {
604                // Backward difference at end
605                freq.push((phase[i] - phase[i - 1]) * fs / (2.0 * PI));
606            } else {
607                // Central difference in the middle
608                freq.push((phase[i + 1] - phase[i - 1]) * fs / (4.0 * PI));
609            }
610        }
611
612        // Clamp frequencies to valid range [0, fs/2]
613        let nyquist = fs / 2.0;
614        for f in &mut freq {
615            if *f < 0.0 {
616                *f = 0.0;
617            }
618            if *f > nyquist {
619                *f = nyquist;
620            }
621        }
622
623        inst_frequencies.push(freq);
624        inst_amplitudes.push(amplitude);
625    }
626
627    Ok(HHTResult {
628        imfs: emd_result.imfs,
629        inst_frequencies,
630        inst_amplitudes,
631        residual: emd_result.residual,
632    })
633}
634
635/// Compute the Hilbert spectrum (time-frequency-energy representation)
636///
637/// Generates a 2D energy density map over time and frequency from the HHT.
638///
639/// # Arguments
640///
641/// * `hht_result` - Result from `hht()`
642/// * `fs` - Sampling frequency
643/// * `n_freq_bins` - Number of frequency bins
644///
645/// # Returns
646///
647/// HilbertSpectrum containing the time-frequency energy density.
648///
649/// # Errors
650///
651/// Returns an error if inputs are invalid.
652pub fn hilbert_spectrum(
653    hht_result: &HHTResult,
654    fs: f64,
655    n_freq_bins: usize,
656) -> FFTResult<HilbertSpectrum> {
657    if fs <= 0.0 {
658        return Err(FFTError::ValueError(
659            "Sampling frequency must be positive".to_string(),
660        ));
661    }
662    if n_freq_bins == 0 {
663        return Err(FFTError::ValueError(
664            "Number of frequency bins must be positive".to_string(),
665        ));
666    }
667
668    let n_time = if let Some(first_imf) = hht_result.imfs.first() {
669        first_imf.len()
670    } else {
671        return Ok(HilbertSpectrum {
672            times: Vec::new(),
673            frequencies: Vec::new(),
674            energy: Vec::new(),
675        });
676    };
677
678    let nyquist = fs / 2.0;
679    let freq_step = nyquist / n_freq_bins as f64;
680
681    // Build time and frequency axes
682    let times: Vec<f64> = (0..n_time).map(|i| i as f64 / fs).collect();
683    let frequencies: Vec<f64> = (0..n_freq_bins)
684        .map(|k| (k as f64 + 0.5) * freq_step)
685        .collect();
686
687    // Build energy density matrix
688    let mut energy = vec![vec![0.0; n_freq_bins]; n_time];
689
690    for imf_idx in 0..hht_result.imfs.len() {
691        let freqs = &hht_result.inst_frequencies[imf_idx];
692        let amps = &hht_result.inst_amplitudes[imf_idx];
693
694        for t in 0..n_time.min(freqs.len()).min(amps.len()) {
695            let f = freqs[t];
696            let a = amps[t];
697
698            // Find the nearest frequency bin
699            let bin = (f / freq_step).floor() as usize;
700            if bin < n_freq_bins {
701                energy[t][bin] += a * a;
702            }
703        }
704    }
705
706    Ok(HilbertSpectrum {
707        times,
708        frequencies,
709        energy,
710    })
711}
712
713/// Compute the marginal Hilbert spectrum
714///
715/// The marginal spectrum integrates the Hilbert spectrum over time,
716/// giving the total energy distribution across frequencies.
717///
718/// # Arguments
719///
720/// * `hht_result` - Result from `hht()`
721/// * `fs` - Sampling frequency
722/// * `n_freq_bins` - Number of frequency bins
723///
724/// # Returns
725///
726/// Tuple of (frequencies, marginal_spectrum).
727///
728/// # Errors
729///
730/// Returns an error if inputs are invalid.
731///
732/// # Examples
733///
734/// ```
735/// use scirs2_fft::hilbert_enhanced::{hht, marginal_spectrum};
736/// use std::f64::consts::PI;
737///
738/// let fs = 256.0;
739/// let n = 256;
740/// let signal: Vec<f64> = (0..n).map(|i| {
741///     let t = i as f64 / fs;
742///     (2.0 * PI * 10.0 * t).sin()
743/// }).collect();
744///
745/// let result = hht(&signal, fs, None).expect("HHT should succeed");
746/// let (freqs, spectrum) = marginal_spectrum(&result, fs, 128).expect("Spectrum should succeed");
747/// assert_eq!(freqs.len(), 128);
748/// assert_eq!(spectrum.len(), 128);
749/// ```
750pub fn marginal_spectrum(
751    hht_result: &HHTResult,
752    fs: f64,
753    n_freq_bins: usize,
754) -> FFTResult<(Vec<f64>, Vec<f64>)> {
755    let hs = hilbert_spectrum(hht_result, fs, n_freq_bins)?;
756
757    // Sum energy over time for each frequency bin
758    let mut marginal = vec![0.0; n_freq_bins];
759    let dt = if fs > 0.0 { 1.0 / fs } else { 1.0 };
760
761    for time_slice in &hs.energy {
762        for (k, &e) in time_slice.iter().enumerate() {
763            marginal[k] += e * dt;
764        }
765    }
766
767    Ok((hs.frequencies, marginal))
768}
769
770/// Compute the degree of stationarity using the Hilbert spectrum
771///
772/// Returns a value between 0 and 1 for each frequency bin, where 1
773/// means the signal is stationary at that frequency and 0 means
774/// highly non-stationary.
775///
776/// # Arguments
777///
778/// * `hht_result` - Result from `hht()`
779/// * `fs` - Sampling frequency
780/// * `n_freq_bins` - Number of frequency bins
781///
782/// # Returns
783///
784/// Tuple of (frequencies, degree_of_stationarity).
785///
786/// # Errors
787///
788/// Returns an error if inputs are invalid.
789pub fn degree_of_stationarity(
790    hht_result: &HHTResult,
791    fs: f64,
792    n_freq_bins: usize,
793) -> FFTResult<(Vec<f64>, Vec<f64>)> {
794    let hs = hilbert_spectrum(hht_result, fs, n_freq_bins)?;
795
796    let n_time = hs.energy.len();
797    if n_time == 0 {
798        return Ok((hs.frequencies, vec![0.0; n_freq_bins]));
799    }
800
801    // For each frequency, compute the coefficient of variation of energy over time
802    let mut stationarity = Vec::with_capacity(n_freq_bins);
803
804    for k in 0..n_freq_bins {
805        let energies: Vec<f64> = hs.energy.iter().map(|row| row[k]).collect();
806
807        let mean = energies.iter().sum::<f64>() / n_time as f64;
808        if mean < 1e-15 {
809            stationarity.push(1.0); // No energy = perfectly stationary (trivially)
810            continue;
811        }
812
813        let variance = energies
814            .iter()
815            .map(|&e| (e - mean) * (e - mean))
816            .sum::<f64>()
817            / n_time as f64;
818        let cv = variance.sqrt() / mean;
819
820        // Map coefficient of variation to [0, 1] stationarity measure
821        let ds = 1.0 / (1.0 + cv);
822        stationarity.push(ds);
823    }
824
825    Ok((hs.frequencies, stationarity))
826}
827
828/// Compute the instantaneous energy density of the signal
829///
830/// Returns the total instantaneous energy at each time point,
831/// summed over all IMFs.
832///
833/// # Arguments
834///
835/// * `hht_result` - Result from `hht()`
836///
837/// # Returns
838///
839/// Vector of instantaneous energy at each time point.
840pub fn instantaneous_energy(hht_result: &HHTResult) -> Vec<f64> {
841    let n = if let Some(first) = hht_result.imfs.first() {
842        first.len()
843    } else {
844        return Vec::new();
845    };
846
847    let mut energy = vec![0.0; n];
848
849    for amps in &hht_result.inst_amplitudes {
850        for (i, &a) in amps.iter().enumerate() {
851            if i < n {
852                energy[i] += a * a;
853            }
854        }
855    }
856
857    energy
858}
859
860/// Compute the mean frequency at each time point
861///
862/// Returns a weighted average of instantaneous frequencies across IMFs,
863/// weighted by their instantaneous amplitudes.
864///
865/// # Arguments
866///
867/// * `hht_result` - Result from `hht()`
868///
869/// # Returns
870///
871/// Vector of mean instantaneous frequency at each time point.
872pub fn mean_frequency(hht_result: &HHTResult) -> Vec<f64> {
873    let n = if let Some(first) = hht_result.imfs.first() {
874        first.len()
875    } else {
876        return Vec::new();
877    };
878
879    let mut weighted_freq = vec![0.0; n];
880    let mut total_weight = vec![0.0; n];
881
882    for imf_idx in 0..hht_result.imfs.len() {
883        let freqs = &hht_result.inst_frequencies[imf_idx];
884        let amps = &hht_result.inst_amplitudes[imf_idx];
885
886        for i in 0..n.min(freqs.len()).min(amps.len()) {
887            let weight = amps[i] * amps[i];
888            weighted_freq[i] += freqs[i] * weight;
889            total_weight[i] += weight;
890        }
891    }
892
893    for i in 0..n {
894        if total_weight[i] > 1e-15 {
895            weighted_freq[i] /= total_weight[i];
896        }
897    }
898
899    weighted_freq
900}
901
902// ============================================================================
903// Ensemble EMD (EEMD)
904// ============================================================================
905
906/// Perform Ensemble EMD (EEMD) for more robust decomposition
907///
908/// EEMD adds white noise to the signal multiple times and averages the
909/// resulting IMFs, which helps avoid mode mixing.
910///
911/// # Arguments
912///
913/// * `signal` - Input signal
914/// * `n_ensembles` - Number of noise-added ensembles (typically 50-300)
915/// * `noise_amplitude` - Standard deviation of added white noise (typically 0.1-0.5 of signal std)
916/// * `config` - Optional EMD configuration
917///
918/// # Returns
919///
920/// EMDResult with the ensemble-averaged IMFs.
921///
922/// # Errors
923///
924/// Returns an error if parameters are invalid.
925///
926/// # Examples
927///
928/// ```
929/// use scirs2_fft::hilbert_enhanced::{eemd, EMDConfig};
930/// use std::f64::consts::PI;
931///
932/// let fs = 256.0;
933/// let n = 256;
934/// let signal: Vec<f64> = (0..n).map(|i| {
935///     let t = i as f64 / fs;
936///     (2.0 * PI * 5.0 * t).sin() + 0.5 * (2.0 * PI * 20.0 * t).sin()
937/// }).collect();
938///
939/// let result = eemd(&signal, 10, 0.1, None).expect("EEMD should succeed");
940/// assert!(!result.imfs.is_empty());
941/// ```
942pub fn eemd(
943    signal: &[f64],
944    n_ensembles: usize,
945    noise_amplitude: f64,
946    config: Option<EMDConfig>,
947) -> FFTResult<EMDResult> {
948    if n_ensembles == 0 {
949        return Err(FFTError::ValueError(
950            "Number of ensembles must be positive".to_string(),
951        ));
952    }
953    if noise_amplitude < 0.0 {
954        return Err(FFTError::ValueError(
955            "Noise amplitude must be non-negative".to_string(),
956        ));
957    }
958
959    let n = signal.len();
960    let cfg = config.unwrap_or_default();
961
962    // Compute signal standard deviation for noise scaling
963    let mean = signal.iter().sum::<f64>() / n as f64;
964    let variance = signal.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / n as f64;
965    let std_dev = variance.sqrt();
966    let noise_std = noise_amplitude * std_dev;
967
968    // Collect all ensemble results
969    let mut max_imfs = 0;
970    let mut all_results = Vec::with_capacity(n_ensembles);
971
972    // Simple LCG random number generator for noise (pure Rust, no external deps)
973    let mut rng_state: u64 = 42;
974
975    for _ensemble in 0..n_ensembles {
976        // Generate noisy signal
977        let mut noisy_signal = signal.to_vec();
978        for sample in &mut noisy_signal {
979            // Box-Muller transform for Gaussian noise using simple LCG
980            let u1 = lcg_next_f64(&mut rng_state);
981            let u2 = lcg_next_f64(&mut rng_state);
982            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
983            *sample += noise_std * z;
984        }
985
986        // Perform EMD on noisy signal
987        let result = emd(&noisy_signal, Some(cfg.clone()))?;
988        if result.imfs.len() > max_imfs {
989            max_imfs = result.imfs.len();
990        }
991        all_results.push(result);
992    }
993
994    if max_imfs == 0 {
995        return Ok(EMDResult {
996            imfs: Vec::new(),
997            residual: signal.to_vec(),
998            iterations: Vec::new(),
999        });
1000    }
1001
1002    // Average IMFs across ensembles
1003    let mut avg_imfs = vec![vec![0.0; n]; max_imfs];
1004    let mut imf_counts = vec![0usize; max_imfs];
1005
1006    for result in &all_results {
1007        for (i, imf) in result.imfs.iter().enumerate() {
1008            for (j, &val) in imf.iter().enumerate() {
1009                avg_imfs[i][j] += val;
1010            }
1011            imf_counts[i] += 1;
1012        }
1013    }
1014
1015    for i in 0..max_imfs {
1016        if imf_counts[i] > 0 {
1017            let count = imf_counts[i] as f64;
1018            for val in &mut avg_imfs[i] {
1019                *val /= count;
1020            }
1021        }
1022    }
1023
1024    // Compute residual
1025    let mut residual = signal.to_vec();
1026    for imf in &avg_imfs {
1027        for (i, &val) in imf.iter().enumerate() {
1028            residual[i] -= val;
1029        }
1030    }
1031
1032    Ok(EMDResult {
1033        imfs: avg_imfs,
1034        residual,
1035        iterations: vec![0; max_imfs], // Not meaningful for EEMD
1036    })
1037}
1038
/// Advance a 64-bit linear congruential generator and return a uniform `f64` in (0, 1).
///
/// The multiplier/increment pair is the one from Knuth's MMIX (also the base LCG
/// underlying PCG), with modulus 2^64 via wrapping arithmetic. Only the top 53 bits of
/// the new state are used, because the low bits of a power-of-two-modulus LCG have
/// short periods. The result is strictly positive, so it is safe to pass to `ln` in a
/// Box-Muller transform.
fn lcg_next_f64(state: &mut u64) -> f64 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    // (state >> 11) <= 2^53 - 1, so `val` lies in [0, 1 - 2^-53] and can never reach 1.0.
    let val = ((*state >> 11) as f64) / ((1u64 << 53) as f64);
    // Map the single zero outcome to the smallest positive normal so `ln(val)` is finite.
    if val > 0.0 {
        val
    } else {
        f64::MIN_POSITIVE
    }
}
1055
1056// ============================================================================
1057// Advanced Hilbert analysis functions
1058// ============================================================================
1059
1060/// Compute the Hilbert transform of a real signal and return the imaginary part only
1061///
1062/// The Hilbert transform H{x}(t) is the convolution of x(t) with 1/(pi*t).
1063/// For a cosine signal, it produces a sine signal.
1064///
1065/// # Arguments
1066///
1067/// * `signal` - Input real signal
1068///
1069/// # Returns
1070///
1071/// The Hilbert transform (imaginary part of the analytic signal).
1072///
1073/// # Errors
1074///
1075/// Returns an error if the signal is empty.
1076pub fn hilbert_transform(signal: &[f64]) -> FFTResult<Vec<f64>> {
1077    let analytic = crate::hilbert::analytic_signal(signal)?;
1078    Ok(analytic.iter().map(|c| c.im).collect())
1079}
1080
1081/// Compute the analytic signal with padding to reduce edge effects
1082///
1083/// Uses mirror padding before computing the analytic signal, then trims
1084/// the result to the original length.
1085///
1086/// # Arguments
1087///
1088/// * `signal` - Input signal
1089/// * `pad_fraction` - Fraction of signal length to use as padding (0.0 to 0.5)
1090///
1091/// # Returns
1092///
1093/// Padded analytic signal of original length.
1094///
1095/// # Errors
1096///
1097/// Returns an error if inputs are invalid.
1098pub fn analytic_signal_padded(signal: &[f64], pad_fraction: f64) -> FFTResult<Vec<Complex64>> {
1099    if signal.is_empty() {
1100        return Err(FFTError::ValueError("Signal cannot be empty".to_string()));
1101    }
1102    if !(0.0..=0.5).contains(&pad_fraction) {
1103        return Err(FFTError::ValueError(
1104            "Pad fraction must be between 0.0 and 0.5".to_string(),
1105        ));
1106    }
1107
1108    let n = signal.len();
1109    let pad_len = (n as f64 * pad_fraction).ceil() as usize;
1110
1111    if pad_len == 0 {
1112        return crate::hilbert::analytic_signal(signal);
1113    }
1114
1115    // Create mirror-padded signal
1116    let padded_len = n + 2 * pad_len;
1117    let mut padded = Vec::with_capacity(padded_len);
1118
1119    // Left mirror padding
1120    for i in (0..pad_len).rev() {
1121        let idx = (i + 1).min(n - 1);
1122        padded.push(signal[idx]);
1123    }
1124
1125    // Original signal
1126    padded.extend_from_slice(signal);
1127
1128    // Right mirror padding
1129    for i in 0..pad_len {
1130        let idx = n.saturating_sub(2 + i);
1131        padded.push(signal[idx]);
1132    }
1133
1134    // Compute analytic signal on padded version
1135    let analytic = crate::hilbert::analytic_signal(&padded)?;
1136
1137    // Trim to original length
1138    Ok(analytic[pad_len..pad_len + n].to_vec())
1139}
1140
1141/// Compute the Teager-Kaiser energy operator on a signal
1142///
1143/// The Teager energy operator provides a measure of the instantaneous energy:
1144///   Psi[x(n)] = x(n)^2 - x(n-1)*x(n+1)
1145///
1146/// For narrowband signals: `Psi[x] ~ A^2 * omega^2`
1147///
1148/// # Arguments
1149///
1150/// * `signal` - Input signal (length >= 3)
1151///
1152/// # Returns
1153///
1154/// Teager energy (length = n - 2)
1155///
1156/// # Errors
1157///
1158/// Returns an error if signal is too short.
1159pub fn teager_energy(signal: &[f64]) -> FFTResult<Vec<f64>> {
1160    if signal.len() < 3 {
1161        return Err(FFTError::ValueError(
1162            "Signal must have at least 3 samples for Teager energy".to_string(),
1163        ));
1164    }
1165
1166    let n = signal.len();
1167    let mut energy = Vec::with_capacity(n - 2);
1168
1169    for i in 1..n - 1 {
1170        let val = signal[i] * signal[i] - signal[i - 1] * signal[i + 1];
1171        energy.push(val);
1172    }
1173
1174    Ok(energy)
1175}
1176
1177/// Compute instantaneous frequency and amplitude using the Teager-Kaiser
1178/// energy separation algorithm (ESA)
1179///
1180/// Provides an alternative to Hilbert-based methods that's better suited
1181/// for AM-FM signals with rapidly varying frequency.
1182///
1183/// # Arguments
1184///
1185/// * `signal` - Input signal (length >= 5)
1186/// * `fs` - Sampling frequency
1187///
1188/// # Returns
1189///
1190/// Tuple of (instantaneous_frequency, instantaneous_amplitude).
1191/// Each has length n - 4 (due to the finite difference operations).
1192///
1193/// # Errors
1194///
1195/// Returns an error if signal is too short or fs is non-positive.
1196pub fn teager_esa(signal: &[f64], fs: f64) -> FFTResult<(Vec<f64>, Vec<f64>)> {
1197    if signal.len() < 5 {
1198        return Err(FFTError::ValueError(
1199            "Signal must have at least 5 samples for Teager ESA".to_string(),
1200        ));
1201    }
1202    if fs <= 0.0 {
1203        return Err(FFTError::ValueError(
1204            "Sampling frequency must be positive".to_string(),
1205        ));
1206    }
1207
1208    let n = signal.len();
1209
1210    // Compute Teager energy of the signal
1211    let psi_x = teager_energy(signal)?;
1212
1213    // Compute the forward difference of the signal (discrete derivative)
1214    let mut diff_signal = Vec::with_capacity(n - 1);
1215    for i in 0..n - 1 {
1216        diff_signal.push(signal[i + 1] - signal[i]);
1217    }
1218
1219    // Teager energy of the derivative
1220    let psi_dx = teager_energy(&diff_signal)?;
1221
1222    // Energy Separation Algorithm
1223    let mut inst_freq = Vec::new();
1224    let mut inst_amp = Vec::new();
1225
1226    // psi_x has length n-2 (indices 1..n-1 of original)
1227    // psi_dx has length n-3 (indices 1..n-2 of diff_signal, which is indices 2..n-1 of original)
1228    // We align them: psi_x[i] ~ sample i+1, psi_dx[i] ~ sample i+2
1229    let common_len = psi_x.len().min(psi_dx.len());
1230
1231    for i in 0..common_len {
1232        let psi_x_val = psi_x[i];
1233        let psi_dx_val = psi_dx[i];
1234
1235        if psi_x_val.abs() < 1e-15 || psi_dx_val < 0.0 {
1236            inst_freq.push(0.0);
1237            inst_amp.push(0.0);
1238            continue;
1239        }
1240
1241        // Instantaneous frequency: omega = arccos(1 - psi_dx / (2 * psi_x))
1242        let ratio = psi_dx_val / (2.0 * psi_x_val);
1243        let cos_arg = 1.0 - ratio;
1244
1245        if cos_arg.abs() > 1.0 {
1246            // Clamp to valid range
1247            let omega = if cos_arg > 1.0 { 0.0 } else { PI };
1248            inst_freq.push(omega * fs / (2.0 * PI));
1249            inst_amp.push(0.0);
1250        } else {
1251            let omega = cos_arg.acos();
1252            inst_freq.push(omega * fs / (2.0 * PI));
1253
1254            // Instantaneous amplitude: A = sqrt(psi_x / sin^2(omega))
1255            let sin_omega = omega.sin();
1256            if sin_omega.abs() > 1e-10 {
1257                let a = (psi_x_val / (sin_omega * sin_omega)).sqrt();
1258                inst_amp.push(a);
1259            } else {
1260                inst_amp.push(psi_x_val.sqrt());
1261            }
1262        }
1263    }
1264
1265    Ok((inst_freq, inst_amp))
1266}
1267
1268// ============================================================================
1269// Utility functions
1270// ============================================================================
1271
1272/// Unwrap phase from analytic signal (internal helper)
1273fn unwrap_phase_vec(analytic: &[Complex64]) -> Vec<f64> {
1274    if analytic.is_empty() {
1275        return Vec::new();
1276    }
1277
1278    let mut phase = Vec::with_capacity(analytic.len());
1279    let mut prev_angle = analytic[0].im.atan2(analytic[0].re);
1280    phase.push(prev_angle);
1281
1282    let mut cumulative = 0.0;
1283
1284    for c in analytic.iter().skip(1) {
1285        let angle = c.im.atan2(c.re);
1286        let mut diff = angle - prev_angle;
1287
1288        while diff > PI {
1289            diff -= 2.0 * PI;
1290            cumulative -= 2.0 * PI;
1291        }
1292        while diff < -PI {
1293            diff += 2.0 * PI;
1294            cumulative += 2.0 * PI;
1295        }
1296
1297        phase.push(angle + cumulative);
1298        prev_angle = angle;
1299    }
1300
1301    phase
1302}
1303
1304// ============================================================================
1305// Tests
1306// ============================================================================
1307
#[cfg(test)]
mod tests {
    // Unit tests for EMD/EEMD/HHT and the Hilbert/Teager helpers in this module.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_emd_basic_two_tone() {
        // Two-tone signal: EMD should separate the components
        let n = 512;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / 512.0;
                (2.0 * PI * 5.0 * t).sin() + 0.5 * (2.0 * PI * 30.0 * t).sin()
            })
            .collect();

        let result = emd(&signal, None).expect("EMD should succeed");
        assert!(!result.imfs.is_empty(), "Should extract at least one IMF");

        // The residual plus all IMFs should reconstruct the original
        let mut reconstructed = result.residual.clone();
        for imf in &result.imfs {
            for (i, &val) in imf.iter().enumerate() {
                reconstructed[i] += val;
            }
        }

        for i in 0..n {
            assert_abs_diff_eq!(reconstructed[i], signal[i], epsilon = 1e-8);
        }
    }

    #[test]
    fn test_emd_monotonic_residual() {
        // A simple linear signal should not produce IMFs (or very few)
        let n = 128;
        let signal: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();

        let result = emd(&signal, None).expect("EMD on linear should succeed");
        // Most of the signal should be in the residual
        let residual_energy: f64 = result.residual.iter().map(|&v| v * v).sum();
        let signal_energy: f64 = signal.iter().map(|&v| v * v).sum();
        assert!(
            residual_energy > 0.5 * signal_energy,
            "Residual should capture most of the linear trend"
        );
    }

    #[test]
    fn test_emd_reconstruction_exact() {
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / 256.0;
                (2.0 * PI * 8.0 * t).sin() + 2.0 * t
            })
            .collect();

        let result = emd(&signal, None).expect("EMD should succeed");

        let mut reconstructed = result.residual.clone();
        for imf in &result.imfs {
            for (i, &v) in imf.iter().enumerate() {
                reconstructed[i] += v;
            }
        }

        for i in 0..n {
            assert_abs_diff_eq!(reconstructed[i], signal[i], epsilon = 1e-8);
        }
    }

    #[test]
    fn test_hht_basic() {
        let fs = 256.0;
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let result = hht(&signal, fs, None).expect("HHT should succeed");
        assert!(!result.imfs.is_empty());
        assert_eq!(result.inst_frequencies.len(), result.imfs.len());
        assert_eq!(result.inst_amplitudes.len(), result.imfs.len());

        // First IMF should have frequency near 10 Hz in the middle
        if !result.inst_frequencies.is_empty() {
            let freqs = &result.inst_frequencies[0];
            let mid_start = n / 4;
            let mid_end = 3 * n / 4;
            let avg_freq: f64 =
                freqs[mid_start..mid_end].iter().sum::<f64>() / (mid_end - mid_start) as f64;
            assert!(
                (avg_freq - 10.0).abs() < 5.0,
                "Average freq should be near 10 Hz, got {avg_freq}"
            );
        }
    }

    #[test]
    fn test_hht_invalid_fs() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(hht(&signal, 0.0, None).is_err());
        assert!(hht(&signal, -1.0, None).is_err());
    }

    #[test]
    fn test_hilbert_spectrum_dimensions() {
        let fs = 256.0;
        let n = 128;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let hht_result = hht(&signal, fs, None).expect("HHT should succeed");
        let n_freq_bins = 64;
        let hs = hilbert_spectrum(&hht_result, fs, n_freq_bins)
            .expect("Hilbert spectrum should succeed");

        assert_eq!(hs.times.len(), n);
        assert_eq!(hs.frequencies.len(), n_freq_bins);
        assert_eq!(hs.energy.len(), n);
        for row in &hs.energy {
            assert_eq!(row.len(), n_freq_bins);
        }
    }

    #[test]
    fn test_marginal_spectrum() {
        let fs = 256.0;
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let hht_result = hht(&signal, fs, None).expect("HHT should succeed");
        let (freqs, spectrum) =
            marginal_spectrum(&hht_result, fs, 64).expect("Marginal spectrum should succeed");

        assert_eq!(freqs.len(), 64);
        assert_eq!(spectrum.len(), 64);

        // Energy should be non-negative
        for &e in &spectrum {
            assert!(e >= 0.0, "Marginal spectrum energy must be non-negative");
        }

        // Total energy should be positive for a non-zero signal
        let total: f64 = spectrum.iter().sum();
        assert!(total > 0.0, "Total marginal energy should be positive");
    }

    #[test]
    fn test_instantaneous_energy() {
        let fs = 256.0;
        let n = 128;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let hht_result = hht(&signal, fs, None).expect("HHT should succeed");
        let energy = instantaneous_energy(&hht_result);

        assert_eq!(energy.len(), n);
        // Energy should be non-negative
        for &e in &energy {
            assert!(e >= 0.0, "Instantaneous energy must be non-negative");
        }
    }

    #[test]
    fn test_eemd_basic() {
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / 256.0;
                (2.0 * PI * 5.0 * t).sin() + 0.5 * (2.0 * PI * 25.0 * t).sin()
            })
            .collect();

        let result = eemd(&signal, 5, 0.1, None).expect("EEMD should succeed");
        assert!(!result.imfs.is_empty(), "EEMD should produce IMFs");

        // Verify approximate reconstruction
        let mut reconstructed = result.residual.clone();
        for imf in &result.imfs {
            for (i, &val) in imf.iter().enumerate() {
                if i < reconstructed.len() {
                    reconstructed[i] += val;
                }
            }
        }

        // With noise, reconstruction won't be exact but should be close
        let error: f64 = reconstructed
            .iter()
            .zip(signal.iter())
            .map(|(&r, &s)| (r - s) * (r - s))
            .sum::<f64>()
            / n as f64;
        let signal_power: f64 = signal.iter().map(|&s| s * s).sum::<f64>() / n as f64;
        assert!(
            error < 0.5 * signal_power,
            "EEMD reconstruction error should be reasonable"
        );
    }

    #[test]
    fn test_hilbert_transform_cosine() {
        // Hilbert transform of cos(t) should give sin(t)
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / 256.0).cos())
            .collect();

        let ht = hilbert_transform(&signal).expect("Hilbert transform should succeed");
        assert_eq!(ht.len(), n);

        // In the middle, the Hilbert transform of cos should be approximately sin
        for i in n / 4..3 * n / 4 {
            let expected = (2.0 * PI * 5.0 * i as f64 / 256.0).sin();
            assert_abs_diff_eq!(ht[i], expected, epsilon = 0.15);
        }
    }

    #[test]
    fn test_analytic_signal_padded() {
        let n = 128;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 10.0 * i as f64 / 128.0).cos())
            .collect();

        let result = analytic_signal_padded(&signal, 0.1).expect("Padded analytic should succeed");
        assert_eq!(result.len(), n);

        // Envelope should be approximately 1 for a pure cosine
        for i in n / 4..3 * n / 4 {
            let mag = result[i].norm();
            assert!(
                (mag - 1.0).abs() < 0.2,
                "Envelope should be near 1, got {mag} at index {i}"
            );
        }
    }

    #[test]
    fn test_teager_energy_sinusoid() {
        // For x(n) = A*sin(omega*n), the discrete Teager energy equals
        // A^2 * sin^2(omega) exactly (which approaches A^2 * omega^2 for small omega)
        let n = 256;
        let omega = 2.0 * PI * 10.0 / 256.0;
        let amplitude = 2.0;
        let signal: Vec<f64> = (0..n)
            .map(|i| amplitude * (omega * i as f64).sin())
            .collect();

        let energy = teager_energy(&signal).expect("Teager energy should succeed");
        assert_eq!(energy.len(), n - 2);

        // In the middle, energy should be approximately A^2 * sin^2(omega)
        let expected = amplitude * amplitude * omega.sin() * omega.sin();
        let mid_start = n / 4;
        let mid_end = 3 * n / 4;
        let avg_energy: f64 =
            energy[mid_start..mid_end - 2].iter().sum::<f64>() / (mid_end - 2 - mid_start) as f64;

        assert!(
            (avg_energy - expected).abs() < expected * 0.3,
            "Average Teager energy {avg_energy:.4} should be near {expected:.4}"
        );
    }

    #[test]
    fn test_teager_energy_exact_values() {
        // Psi[x](n) = x(n)^2 - x(n-1) * x(n+1): 2^2 - 1*3 = 1 and 3^2 - 2*5 = -1
        let energy = teager_energy(&[1.0, 2.0, 3.0, 5.0]).expect("Teager energy should succeed");
        assert_eq!(energy, vec![1.0, -1.0]);
    }

    #[test]
    fn test_teager_esa_basic() {
        let fs = 1000.0;
        let n = 256;
        let freq = 50.0;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * freq * t).sin()
            })
            .collect();

        let (inst_freq, inst_amp) = teager_esa(&signal, fs).expect("Teager ESA should succeed");
        assert!(!inst_freq.is_empty());
        assert_eq!(inst_freq.len(), inst_amp.len());

        // Check that estimated frequency is reasonable in the middle
        let mid = inst_freq.len() / 2;
        let mid_range = mid / 2..mid + mid / 2;
        let avg_freq: f64 =
            inst_freq[mid_range.clone()].iter().sum::<f64>() / mid_range.len() as f64;
        assert!(
            (avg_freq - freq).abs() < 20.0,
            "Estimated freq {avg_freq:.1} should be near {freq}"
        );
    }

    #[test]
    fn test_mean_frequency() {
        let fs = 256.0;
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 15.0 * t).sin()
            })
            .collect();

        let hht_result = hht(&signal, fs, None).expect("HHT should succeed");
        let mf = mean_frequency(&hht_result);
        assert_eq!(mf.len(), n);

        // Mean frequency should be positive and within the valid range [0, fs/2]
        // EMD may decompose the signal differently, so we check the structure
        // rather than the exact value
        let mid_start = n / 4;
        let mid_end = 3 * n / 4;
        let avg: f64 = mf[mid_start..mid_end].iter().sum::<f64>() / (mid_end - mid_start) as f64;

        // The mean frequency should be positive (signal has energy)
        assert!(avg > 0.0, "Mean frequency should be positive, got {avg:.1}");
        // And within the Nyquist range
        assert!(
            avg <= fs / 2.0,
            "Mean frequency {avg:.1} should be <= Nyquist ({:.1})",
            fs / 2.0
        );
    }

    #[test]
    fn test_degree_of_stationarity_stationary() {
        let fs = 256.0;
        let n = 256;
        // Purely stationary signal
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                (2.0 * PI * 10.0 * t).sin()
            })
            .collect();

        let hht_result = hht(&signal, fs, None).expect("HHT should succeed");
        let (freqs, ds) = degree_of_stationarity(&hht_result, fs, 64).expect("DoS should succeed");
        assert_eq!(freqs.len(), 64);
        assert_eq!(ds.len(), 64);

        // All stationarity values should be between 0 and 1
        for &val in &ds {
            assert!(
                (0.0..=1.0).contains(&val),
                "Stationarity should be in [0,1], got {val}"
            );
        }
    }

    #[test]
    fn test_lcg_deterministic_and_in_open_unit_interval() {
        // Two generators with the same seed must produce identical sequences,
        // and every value must be usable as a Box-Muller input (strictly inside (0, 1)).
        let mut state_a: u64 = 42;
        let mut state_b: u64 = 42;
        for _ in 0..1000 {
            let a = lcg_next_f64(&mut state_a);
            let b = lcg_next_f64(&mut state_b);
            assert_eq!(a, b);
            assert!(a > 0.0 && a < 1.0, "LCG value {a} outside (0, 1)");
        }
    }

    #[test]
    fn test_unwrap_phase_linear_ramp() {
        assert!(unwrap_phase_vec(&[]).is_empty());

        // A step below pi between samples makes the true phase unambiguous.
        let step = 2.5;
        let analytic: Vec<Complex64> = (0..64)
            .map(|k| {
                let theta = step * k as f64;
                Complex64::new(theta.cos(), theta.sin())
            })
            .collect();

        let phase = unwrap_phase_vec(&analytic);
        assert_eq!(phase.len(), analytic.len());
        for (k, &p) in phase.iter().enumerate() {
            assert_abs_diff_eq!(p, step * k as f64, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_error_handling() {
        // Too short signals
        assert!(emd(&[1.0, 2.0], None).is_err());
        assert!(teager_energy(&[1.0, 2.0]).is_err());
        assert!(teager_esa(&[1.0, 2.0, 3.0], 100.0).is_err());
        assert!(analytic_signal_padded(&[], 0.1).is_err());

        // Invalid parameters
        assert!(eemd(&[1.0; 10], 0, 0.1, None).is_err());
        assert!(eemd(&[1.0; 10], 5, -0.1, None).is_err());
        assert!(analytic_signal_padded(&[1.0], 0.6).is_err());
        assert!(hilbert_spectrum(
            &HHTResult {
                imfs: vec![vec![0.0; 10]],
                inst_frequencies: vec![vec![0.0; 10]],
                inst_amplitudes: vec![vec![0.0; 10]],
                residual: vec![0.0; 10],
            },
            0.0,
            64,
        )
        .is_err());
    }
}