use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use crate::axes::{compute_strides, resolve_axes};
/// Reorders a discrete Fourier transform so the zero-frequency term sits in the middle.
///
/// Every selected axis of length `n` is rolled cyclically by `n / 2` positions
/// (integer division) towards higher indices. For `n = 8` the order `0..8`
/// becomes `4, 5, 6, 7, 0, 1, 2, 3`; for `n = 5` it becomes `3, 4, 0, 1, 2`.
/// Axes that are not selected are left untouched.
///
/// # Arguments
///
/// * `a` - the array to shift; it is only read, never modified.
/// * `axes` - which axes to roll. `None` rolls every axis. Explicit entries are
///   resolved and validated by `resolve_axes`.
///
/// # Returns
///
/// A new, dynamically dimensioned array with the same shape as `a`.
///
/// # Errors
///
/// Fails when `resolve_axes` rejects `axes` (for example an axis index beyond
/// the array's rank) or when the output array cannot be constructed.
pub fn fftshift<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: Option<&[isize]>,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let resolved = resolve_axes(dims.len(), axes)?;

    let mut offsets: Vec<isize> = Vec::with_capacity(resolved.len());
    for &axis in resolved.iter() {
        // Half the axis length, rounded down: element 0 lands at index n / 2.
        offsets.push((dims[axis] / 2) as isize);
    }

    roll_along_axes(a, &resolved, &offsets)
}
/// Undoes [`fftshift`], moving the centred zero-frequency term back to index 0.
///
/// Every selected axis of length `n` is rolled cyclically by `-(n / 2)`
/// positions (integer division), i.e. towards lower indices. This is exactly
/// the opposite roll of `fftshift`, so the two compose to the identity for both
/// even and odd lengths. For `n = 5`, `3, 4, 0, 1, 2` becomes `0, 1, 2, 3, 4`.
///
/// # Arguments
///
/// * `a` - the array to unshift; it is only read, never modified.
/// * `axes` - which axes to roll. `None` rolls every axis. Explicit entries are
///   resolved and validated by `resolve_axes`.
///
/// # Returns
///
/// A new, dynamically dimensioned array with the same shape as `a`.
///
/// # Errors
///
/// Fails when `resolve_axes` rejects `axes` (for example an axis index beyond
/// the array's rank) or when the output array cannot be constructed.
pub fn ifftshift<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: Option<&[isize]>,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let resolved = resolve_axes(dims.len(), axes)?;

    // Negate the forward half-length roll so odd-length axes round-trip exactly.
    let offsets = resolved
        .iter()
        .map(|&axis| {
            let half = (dims[axis] / 2) as isize;
            -half
        })
        .collect::<Vec<isize>>();

    roll_along_axes(a, &resolved, &offsets)
}
/// Cyclically rolls `a` by `shifts[i]` positions along `axes[i]`.
///
/// The result is a freshly allocated, dynamically dimensioned array with the
/// same shape as `a`. Along a rolled axis of length `n`, output index `i` is
/// read from input index `(i - shift).rem_euclid(n)`. Positive shifts therefore
/// move data towards higher indices, and negative shifts towards lower ones.
/// Shifts of any magnitude are reduced modulo the axis length.
///
/// If the same axis appears more than once in `axes`, its shifts are summed.
/// This matches `np.roll`; the previous implementation kept only the last one.
///
/// # Preconditions
///
/// `axes` and `shifts` must have the same length, and every entry of `axes`
/// must be an in-range axis index. Callers pass the output of `resolve_axes`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` when building the output.
fn roll_along_axes<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: &[usize],
    shifts: &[isize],
) -> FerrayResult<Array<T, IxDyn>> {
    debug_assert_eq!(axes.len(), shifts.len(), "one shift per axis is required");

    let shape = a.shape();
    let ndim = shape.len();
    let total: usize = shape.iter().product();
    if total == 0 {
        // A zero-length axis means there are no elements to move.
        return Array::from_vec(IxDyn::new(shape), Vec::<T>::new());
    }
    // From here on every axis length is non-zero, so `rem_euclid(n)` cannot divide by zero.

    let strides: Vec<usize> = compute_strides(shape).iter().map(|&s| s as usize).collect();

    // Net shift per axis, normalised into `0..n`. Accumulating (rather than
    // overwriting) makes repeated axes compose like consecutive rolls.
    let mut axis_shifts = vec![0isize; ndim];
    for (&ax, &sh) in axes.iter().zip(shifts) {
        let n = shape[ax] as isize;
        axis_shifts[ax] = (axis_shifts[ax] + sh.rem_euclid(n)).rem_euclid(n);
    }

    // Index the array's own slice when `as_slice` provides one; otherwise gather
    // the elements via `iter()` into an owned buffer. Both are addressed with the
    // row-major strides from `compute_strides`. NOTE(review): this assumes
    // `as_slice`/`iter` expose elements in logical row-major order.
    let input: std::borrow::Cow<'_, [T]> = match a.as_slice() {
        Some(s) => std::borrow::Cow::Borrowed(s),
        None => std::borrow::Cow::Owned(a.iter().cloned().collect()),
    };

    let mut output = Vec::with_capacity(total);
    for out_flat in 0..total {
        // Decode the output position one axis at a time. Un-roll each coordinate
        // and re-encode it as a flat source offset.
        let mut remaining = out_flat;
        let mut src_flat = 0usize;
        for ((&stride, &len), &shift) in strides.iter().zip(shape).zip(&axis_shifts) {
            let idx = (remaining / stride) as isize;
            remaining %= stride;
            let src_idx = (idx - shift).rem_euclid(len as isize);
            src_flat += src_idx as usize * stride;
        }
        output.push(input[src_flat].clone());
    }

    Array::from_vec(IxDyn::new(shape), output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// `0.0, 1.0, …, (n - 1) as f64`: the ramp used as the unshifted reference.
    fn ramp(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64).collect()
    }

    /// Wraps `values` in a 1-D array whose length matches the vector.
    fn line(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// A 2x4 array filled with `0.0..8.0` in construction order.
    fn grid_2x4() -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([2, 4]), ramp(8)).unwrap()
    }

    /// Collects every element of `arr` in iteration order.
    fn contents<D: Dimension>(arr: &Array<f64, D>) -> Vec<f64> {
        arr.iter().copied().collect()
    }

    #[test]
    fn fftshift_even() {
        let out = fftshift(&line(ramp(8)), None).unwrap();
        assert_eq!(contents(&out), vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn fftshift_odd() {
        let out = fftshift(&line(ramp(5)), None).unwrap();
        assert_eq!(contents(&out), vec![3.0, 4.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn ifftshift_even() {
        let centred = line(vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]);
        let out = ifftshift(&centred, None).unwrap();
        assert_eq!(contents(&out), ramp(8));
    }

    #[test]
    fn ifftshift_odd() {
        let centred = line(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
        let out = ifftshift(&centred, None).unwrap();
        assert_eq!(contents(&out), ramp(5));
    }

    #[test]
    fn fftshift_ifftshift_roundtrip_even() {
        let original = line(ramp(8));
        let back = ifftshift(&fftshift(&original, None).unwrap(), None).unwrap();
        assert_eq!(contents(&back), ramp(8));
    }

    #[test]
    fn fftshift_ifftshift_roundtrip_odd() {
        let original = line(ramp(5));
        let back = ifftshift(&fftshift(&original, None).unwrap(), None).unwrap();
        assert_eq!(contents(&back), ramp(5));
    }

    #[test]
    fn fftshift_2d() {
        let out = fftshift(&grid_2x4(), None).unwrap();
        assert_eq!(contents(&out), vec![6.0, 7.0, 4.0, 5.0, 2.0, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn fftshift_specific_axis() {
        let out = fftshift(&grid_2x4(), Some(&[1])).unwrap();
        assert_eq!(contents(&out), vec![2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0]);
    }

    #[test]
    fn fftshift_axis_out_of_bounds() {
        let a = line(vec![0.0; 4]);
        assert!(fftshift(&a, Some(&[1])).is_err());
    }
}