use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use crate::window::{get_window, Window};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Configuration for a single fractional Fourier transform.
#[derive(Debug, Clone)]
pub struct FrftConfig {
    /// Fractional order `a`. The transform functions in this module interpret
    /// it modulo 4; order 1 is the unitary DFT.
    pub order: f64,
}

impl Default for FrftConfig {
    /// Order 1, i.e. the ordinary (unitary) Fourier transform.
    fn default() -> Self {
        let order = 1.0;
        FrftConfig { order }
    }
}
/// Parameters for the short-time fractional Fourier transform (`stfrft`).
#[derive(Debug, Clone)]
pub struct StfrftConfig {
    /// FrFT order applied to every segment.
    pub order: f64,
    /// Number of samples per segment.
    pub segment_len: usize,
    /// Number of samples shared by consecutive segments
    /// (the hop is `segment_len - overlap`).
    pub overlap: usize,
    /// Window applied to each segment before the transform.
    pub window: Window,
    /// Sampling rate, used to convert segment centres to seconds.
    pub fs: f64,
}

impl Default for StfrftConfig {
    /// 256-sample Hann segments with 128 samples of overlap, order 1 and
    /// `fs = 1.0`.
    fn default() -> Self {
        StfrftConfig {
            window: Window::Hann,
            segment_len: 256,
            overlap: 128,
            order: 1.0,
            fs: 1.0,
        }
    }
}
/// Output of the short-time fractional Fourier transform (`stfrft`).
#[derive(Debug, Clone)]
pub struct StfrftResult {
    /// Centre time (in seconds) of each segment; one entry per row of `matrix`.
    pub times: Vec<f64>,
    /// Output sample index of each column: `0.0, 1.0, …, segment_len - 1`.
    pub freq_indices: Vec<f64>,
    /// FrFT coefficients, indexed as `matrix[segment][index]`.
    pub matrix: Vec<Vec<Complex64>>,
    /// Order that was applied to every segment.
    pub order: f64,
}
/// Fractional Fourier transform of a real-valued signal.
///
/// Each sample is promoted to a complex number with zero imaginary part, and
/// the result is passed to [`frft_complex`], which defines the order
/// convention.
///
/// # Errors
/// Returns `FFTError::ValueError` if `signal` is empty. Otherwise, any error
/// from [`frft_complex`] is returned.
pub fn frft(signal: &[f64], order: f64) -> FFTResult<Vec<Complex64>> {
    match signal {
        [] => Err(FFTError::ValueError("Input signal is empty".to_string())),
        samples => {
            let promoted: Vec<Complex64> = samples
                .iter()
                .copied()
                .map(|re| Complex64::new(re, 0.0))
                .collect();
            frft_complex(&promoted, order)
        }
    }
}
/// Fractional Fourier transform of order `order` applied to a complex signal.
///
/// The order is reduced modulo 4. Orders within `1e-12` of an integer use
/// exact fast paths:
/// - `0` (and `4`): identity;
/// - `1`: unitary forward DFT, `fft(x) / sqrt(N)`;
/// - `2`: index reversal;
/// - `3`: unitary inverse DFT, `ifft(x) * sqrt(N)`.
///
/// Every other order is computed by `ozaktas_frft`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `x` is empty or `order` is not finite.
/// Errors from the underlying FFT are propagated.
pub fn frft_complex(x: &[Complex64], order: f64) -> FFTResult<Vec<Complex64>> {
    let n = x.len();
    if n == 0 {
        return Err(FFTError::ValueError("Input signal is empty".to_string()));
    }
    // NaN or inf would otherwise propagate through the modulo into the
    // general path and silently produce an all-NaN result.
    if !order.is_finite() {
        return Err(FFTError::ValueError(format!(
            "FrFT order must be finite, got {}",
            order
        )));
    }

    // Map the order into [0, 4]. rem_euclid may round tiny negative orders up
    // to exactly 4.0, which the identity check below also catches.
    let a = order.rem_euclid(4.0);
    let near = |target: f64| (a - target).abs() < 1e-12;

    if near(0.0) || near(4.0) {
        return Ok(x.to_vec());
    }
    if near(2.0) {
        return Ok(x.iter().rev().copied().collect());
    }
    if near(1.0) {
        let scale = 1.0 / (n as f64).sqrt();
        return Ok(fft_of_complex(x)?
            .into_iter()
            .map(|c| c * scale)
            .collect());
    }
    if near(3.0) {
        let scale = (n as f64).sqrt();
        return Ok(ifft_of_complex(x)?
            .into_iter()
            .map(|c| c * scale)
            .collect());
    }
    ozaktas_frft(x, a)
}
/// Inverse fractional Fourier transform: the FrFT of order `-order`.
///
/// For integer orders, the fast paths in [`frft_complex`] make
/// `ifrft(&frft_complex(x, a)?, a)` reproduce `x` up to rounding.
pub fn ifrft(signal: &[Complex64], order: f64) -> FFTResult<Vec<Complex64>> {
    let inverse_order = -order;
    frft_complex(signal, inverse_order)
}
fn ozaktas_frft(x: &[Complex64], a: f64) -> FFTResult<Vec<Complex64>> {
let n = x.len();
let phi = a * PI / 2.0;
let sin_phi = phi.sin();
let cos_phi = phi.cos();
let cot_phi = cos_phi / sin_phi;
let csc_phi = 1.0 / sin_phi;
let norm_factor = (1.0 + cot_phi * cot_phi).sqrt().sqrt() / (n as f64).sqrt();
let n_f = n as f64;
let pre_chirp: Vec<Complex64> = (0..n)
.map(|k| {
let kc = k as f64 - n_f / 2.0;
Complex64::from_polar(1.0, -PI * cot_phi * kc * kc / n_f)
})
.collect();
let modulated: Vec<Complex64> = x
.iter()
.zip(pre_chirp.iter())
.map(|(&xi, &pc)| xi * pc)
.collect();
let pad_len = 2 * n;
let mut kernel = vec![Complex64::new(0.0, 0.0); pad_len];
for k in 0..pad_len {
let kc = if k < n {
k as f64 - n_f / 2.0
} else {
k as f64 - n_f / 2.0 - pad_len as f64 + n_f };
kernel[k] = Complex64::from_polar(1.0, PI * csc_phi * kc * kc / n_f);
}
let mut mod_padded = vec![Complex64::new(0.0, 0.0); pad_len];
for (i, &v) in modulated.iter().enumerate() {
mod_padded[i] = v;
}
let mod_fft = fft_of_complex(&mod_padded)?;
let kern_fft = fft_of_complex(&kernel)?;
let product: Vec<Complex64> = mod_fft
.iter()
.zip(kern_fft.iter())
.map(|(&m, &k)| m * k)
.collect();
let conv_result = ifft_of_complex(&product)?;
let result: Vec<Complex64> = (0..n)
.map(|k| {
let kc = k as f64 - n_f / 2.0;
let post_chirp = Complex64::from_polar(1.0, -PI * cot_phi * kc * kc / n_f);
conv_result[k] * post_chirp * norm_factor
})
.collect();
Ok(result)
}
/// Short-time fractional Fourier transform (STFrFT) of a real signal.
///
/// The signal is cut into segments of `config.segment_len` samples, advanced
/// by `segment_len - overlap` samples. The hop is at least one sample, so an
/// `overlap` of `segment_len` or more degrades to a hop of 1 rather than an
/// error. Each segment is multiplied by `config.window` and transformed with
/// [`frft_complex`] at `config.order`. Trailing samples that do not fill a
/// whole segment are ignored.
///
/// # Returns
/// - `times`: the centre of each segment in seconds,
///   `(start + segment_len / 2) / fs`;
/// - `freq_indices`: output sample indices `0..segment_len` as `f64`;
/// - `matrix`: one FrFT row per segment;
/// - `order`: copied from `config`.
///
/// # Errors
/// Returns `FFTError::ValueError` if the signal is empty, `segment_len` is 0
/// or exceeds the signal length, or `fs` is not a positive finite number.
/// Window and FrFT errors are propagated.
pub fn stfrft(signal: &[f64], config: &StfrftConfig) -> FFTResult<StfrftResult> {
    let n = signal.len();
    if n == 0 {
        return Err(FFTError::ValueError("Signal is empty".to_string()));
    }
    let seg_len = config.segment_len;
    if seg_len == 0 {
        return Err(FFTError::ValueError(
            "segment_len must be positive".to_string(),
        ));
    }
    if seg_len > n {
        return Err(FFTError::ValueError(format!(
            "segment_len ({}) exceeds signal length ({})",
            seg_len, n
        )));
    }
    // A zero, negative or non-finite rate would make every entry of `times`
    // inf, NaN or negative without any error.
    if !(config.fs.is_finite() && config.fs > 0.0) {
        return Err(FFTError::ValueError(format!(
            "fs must be a positive finite sampling rate, got {}",
            config.fs
        )));
    }

    let hop = seg_len.saturating_sub(config.overlap).max(1);
    let num_segments = (n - seg_len) / hop + 1;
    let win = get_window(config.window.clone(), seg_len, true)?;
    let win_coeffs: Vec<f64> = win.to_vec();

    let times: Vec<f64> = (0..num_segments)
        .map(|s| (s * hop + seg_len / 2) as f64 / config.fs)
        .collect();
    let freq_indices: Vec<f64> = (0..seg_len).map(|k| k as f64).collect();

    let matrix = (0..num_segments)
        .map(|seg_idx| {
            let start = seg_idx * hop;
            let windowed: Vec<Complex64> = signal[start..start + seg_len]
                .iter()
                .zip(&win_coeffs)
                .map(|(&sample, &w)| Complex64::new(sample * w, 0.0))
                .collect();
            frft_complex(&windowed, config.order)
        })
        .collect::<FFTResult<Vec<_>>>()?;

    Ok(StfrftResult {
        times,
        freq_indices,
        matrix,
        order: config.order,
    })
}
/// Forward FFT of a complex slice.
///
/// Thin wrapper over [`fft`] that passes `None` for its optional length
/// argument.
fn fft_of_complex(x: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    let spectrum = fft(x, None)?;
    Ok(spectrum)
}

/// Inverse FFT of a complex slice.
///
/// Thin wrapper over [`ifft`] that passes `None` for its optional length
/// argument.
fn ifft_of_complex(x: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    let samples = ifft(x, None)?;
    Ok(samples)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    #[test]
    fn test_frft_order_zero_is_identity() {
        let signal: Vec<f64> = (0..64)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / 64.0).sin())
            .collect();
        let result = frft(&signal, 0.0).expect("FrFT order=0 should succeed");
        for (i, (&orig, &res)) in signal.iter().zip(result.iter()).enumerate() {
            assert!(
                (res.re - orig).abs() < 1e-10,
                "Mismatch at index {}: expected {}, got {}",
                i,
                orig,
                res.re
            );
            assert!(
                res.im.abs() < 1e-10,
                "Imaginary part should be zero at index {}",
                i
            );
        }
    }

    #[test]
    fn test_frft_order_one_matches_fft() {
        let signal: Vec<f64> = (0..64)
            .map(|i| (2.0 * PI * 3.0 * i as f64 / 64.0).cos())
            .collect();
        let frft_result = frft(&signal, 1.0).expect("FrFT order=1 should succeed");
        let fft_result = fft(&signal, None).expect("FFT should succeed");
        let scale = 1.0 / (signal.len() as f64).sqrt();
        for (i, (&fr, &ff)) in frft_result.iter().zip(fft_result.iter()).enumerate() {
            let expected = ff * scale;
            assert!(
                (fr.re - expected.re).abs() < 1e-8,
                "Real mismatch at {}: frft={}, fft_scaled={}",
                i,
                fr.re,
                expected.re
            );
            assert!(
                (fr.im - expected.im).abs() < 1e-8,
                "Imag mismatch at {}: frft={}, fft_scaled={}",
                i,
                fr.im,
                expected.im
            );
        }
    }

    #[test]
    fn test_frft_order_two_reverses() {
        let signal = [1.0, 2.0, 3.0, 4.0, 5.0];
        // 2, -2 and 6 are the same order modulo 4.
        for &order in &[2.0, -2.0, 6.0] {
            let result = frft(&signal, order).expect("FrFT order=2 should succeed");
            let re: Vec<f64> = result.iter().map(|c| c.re).collect();
            assert_eq!(re, vec![5.0, 4.0, 3.0, 2.0, 1.0], "order {}", order);
            assert!(result.iter().all(|c| c.im == 0.0), "order {}", order);
        }
    }

    #[test]
    fn test_frft_rejects_empty_input() {
        assert!(frft(&[], 0.5).is_err());
        assert!(frft_complex(&[], 0.5).is_err());
    }

    #[test]
    fn test_frft_energy_preservation() {
        let signal: Vec<f64> = (0..128)
            .map(|i| (2.0 * PI * 7.0 * i as f64 / 128.0).sin() + 0.5)
            .collect();
        let input_energy: f64 = signal.iter().map(|&v| v * v).sum();
        for &order in &[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5] {
            let result = frft(&signal, order).expect("FrFT should succeed");
            let output_energy: f64 = result.iter().map(|c| c.norm_sqr()).sum();
            let ratio = output_energy / input_energy;
            assert!(
                ratio > 0.1 && ratio < 10.0,
                "Energy ratio for order {}: {} (input={}, output={})",
                order,
                ratio,
                input_energy,
                output_energy
            );
        }
    }

    #[test]
    fn test_frft_inverse_roundtrip() {
        // Non-zero imaginary part so the whole complex value is exercised.
        let x: Vec<Complex64> = (0..64)
            .map(|i| {
                Complex64::new(
                    (2.0 * PI * 4.0 * i as f64 / 64.0).sin(),
                    0.25 * i as f64 / 64.0,
                )
            })
            .collect();
        // Integer orders take exact fast paths, so the round trip is exact up
        // to rounding.
        for &order in &[0.0, 1.0, 2.0, 3.0, -1.0] {
            let forward = frft_complex(&x, order).expect("Forward FrFT");
            let recovered = ifrft(&forward, order).expect("Inverse FrFT");
            assert_eq!(recovered.len(), x.len());
            for (i, (&orig, &rec)) in x.iter().zip(recovered.iter()).enumerate() {
                assert!(
                    (orig - rec).norm() < 1e-8,
                    "Round-trip mismatch at {} for order {}: orig={}, recovered={}",
                    i,
                    order,
                    orig,
                    rec
                );
            }
        }
    }

    #[test]
    fn test_stfrft_basic() {
        let fs = 1000.0;
        let n = 1024;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 50.0 * i as f64 / fs).sin())
            .collect();
        let config = StfrftConfig {
            order: 1.0,
            segment_len: 128,
            overlap: 64,
            window: Window::Hann,
            fs,
        };
        let result = stfrft(&signal, &config).expect("STFRFT should succeed");
        let expected_segments = (n - 128) / 64 + 1;
        assert_eq!(result.times.len(), expected_segments);
        assert_eq!(result.matrix.len(), expected_segments);
        assert_eq!(result.matrix[0].len(), 128);
        assert_eq!(result.freq_indices.len(), 128);
        for row in &result.matrix {
            let energy: f64 = row.iter().map(|c| c.norm_sqr()).sum();
            assert!(energy > 0.0, "Segment should have non-zero energy");
        }
    }

    #[test]
    fn test_stfrft_chirp_localization() {
        let fs = 1000.0;
        let n = 2048;
        // Instantaneous frequency rises from 50 Hz to about 350 Hz over the
        // record, well below the 500 Hz Nyquist limit.
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / fs;
                let freq = 50.0 + 150.0 * t * fs / (n as f64);
                (2.0 * PI * freq * t).sin()
            })
            .collect();
        let config = StfrftConfig {
            order: 1.0,
            segment_len: 128,
            overlap: 96,
            window: Window::Hann,
            fs,
        };
        let result = stfrft(&signal, &config).expect("STFRFT should succeed");

        // The input is real, so each order-1 row is conjugate-symmetric.
        // Search only the non-negative-frequency bins, so the peak cannot
        // jump to the mirror image.
        let half = config.segment_len / 2;
        let peak_indices: Vec<usize> = result
            .matrix
            .iter()
            .map(|row| {
                row[..=half]
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.norm_sqr()
                            .partial_cmp(&b.norm_sqr())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(idx, _)| idx)
                    .unwrap_or(0)
            })
            .collect();

        let n_seg = peak_indices.len();
        assert!(n_seg >= 4, "expected at least 4 segments, got {}", n_seg);
        let mean = |s: &[usize]| s.iter().map(|&v| v as f64).sum::<f64>() / s.len() as f64;
        let first_q = mean(&peak_indices[..n_seg / 4]);
        let last_q = mean(&peak_indices[3 * n_seg / 4..]);
        assert!(
            last_q > first_q,
            "Chirp peak should move to higher bins over time: first quarter {}, last quarter {}",
            first_q,
            last_q
        );
    }

    #[test]
    fn test_stfrft_invalid_params() {
        let signal = vec![1.0; 100];
        let config = StfrftConfig {
            segment_len: 0,
            ..Default::default()
        };
        assert!(stfrft(&signal, &config).is_err());
        let config2 = StfrftConfig {
            segment_len: 200,
            ..Default::default()
        };
        assert!(stfrft(&signal, &config2).is_err());
    }
}