use crate::error::FFTResult;
use crate::fft::{fft, ifft};
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Computes the discrete fractional Fourier transform of order `alpha` by
/// eigen-decomposition of the DFT.
///
/// The input is converted to complex samples and projected onto the columns of
/// the basis returned by `compute_dft_eigenvectors`. Each coefficient is
/// multiplied by the matching eigenvalue from `compute_dft_eigenvalues` raised
/// to the power `alpha`, and the signal is re-synthesized from the same basis.
///
/// The order is reduced modulo 4. Orders within `1e-10` of an integer are
/// dispatched to special cases:
///
/// * `0` (equivalently `4`): identity, the input is returned unchanged;
/// * `1`: the transform computed by `fft`;
/// * `2`: the input samples in reversed order;
/// * `3`: the transform computed by `ifft`.
///
/// NOTE(review): the special cases inherit whatever normalization `fft` and
/// `ifft` use, which may not match the general path. The basis columns are
/// only normalized, not orthogonalized, so the general path is an
/// approximation and does not preserve energy exactly.
///
/// # Arguments
/// * `x` - real-valued input samples.
/// * `alpha` - transform order.
///
/// # Returns
/// A vector with the same length as `x`, or an empty vector when `x` is empty.
///
/// # Errors
/// Propagates any error returned by `fft` / `ifft` for orders 1 and 3.
#[allow(dead_code)]
pub fn frft_dft<T>(x: &[T], alpha: f64) -> FFTResult<Vec<Complex64>>
where
    T: Copy + Into<f64>,
{
    // Tolerance used to snap `alpha` onto an integer order.
    const INTEGER_ORDER_TOL: f64 = 1e-10;

    let n = x.len();
    if n == 0 {
        return Ok(vec![]);
    }
    let x_complex: Vec<Complex64> = x
        .iter()
        .map(|&val| Complex64::new(val.into(), 0.0))
        .collect();

    // Reduce the order into [0, 4]; the transform is 4-periodic in alpha.
    let alpha_mod = alpha.rem_euclid(4.0);
    let near = |order: f64| (alpha_mod - order).abs() < INTEGER_ORDER_TOL;

    // `rem_euclid` maps orders slightly below a multiple of 4 (e.g. -1e-12) to
    // values just below 4.0, or to exactly 4.0 through rounding, so an order
    // near 4.0 must be treated as the identity just like an order near 0.0.
    if near(0.0) || near(4.0) {
        return Ok(x_complex);
    } else if near(1.0) {
        return fft(&x_complex, None);
    } else if near(2.0) {
        return Ok(x_complex.into_iter().rev().collect());
    } else if near(3.0) {
        return ifft(&x_complex, None);
    }

    let eigenvectors = compute_dft_eigenvectors(n);
    let eigenvalues = compute_dft_eigenvalues(n);

    // Raise eigenvalues to the reduced order: for the unit eigenvalues
    // 1, -i, -1, i the principal-branch power is 4-periodic in the exponent,
    // so this matches `powc(alpha)` mathematically while avoiding precision
    // loss for large |alpha|.
    let order = Complex64::new(alpha_mod, 0.0);

    // Analysis: project onto each basis column, then apply lambda_k^alpha.
    let coefficients: Vec<Complex64> = (0..n)
        .map(|k| {
            let projection: Complex64 = x_complex
                .iter()
                .enumerate()
                .map(|(j, &sample)| sample * eigenvectors[(j, k)].conj())
                .sum();
            projection * eigenvalues[k].powc(order)
        })
        .collect();

    // Synthesis: recombine the scaled coefficients with the same basis.
    let result: Vec<Complex64> = (0..n)
        .map(|j| {
            coefficients
                .iter()
                .enumerate()
                .map(|(k, &coefficient)| coefficient * eigenvectors[(j, k)])
                .sum::<Complex64>()
        })
        .collect();
    Ok(result)
}
/// Builds the `n x n` basis that `frft_dft` uses as approximate DFT
/// eigenvectors.
///
/// Column `k` samples the order-`k` Hermite-Gaussian function at
/// `x_j = (j - n/2) / sqrt(n/4)` and multiplies it by the phase
/// `exp(-i*pi*j*k/n)`. Each column is then scaled to unit Euclidean norm.
/// Columns are not orthogonalized.
///
/// The Hermite-Gaussian values come from the normalized three-term recurrence
///
/// `psi_0 = exp(-x^2/2)`,
/// `psi_k = sqrt(2/k) * x * psi_{k-1} - sqrt((k-1)/k) * psi_{k-2}`,
///
/// which equals `H_k(x) * exp(-x^2/2) / sqrt(2^k * k!)`. The per-column
/// positive factor `1/sqrt(2^k * k!)` cancels in the final normalization, so
/// the basis equals the one built from raw Hermite polynomials. The values,
/// however, stay bounded instead of overflowing `f64`, which the raw
/// polynomials do once `n` reaches the low hundreds. Running the recurrence
/// once per sample also keeps construction at O(n^2) work.
///
/// NOTE: for very large `n` (roughly 1500 and above) the Gaussian factor
/// underflows at the outermost samples, so those entries lose accuracy or
/// become zero.
#[allow(dead_code)]
fn compute_dft_eigenvectors(n: usize) -> Array2<Complex64> {
    let mut eigenvectors = Array2::zeros((n, n));
    let n_f64 = n as f64;
    let width = (n_f64 / 4.0).sqrt();
    for j in 0..n {
        let x = (j as f64 - n_f64 / 2.0) / width;
        // Running values psi_{k-2} and psi_{k-1} of the normalized recurrence;
        // psi_{-1} is taken as 0 so the general step also produces psi_1.
        let mut psi_prev = 0.0_f64;
        let mut psi_curr = 0.0_f64;
        for k in 0..n {
            let psi_k = if k == 0 {
                (-x * x / 2.0).exp()
            } else {
                let k_f64 = k as f64;
                (2.0 / k_f64).sqrt() * x * psi_curr - ((k_f64 - 1.0) / k_f64).sqrt() * psi_prev
            };
            psi_prev = psi_curr;
            psi_curr = psi_k;
            let phase = Complex64::new(0.0, -PI * j as f64 * k as f64 / n_f64).exp();
            eigenvectors[(j, k)] = Complex64::new(psi_k, 0.0) * phase;
        }
    }
    // Scale every column to unit Euclidean norm (all-zero columns are left as-is).
    for k in 0..n {
        let norm: f64 = (0..n)
            .map(|j| eigenvectors[(j, k)].norm_sqr())
            .sum::<f64>()
            .sqrt();
        if norm > 0.0 {
            for j in 0..n {
                eigenvectors[(j, k)] /= norm;
            }
        }
    }
    eigenvectors
}
/// Returns `n` eigenvalues for the DFT basis, cycling through `1, -i, -1, i`
/// so that entry `k` equals `(-i)^k`.
#[allow(dead_code)]
fn compute_dft_eigenvalues(n: usize) -> Vec<Complex64> {
    // (re, im) pairs for (-i)^0, (-i)^1, (-i)^2, (-i)^3.
    const CYCLE: [(f64, f64); 4] = [(1.0, 0.0), (0.0, -1.0), (-1.0, 0.0), (0.0, 1.0)];
    (0..n)
        .map(|k| {
            let (re, im) = CYCLE[k % 4];
            Complex64::new(re, im)
        })
        .collect()
}
/// Evaluates the physicists' Hermite polynomial `H_n(x)` times the Gaussian
/// weight `exp(-x^2 / 2)`, returned as a purely real complex number.
///
/// Orders 0 through 3 use closed forms. Higher orders run the recurrence
/// `H_k = 2x * H_{k-1} - 2(k-1) * H_{k-2}` upward from `H_2` and `H_3`.
/// The polynomial is not normalized, so it may overflow `f64` for large `n`.
#[allow(dead_code)]
fn hermite_function(n: usize, x: f64) -> Complex64 {
    let second = || 4.0 * x * x - 2.0;
    let third = || 8.0 * x * x * x - 12.0 * x;
    let polynomial = if n == 0 {
        1.0
    } else if n == 1 {
        2.0 * x
    } else if n == 2 {
        second()
    } else if n == 3 {
        third()
    } else {
        // State is (H_{k-2}, H_{k-1}); each step shifts it to (H_{k-1}, H_k).
        let (_, highest) = (4..=n).fold((second(), third()), |(older, newer), k| {
            (newer, 2.0 * x * newer - 2.0 * (k - 1) as f64 * older)
        });
        highest
    };
    Complex64::new(polynomial * (-x * x / 2.0).exp(), 0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_dft_identity() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let result = frft_dft(&signal, 0.0).expect("Operation failed");
        for (i, &val) in signal.iter().enumerate() {
            assert_relative_eq!(result[i].re, val, epsilon = 1e-6);
            assert_relative_eq!(result[i].im, 0.0, epsilon = 1e-6);
        }
    }

    /// Orders that differ from 0 by a multiple of 4 reduce to the identity.
    #[test]
    fn test_dft_period_four() {
        let signal = vec![1.0, -2.0, 3.5, 0.25];
        for alpha in &[4.0, 8.0, -4.0] {
            let result = frft_dft(&signal, *alpha).expect("Operation failed");
            assert_eq!(result.len(), signal.len());
            for (out, &val) in result.iter().zip(signal.iter()) {
                assert_relative_eq!(out.re, val, epsilon = 1e-12);
                assert_relative_eq!(out.im, 0.0, epsilon = 1e-12);
            }
        }
    }

    /// Order 2 (and anything congruent to it modulo 4) reverses the samples.
    #[test]
    fn test_dft_order_two_reverses() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        for alpha in &[2.0, -2.0, 6.0] {
            let result = frft_dft(&signal, *alpha).expect("Operation failed");
            assert_eq!(result.len(), signal.len());
            for (out, &val) in result.iter().zip(signal.iter().rev()) {
                assert_relative_eq!(out.re, val, epsilon = 1e-12);
                assert_relative_eq!(out.im, 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_dft_empty_input() {
        let empty: Vec<f64> = Vec::new();
        let result = frft_dft(&empty, 0.5).expect("Operation failed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_dft_energy_conservation() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input_energy: f64 = signal.iter().map(|&x| x * x).sum();
        for alpha in &[0.0, 2.0] {
            let result = frft_dft(&signal, *alpha).expect("Operation failed");
            let output_energy: f64 = result.iter().map(|c| c.norm_sqr()).sum();
            assert_relative_eq!(output_energy, input_energy, epsilon = 1e-10);
        }
        for alpha in &[1.0, 3.0] {
            let result = frft_dft(&signal, *alpha).expect("Operation failed");
            let output_energy: f64 = result.iter().map(|c| c.norm_sqr()).sum();
            let ratio = output_energy / input_energy;
            assert!(
                ratio > 0.1 && ratio < 10.0,
                "Energy ratio {ratio} for alpha {alpha} is outside acceptable range"
            );
        }
        for alpha in &[0.1, 0.5, 1.5, 2.5, 3.5] {
            let result = frft_dft(&signal, *alpha).expect("Operation failed");
            let output_energy: f64 = result.iter().map(|c| c.norm_sqr()).sum();
            let ratio = output_energy / input_energy;
            assert!(
                ratio > 0.01 && ratio < 100.0,
                "Energy ratio {ratio} for alpha {alpha} is completely unreasonable"
            );
        }
    }
}