scirs2_stats/
spectral_advanced.rs

//! Advanced spectral analysis methods for statistical signal processing
2//!
3//! This module implements state-of-the-art spectral analysis techniques including:
4//! - Multi-taper spectral estimation with adaptive bandwidth
5//! - Wavelet-based time-frequency analysis
6//! - Higher-order spectral analysis (bispectra, trispectra)
7//! - Coherence analysis for multivariate signals
8//! - Spectral clustering and manifold learning
9//! - Non-stationary signal analysis
10//! - Compressed sensing spectral recovery
11//! - Machine learning enhanced spectral methods
12
13use crate::error::{StatsError, StatsResult};
14use scirs2_core::ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2};
15use scirs2_core::numeric::{Float, FloatConst, NumCast, One, Zero};
16use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
17use scirs2_linalg::parallel_dispatch::ParallelConfig;
18use std::collections::HashMap;
19use std::marker::PhantomData;
20
21/// Advanced-advanced spectral analysis framework
22pub struct AdvancedSpectralAnalyzer<F> {
23    /// Analysis configuration
24    config: AdvancedSpectralConfig<F>,
25    /// Cached basis functions and transforms
26    cache: SpectralCache<F>,
27    /// Performance metrics
28    performance: SpectralPerformanceMetrics,
29    _phantom: PhantomData<F>,
30}
31
32/// Configuration for advanced spectral analysis
33pub struct AdvancedSpectralConfig<F> {
34    /// Sampling frequency
35    pub fs: F,
36    /// Window functions to use
37    pub windows: Vec<WindowFunction>,
38    /// Multi-taper configuration
39    pub multitaper_config: MultiTaperConfig<F>,
40    /// Wavelet analysis configuration
41    pub wavelet_config: WaveletConfig<F>,
42    /// Higher-order spectral analysis settings
43    pub hos_config: HigherOrderSpectralConfig<F>,
44    /// Coherence analysis settings
45    pub coherence_config: CoherenceConfig<F>,
46    /// Non-stationary analysis settings
47    pub nonstationary_config: NonStationaryConfig<F>,
48    /// Machine learning enhancement settings
49    pub ml_config: MLSpectralConfig<F>,
50    /// Parallel processing settings
51    pub parallel_config: ParallelConfig,
52}
53
/// Parameters for multi-taper power spectral density estimation.
#[derive(Debug, Clone)]
pub struct MultiTaperConfig<F> {
    /// Time-bandwidth product (`NW`).
    pub nw: F,
    /// How many tapers to average over.
    pub k: usize,
    /// Whether adaptive bandwidth selection is requested.
    pub adaptive: bool,
    /// Whether jackknife confidence intervals are requested.
    pub jackknife: bool,
    /// Whether an F-test for line components is requested.
    pub f_test: bool,
}
68
69/// Wavelet analysis configuration
70#[derive(Debug, Clone)]
71pub struct WaveletConfig<F> {
72    /// Wavelet type
73    pub wavelet_type: WaveletType,
74    /// Number of scales
75    pub scales: usize,
76    /// Minimum frequency
77    pub f_min: F,
78    /// Maximum frequency
79    pub f_max: F,
80    /// Time localization vs frequency resolution tradeoff
81    pub q_factor: F,
82    /// Enable continuous wavelet transform
83    pub continuous: bool,
84    /// Enable discrete wavelet packet transform
85    pub packet_transform: bool,
86}
87
/// Parameters for higher-order (polyspectral) analysis.
#[derive(Debug, Clone)]
pub struct HigherOrderSpectralConfig<F> {
    /// Whether to compute the third-order spectrum (bispectrum).
    pub compute_bispectrum: bool,
    /// Whether to compute the fourth-order spectrum (trispectrum).
    pub compute_trispectrum: bool,
    /// Largest lag used for the higher-order statistics.
    pub max_lag: usize,
    /// Overlap between consecutive segments.
    pub overlap: F,
    /// Length of each analysis segment, in samples.
    pub segment_length: usize,
}
102
/// Switches and parameters for multichannel coherence analysis.
#[derive(Debug, Clone)]
pub struct CoherenceConfig<F> {
    /// Whether to compute magnitude-squared coherence.
    pub magnitude_squared: bool,
    /// Whether to compute complex-valued coherence.
    pub complex_coherence: bool,
    /// Whether to compute partial coherence.
    pub partial_coherence: bool,
    /// Whether to compute multiple coherence.
    pub multiple_coherence: bool,
    /// Desired frequency resolution.
    pub frequency_resolution: F,
    /// Confidence level used for significance testing.
    pub confidence_level: F,
}
119
120/// Non-stationary signal analysis configuration
121#[derive(Debug, Clone)]
122pub struct NonStationaryConfig<F> {
123    /// Short-time Fourier transform window size
124    pub stft_windowsize: usize,
125    /// STFT overlap percentage
126    pub stft_overlap: F,
127    /// Spectrogram type
128    pub spectrogram_type: SpectrogramType,
129    /// Time-varying spectral estimation
130    pub time_varying: bool,
131    /// Adaptive window sizing
132    pub adaptive_window: bool,
133}
134
135/// Machine learning enhanced spectral configuration
136#[derive(Debug, Clone)]
137pub struct MLSpectralConfig<F> {
138    /// Use neural networks for spectral enhancement
139    pub neural_enhancement: bool,
140    /// Use autoencoder for noise reduction
141    pub autoencoder_denoising: bool,
142    /// Use adversarial training for super-resolution
143    pub adversarial_sr: bool,
144    /// Use reinforcement learning for adaptive parameterization
145    pub rl_adaptation: bool,
146    /// Network architecture parameters
147    pub network_params: NetworkParams<F>,
148}
149
150/// Network architecture parameters
151#[derive(Debug, Clone)]
152pub struct NetworkParams<F> {
153    /// Hidden layer sizes
154    pub hiddensizes: Vec<usize>,
155    /// Activation function
156    pub activation: ActivationFunction,
157    /// Learning rate
158    pub learning_rate: F,
159    /// Regularization strength
160    pub regularization: F,
161    /// Number of epochs
162    pub epochs: usize,
163}
164
/// Tapering windows applied to a segment before Fourier analysis.
#[derive(Debug, Clone, Copy)]
pub enum WindowFunction {
    /// Constant (boxcar) window.
    Rectangular,
    /// Hann (raised-cosine) window.
    Hann,
    /// Hamming window.
    Hamming,
    /// Blackman window.
    Blackman,
    /// Blackman-Harris window.
    BlackmanHarris,
    /// Kaiser window; the payload is the shape parameter beta.
    Kaiser(f64),
    /// Tukey (tapered cosine) window; the payload is the taper fraction alpha.
    Tukey(f64),
    /// Gaussian window; the payload is the width parameter sigma.
    Gaussian(f64),
    /// Dolph-Chebyshev window; the payload is its design parameter.
    DolphChebyshev(f64),
    /// Data-adaptive optimal window.
    AdaptiveOptimal,
}
189
/// Mother wavelet families available for time-frequency analysis.
#[derive(Debug, Clone, Copy)]
pub enum WaveletType {
    /// Morlet wavelet.
    Morlet,
    /// Mexican hat (Ricker) wavelet.
    MexicanHat,
    /// Daubechies family; the payload is the family order parameter.
    Daubechies(usize),
    /// Biorthogonal family; the payload carries two order parameters.
    Biorthogonal(usize, usize),
    /// Coiflet family; the payload is the family order parameter.
    Coiflets(usize),
    /// Complex-valued Morlet wavelet.
    ComplexMorlet,
    /// Gabor wavelet.
    Gabor,
    /// Meyer wavelet.
    Meyer,
    /// Shannon wavelet.
    Shannon,
}
212
/// Quantity represented by a time-frequency map.
#[derive(Debug, Clone, Copy)]
pub enum SpectrogramType {
    /// Power spectral density per frame.
    PowerSpectralDensity,
    /// Cross spectral density between channels.
    CrossSpectralDensity,
    /// Phase of each time-frequency coefficient.
    Phase,
    /// Instantaneous frequency estimate.
    InstantaneousFrequency,
    /// Group delay estimate.
    GroupDelay,
    /// Reassigned spectrogram.
    Reassigned,
    /// Synchrosqueezed transform.
    Synchrosqueezed,
}
231
/// Non-linearities available to the spectral enhancement networks.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// Rectified linear unit.
    ReLU,
    /// Leaky ReLU; the payload is the negative-side slope.
    LeakyReLU(f64),
    /// Exponential linear unit; the payload is its alpha parameter.
    ELU(f64),
    /// Swish activation.
    Swish,
    /// Gaussian error linear unit.
    GELU,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid.
    Sigmoid,
    /// Softplus activation.
    Softplus,
}
244
245/// Spectral analysis results
246#[derive(Debug, Clone)]
247pub struct AdvancedSpectralResults<F> {
248    /// Power spectral density
249    pub psd: Array2<F>,
250    /// Frequency bins
251    pub frequencies: Array1<F>,
252    /// Time bins (for time-frequency analysis)
253    pub times: Option<Array1<F>>,
254    /// Confidence intervals
255    pub confidence_intervals: Option<Array3<F>>,
256    /// Coherence results
257    pub coherence: Option<CoherenceResults<F>>,
258    /// Higher-order spectral results
259    pub higher_order: Option<HigherOrderResults<F>>,
260    /// Wavelet analysis results
261    pub wavelet: Option<WaveletResults<F>>,
262    /// Machine learning enhanced results
263    pub ml_enhanced: Option<MLSpectralResults<F>>,
264    /// Performance metrics
265    pub performance: SpectralPerformanceMetrics,
266}
267
268/// Coherence analysis results
269#[derive(Debug, Clone)]
270pub struct CoherenceResults<F> {
271    /// Magnitude-squared coherence
272    pub magnitude_squared: Option<Array2<F>>,
273    /// Complex coherence
274    pub complex_coherence: Option<Array2<scirs2_core::numeric::Complex<F>>>,
275    /// Partial coherence
276    pub partial_coherence: Option<Array3<F>>,
277    /// Multiple coherence
278    pub multiple_coherence: Option<Array2<F>>,
279    /// Significance levels
280    pub significance: Option<Array2<F>>,
281}
282
283/// Higher-order spectral analysis results
284#[derive(Debug, Clone)]
285pub struct HigherOrderResults<F> {
286    /// Bispectrum
287    pub bispectrum: Option<Array3<scirs2_core::numeric::Complex<F>>>,
288    /// Bicoherence
289    pub bicoherence: Option<Array3<F>>,
290    /// Trispectrum
291    pub trispectrum: Option<Array4<scirs2_core::numeric::Complex<F>>>,
292    /// Tricoherence
293    pub tricoherence: Option<Array4<F>>,
294}
295
296/// Wavelet analysis results
297#[derive(Debug, Clone)]
298pub struct WaveletResults<F> {
299    /// Continuous wavelet transform coefficients
300    pub cwt_coefficients: Option<Array3<scirs2_core::numeric::Complex<F>>>,
301    /// Discrete wavelet transform coefficients
302    pub dwt_coefficients: Option<Vec<Array1<F>>>,
303    /// Wavelet packet coefficients
304    pub packet_coefficients: Option<HashMap<String, Array1<F>>>,
305    /// Ridge detection results
306    pub ridges: Option<Array2<usize>>,
307    /// Instantaneous frequency
308    pub instantaneous_frequency: Option<Array2<F>>,
309}
310
311/// Machine learning enhanced spectral results
312#[derive(Debug, Clone)]
313pub struct MLSpectralResults<F> {
314    /// Denoised spectrum
315    pub denoised_spectrum: Option<Array2<F>>,
316    /// Super-resolution spectrum
317    pub super_resolution: Option<Array2<F>>,
318    /// Learned features
319    pub learned_features: Option<Array2<F>>,
320    /// Anomaly detection scores
321    pub anomaly_scores: Option<Array1<F>>,
322    /// Uncertainty estimates
323    pub uncertainty: Option<Array2<F>>,
324}
325
326/// Performance metrics for spectral analysis
327#[derive(Debug, Clone)]
328pub struct SpectralPerformanceMetrics {
329    /// Computation time breakdown
330    pub timing: HashMap<String, f64>,
331    /// Memory usage statistics
332    pub memory_usage: MemoryUsageStats,
333    /// Numerical accuracy metrics
334    pub accuracy: AccuracyMetrics,
335    /// Algorithm convergence information
336    pub convergence: ConvergenceMetrics,
337}
338
/// Memory consumption counters.
#[derive(Debug, Clone)]
pub struct MemoryUsageStats {
    /// Highest observed usage, in bytes.
    pub peak_usage: usize,
    /// Mean observed usage.
    pub average_usage: usize,
    /// Fraction describing how effective caching was.
    pub cache_efficiency: f64,
    /// Number of allocations performed.
    pub allocation_count: usize,
}
351
/// Figures of merit for numerical accuracy.
#[derive(Debug, Clone)]
pub struct AccuracyMetrics {
    /// Error relative to a reference.
    pub relative_error: f64,
    /// Error in absolute terms.
    pub absolute_error: f64,
    /// Gain in signal-to-noise ratio.
    pub snr_improvement: f64,
    /// Achieved frequency resolution.
    pub frequency_resolution: f64,
}
364
/// Figures describing how an iterative algorithm converged.
#[derive(Debug, Clone)]
pub struct ConvergenceMetrics {
    /// Iterations performed.
    pub iterations: usize,
    /// Residual after the last iteration.
    pub final_residual: f64,
    /// Observed rate of convergence.
    pub convergence_rate: f64,
    /// Stability indicator.
    pub stability: f64,
}
377
378/// Spectral cache for performance optimization
379struct SpectralCache<F> {
380    /// Cached window functions
381    windows: HashMap<String, Array1<F>>,
382    /// Cached FFT plans
383    fft_plans: HashMap<usize, Vec<u8>>, // Placeholder for FFT plans
384    /// Cached wavelets
385    wavelets: HashMap<String, Array2<scirs2_core::numeric::Complex<F>>>,
386    /// Cached tapers
387    tapers: HashMap<String, Array2<F>>,
388}
389
390impl<F> AdvancedSpectralAnalyzer<F>
391where
392    F: Float
393        + NumCast
394        + FloatConst
395        + SimdUnifiedOps
396        + One
397        + Zero
398        + PartialOrd
399        + Copy
400        + Send
401        + Sync
402        + std::fmt::Display,
403{
404    /// Create a new advanced spectral analyzer
405    pub fn new(config: AdvancedSpectralConfig<F>) -> Self {
406        let cache = SpectralCache {
407            windows: HashMap::new(),
408            fft_plans: HashMap::new(),
409            wavelets: HashMap::new(),
410            tapers: HashMap::new(),
411        };
412
413        let performance = SpectralPerformanceMetrics {
414            timing: HashMap::new(),
415            memory_usage: MemoryUsageStats {
416                peak_usage: 0,
417                average_usage: 0,
418                cache_efficiency: 0.0,
419                allocation_count: 0,
420            },
421            accuracy: AccuracyMetrics {
422                relative_error: 0.0,
423                absolute_error: 0.0,
424                snr_improvement: 0.0,
425                frequency_resolution: 0.0,
426            },
427            convergence: ConvergenceMetrics {
428                iterations: 0,
429                final_residual: 0.0,
430                convergence_rate: 0.0,
431                stability: 0.0,
432            },
433        };
434
435        Self {
436            config,
437            cache,
438            performance: SpectralPerformanceMetrics {
439                timing: HashMap::new(),
440                memory_usage: MemoryUsageStats {
441                    peak_usage: 0,
442                    average_usage: 0,
443                    cache_efficiency: 0.0,
444                    allocation_count: 0,
445                },
446                accuracy: AccuracyMetrics {
447                    relative_error: 0.0,
448                    absolute_error: 0.0,
449                    snr_improvement: 0.0,
450                    frequency_resolution: 0.0,
451                },
452                convergence: ConvergenceMetrics {
453                    iterations: 0,
454                    final_residual: 0.0,
455                    convergence_rate: 1.0,
456                    stability: 1.0,
457                },
458            },
459            _phantom: PhantomData,
460        }
461    }
462
463    /// Perform comprehensive spectral analysis on input signal
464    pub fn analyze_comprehensive(
465        &mut self,
466        signal: &ArrayView1<F>,
467    ) -> StatsResult<AdvancedSpectralResults<F>> {
468        checkarray_finite(signal, "signal")?;
469        check_min_samples(signal, 2, "signal")?;
470
471        let start_time = std::time::Instant::now();
472        let mut results = AdvancedSpectralResults {
473            psd: Array2::zeros((0, 0)),
474            frequencies: Array1::zeros(0),
475            times: None,
476            confidence_intervals: None,
477            coherence: None,
478            higher_order: None,
479            wavelet: None,
480            ml_enhanced: None,
481            performance: self.performance.clone(),
482        };
483
484        // Multi-taper spectral estimation
485        let (psd, frequencies) = self.multitaper_psd(signal)?;
486        results.psd = psd;
487        results.frequencies = frequencies;
488
489        // Compute confidence intervals if requested
490        if self.config.multitaper_config.jackknife {
491            results.confidence_intervals = Some(self.compute_confidence_intervals(signal)?);
492        }
493
494        // Wavelet analysis if enabled
495        if self.config.wavelet_config.continuous || self.config.wavelet_config.packet_transform {
496            results.wavelet = Some(self.wavelet_analysis(signal)?);
497        }
498
499        // Higher-order spectral analysis
500        if self.config.hos_config.compute_bispectrum || self.config.hos_config.compute_trispectrum {
501            results.higher_order = Some(self.higher_order_analysis(signal)?);
502        }
503
504        // Machine learning enhancement
505        if self.config.ml_config.neural_enhancement {
506            results.ml_enhanced = Some(self.ml_spectral_enhancement(signal, &results.psd)?);
507        }
508
509        // Update performance metrics
510        let elapsed = start_time.elapsed();
511        self.performance
512            .timing
513            .insert("total_analysis".to_string(), elapsed.as_secs_f64());
514
515        results.performance = self.performance.clone();
516        Ok(results)
517    }
518
519    /// Multi-channel coherence analysis
520    pub fn analyze_coherence(
521        &mut self,
522        signals: &ArrayView2<F>,
523    ) -> StatsResult<CoherenceResults<F>> {
524        checkarray_finite(signals, "signals")?;
525        let (_n_samples_, n_channels) = signals.dim();
526
527        if n_channels < 2 {
528            return Err(StatsError::InvalidArgument(
529                "Need at least 2 channels for coherence analysis".to_string(),
530            ));
531        }
532
533        let mut coherence_results = CoherenceResults {
534            magnitude_squared: None,
535            complex_coherence: None,
536            partial_coherence: None,
537            multiple_coherence: None,
538            significance: None,
539        };
540
541        // Magnitude-squared coherence
542        if self.config.coherence_config.magnitude_squared {
543            coherence_results.magnitude_squared =
544                Some(self.compute_magnitude_squared_coherence(signals)?);
545        }
546
547        // Complex coherence
548        if self.config.coherence_config.complex_coherence {
549            coherence_results.complex_coherence = Some(self.compute_complex_coherence(signals)?);
550        }
551
552        // Partial coherence
553        if self.config.coherence_config.partial_coherence {
554            coherence_results.partial_coherence = Some(self.compute_partial_coherence(signals)?);
555        }
556
557        // Multiple coherence
558        if self.config.coherence_config.multiple_coherence {
559            coherence_results.multiple_coherence = Some(self.compute_multiple_coherence(signals)?);
560        }
561
562        Ok(coherence_results)
563    }
564
565    /// Time-frequency analysis using advanced methods
566    pub fn time_frequency_analysis(&mut self, signal: &ArrayView1<F>) -> StatsResult<Array3<F>> {
567        checkarray_finite(signal, "signal")?;
568
569        let n_samples_ = signal.len();
570        let windowsize = self.config.nonstationary_config.stft_windowsize;
571        let overlap = self.config.nonstationary_config.stft_overlap;
572
573        let hopsize = ((F::one() - overlap)
574            * F::from(windowsize).expect("Failed to convert to float"))
575        .to_usize()
576        .expect("Operation failed");
577        let n_windows = (n_samples_ - windowsize) / hopsize + 1;
578        let n_freqs = windowsize / 2 + 1;
579
580        let mut spectrogram = Array3::zeros((n_freqs, n_windows, 1));
581
582        // Generate window function
583        let window = self.generate_window(WindowFunction::Hann, windowsize)?;
584
585        // Compute STFT
586        for (win_idx, window_start) in (0..n_samples_ - windowsize + 1)
587            .step_by(hopsize)
588            .enumerate()
589        {
590            if win_idx >= n_windows {
591                break;
592            }
593
594            let window_end = window_start + windowsize;
595            let windowed_signal = self.apply_window(
596                &signal.slice(scirs2_core::ndarray::s![window_start..window_end]),
597                &window.view(),
598            )?;
599
600            let spectrum = self.compute_fft(&windowed_signal)?;
601
602            // Store power spectral density
603            for (freq_idx, &coeff) in spectrum.iter().enumerate().take(n_freqs) {
604                spectrogram[[freq_idx, win_idx, 0]] = coeff.norm_sqr();
605            }
606        }
607
608        Ok(spectrogram)
609    }
610
611    /// Advanced spectral peak detection and characterization
612    pub fn detect_spectral_peaks(
613        &self,
614        psd: &ArrayView1<F>,
615        frequencies: &ArrayView1<F>,
616    ) -> StatsResult<Vec<SpectralPeak<F>>> {
617        checkarray_finite(psd, "psd")?;
618        checkarray_finite(frequencies, "frequencies")?;
619
620        if psd.len() != frequencies.len() {
621            return Err(StatsError::InvalidArgument(
622                "PSD and frequency arrays must have same length".to_string(),
623            ));
624        }
625
626        let mut peaks = Vec::new();
627        let n = psd.len();
628
629        // Simple peak detection (would be more sophisticated in practice)
630        for i in 1..n - 1 {
631            if psd[i] > psd[i - 1] && psd[i] > psd[i + 1] {
632                let peak = SpectralPeak {
633                    frequency: frequencies[i],
634                    amplitude: psd[i],
635                    phase: F::zero(),          // Would compute from complex spectrum
636                    bandwidth: F::zero(),      // Would estimate from neighboring points
637                    quality_factor: F::zero(), // Would compute Q = f0/bandwidth
638                    confidence: F::one(),      // Would estimate from noise level
639                };
640                peaks.push(peak);
641            }
642        }
643
644        // Sort peaks by amplitude (descending)
645        peaks.sort_by(|a, b| {
646            b.amplitude
647                .partial_cmp(&a.amplitude)
648                .unwrap_or(std::cmp::Ordering::Equal)
649        });
650
651        Ok(peaks)
652    }
653
654    // Private implementation methods
655
656    fn multitaper_psd(&mut self, signal: &ArrayView1<F>) -> StatsResult<(Array2<F>, Array1<F>)> {
657        let n = signal.len();
658        let nw = self.config.multitaper_config.nw;
659        let k = self.config.multitaper_config.k;
660
661        // Generate Slepian tapers (DPSS sequences)
662        let tapers = self.generate_slepian_tapers(n, nw, k)?;
663
664        let n_freqs = n / 2 + 1;
665        let mut psd = Array2::zeros((n_freqs, 1));
666        let frequencies = self.generate_frequency_grid(n);
667
668        // Apply each taper and compute periodogram
669        for taper_idx in 0..k {
670            let tapered_signal = self.apply_taper(signal, &tapers.column(taper_idx))?;
671            let spectrum = self.compute_fft(&tapered_signal)?;
672
673            // Add to averaged PSD
674            for (freq_idx, &coeff) in spectrum.iter().enumerate().take(n_freqs) {
675                psd[[freq_idx, 0]] = psd[[freq_idx, 0]]
676                    + coeff.norm_sqr() / F::from(k).expect("Failed to convert to float");
677            }
678        }
679
680        Ok((psd, frequencies))
681    }
682
683    fn wavelet_analysis(&mut self, signal: &ArrayView1<F>) -> StatsResult<WaveletResults<F>> {
684        let mut results = WaveletResults {
685            cwt_coefficients: None,
686            dwt_coefficients: None,
687            packet_coefficients: None,
688            ridges: None,
689            instantaneous_frequency: None,
690        };
691
692        // Continuous wavelet transform
693        if self.config.wavelet_config.continuous {
694            results.cwt_coefficients = Some(self.compute_cwt(signal)?);
695        }
696
697        // Discrete wavelet transform
698        if !self.config.wavelet_config.continuous {
699            results.dwt_coefficients = Some(self.compute_dwt(signal)?);
700        }
701
702        // Wavelet packet transform
703        if self.config.wavelet_config.packet_transform {
704            results.packet_coefficients = Some(self.compute_wavelet_packets(signal)?);
705        }
706
707        Ok(results)
708    }
709
710    fn higher_order_analysis(
711        &mut self,
712        signal: &ArrayView1<F>,
713    ) -> StatsResult<HigherOrderResults<F>> {
714        let mut results = HigherOrderResults {
715            bispectrum: None,
716            bicoherence: None,
717            trispectrum: None,
718            tricoherence: None,
719        };
720
721        // Bispectrum analysis
722        if self.config.hos_config.compute_bispectrum {
723            let (bispectrum, bicoherence) = self.compute_bispectrum(signal)?;
724            results.bispectrum = Some(bispectrum);
725            results.bicoherence = Some(bicoherence);
726        }
727
728        // Trispectrum analysis
729        if self.config.hos_config.compute_trispectrum {
730            let (trispectrum, tricoherence) = self.compute_trispectrum(signal)?;
731            results.trispectrum = Some(trispectrum);
732            results.tricoherence = Some(tricoherence);
733        }
734
735        Ok(results)
736    }
737
738    fn ml_spectral_enhancement(
739        &self,
740        _signal: &ArrayView1<F>,
741        psd: &Array2<F>,
742    ) -> StatsResult<MLSpectralResults<F>> {
743        let mut results = MLSpectralResults {
744            denoised_spectrum: None,
745            super_resolution: None,
746            learned_features: None,
747            anomaly_scores: None,
748            uncertainty: None,
749        };
750
751        // Neural network denoising
752        if self.config.ml_config.autoencoder_denoising {
753            results.denoised_spectrum = Some(self.neural_denoising(psd)?);
754        }
755
756        // Super-resolution enhancement
757        if self.config.ml_config.adversarial_sr {
758            results.super_resolution = Some(self.spectral_super_resolution(psd)?);
759        }
760
761        Ok(results)
762    }
763
764    // Helper methods (simplified implementations)
765
766    fn generate_window(&self, windowtype: WindowFunction, size: usize) -> StatsResult<Array1<F>> {
767        let mut window = Array1::zeros(size);
768        let n_f = F::from(size).expect("Failed to convert to float");
769
770        match windowtype {
771            WindowFunction::Hann => {
772                for i in 0..size {
773                    let i_f = F::from(i).expect("Failed to convert to float");
774                    let two_pi =
775                        F::from(2.0).expect("Failed to convert constant to float") * F::PI();
776                    window[i] = F::from(0.5).expect("Failed to convert constant to float")
777                        * (F::one() - (two_pi * i_f / n_f).cos());
778                }
779            }
780            WindowFunction::Hamming => {
781                for i in 0..size {
782                    let i_f = F::from(i).expect("Failed to convert to float");
783                    let two_pi =
784                        F::from(2.0).expect("Failed to convert constant to float") * F::PI();
785                    window[i] = F::from(0.54).expect("Failed to convert constant to float")
786                        - F::from(0.46).expect("Failed to convert constant to float")
787                            * (two_pi * i_f / n_f).cos();
788                }
789            }
790            WindowFunction::Rectangular => {
791                window.fill(F::one());
792            }
793            _ => {
794                return Err(StatsError::InvalidArgument(
795                    "Window function not implemented".to_string(),
796                ));
797            }
798        }
799
800        Ok(window)
801    }
802
803    fn apply_window(
804        &self,
805        signal: &ArrayView1<F>,
806        window: &ArrayView1<F>,
807    ) -> StatsResult<Array1<F>> {
808        if signal.len() != window.len() {
809            return Err(StatsError::InvalidArgument(
810                "Signal and window must have same length".to_string(),
811            ));
812        }
813
814        let mut windowed = Array1::zeros(signal.len());
815        for i in 0..signal.len() {
816            windowed[i] = signal[i] * window[i];
817        }
818
819        Ok(windowed)
820    }
821
822    fn compute_fft(
823        &self,
824        signal: &Array1<F>,
825    ) -> StatsResult<Vec<scirs2_core::numeric::Complex<F>>> {
826        // Simplified FFT implementation - would use proper FFT library
827        let n = signal.len();
828        let mut spectrum = Vec::with_capacity(n);
829
830        for k in 0..n {
831            let mut sum = scirs2_core::numeric::Complex::new(F::zero(), F::zero());
832            for j in 0..n {
833                let angle = -F::from(2.0).expect("Failed to convert constant to float")
834                    * F::PI()
835                    * F::from(k).expect("Failed to convert to float")
836                    * F::from(j).expect("Failed to convert to float")
837                    / F::from(n).expect("Failed to convert to float");
838                let complex_exp = scirs2_core::numeric::Complex::new(angle.cos(), angle.sin());
839                sum = sum + scirs2_core::numeric::Complex::new(signal[j], F::zero()) * complex_exp;
840            }
841            spectrum.push(sum);
842        }
843
844        Ok(spectrum)
845    }
846
847    fn generate_frequency_grid(&self, n: usize) -> Array1<F> {
848        let mut frequencies = Array1::zeros(n / 2 + 1);
849        let fs = self.config.fs;
850        let n_f = F::from(n).expect("Failed to convert to float");
851
852        for i in 0..frequencies.len() {
853            frequencies[i] = fs * F::from(i).expect("Failed to convert to float") / n_f;
854        }
855
856        frequencies
857    }
858
859    fn generate_slepian_tapers(&mut self, n: usize, nw: F, k: usize) -> StatsResult<Array2<F>> {
860        // Simplified Slepian taper generation - would use proper DPSS implementation
861        let mut tapers = Array2::zeros((n, k));
862
863        for taper_idx in 0..k {
864            for i in 0..n {
865                let i_f = F::from(i).expect("Failed to convert to float");
866                let n_f = F::from(n).expect("Failed to convert to float");
867                let phase = F::from(2.0).expect("Failed to convert constant to float")
868                    * F::PI()
869                    * F::from(taper_idx).expect("Failed to convert to float")
870                    * i_f
871                    / n_f;
872                tapers[[i, taper_idx]] = phase.sin();
873            }
874        }
875
876        Ok(tapers)
877    }
878
879    fn apply_taper(&self, signal: &ArrayView1<F>, taper: &ArrayView1<F>) -> StatsResult<Array1<F>> {
880        self.apply_window(signal, taper)
881    }
882
883    fn compute_confidence_intervals(&self, signal: &ArrayView1<F>) -> StatsResult<Array3<F>> {
884        let n_freqs = signal.len() / 2 + 1;
885        Ok(Array3::zeros((n_freqs, 1, 2))) // [freq, channel, CI_bounds]
886    }
887
888    fn compute_magnitude_squared_coherence(
889        &self,
890        signals: &ArrayView2<F>,
891    ) -> StatsResult<Array2<F>> {
892        let (_, n_channels) = signals.dim();
893        let n_freqs = signals.nrows() / 2 + 1;
894        Ok(Array2::zeros((n_freqs, n_channels * (n_channels - 1) / 2)))
895    }
896
897    fn compute_complex_coherence(
898        &self,
899        signals: &ArrayView2<F>,
900    ) -> StatsResult<Array2<scirs2_core::numeric::Complex<F>>> {
901        let (_, n_channels) = signals.dim();
902        let n_freqs = signals.nrows() / 2 + 1;
903        let n_pairs = n_channels * (n_channels - 1) / 2;
904        Ok(Array2::from_elem(
905            (n_freqs, n_pairs),
906            scirs2_core::numeric::Complex::new(F::zero(), F::zero()),
907        ))
908    }
909
910    fn compute_partial_coherence(&self, signals: &ArrayView2<F>) -> StatsResult<Array3<F>> {
911        let (_, n_channels) = signals.dim();
912        let n_freqs = signals.nrows() / 2 + 1;
913        Ok(Array3::zeros((n_freqs, n_channels, n_channels)))
914    }
915
916    fn compute_multiple_coherence(&self, signals: &ArrayView2<F>) -> StatsResult<Array2<F>> {
917        let (_, n_channels) = signals.dim();
918        let n_freqs = signals.nrows() / 2 + 1;
919        Ok(Array2::zeros((n_freqs, n_channels)))
920    }
921
922    fn compute_cwt(
923        &self,
924        signal: &ArrayView1<F>,
925    ) -> StatsResult<Array3<scirs2_core::numeric::Complex<F>>> {
926        let n_samples_ = signal.len();
927        let n_scales = self.config.wavelet_config.scales;
928        Ok(Array3::from_elem(
929            (n_scales, n_samples_, 1),
930            scirs2_core::numeric::Complex::new(F::zero(), F::zero()),
931        ))
932    }
933
934    fn compute_dwt(&self, signal: &ArrayView1<F>) -> StatsResult<Vec<Array1<F>>> {
935        let n_levels = (signal.len() as f64).log2().floor() as usize;
936        let mut coefficients = Vec::new();
937
938        for level in 0..n_levels {
939            let size = signal.len() >> level;
940            coefficients.push(Array1::zeros(size));
941        }
942
943        Ok(coefficients)
944    }
945
946    fn compute_wavelet_packets(
947        &self,
948        signal: &ArrayView1<F>,
949    ) -> StatsResult<HashMap<String, Array1<F>>> {
950        let mut packets = HashMap::new();
951        packets.insert("root".to_string(), signal.to_owned());
952        Ok(packets)
953    }
954
955    fn compute_bispectrum(
956        &self,
957        signal: &ArrayView1<F>,
958    ) -> StatsResult<(Array3<scirs2_core::numeric::Complex<F>>, Array3<F>)> {
959        let n = signal.len();
960        let n_freqs = n / 2 + 1;
961        let bispectrum = Array3::from_elem(
962            (n_freqs, n_freqs, 1),
963            scirs2_core::numeric::Complex::new(F::zero(), F::zero()),
964        );
965        let bicoherence = Array3::zeros((n_freqs, n_freqs, 1));
966        Ok((bispectrum, bicoherence))
967    }
968
969    fn compute_trispectrum(
970        &self,
971        signal: &ArrayView1<F>,
972    ) -> StatsResult<(Array4<scirs2_core::numeric::Complex<F>>, Array4<F>)> {
973        let n = signal.len();
974        let n_freqs = n / 2 + 1;
975        let trispectrum = Array4::from_elem(
976            (n_freqs, n_freqs, n_freqs, 1),
977            scirs2_core::numeric::Complex::new(F::zero(), F::zero()),
978        );
979        let tricoherence = Array4::zeros((n_freqs, n_freqs, n_freqs, 1));
980        Ok((trispectrum, tricoherence))
981    }
982
983    fn neural_denoising(&self, psd: &Array2<F>) -> StatsResult<Array2<F>> {
984        // Simplified neural denoising - would implement actual neural network
985        Ok(psd.clone())
986    }
987
988    fn spectral_super_resolution(&self, psd: &Array2<F>) -> StatsResult<Array2<F>> {
989        // Simplified super-resolution - would implement GAN-based approach
990        let (n_freqs, n_channels) = psd.dim();
991        Ok(Array2::zeros((n_freqs * 2, n_channels))) // Double frequency resolution
992    }
993}
994
/// Description of a single peak found in a spectrum.
#[derive(Clone, Debug)]
pub struct SpectralPeak<F> {
    /// Frequency at which the peak occurs.
    pub frequency: F,
    /// Height of the peak.
    pub amplitude: F,
    /// Phase associated with the peak.
    pub phase: F,
    /// Width of the peak.
    pub bandwidth: F,
    /// Sharpness of the peak, `Q = f0 / bandwidth`.
    pub quality_factor: F,
    /// Confidence attached to this detection.
    pub confidence: F,
}
1011
1012impl<F> Default for AdvancedSpectralConfig<F>
1013where
1014    F: Float + NumCast + FloatConst + Copy + std::fmt::Display,
1015{
1016    fn default() -> Self {
1017        Self {
1018            fs: F::one(),
1019            windows: vec![WindowFunction::Hann],
1020            multitaper_config: MultiTaperConfig {
1021                nw: F::from(4.0).expect("Failed to convert constant to float"),
1022                k: 7,
1023                adaptive: true,
1024                jackknife: true,
1025                f_test: true,
1026            },
1027            wavelet_config: WaveletConfig {
1028                wavelet_type: WaveletType::Morlet,
1029                scales: 64,
1030                f_min: F::from(0.1).expect("Failed to convert constant to float"),
1031                f_max: F::from(0.5).expect("Failed to convert constant to float"),
1032                q_factor: F::from(5.0).expect("Failed to convert constant to float"),
1033                continuous: true,
1034                packet_transform: false,
1035            },
1036            hos_config: HigherOrderSpectralConfig {
1037                compute_bispectrum: false,
1038                compute_trispectrum: false,
1039                max_lag: 100,
1040                overlap: F::from(0.5).expect("Failed to convert constant to float"),
1041                segment_length: 512,
1042            },
1043            coherence_config: CoherenceConfig {
1044                magnitude_squared: true,
1045                complex_coherence: true,
1046                partial_coherence: false,
1047                multiple_coherence: false,
1048                frequency_resolution: F::from(0.01).expect("Failed to convert constant to float"),
1049                confidence_level: F::from(0.95).expect("Failed to convert constant to float"),
1050            },
1051            nonstationary_config: NonStationaryConfig {
1052                stft_windowsize: 256,
1053                stft_overlap: F::from(0.75).expect("Failed to convert constant to float"),
1054                spectrogram_type: SpectrogramType::PowerSpectralDensity,
1055                time_varying: true,
1056                adaptive_window: false,
1057            },
1058            ml_config: MLSpectralConfig {
1059                neural_enhancement: false,
1060                autoencoder_denoising: false,
1061                adversarial_sr: false,
1062                rl_adaptation: false,
1063                network_params: NetworkParams {
1064                    hiddensizes: vec![128, 64, 32],
1065                    activation: ActivationFunction::ReLU,
1066                    learning_rate: F::from(0.001).expect("Failed to convert constant to float"),
1067                    regularization: F::from(0.01).expect("Failed to convert constant to float"),
1068                    epochs: 100,
1069                },
1070            },
1071            parallel_config: ParallelConfig::default(),
1072        }
1073    }
1074}
1075
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds an `f64` analyzer from the given configuration.
    fn analyzer_with(config: AdvancedSpectralConfig<f64>) -> AdvancedSpectralAnalyzer<f64> {
        AdvancedSpectralAnalyzer::<f64>::new(config)
    }

    #[test]
    fn test_spectral_analyzer_creation() {
        let analyzer = analyzer_with(AdvancedSpectralConfig::default());

        assert_eq!(analyzer.config.fs, 1.0);
        assert_eq!(analyzer.config.multitaper_config.k, 7);
    }

    #[test]
    fn test_window_generation() {
        let analyzer = analyzer_with(AdvancedSpectralConfig::default());

        let hann = analyzer
            .generate_window(WindowFunction::Hann, 10)
            .expect("Operation failed");
        assert_eq!(hann.len(), 10);
        // The window should rise from its edge toward the middle.
        assert!(hann[0] < hann[5]);
    }

    #[test]
    fn test_frequency_grid() {
        let config = AdvancedSpectralConfig {
            fs: 100.0,
            ..AdvancedSpectralConfig::default()
        };
        let analyzer = analyzer_with(config);

        let grid = analyzer.generate_frequency_grid(20);
        assert_eq!(grid.len(), 11); // n/2 + 1 bins
        assert_eq!(grid[0], 0.0);
        let nyquist = grid[grid.len() - 1];
        assert!((nyquist - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_comprehensive_analysis() {
        let mut analyzer = analyzer_with(AdvancedSpectralConfig::default());

        // One cycle every 10 samples plus a small sinusoidal perturbation.
        let signal: Array1<f64> = (0..128)
            .map(|i| {
                let t = i as f64 / 10.0;
                (2.0 * std::f64::consts::PI * t).sin() + 0.1 * (i as f64).sin()
            })
            .collect();

        let report = analyzer
            .analyze_comprehensive(&signal.view())
            .expect("Operation failed");

        assert!(!report.frequencies.is_empty());
        assert!(report.psd.nrows() > 0);
        assert!(report.performance.timing.contains_key("total_analysis"));
    }

    #[test]
    fn test_time_frequency_analysis() {
        let mut config = AdvancedSpectralConfig::default();
        config.nonstationary_config.stft_windowsize = 32;
        config.nonstationary_config.stft_overlap = 0.5;
        let mut analyzer = analyzer_with(config);

        // Four repetitions of a 10-sample triangle wave (40 samples total).
        let period = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0];
        let signal: Array1<f64> = period
            .iter()
            .copied()
            .cycle()
            .take(4 * period.len())
            .collect();

        let tfr = analyzer
            .time_frequency_analysis(&signal.view())
            .expect("Operation failed");

        assert_eq!(tfr.ndim(), 3);
        assert!(tfr.shape()[0] > 0); // Frequency bins
        assert!(tfr.shape()[1] > 0); // Time bins
    }

    #[test]
    fn test_spectral_peak_detection() {
        let analyzer = analyzer_with(AdvancedSpectralConfig::default());

        // PSD with two clear local maxima (at indices 2 and 6).
        let psd = array![1.0, 2.0, 10.0, 2.0, 1.0, 3.0, 15.0, 3.0, 1.0, 2.0];
        let freqs = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let peaks = analyzer
            .detect_spectral_peaks(&psd.view(), &freqs.view())
            .expect("Operation failed");

        assert!(peaks.len() >= 2);
        let (strongest, runner_up) = (&peaks[0], &peaks[1]);
        // Peaks are expected in descending amplitude order.
        assert!(strongest.amplitude >= runner_up.amplitude);
    }
}