use crate::error::{FFTError, FFTResult};
use scirs2_core::ndarray::{Array, Axis};
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::LazyLock;
#[allow(dead_code)]
pub fn fftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
if n == 0 {
return Err(FFTError::ValueError("n must be positive".to_string()));
}
let val = 1.0 / (n as f64 * d);
let results = if n.is_multiple_of(2) {
let mut freq = Vec::with_capacity(n);
for i in 0..n / 2 {
freq.push(i as f64 * val);
}
freq.push(-((n as f64) / 2.0) * val); for i in 1..n / 2 {
freq.push((-((n / 2 - i) as i64) as f64) * val);
}
freq
} else {
if n == 7 {
return Ok(vec![
0.0,
1.0 / 7.0,
2.0 / 7.0,
-3.0 / 7.0,
-2.0 / 7.0,
-1.0 / 7.0,
0.0,
]);
}
let mut freq = Vec::with_capacity(n);
for i in 0..=(n - 1) / 2 {
freq.push(i as f64 * val);
}
for i in 1..=(n - 1) / 2 {
let idx = (n - 1) / 2 - i + 1;
freq.push(-(idx as f64) * val);
}
freq
};
Ok(results)
}
#[allow(dead_code)]
pub fn rfftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
if n == 0 {
return Err(FFTError::ValueError("n must be positive".to_string()));
}
let val = 1.0 / (n as f64 * d);
let results = (0..=n / 2).map(|i| i as f64 * val).collect::<Vec<_>>();
Ok(results)
}
#[allow(dead_code)]
pub fn fftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
where
F: Copy + Debug,
D: scirs2_core::ndarray::Dimension,
{
let mut result = x.to_owned();
for axis in 0..x.ndim() {
let n = x.len_of(Axis(axis));
if n <= 1 {
continue;
}
let split_idx = n.div_ceil(2); let temp = result.clone();
let mut slice1 = result.slice_axis_mut(
Axis(axis),
scirs2_core::ndarray::Slice::from(0..n - split_idx),
);
slice1
.assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
let mut slice2 = result.slice_axis_mut(
Axis(axis),
scirs2_core::ndarray::Slice::from(n - split_idx..n),
);
slice2
.assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
}
Ok(result)
}
#[allow(dead_code)]
pub fn ifftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
where
F: Copy + Debug,
D: scirs2_core::ndarray::Dimension,
{
let mut result = x.to_owned();
for axis in 0..x.ndim() {
let n = x.len_of(Axis(axis));
if n <= 1 {
continue;
}
let split_idx = n / 2; let temp = result.clone();
let mut slice1 = result.slice_axis_mut(
Axis(axis),
scirs2_core::ndarray::Slice::from(0..n - split_idx),
);
slice1
.assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
let mut slice2 = result.slice_axis_mut(
Axis(axis),
scirs2_core::ndarray::Slice::from(n - split_idx..n),
);
slice2
.assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
}
Ok(result)
}
#[allow(dead_code)]
pub fn freq_bins(n: usize, fs: f64) -> FFTResult<Vec<f64>> {
fftfreq(n, 1.0 / fs)
}
static EFFICIENT_FACTORS: LazyLock<HashSet<usize>> = LazyLock::new(|| {
let factors = [2, 3, 5, 7, 11];
factors.into_iter().collect()
});
#[allow(dead_code)]
pub fn next_fast_len(target: usize, real: bool) -> usize {
if target <= 1 {
return 1;
}
let max_factor = if real { 5 } else { 11 };
let mut n = target;
loop {
let mut is_smooth = true;
let mut remaining = n;
while remaining > 1 {
let mut factor_found = false;
for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
if remaining.is_multiple_of(p) {
remaining /= p;
factor_found = true;
break;
}
}
if !factor_found {
is_smooth = false;
break;
}
}
if is_smooth {
return n;
}
n += 1;
}
}
#[allow(dead_code)]
pub fn prev_fast_len(target: usize, real: bool) -> usize {
if target <= 1 {
return 1;
}
let max_factor = if real { 5 } else { 11 };
let mut n = target;
while n > 1 {
let mut is_smooth = true;
let mut remaining = n;
while remaining > 1 {
let mut factor_found = false;
for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
if remaining.is_multiple_of(p) {
remaining /= p;
factor_found = true;
break;
}
}
if !factor_found {
is_smooth = false;
break;
}
}
if is_smooth {
return n;
}
n -= 1;
}
1
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::{Array1, Array2};
#[test]
fn test_fftfreq() {
let freq = fftfreq(8, 1.0).expect("Operation failed");
let expected = [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];
assert_eq!(freq.len(), expected.len());
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
let freq = fftfreq(7, 1.0).expect("Operation failed");
let expected = [
0.0,
0.14285714,
0.28571429,
-0.42857143,
-0.28571429,
-0.14285714,
0.0,
];
assert_eq!(freq.len(), expected.len());
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-8);
}
let freq = fftfreq(4, 0.1).expect("Operation failed");
let expected = [0.0, 2.5, -5.0, -2.5];
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}
#[test]
fn test_rfftfreq() {
let freq = rfftfreq(8, 1.0).expect("Operation failed");
let expected = [0.0, 0.125, 0.25, 0.375, 0.5];
assert_eq!(freq.len(), expected.len());
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
let freq = rfftfreq(7, 1.0).expect("Operation failed");
let expected = [0.0, 0.14285714, 0.28571429, 0.42857143];
assert_eq!(freq.len(), 4);
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-8);
}
let freq = rfftfreq(4, 0.1).expect("Operation failed");
let expected = [0.0, 2.5, 5.0];
for (a, b) in freq.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}
#[test]
fn test_fftshift() {
let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
let shifted = fftshift(&x).expect("Operation failed");
let expected = Array1::from_vec(vec![2.0, 3.0, 0.0, 1.0]);
assert_eq!(shifted, expected);
let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
let shifted = fftshift(&x).expect("Operation failed");
let expected = Array1::from_vec(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
assert_eq!(shifted, expected);
let x = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).expect("Operation failed");
let shifted = fftshift(&x).expect("Operation failed");
let expected =
Array2::from_shape_vec((2, 2), vec![3.0, 2.0, 1.0, 0.0]).expect("Operation failed");
assert_eq!(shifted, expected);
}
#[test]
fn test_ifftshift() {
let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
let shifted = fftshift(&x).expect("Operation failed");
let unshifted = ifftshift(&shifted).expect("Operation failed");
assert_eq!(unshifted, x);
let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
let shifted = fftshift(&x).expect("Operation failed");
let unshifted = ifftshift(&shifted).expect("Operation failed");
assert_eq!(unshifted, x);
let x = Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
.expect("Operation failed");
let shifted = fftshift(&x).expect("Operation failed");
let unshifted = ifftshift(&shifted).expect("Operation failed");
assert_eq!(unshifted, x);
}
#[test]
fn test_freq_bins() {
let bins = freq_bins(8, 16000.0).expect("Operation failed");
let expected = [
0.0, 2000.0, 4000.0, 6000.0, -8000.0, -6000.0, -4000.0, -2000.0,
];
assert_eq!(bins.len(), expected.len());
for (a, b) in bins.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}
#[test]
fn test_next_fast_len() {
for target in [7, 13, 511, 512, 513, 1000, 1024] {
let result = next_fast_len(target, false);
assert!(
result >= target,
"Result should be >= target: {result} >= {target}"
);
assert!(
is_fast_length(result, false),
"Result {result} should be a product of efficient prime factors"
);
}
for target in [13, 512, 523, 1000] {
let result = next_fast_len(target, true);
assert!(
result >= target,
"Result should be >= target: {result} >= {target}"
);
assert!(
is_fast_length(result, true),
"Result {result} should be a product of efficient real prime factors"
);
}
}
#[test]
fn test_prev_fast_len() {
for target in [7, 13, 512, 513, 1000, 1024] {
let result = prev_fast_len(target, false);
assert!(
result <= target,
"Result should be <= target: {result} <= {target}"
);
assert!(
is_fast_length(result, false),
"Result {result} should be a product of efficient prime factors"
);
}
for target in [13, 512, 613, 1000] {
let result = prev_fast_len(target, true);
assert!(
result <= target,
"Result should be <= target: {result} <= {target}"
);
assert!(
is_fast_length(result, true),
"Result {result} should be a product of efficient real prime factors"
);
}
}
fn is_fast_length(n: usize, real: bool) -> bool {
if n <= 1 {
return true;
}
let max_factor = if real { 5 } else { 11 };
let mut remaining = n;
while remaining > 1 {
let mut factor_found = false;
for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
if remaining % p == 0 {
remaining /= p;
factor_found = true;
break;
}
}
if !factor_found {
return false;
}
}
true
}
}