use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use crate::window::{get_window, Window};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use super::types::{CyclostationaryConfig, SpectralCorrelationResult};
/// Estimator of the spectral correlation density (SCD) of cyclostationary signals.
///
/// The analyzer holds no state; every tuning option lives in [`CyclostationaryConfig`].
#[derive(Debug, Clone, Copy)]
pub struct CyclostationaryAnalyzer;

impl CyclostationaryAnalyzer {
    /// Creates an analyzer.
    pub fn new() -> Self {
        Self
    }

    /// Computes the spectral correlation density of `signal` sampled at `fs` Hz.
    ///
    /// The estimate is a time-smoothed cyclic periodogram over Hann-windowed STFT frames of
    /// `config.n_fft` samples (hop derived from `config.overlap`). For each cyclic frequency
    /// `alpha` and spectral bin `f`:
    ///
    /// `S(alpha, f) = mean_k X_k(f + alpha/2) * conj(X_k(f - alpha/2)) * exp(-j*2*pi*alpha*t_k)`
    ///
    /// where `t_k` is the start time of frame `k`. Frequency offsets are rounded to whole FFT
    /// bins and wrap around modulo `n_fft`.
    ///
    /// Cyclic frequencies come from `config.cyclic_freqs` when set. Otherwise they are detected
    /// from the spectrum of `signal^2`. When nothing is detected, a uniform grid with step
    /// `config.alpha_resolution` up to `fs / 2` is used instead.
    ///
    /// `fs` is the sampling frequency used for every computation here; `config.fs` is not read.
    ///
    /// # Errors
    ///
    /// Returns [`FFTError::ValueError`] in any of these cases:
    /// - the signal is shorter than `config.n_fft`;
    /// - `config.n_fft` is zero;
    /// - `fs` is not positive and finite;
    /// - a cyclic frequency is not finite;
    /// - the fallback grid is needed but `config.alpha_resolution` is not positive and finite.
    ///
    /// Errors from the FFT are propagated.
    pub fn compute_scd(
        &self,
        signal: &[f64],
        fs: f64,
        config: &CyclostationaryConfig,
    ) -> FFTResult<SpectralCorrelationResult> {
        let n = signal.len();
        let n_fft = config.n_fft;
        if n < n_fft {
            return Err(FFTError::ValueError(format!(
                "Signal length {n} must be >= n_fft {n_fft}"
            )));
        }
        if n_fft == 0 {
            return Err(FFTError::ValueError("n_fft must be positive".to_string()));
        }
        // fs divides every bin/phase computation; zero or non-finite values would produce
        // infinite bin offsets (and integer overflow when indexing).
        if !(fs > 0.0 && fs.is_finite()) {
            return Err(FFTError::ValueError(format!(
                "Sampling frequency must be positive and finite, got {fs}"
            )));
        }

        let alpha_vec: Vec<f64> = match &config.cyclic_freqs {
            Some(alphas) => alphas.clone(),
            None => {
                let detected = detect_cyclic_frequencies_impl(signal, fs, config)?;
                if detected.is_empty() {
                    let resolution = config.alpha_resolution;
                    // A zero, negative or non-finite step can never advance the grid.
                    if !(resolution > 0.0 && resolution.is_finite()) {
                        return Err(FFTError::ValueError(format!(
                            "alpha_resolution must be positive and finite, got {resolution}"
                        )));
                    }
                    build_alpha_grid(fs, resolution)
                } else {
                    detected
                }
            }
        };
        if let Some(bad) = alpha_vec.iter().find(|alpha| !alpha.is_finite()) {
            return Err(FFTError::ValueError(format!(
                "Cyclic frequencies must be finite, got {bad}"
            )));
        }

        let hop = compute_hop(n_fft, config.overlap);
        let window = build_hann_window(n_fft);
        let stft_matrix = compute_stft(signal, &window, n_fft, hop)?;
        let n_frames = stft_matrix.len();
        if n_frames == 0 {
            return Err(FFTError::ComputationError(
                "Signal too short for STFT computation".to_string(),
            ));
        }

        let spectral_frequencies: Vec<f64> =
            (0..n_fft).map(|k| k as f64 * fs / n_fft as f64).collect();
        let mut scd_matrix: Vec<Vec<Complex64>> = Vec::with_capacity(alpha_vec.len());
        for &alpha in &alpha_vec {
            scd_matrix.push(compute_scd_row(
                &stft_matrix,
                alpha,
                fs,
                n_fft,
                n_frames,
                hop,
            )?);
        }
        Ok(SpectralCorrelationResult {
            scd: scd_matrix,
            cyclic_frequencies: alpha_vec,
            spectral_frequencies,
        })
    }

    /// Detects candidate cyclic frequencies of `signal` sampled at `fs` Hz.
    ///
    /// Candidates are local maxima of the magnitude spectrum of `signal^2`, selected with the
    /// default `detection_threshold` of [`CyclostationaryConfig`] and reported in Hz.
    ///
    /// NOTE(review): `resolution` is stored as `alpha_resolution` in the configuration handed
    /// to the detector. The detector currently reads only `detection_threshold`, so
    /// `resolution` does not affect the result.
    ///
    /// # Errors
    ///
    /// Propagates errors from the detector (e.g. an empty signal).
    pub fn detect_cyclic_frequencies(
        &self,
        signal: &[f64],
        fs: f64,
        resolution: f64,
    ) -> FFTResult<Vec<f64>> {
        let config = CyclostationaryConfig {
            alpha_resolution: resolution,
            fs,
            ..CyclostationaryConfig::default()
        };
        detect_cyclic_frequencies_impl(signal, fs, &config)
    }
}

impl Default for CyclostationaryAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

/// Periodic Hann window of length `n`: `w[i] = 0.5 * (1 - cos(2*pi*i / n))`.
///
/// For `n == 1` the formula yields a single zero, which would blank every STFT frame, so a
/// unit window `[1.0]` is returned instead.
fn build_hann_window(n: usize) -> Vec<f64> {
    if n == 1 {
        return vec![1.0];
    }
    (0..n)
        .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f64 / n as f64).cos()))
        .collect()
}

/// Converts a fractional frame `overlap` into a hop size in samples.
///
/// `overlap` is clamped to `[0.0, 0.99]`. The hop is `round((1 - overlap) * n_fft)`, and never
/// less than one sample.
fn compute_hop(n_fft: usize, overlap: f64) -> usize {
    let overlap = overlap.clamp(0.0, 0.99);
    let hop = ((1.0 - overlap) * n_fft as f64).round() as usize;
    hop.max(1)
}

/// Short-time Fourier transform with frames of `n_fft` samples starting every `hop` samples.
///
/// Frame `k` starts at sample `k * hop` and is multiplied element-wise by `window` before its
/// FFT. A trailing partial frame is dropped.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] if:
/// - the signal is shorter than `n_fft`;
/// - `n_fft` or `hop` is zero;
/// - `window` is shorter than `n_fft`.
///
/// Errors from the FFT are propagated.
fn compute_stft(
    signal: &[f64],
    window: &[f64],
    n_fft: usize,
    hop: usize,
) -> FFTResult<Vec<Vec<Complex64>>> {
    let n = signal.len();
    if n < n_fft {
        return Err(FFTError::ValueError(
            "Signal must be at least n_fft samples long".to_string(),
        ));
    }
    // hop == 0 would never advance; n_fft == 0 would emit an empty frame per sample.
    if n_fft == 0 || hop == 0 {
        return Err(FFTError::ValueError(
            "n_fft and hop must both be positive".to_string(),
        ));
    }
    if window.len() < n_fft {
        return Err(FFTError::ValueError(format!(
            "Window length {} must be >= n_fft {n_fft}",
            window.len()
        )));
    }
    let window = &window[..n_fft];
    let mut frames: Vec<Vec<Complex64>> = Vec::with_capacity((n - n_fft) / hop + 1);
    let mut start = 0;
    while start + n_fft <= n {
        let segment: Vec<f64> = signal[start..start + n_fft]
            .iter()
            .zip(window)
            .map(|(&s, &w)| s * w)
            .collect();
        frames.push(fft(&segment, None)?);
        start += hop;
    }
    Ok(frames)
}

/// Computes one row of the spectral correlation density, for a single cyclic frequency
/// `alpha` (Hz).
///
/// For every spectral bin `f`, this averages over the first `n_frames` STFT frames:
///
/// `X_k[f + up] * conj(X_k[f - down]) * exp(-j*2*pi*alpha*(k*hop)/fs)`
///
/// - `up + down = round(alpha * n_fft / fs)` splits the cyclic offset as evenly as possible
///   between the two factors. The pair is thus separated by `alpha`, as the definition
///   `X(f + alpha/2) X*(f - alpha/2)` requires.
/// - The exponential removes the phase drift of `alpha` cycles per second between frames
///   that start `hop` samples apart, so the frame average is coherent.
///
/// Bin indices wrap modulo `n_fft`. The accumulated sum is divided by `n_frames`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] if `alpha` is not finite and
/// [`FFTError::ComputationError`] if `n_frames` is zero.
fn compute_scd_row(
    stft_matrix: &[Vec<Complex64>],
    alpha: f64,
    fs: f64,
    n_fft: usize,
    n_frames: usize,
    hop: usize,
) -> FFTResult<Vec<Complex64>> {
    if !alpha.is_finite() {
        return Err(FFTError::ValueError(format!(
            "Cyclic frequency must be finite, got {alpha}"
        )));
    }
    if n_frames == 0 {
        return Err(FFTError::ComputationError(
            "No STFT frames available for spectral correlation".to_string(),
        ));
    }
    let total_shift = (alpha * n_fft as f64 / fs).round() as i64;
    // Split the offset in halves; for an odd offset the extra bin goes to the upper factor.
    // Each half is at most ~2^62 in magnitude, so `f_idx +/- shift` cannot overflow.
    let lower_shift = total_shift / 2;
    let upper_shift = total_shift - lower_shift;

    let mut row: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); n_fft];
    for (frame_idx, frame) in stft_matrix.iter().take(n_frames).enumerate() {
        // Only the fractional number of alpha-cycles before this frame matters; dropping the
        // integer part keeps the phase accurate for long signals.
        let cycles = alpha * (frame_idx * hop) as f64 / fs;
        let phase = -2.0 * PI * (cycles - cycles.floor());
        let rotation = Complex64::new(phase.cos(), phase.sin());
        for (f_idx, acc) in row.iter_mut().enumerate() {
            let upper = frame[wrap_bin(f_idx as i64 + upper_shift, n_fft)];
            let lower = frame[wrap_bin(f_idx as i64 - lower_shift, n_fft)];
            *acc = *acc + upper * lower.conj() * rotation;
        }
    }
    let norm = 1.0 / n_frames as f64;
    for val in row.iter_mut() {
        *val = Complex64::new(val.re * norm, val.im * norm);
    }
    Ok(row)
}
/// Maps a possibly negative or out-of-range bin index onto `0..n` (modular wrap-around).
///
/// # Panics
///
/// Panics if `n == 0` (remainder by zero).
#[inline]
fn wrap_bin(idx: i64, n: usize) -> usize {
    idx.rem_euclid(n as i64) as usize
}
/// Builds the fallback cyclic-frequency grid `0, r, 2r, ...` up to and including `fs / 2`,
/// where `r` is `resolution`.
///
/// Inputs that cannot define a finite grid return immediately instead of looping:
/// - a negative, NaN or infinite `fs` yields an empty grid;
/// - a `resolution` that is not strictly positive and finite cannot advance the grid, so only
///   the zero cyclic frequency is returned.
///
/// Very small positive resolutions still produce correspondingly large grids.
fn build_alpha_grid(fs: f64, resolution: f64) -> Vec<f64> {
    let max_alpha = fs / 2.0;
    if !(max_alpha >= 0.0 && max_alpha.is_finite()) {
        return Vec::new();
    }
    if !(resolution > 0.0 && resolution.is_finite()) {
        return vec![0.0];
    }
    // `floor` gives the analytic number of whole steps; one spare step absorbs rounding in
    // the division, and `take_while` enforces the exact `alpha <= fs / 2` bound.
    let n_steps = ((max_alpha / resolution).floor() as usize).saturating_add(2);
    (0..n_steps)
        .map(|i| i as f64 * resolution)
        .take_while(|&alpha| alpha <= max_alpha)
        .collect()
}
/// Finds candidate cyclic frequencies as peaks in the magnitude spectrum of `signal^2`.
///
/// Let `len` be the FFT length. A bin `k` in `1..=len / 2` (DC excluded) is reported as
/// `k * fs / len` when all of the following hold:
/// - its magnitude is at least `config.detection_threshold` times the largest non-DC
///   magnitude;
/// - it is not smaller than either neighbour (bin 1 treats its left neighbour as 0, and a
///   missing right neighbour counts as 0);
/// - the resulting frequency does not exceed `fs / 2`.
///
/// An empty vector is returned when every non-DC magnitude is below `f64::EPSILON`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] for an empty signal. Errors from the FFT are propagated.
fn detect_cyclic_frequencies_impl(
    signal: &[f64],
    fs: f64,
    config: &CyclostationaryConfig,
) -> FFTResult<Vec<f64>> {
    if signal.is_empty() {
        return Err(FFTError::ValueError("Signal must not be empty".to_string()));
    }
    // Spectral peaks of the squared signal serve as cyclic-frequency candidates.
    let squared: Vec<f64> = signal.iter().map(|&x| x * x).collect();
    let spectrum = fft(&squared, None)?;
    let len = spectrum.len();
    let mags: Vec<f64> = spectrum
        .iter()
        .map(|z| (z.re * z.re + z.im * z.im).sqrt())
        .collect();

    let strongest = mags[1..].iter().fold(0.0_f64, |best, &m| best.max(m));
    if strongest < f64::EPSILON {
        return Ok(Vec::new());
    }
    let threshold = config.detection_threshold * strongest;
    let nyquist = fs / 2.0;

    let candidates: Vec<f64> = (1..=len / 2)
        .filter_map(|k| {
            let mag = mags[k];
            if mag < threshold {
                return None;
            }
            let left = if k > 1 { mags[k - 1] } else { 0.0 };
            let right = mags.get(k + 1).copied().unwrap_or(0.0);
            if !(mag >= left && mag >= right) {
                return None;
            }
            let alpha = k as f64 * fs / len as f64;
            if alpha <= nyquist {
                Some(alpha)
            } else {
                None
            }
        })
        .collect();
    Ok(candidates)
}
/// Computes the spectral correlation density of `signal` with default analysis settings.
///
/// `cyclic_freqs` selects the cyclic frequencies (Hz) to evaluate. `None` lets the analyzer
/// detect them. All other options come from `CyclostationaryConfig::default()`, with its `fs`
/// field set to `fs`.
///
/// # Errors
///
/// Propagates every error from [`CyclostationaryAnalyzer::compute_scd`].
pub fn compute_scd(
    signal: &[f64],
    fs: f64,
    cyclic_freqs: Option<Vec<f64>>,
) -> FFTResult<SpectralCorrelationResult> {
    let settings = CyclostationaryConfig {
        fs,
        cyclic_freqs,
        ..CyclostationaryConfig::default()
    };
    CyclostationaryAnalyzer::new().compute_scd(signal, fs, &settings)
}
/// Detects candidate cyclic frequencies (Hz) of `signal` sampled at `fs` Hz.
///
/// Convenience wrapper around [`CyclostationaryAnalyzer::detect_cyclic_frequencies`]; see it
/// for how `resolution` is used.
///
/// # Errors
///
/// Propagates every error from the analyzer method.
pub fn detect_cyclic_frequencies(signal: &[f64], fs: f64, resolution: f64) -> FFTResult<Vec<f64>> {
    CyclostationaryAnalyzer::default().detect_cyclic_frequencies(signal, fs, resolution)
}