use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use crate::window::{get_window, Window};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Parameters for the segment-averaged cyclostationary estimators in this module.
#[derive(Debug, Clone)]
pub struct CyclostationaryConfig {
/// Length in samples of each analysis segment and of its FFT. This is also the
/// number of frequency bins and of cycle-frequency rows in the result. Must be
/// non-zero and no larger than the signal length.
pub fft_size: usize,
/// Window applied to every segment before its FFT (passed to `get_window`).
pub window: Window,
/// Number of samples shared by consecutive segments. The hop between segment
/// starts is `fft_size - overlap`, clamped to at least 1.
pub overlap: usize,
/// Upper bound on the number of segments averaged. `None` averages every
/// available segment; `Some(0)` is rejected as an error.
pub num_averages: Option<usize>,
/// Sampling rate. Frequencies and cycle frequencies are reported on a grid
/// with spacing `fs / fft_size`.
pub fs: f64,
}
impl Default for CyclostationaryConfig {
fn default() -> Self {
Self {
fft_size: 256,
window: Window::Hann,
overlap: 128,
num_averages: None,
fs: 1.0,
}
}
}
/// Output grid of the cyclostationary estimators in this module.
///
/// `scf` is indexed as `scf[a][k]`: row `a` corresponds to cycle frequency
/// `alphas[a]` and column `k` to frequency `frequencies[k]`. Depending on the
/// producing function, the entries are spectral-correlation magnitudes or
/// spectral-coherence values.
#[derive(Debug, Clone, PartialEq)]
pub struct CyclostationaryResult {
    /// Frequency of each column of `scf`, in the units of the sampling rate.
    pub frequencies: Vec<f64>,
    /// Cycle frequency of each row of `scf`, in the units of the sampling rate.
    pub alphas: Vec<f64>,
    /// Estimator values, one row per entry of `alphas`.
    pub scf: Vec<Vec<f64>>,
}
/// Splits `signal` into windowed segments and returns the FFT of each one.
///
/// Segment `p` covers samples `p * hop .. p * hop + fft_size`. It is multiplied
/// element-wise by the first `fft_size` window coefficients before the FFT.
/// Only segments that fit entirely inside the signal are produced, giving
/// `(signal.len() - fft_size) / hop + 1` spectra (at least one on success).
///
/// # Errors
///
/// Returns `FFTError::ValueError` in these cases:
/// * `fft_size` or `hop` is zero;
/// * `fft_size` exceeds the signal length;
/// * `window_coeffs` holds fewer than `fft_size` values.
///
/// Errors from `fft` are propagated.
fn channelise(
    signal: &[f64],
    fft_size: usize,
    window_coeffs: &[f64],
    hop: usize,
) -> FFTResult<Vec<Vec<Complex64>>> {
    let n = signal.len();
    if fft_size == 0 {
        return Err(FFTError::ValueError(
            "fft_size must be positive".to_string(),
        ));
    }
    if hop == 0 {
        return Err(FFTError::ValueError("hop must be positive".to_string()));
    }
    if fft_size > n {
        return Err(FFTError::ValueError(format!(
            "fft_size ({}) exceeds signal length ({})",
            fft_size, n
        )));
    }
    // A short window would otherwise cause an out-of-bounds panic mid-loop.
    if window_coeffs.len() < fft_size {
        return Err(FFTError::ValueError(format!(
            "window length ({}) is shorter than fft_size ({})",
            window_coeffs.len(),
            fft_size
        )));
    }
    let window = &window_coeffs[..fft_size];
    let num_segments = (n - fft_size) / hop + 1;
    let mut segments: Vec<Vec<Complex64>> = Vec::with_capacity(num_segments);
    // The last start is (num_segments - 1) * hop <= n - fft_size, so every
    // slice below is in bounds.
    for frame in (0..num_segments).map(|p| &signal[p * hop..p * hop + fft_size]) {
        let windowed: Vec<f64> = frame
            .iter()
            .zip(window)
            .map(|(&x, &w)| x * w)
            .collect();
        segments.push(fft(&windowed, None)?);
    }
    Ok(segments)
}
pub fn spectral_correlation_function(
signal: &[f64],
config: &CyclostationaryConfig,
) -> FFTResult<CyclostationaryResult> {
let n = signal.len();
if n == 0 {
return Err(FFTError::ValueError("Signal is empty".to_string()));
}
let fft_size = config.fft_size;
let hop = fft_size.saturating_sub(config.overlap).max(1);
let win = get_window(config.window.clone(), fft_size, true)?;
let window_coeffs: Vec<f64> = win.to_vec();
let segments = channelise(signal, fft_size, &window_coeffs, hop)?;
let num_segments = segments.len();
let max_averages = config
.num_averages
.unwrap_or(num_segments)
.min(num_segments);
if max_averages == 0 {
return Err(FFTError::ValueError(
"Not enough data for even one average".to_string(),
));
}
let freq_resolution = config.fs / fft_size as f64;
let frequencies: Vec<f64> = (0..fft_size)
.map(|k| {
let k_shifted = if k < fft_size / 2 {
k as f64
} else {
k as f64 - fft_size as f64
};
k_shifted * freq_resolution
})
.collect();
let num_alpha = fft_size;
let alphas: Vec<f64> = (0..num_alpha)
.map(|a| {
let a_shifted = if a < num_alpha / 2 {
a as f64
} else {
a as f64 - num_alpha as f64
};
a_shifted * freq_resolution
})
.collect();
let mut scf: Vec<Vec<f64>> = vec![vec![0.0; fft_size]; num_alpha];
for a_idx in 0..num_alpha {
let shift = a_idx;
let mut accum: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); fft_size];
for seg in segments.iter().take(max_averages) {
for k in 0..fft_size {
let upper = (k + shift) % fft_size;
let lower = k;
accum[k] += seg[upper] * seg[lower].conj();
}
}
let inv_p = 1.0 / max_averages as f64;
for k in 0..fft_size {
scf[a_idx][k] = (accum[k] * inv_p).norm();
}
}
Ok(CyclostationaryResult {
frequencies,
alphas,
scf,
})
}
pub fn cyclic_spectral_density(
signal: &[f64],
alpha: f64,
config: &CyclostationaryConfig,
) -> FFTResult<(Vec<f64>, Vec<f64>)> {
let result = spectral_correlation_function(signal, config)?;
let best_idx = result
.alphas
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let da = (*a - alpha).abs();
let db = (*b - alpha).abs();
da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0);
Ok((result.frequencies, result.scf[best_idx].clone()))
}
pub fn cyclic_frequency_detection(
signal: &[f64],
config: &CyclostationaryConfig,
) -> FFTResult<(Vec<f64>, Vec<f64>)> {
let result = spectral_correlation_function(signal, config)?;
let num_freqs = result.frequencies.len();
let profile: Vec<f64> = result
.scf
.iter()
.map(|row| {
let sum: f64 = row.iter().sum();
sum / num_freqs as f64
})
.collect();
Ok((result.alphas, profile))
}
/// Normalises the spectral correlation function into cyclic spectral coherence.
///
/// For row `a` and bin `k`, the output is
/// `scf[a][k] / sqrt(psd[(k + a) mod N] * psd[k])`. Here `psd` is the
/// `alpha = 0` row of [`spectral_correlation_function`], and `N` is the
/// number of bins. Entries whose denominator is not above `1e-30` are set to
/// `0.0`. Because each SCF entry is the magnitude of an average of
/// `X[k + a] * conj(X[k])` and `psd` is the average of `|X|^2`, the values lie
/// in `[0, 1]` up to rounding (Cauchy–Schwarz).
///
/// # Errors
///
/// Propagates every error of [`spectral_correlation_function`].
pub fn spectral_coherence_cyclic(
    signal: &[f64],
    config: &CyclostationaryConfig,
) -> FFTResult<CyclostationaryResult> {
    let CyclostationaryResult {
        frequencies,
        alphas,
        scf,
    } = spectral_correlation_function(signal, config)?;
    let bins = frequencies.len();
    let psd = &scf[0];
    let coherence: Vec<Vec<f64>> = scf
        .iter()
        .enumerate()
        .map(|(shift, row)| {
            (0..bins)
                .map(|k| {
                    let denom = (psd[(k + shift) % bins] * psd[k]).sqrt();
                    if denom > 1e-30 {
                        row[k] / denom
                    } else {
                        0.0
                    }
                })
                .collect()
        })
        .collect();
    Ok(CyclostationaryResult {
        frequencies,
        alphas,
        scf: coherence,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Unit-amplitude sine at `freq` Hz sampled at `fs` Hz.
    fn sine_signal(freq: f64, fs: f64, n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / fs).sin())
            .collect()
    }

    /// Deterministic, zero-mean uniform noise in [-1, 1) from a 64-bit LCG
    /// (Knuth's MMIX multiplier/increment).
    fn pseudo_noise(n: usize, seed: u64) -> Vec<f64> {
        let mut state = seed;
        (0..n)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                // `state >> 33` keeps the top 31 bits, i.e. [0, 2^31). Scaling by
                // 2^30 gives [0, 2), so subtracting 1 centres the noise on zero.
                // Scaling by 2^31 gave [-1, 0), adding a strong DC component.
                (state >> 33) as f64 / (1u64 << 30) as f64 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_scf_sinusoid_peaks() {
        let fs = 1000.0;
        let f0 = 100.0;
        let n = 4096;
        let signal = sine_signal(f0, fs, n);
        let config = CyclostationaryConfig {
            fft_size: 256,
            window: Window::Hann,
            overlap: 128,
            num_averages: None,
            fs,
        };
        let result = spectral_correlation_function(&signal, &config)
            .expect("SCF computation should succeed");
        let psd_row = &result.scf[0];
        let peak_idx = psd_row
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .expect("psd row should be non-empty");
        let peak_freq = result.frequencies[peak_idx].abs();
        assert!(
            (peak_freq - f0).abs() < 20.0,
            "PSD peak at {} Hz, expected near {} Hz",
            peak_freq,
            f0
        );
        let alpha_profile: Vec<f64> = result
            .scf
            .iter()
            .map(|row| row.iter().sum::<f64>())
            .collect();
        let max_profile = alpha_profile
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(max_profile > 0.0, "SCF should have non-zero content");
    }

    #[test]
    fn test_scf_white_noise_flat() {
        let fs = 1000.0;
        let n = 4096;
        let signal = pseudo_noise(n, 42);
        let config = CyclostationaryConfig {
            fft_size: 256,
            window: Window::Hann,
            overlap: 128,
            num_averages: None,
            fs,
        };
        let (alphas, profile) =
            cyclic_frequency_detection(&signal, &config).expect("Detection should succeed");
        let _ = alphas;
        let psd_energy = profile[0];
        let non_zero_alphas: Vec<f64> = profile.iter().skip(1).copied().collect();
        let mean_nz: f64 = non_zero_alphas.iter().sum::<f64>() / non_zero_alphas.len() as f64;
        let max_nz = non_zero_alphas
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_nz < psd_energy,
            "Noise cyclic features ({}) should be smaller than PSD ({})",
            max_nz,
            psd_energy
        );
        let var: f64 = non_zero_alphas
            .iter()
            .map(|&v| (v - mean_nz).powi(2))
            .sum::<f64>()
            / non_zero_alphas.len() as f64;
        let std_dev = var.sqrt();
        if mean_nz > 1e-15 {
            let cv = std_dev / mean_nz;
            assert!(
                cv < 2.0,
                "Noise alpha profile should be relatively flat, CV = {}",
                cv
            );
        }
    }

    #[test]
    fn test_cyclic_spectral_density_extraction() {
        let fs = 1000.0;
        let f0 = 100.0;
        let n = 4096;
        let signal = sine_signal(f0, fs, n);
        let config = CyclostationaryConfig {
            fft_size: 256,
            window: Window::Hann,
            overlap: 128,
            num_averages: None,
            fs,
        };
        let (freqs, csd) =
            cyclic_spectral_density(&signal, 0.0, &config).expect("CSD should succeed");
        assert_eq!(freqs.len(), csd.len());
        assert!(!csd.is_empty());
        let peak_idx = csd
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .expect("CSD should have elements");
        let peak_freq = freqs[peak_idx].abs();
        assert!(
            (peak_freq - f0).abs() < 20.0,
            "CSD peak at {} Hz, expected near {} Hz",
            peak_freq,
            f0
        );
    }

    #[test]
    fn test_spectral_coherence() {
        let fs = 1000.0;
        let f0 = 100.0;
        let n = 4096;
        let signal = sine_signal(f0, fs, n);
        let config = CyclostationaryConfig {
            fft_size: 256,
            window: Window::Hann,
            overlap: 128,
            num_averages: None,
            fs,
        };
        let coherence =
            spectral_coherence_cyclic(&signal, &config).expect("Coherence should succeed");
        let alpha0_row = &coherence.scf[0];
        for &c in alpha0_row {
            assert!(c <= 1.0 + 1e-10, "Coherence should be <= 1, got {}", c);
            assert!(c >= 0.0, "Coherence should be >= 0, got {}", c);
        }
    }

    #[test]
    fn test_scf_invalid_params() {
        let signal = vec![1.0; 100];
        let config = CyclostationaryConfig {
            fft_size: 0,
            ..Default::default()
        };
        assert!(spectral_correlation_function(&signal, &config).is_err());
        let config2 = CyclostationaryConfig {
            fft_size: 200,
            ..Default::default()
        };
        assert!(spectral_correlation_function(&signal, &config2).is_err());
    }
}