use ferray_core::Array;
use ferray_core::dimension::Ix1;
use ferray_core::error::{FerrayError, FerrayResult};
/// Return the discrete Fourier transform sample frequencies for a length-`n`
/// transform with sample spacing `d`.
///
/// The result has exactly `n` entries, laid out in FFT output order:
/// zero and the positive frequencies first, then the negative frequencies in
/// increasing order. Writing `s = 1 / (n * d)`:
///
/// * `n` even: `[0, 1, ..., n/2 - 1, -n/2, ..., -1] * s`
/// * `n` odd:  `[0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] * s`
///
/// # Errors
///
/// Returns an invalid-value error if `n` is zero, or if `d` compares equal
/// to zero (this includes `-0.0`).
pub fn fftfreq(n: usize, d: f64) -> FerrayResult<Array<f64, Ix1>> {
    if n == 0 {
        return Err(FerrayError::invalid_value("fftfreq: n must be > 0"));
    }
    if d == 0.0 {
        return Err(FerrayError::invalid_value(
            "fftfreq: sample spacing d must be nonzero",
        ));
    }

    let step = 1.0 / (n as f64 * d);

    // Integer division yields the negative-bin count for both parities:
    // even n -> n/2 bins, odd n -> (n-1)/2 bins.
    let neg_bins = (n / 2) as isize;

    let non_negative = (0..n.div_ceil(2)).map(|k| k as f64 * step);
    let negative = (-neg_bins..0).map(|k| k as f64 * step);
    let freqs: Vec<f64> = non_negative.chain(negative).collect();

    Array::from_vec(Ix1::new([n]), freqs)
}
/// Return the sample frequencies for a real-input FFT of length `n` with
/// sample spacing `d`.
///
/// Only the non-negative frequencies are produced: `n / 2 + 1` entries equal
/// to `[0, 1, ..., n/2] / (n * d)`. For even `n` the last entry is the
/// positive `n/2` bin, which is where this differs from the corresponding
/// prefix of `fftfreq`.
///
/// # Errors
///
/// Returns an invalid-value error if `n` is zero, or if `d` compares equal
/// to zero (this includes `-0.0`).
pub fn rfftfreq(n: usize, d: f64) -> FerrayResult<Array<f64, Ix1>> {
    if n == 0 {
        return Err(FerrayError::invalid_value("rfftfreq: n must be > 0"));
    }
    if d == 0.0 {
        return Err(FerrayError::invalid_value(
            "rfftfreq: sample spacing d must be nonzero",
        ));
    }

    let len = n / 2 + 1;
    let step = 1.0 / (n as f64 * d);
    let freqs: Vec<f64> = (0..len).map(|k| k as f64 * step).collect();

    Array::from_vec(Ix1::new([len]), freqs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for element-wise comparisons; the expected values
    /// here are small multiples of a single step, so rounding error is tiny.
    const TOL: f64 = 1e-15;

    /// Assert that `actual` and `expected` have the same length and agree
    /// element-wise within `TOL`.
    ///
    /// The length check comes first on purpose: `zip` silently stops at the
    /// shorter input, so without it a truncated result would still pass.
    fn assert_all_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (i, (a, b)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - b).abs() < TOL,
                "mismatch at index {i}: got {a}, expected {b}"
            );
        }
    }

    #[test]
    fn fftfreq_8() {
        let freq = fftfreq(8, 1.0).unwrap();
        let expected = [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn fftfreq_odd() {
        let freq = fftfreq(5, 1.0).unwrap();
        let expected = [0.0, 0.2, 0.4, -0.4, -0.2];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn fftfreq_with_spacing() {
        let freq = fftfreq(4, 0.5).unwrap();
        let expected = [0.0, 0.5, -1.0, -0.5];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn fftfreq_n1() {
        let freq = fftfreq(1, 1.0).unwrap();
        assert_eq!(freq.shape(), &[1]);
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &[0.0]);
    }

    #[test]
    fn fftfreq_length_matches_n() {
        for n in 1..=16 {
            let freq = fftfreq(n, 1.0).unwrap();
            assert_eq!(freq.iter().count(), n, "fftfreq({n}) has wrong length");
        }
    }

    #[test]
    fn fftfreq_zero_n_errors() {
        assert!(fftfreq(0, 1.0).is_err());
    }

    #[test]
    fn fftfreq_zero_d_errors() {
        assert!(fftfreq(8, 0.0).is_err());
        // -0.0 compares equal to 0.0 and must be rejected as well.
        assert!(fftfreq(8, -0.0).is_err());
    }

    #[test]
    fn rfftfreq_8() {
        let freq = rfftfreq(8, 1.0).unwrap();
        let expected = [0.0, 0.125, 0.25, 0.375, 0.5];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn rfftfreq_odd() {
        let freq = rfftfreq(5, 1.0).unwrap();
        let expected = [0.0, 0.2, 0.4];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn rfftfreq_with_spacing() {
        let freq = rfftfreq(4, 0.5).unwrap();
        let expected = [0.0, 0.5, 1.0];
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &expected);
    }

    #[test]
    fn rfftfreq_zero_n_errors() {
        assert!(rfftfreq(0, 1.0).is_err());
    }

    #[test]
    fn rfftfreq_zero_d_errors() {
        assert!(rfftfreq(8, 0.0).is_err());
        assert!(rfftfreq(8, -0.0).is_err());
    }

    #[test]
    fn rfftfreq_n1() {
        let freq = rfftfreq(1, 1.0).unwrap();
        assert_eq!(freq.shape(), &[1]);
        let data: Vec<f64> = freq.iter().copied().collect();
        assert_all_close(&data, &[0.0]);
    }

    #[test]
    fn rfftfreq_matches_fftfreq_non_negative_prefix() {
        // The first ceil(n/2) bins are identical in both layouts; for even n
        // rfftfreq additionally ends with the positive n/2 bin.
        for n in 1..=16 {
            let full: Vec<f64> = fftfreq(n, 0.25).unwrap().iter().copied().collect();
            let half: Vec<f64> = rfftfreq(n, 0.25).unwrap().iter().copied().collect();
            assert_eq!(half.len(), n / 2 + 1, "rfftfreq({n}) has wrong length");
            let shared = n.div_ceil(2);
            assert_all_close(&half[..shared], &full[..shared]);
        }
    }
}