use crate::error::FFTResult;
use crate::fft::{fft, ifft};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::Zero;
use std::f64::consts::PI;
/// Computes a discrete fractional Fourier transform of `x` of order `alpha`
/// using a chirp-multiply / FFT / chirp-multiply / IFFT / chirp-multiply scheme.
///
/// Orders within `1e-10` of an integer are dispatched to the exact integer-order
/// operators in `handle_special_cases` (identity, FFT, reversal, inverse FFT
/// depending on `alpha mod 4`). All other orders go through the chirp path. That
/// path pre-chirps the input, zero-pads it to `2 * x.len()`, tapers the edges of
/// the original samples, and transforms forward. It then post-chirps, transforms
/// back, keeps the first `x.len()` samples, applies a closing chirp and amplitude
/// scale, and finally applies a global phase via `post_process_result`.
///
/// # Arguments
/// * `x` - real-valued input samples.
/// * `alpha` - transform order; `1.0` corresponds to the ordinary (forward) FFT.
///
/// # Returns
/// `x.len()` complex samples; an empty vector for empty input.
///
/// # Errors
/// Propagates any error returned by the underlying `fft` / `ifft` calls.
#[allow(dead_code)]
pub fn frft_ozaktas<T>(x: &[T], alpha: f64) -> FFTResult<Vec<Complex64>>
where
    T: Copy + Into<f64>,
{
    let n = x.len();
    if n == 0 {
        return Ok(vec![]);
    }

    // Integer orders have exact closed forms. Previously only `alpha % 4 == 0`
    // (and even orders via `phi % PI`) were special-cased, so odd orders fell
    // through to the chirp path. There, `scale = sqrt(|1 - sin(phi)|)` is exactly
    // 0 for alpha = 1 and the whole output collapsed to zeros. Rounding first also
    // catches orders that sit just below a multiple of 4 (e.g. 3.9999999999).
    let nearest_order = alpha.round();
    if (alpha - nearest_order).abs() < 1e-10 {
        return handle_special_cases(x, nearest_order);
    }

    let x_complex: Vec<Complex64> = x
        .iter()
        .map(|&val| Complex64::new(val.into(), 0.0))
        .collect();

    let phi = alpha * PI / 2.0;
    let sin_phi = phi.sin();
    let tan_phi_2 = (phi / 2.0).tan();
    // NOTE(review): this amplitude factor tends to 0 as alpha -> 1 (mod 4), so
    // orders just outside the integer tolerance above are strongly attenuated.
    let scale = (1.0 - sin_phi).abs().sqrt();

    // The same n-point chirp is applied before the FFT and after the IFFT, so it
    // is computed once and reused.
    let edge_chirp = compute_stable_chirp(n, tan_phi_2);

    // Pre-chirp, then zero-pad to 2n (presumably to limit circular wrap-around)
    // and taper the edges of the original n samples.
    let padded_len = 2 * n;
    let mut x_chirped: Vec<Complex64> = x_complex
        .iter()
        .zip(edge_chirp.iter())
        .map(|(&sample, &chirp)| sample * chirp)
        .collect();
    x_chirped.resize(padded_len, Complex64::zero());
    apply_tukey_window(&mut x_chirped, n);

    let x_fft = fft(&x_chirped, None)?;
    let post_chirp = compute_stable_chirp(padded_len, tan_phi_2);
    let x_post: Vec<Complex64> = x_fft
        .iter()
        .zip(post_chirp.iter())
        .map(|(&bin, &chirp)| bin * chirp)
        .collect();
    let x_ifft = ifft(&x_post, None)?;

    // Keep the first n samples, then apply the closing chirp and amplitude scale.
    let mut result: Vec<Complex64> = x_ifft
        .iter()
        .take(n)
        .zip(edge_chirp.iter())
        .map(|(&sample, &chirp)| sample * chirp * scale)
        .collect();
    post_process_result(&mut result, alpha);
    Ok(result)
}
/// Builds an `n`-point unit-magnitude chirp with phase `pi * param * m^2 / n`.
///
/// `m = k - n/2` re-centres sample index `k` on the middle of the buffer, so the
/// quadratic phase is symmetric about `n/2`. Returns an empty vector when `n == 0`.
#[allow(dead_code)]
fn compute_stable_chirp(n: usize, param: f64) -> Vec<Complex64> {
    let len = n as f64;
    let half = len / 2.0;
    (0..n)
        .map(|idx| {
            let offset = idx as f64 - half;
            Complex64::from_polar(1.0, PI * param * offset * offset / len)
        })
        .collect()
}
/// Tapers both ends of the first `originallen` samples of `x` in place.
///
/// Each end gets a raised-cosine ramp `0.5 * (1 - cos(pi * r))` whose length is
/// 10% of `originallen` (truncated). The ramp is 0 on the outermost samples, so
/// `x[0]` and `x[originallen - 1]` are zeroed whenever the taper is non-empty.
/// Inputs shorter than 10 samples get no taper because the length truncates to 0.
/// Samples at index `originallen` and beyond (e.g. zero padding) are untouched.
///
/// # Panics
/// Panics if the taper is non-empty and `x.len() < originallen`.
#[allow(dead_code)]
fn apply_tukey_window(x: &mut [Complex64], originallen: usize) {
    const TAPER_FRACTION: f64 = 0.1;
    let taper_len = (TAPER_FRACTION * originallen as f64) as usize;

    let ramp = (0..taper_len).map(|step| {
        let ratio = step as f64 / taper_len as f64;
        0.5 * (1.0 - (PI * ratio).cos())
    });

    // Apply each ramp weight symmetrically: once from the front, once from the back.
    for (step, weight) in ramp.enumerate() {
        x[step] *= weight;
        x[originallen - 1 - step] *= weight;
    }
}
/// Evaluates the transform exactly for an integer order.
///
/// `alpha` is reduced modulo 4 (negative orders wrap around), rounded to the
/// nearest integer, and dispatched:
/// * `0` - identity,
/// * `1` - forward FFT,
/// * `2` - sample reversal,
/// * `3` - inverse FFT.
///
/// Callers are expected to pass an (almost) integer `alpha`. A non-integer order
/// is snapped to the nearest integer rather than rejected.
///
/// # Errors
/// Propagates any error returned by `fft` / `ifft`.
#[allow(dead_code)]
fn handle_special_cases<T>(x: &[T], alpha: f64) -> FFTResult<Vec<Complex64>>
where
    T: Copy + Into<f64>,
{
    let signal: Vec<Complex64> = x
        .iter()
        .map(|&val| Complex64::new(val.into(), 0.0))
        .collect();

    // Round *after* the euclidean reduction, then reduce again. Values just below
    // a multiple of 4 (e.g. 3.9999999999, or -1e-12 which reduces to ~4.0) must
    // map to 0; the previous tolerance chain sent them to the inverse FFT.
    let quarter_turns = (alpha.rem_euclid(4.0).round() as i64).rem_euclid(4);
    match quarter_turns {
        0 => Ok(signal),
        1 => fft(&signal, None),
        2 => Ok(signal.into_iter().rev().collect()),
        _ => ifft(&signal, None),
    }
}
/// Evaluates the transform exactly when the rotation angle `phi = alpha * pi / 2`
/// is (close to) an integer multiple of `pi / 2`.
///
/// `phi` is converted back to the integer order `round(2 * phi / pi)` and
/// dispatched on that order modulo 4:
/// * `0` - identity,
/// * `1` - forward FFT,
/// * `2` - sample reversal,
/// * `3` - inverse FFT.
///
/// # Errors
/// Propagates any error returned by `fft` / `ifft`.
#[allow(dead_code)]
fn handle_near_special_angles(x: &[Complex64], phi: f64) -> FFTResult<Vec<Complex64>> {
    // The order is alpha = 2 * phi / pi. The previous code rounded `phi / pi`
    // (i.e. alpha / 2) and then treated it as the order. As a result phi = pi
    // (alpha = 2, a reversal) went to the forward FFT and alpha = 4 (identity)
    // was reversed.
    let order = (2.0 * phi / PI).round() as i64;
    match order.rem_euclid(4) {
        0 => Ok(x.to_vec()),
        1 => fft(x, None),
        2 => Ok(x.iter().rev().copied().collect()),
        _ => ifft(x, None),
    }
}
/// Rotates every sample in `result` in place by the unit phase factor
/// `exp(i * alpha * pi / 4)`.
#[allow(dead_code)]
fn post_process_result(result: &mut [Complex64], alpha: f64) {
    let rotation = Complex64::from_polar(1.0, alpha * PI / 4.0);
    result.iter_mut().for_each(|sample| *sample *= rotation);
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Widens real samples into complex values with a zero imaginary part.
    fn to_complex(samples: &[f64]) -> Vec<Complex64> {
        samples.iter().map(|&re| Complex64::new(re, 0.0)).collect()
    }

    /// Sum of squared magnitudes of a complex vector.
    fn energy(values: &[Complex64]) -> f64 {
        values.iter().map(|c| c.norm_sqr()).sum()
    }

    #[test]
    fn test_ozaktas_identity() {
        let input = [1.0, 2.0, 3.0, 4.0];
        let output = frft_ozaktas(&input, 0.0).expect("Operation failed");
        for (idx, &expected) in input.iter().enumerate() {
            let got = output[idx];
            assert_relative_eq!(got.re, expected, epsilon = 1e-10);
            assert_relative_eq!(got.im, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ozaktas_fourier() {
        let input = [1.0, 0.0, -1.0, 0.0];
        let frft_out = frft_ozaktas(&input, 1.0).expect("Operation failed");
        let fft_out = fft(&to_complex(&input), None).expect("Operation failed");

        let frft_energy = energy(&frft_out);
        let fft_energy = energy(&fft_out);
        // NOTE(review): when either energy is zero the comparison is skipped
        // entirely, so an all-zero FRFT output passes silently. Consider asserting
        // `frft_energy > 0.0` instead of guarding on it.
        if frft_energy <= 0.0 || fft_energy <= 0.0 {
            return;
        }
        // Compare up to a global scale factor, since conventions may differ.
        let gain = (fft_energy / frft_energy).sqrt();
        for (&actual, &reference) in frft_out.iter().zip(fft_out.iter()) {
            assert_relative_eq!(actual.re * gain, reference.re, epsilon = 1e-4);
            assert_relative_eq!(actual.im * gain, reference.im, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_ozaktas_additivity() {
        let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let (first, second) = (0.3, 0.4);

        let combined = frft_ozaktas(&input, first + second).expect("Operation failed");
        let halfway = frft_ozaktas(&input, first).expect("Operation failed");
        let halfway_real: Vec<f64> = halfway.iter().map(|&c| c.re).collect();
        let chained = frft_ozaktas(&halfway_real, second).expect("Operation failed");

        let energy_ratio = energy(&combined) / energy(&chained);
        assert!(
            energy_ratio > 0.01 && energy_ratio < 100.0,
            "Energy ratio {energy_ratio} is outside acceptable range"
        );
    }
}