use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// Circularly shifts `a` forward by half of each selected axis' length.
///
/// For every selected axis of length `n` the data is rolled by `n / 2`
/// positions (integer division), so a 1-D input `[0, 1, 2, 3, 4]` becomes
/// `[3, 4, 0, 1, 2]`.
///
/// # Arguments
///
/// * `a` - The array to shift; it is not modified.
/// * `axes` - Axes to shift. `None` selects every axis of `a`.
///
/// # Errors
///
/// Returns an error when any requested axis is not smaller than the number
/// of dimensions of `a`, or when building the output array fails.
pub fn fftshift<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: Option<&[usize]>,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let selected = resolve_axes(dims.len(), axes)?;

    // One forward offset per selected axis: half the axis length, rounded down.
    let mut offsets: Vec<isize> = Vec::with_capacity(selected.len());
    for &axis in &selected {
        offsets.push((dims[axis] / 2) as isize);
    }

    roll_along_axes(a, &selected, &offsets)
}
/// Circularly shifts `a` backward by half of each selected axis' length.
///
/// For every selected axis of length `n` the data is rolled by `-(n / 2)`
/// positions (integer division), so a 1-D input `[3, 4, 0, 1, 2]` becomes
/// `[0, 1, 2, 3, 4]`. Applied with the same axes, this undoes [`fftshift`]
/// for both even and odd axis lengths.
///
/// # Arguments
///
/// * `a` - The array to shift; it is not modified.
/// * `axes` - Axes to shift. `None` selects every axis of `a`.
///
/// # Errors
///
/// Returns an error when any requested axis is not smaller than the number
/// of dimensions of `a`, or when building the output array fails.
pub fn ifftshift<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: Option<&[usize]>,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let selected = resolve_axes(dims.len(), axes)?;

    // One backward offset per selected axis: minus half the axis length.
    let mut offsets: Vec<isize> = Vec::with_capacity(selected.len());
    for &axis in &selected {
        let half = (dims[axis] / 2) as isize;
        offsets.push(-half);
    }

    roll_along_axes(a, &selected, &offsets)
}
/// Turns an optional axis selection into an explicit list of axes.
///
/// `None` expands to every axis `0..ndim`. An explicit list is returned as
/// given (order and repetitions preserved) after checking that each entry
/// is smaller than `ndim`.
///
/// # Errors
///
/// Returns `FerrayError::axis_out_of_bounds` for the first entry that is
/// greater than or equal to `ndim`.
fn resolve_axes(ndim: usize, axes: Option<&[usize]>) -> FerrayResult<Vec<usize>> {
    let requested = match axes {
        None => return Ok((0..ndim).collect()),
        Some(list) => list,
    };

    // Report the first offending axis, scanning in the caller's order.
    match requested.iter().find(|&&axis| axis >= ndim) {
        Some(&bad) => Err(FerrayError::axis_out_of_bounds(bad, ndim)),
        None => Ok(requested.to_vec()),
    }
}
/// Cyclically rolls `a` along the given axes and returns the result as a
/// newly built dynamic-rank array of the same shape.
///
/// `axes[i]` is rolled by `shifts[i]` positions; a positive shift moves
/// elements toward higher indices and a negative shift toward lower ones,
/// wrapping around at the ends. When an axis is listed more than once its
/// shifts are added together (the `numpy.roll` convention) instead of the
/// last entry silently winning.
///
/// # Arguments
///
/// * `a` - Source array; it is not modified.
/// * `axes` - Axes to roll. Every entry must be a valid axis of `a`
///   (callers validate this; an invalid entry panics on indexing).
/// * `shifts` - Shift for each entry of `axes`. If the two slices differ in
///   length, the extra entries of the longer one are ignored.
///
/// # Errors
///
/// Propagates any error returned by `Array::from_vec` when building the
/// output.
fn roll_along_axes<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: &[usize],
    shifts: &[isize],
) -> FerrayResult<Array<T, IxDyn>> {
    let shape = a.shape();
    let ndim = shape.len();
    let total: usize = shape.iter().product();
    if total == 0 {
        // Nothing to move: an empty array rolls to an empty array.
        return Array::from_vec(IxDyn::new(shape), Vec::new());
    }
    let strides = compute_strides(shape);

    // Effective non-negative shift per axis, reduced into `0..n`.
    let mut axis_shifts = vec![0usize; ndim];
    for (&ax, &sh) in axes.iter().zip(shifts.iter()) {
        // `total != 0` guarantees every extent is at least 1, so the
        // modulo operations below never divide by zero.
        let n = shape[ax];
        let step = sh.rem_euclid(n as isize) as usize;
        // Accumulate so that repeated axes compose instead of overwriting.
        axis_shifts[ax] = (axis_shifts[ax] + step) % n;
    }

    // Borrow the elements in iteration order; each one is cloned exactly
    // once, when it is written to its destination slot.
    let input: Vec<&T> = a.iter().collect();
    let mut output = Vec::with_capacity(total);
    for out_flat in 0..total {
        let mut src_flat = 0usize;
        let mut remaining = out_flat;
        for ((&stride, &n), &shift) in strides.iter().zip(shape.iter()).zip(axis_shifts.iter()) {
            // Decompose the flat output position into a per-axis index.
            let idx = remaining / stride;
            remaining %= stride;
            // The output at `idx` comes from `idx - shift` (mod n); adding
            // `n` first keeps the arithmetic in unsigned range.
            src_flat += ((idx + n - shift) % n) * stride;
        }
        output.push(input[src_flat].clone());
    }
    Array::from_vec(IxDyn::new(shape), output)
}
/// Computes element strides for `shape`, with the last axis varying
/// fastest.
///
/// The last stride is `1` and every earlier stride is the product of all
/// extents to its right, e.g. `[2, 3, 4]` yields `[12, 4, 1]`. An empty
/// shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0usize; shape.len()];

    if let Some((innermost, outer)) = strides.split_last_mut() {
        *innermost = 1;

        // Walk right-to-left, pairing stride slot `i` with extent `i + 1`
        // and carrying the running product of the extents seen so far.
        let mut running = 1usize;
        for (slot, &extent) in outer.iter_mut().zip(shape[1..].iter()).rev() {
            running *= extent;
            *slot = running;
        }
    }

    strides
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// Builds a 1-D `f64` array holding exactly `values`.
    fn vector(values: Vec<f64>) -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([values.len()]), values).unwrap()
    }

    /// Builds the 2x4 `f64` array containing `0.0..=7.0` in iteration order.
    fn grid_2x4() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 4]),
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
        )
        .unwrap()
    }

    /// Copies the elements of a shifted result out in iteration order.
    fn flatten(result: &Array<f64, IxDyn>) -> Vec<f64> {
        result.iter().copied().collect()
    }

    // Even length: the two halves swap places.
    #[test]
    fn fftshift_even() {
        let input = vector(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let shifted = fftshift(&input, None).unwrap();
        assert_eq!(
            flatten(&shifted),
            vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]
        );
    }

    // Odd length: the shift is `5 / 2 == 2` positions forward.
    #[test]
    fn fftshift_odd() {
        let input = vector(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let shifted = fftshift(&input, None).unwrap();
        assert_eq!(flatten(&shifted), vec![3.0, 4.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn ifftshift_even() {
        let input = vector(vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]);
        let unshifted = ifftshift(&input, None).unwrap();
        assert_eq!(
            flatten(&unshifted),
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
        );
    }

    #[test]
    fn ifftshift_odd() {
        let input = vector(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
        let unshifted = ifftshift(&input, None).unwrap();
        assert_eq!(flatten(&unshifted), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    // `ifftshift` must undo `fftshift` for an even-length axis.
    #[test]
    fn fftshift_ifftshift_roundtrip_even() {
        let input = vector(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let shifted = fftshift(&input, None).unwrap();
        let recovered = ifftshift(&shifted, None).unwrap();
        assert_eq!(
            flatten(&recovered),
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
        );
    }

    // `ifftshift` must undo `fftshift` for an odd-length axis as well.
    #[test]
    fn fftshift_ifftshift_roundtrip_odd() {
        let input = vector(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let shifted = fftshift(&input, None).unwrap();
        let recovered = ifftshift(&shifted, None).unwrap();
        assert_eq!(flatten(&recovered), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    // With no axes given, both axes of the 2x4 grid are shifted.
    #[test]
    fn fftshift_2d() {
        let input = grid_2x4();
        let shifted = fftshift(&input, None).unwrap();
        assert_eq!(
            flatten(&shifted),
            vec![6.0, 7.0, 4.0, 5.0, 2.0, 3.0, 0.0, 1.0]
        );
    }

    // Only axis 1 is shifted; the row order stays as it was.
    #[test]
    fn fftshift_specific_axis() {
        let input = grid_2x4();
        let shifted = fftshift(&input, Some(&[1])).unwrap();
        assert_eq!(
            flatten(&shifted),
            vec![2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0]
        );
    }

    // Axis 1 does not exist on a 1-D array, so an error is expected.
    #[test]
    fn fftshift_axis_out_of_bounds() {
        let input = vector(vec![0.0; 4]);
        assert!(fftshift(&input, Some(&[1])).is_err());
    }
}