use ferray_core::Array as FerrayArray;
use ferray_core::IxDyn as FerrayIxDyn;
pub use ferray_fft::FftNorm;
use rustfft::num_complex::Complex;
use crate::dtype::{DType, Element, Float};
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::numeric_cast::cast;
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Parses an optional normalization-mode string into an [`FftNorm`].
///
/// `None` and `"backward"` select [`FftNorm::Backward`], `"forward"` selects
/// [`FftNorm::Forward`] and `"ortho"` selects [`FftNorm::Ortho`].
///
/// # Errors
///
/// Any other string yields [`FerrotorchError::InvalidArgument`] with a
/// message prefixed by `op`.
pub fn fft_norm_from_str(norm: Option<&str>, op: &'static str) -> FerrotorchResult<FftNorm> {
    // A missing mode behaves exactly like an explicit "backward".
    match norm.unwrap_or("backward") {
        "backward" => Ok(FftNorm::Backward),
        "forward" => Ok(FftNorm::Forward),
        "ortho" => Ok(FftNorm::Ortho),
        other => {
            let message = format!("{op}: invalid normalization mode: \"{other}\"");
            Err(FerrotorchError::InvalidArgument { message })
        }
    }
}
/// Returns `true` when `T` is four bytes wide.
///
/// The decision is made purely on `size_of::<T>()`, not on the dtype tag;
/// this module uses it to select the `*_f32` GPU kernels.
#[inline]
fn is_f32<T: Float>() -> bool {
    const F32_WIDTH: usize = 4;
    F32_WIDTH == std::mem::size_of::<T>()
}
/// Returns `true` when `T` is eight bytes wide.
///
/// The decision is made purely on `size_of::<T>()`, not on the dtype tag;
/// this module uses it to select the `*_f64` GPU kernels.
#[inline]
fn is_f64<T: Float>() -> bool {
    const F64_WIDTH: usize = 8;
    F64_WIDTH == std::mem::size_of::<T>()
}
/// Rejects half-precision element types for the FFT operation named `op`.
///
/// The check looks only at `T`'s dtype (not at any tensor or device): `F16`
/// and `BF16` produce [`FerrotorchError::InvalidArgument`]; every other dtype
/// passes.
#[inline]
fn reject_half_cpu_fft<T: Float>(op: &'static str) -> FerrotorchResult<()> {
    let dtype = <T as Element>::dtype();
    if !matches!(dtype, DType::F16 | DType::BF16) {
        return Ok(());
    }
    Err(FerrotorchError::InvalidArgument {
        message: format!(
            "{op}: Unsupported dtype {:?} — torch.fft.* does not support \
             half/bfloat16 on CPU (half is CUDA-only as a native complex-half \
             transform; see SpectralOps.cpp:88-90)",
            dtype,
        ),
    })
}
/// 1-D complex FFT with default settings.
///
/// Shorthand for [`fft_norm`] with `dim = None` and [`FftNorm::Backward`].
/// `input` uses the interleaved complex layout `[..., n, 2]`; `n` optionally
/// overrides the transform length.
pub fn fft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    fft_norm(input, n, None, default_norm)
}
/// 1-D complex-to-complex forward FFT with explicit axis and normalization.
///
/// `input` must have shape `[..., n, 2]`, where the trailing axis holds the
/// `(re, im)` pair.
///
/// * CUDA inputs with a 4- or 8-byte float type, [`FftNorm::Backward`] and a
///   `dim` that resolves to the last signal axis are transformed by the GPU
///   backend, after padding/truncating to `n` when it differs from the input
///   length. The result has shape `[..., n, 2]`.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::fft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, fewer than two dimensions, a zero GPU transform
/// length, or a `ferray_fft` failure;
/// [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
pub fn fft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("fft")?;

    let shape = input.shape();
    if shape.last().copied() != Some(2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "fft: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "fft: input must have at least 2 dimensions ([..., n, 2])".into(),
        });
    }

    let single_precision = is_f32::<T>();
    let gpu_fast_path = input.is_cuda()
        && (single_precision || is_f64::<T>())
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, ndim - 1);

    if !gpu_fast_path {
        // CPU route through ferray (the conversion helper rejects CUDA tensors).
        let arr = tensor_to_complex_array(input, "fft")?;
        return match ferray_fft::fft(&arr, n, dim, norm) {
            Ok(spectrum) => complex_array_to_tensor(&spectrum),
            Err(e) => Err(FerrotorchError::InvalidArgument {
                message: format!("fft: {e}"),
            }),
        };
    }

    let input_n = shape[ndim - 2];
    let fft_n = n.unwrap_or(input_n);
    if fft_n == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "fft: n must be > 0".into(),
        });
    }
    let batch_shape = &shape[..ndim - 2];
    let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let buf = input.gpu_handle()?;

    // `resized` is only initialised when the signal must be padded or
    // truncated; it has to outlive the borrow taken for the transform.
    let resized;
    let buf_for_fft: &crate::gpu_dispatch::GpuBufferHandle = if fft_n == input_n {
        buf
    } else {
        resized = if single_precision {
            backend.pad_truncate_complex_f32(buf, batch_size, input_n, fft_n)?
        } else {
            backend.pad_truncate_complex_f64(buf, batch_size, input_n, fft_n)?
        };
        &resized
    };

    let h = if single_precision {
        backend.fft_c2c_f32(buf_for_fft, batch_size, fft_n, false)?
    } else {
        backend.fft_c2c_f64(buf_for_fft, batch_size, fft_n, false)?
    };

    let mut out_shape = batch_shape.to_vec();
    out_shape.extend_from_slice(&[fft_n, 2]);
    Tensor::from_storage(TensorStorage::gpu(h), out_shape, false)
}
/// Reports whether `dim` refers to the last of `signal_ndim` signal axes.
///
/// `None` counts as the last axis. Negative values are offset by
/// `signal_ndim` before being compared with `signal_ndim - 1`.
#[inline]
fn is_last_signal_axis(dim: Option<isize>, signal_ndim: usize) -> bool {
    let rank = signal_ndim as isize;
    dim.map_or(true, |d| {
        let resolved = if d < 0 { rank + d } else { d };
        resolved == rank - 1
    })
}
/// 1-D complex inverse FFT with default settings.
///
/// Shorthand for [`ifft_norm`] with `dim = None` and [`FftNorm::Backward`].
/// `input` uses the interleaved complex layout `[..., n, 2]`; `n` optionally
/// overrides the transform length.
pub fn ifft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ifft_norm(input, n, None, default_norm)
}
/// 1-D complex-to-complex inverse FFT with explicit axis and normalization.
///
/// `input` must have shape `[..., n, 2]`, where the trailing axis holds the
/// `(re, im)` pair.
///
/// * CUDA inputs with a 4- or 8-byte float type, [`FftNorm::Backward`] and a
///   `dim` that resolves to the last signal axis are transformed by the GPU
///   backend (inverse flag set), after padding/truncating to `n` when it
///   differs from the input length. The result has shape `[..., n, 2]`.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::ifft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, fewer than two dimensions, a zero GPU transform
/// length, or a `ferray_fft` failure;
/// [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
pub fn ifft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ifft")?;

    let shape = input.shape();
    if shape.last().copied() != Some(2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "ifft: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "ifft: input must have at least 2 dimensions ([..., n, 2])".into(),
        });
    }

    let single_precision = is_f32::<T>();
    let gpu_fast_path = input.is_cuda()
        && (single_precision || is_f64::<T>())
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, ndim - 1);

    if !gpu_fast_path {
        // CPU route through ferray (the conversion helper rejects CUDA tensors).
        let arr = tensor_to_complex_array(input, "ifft")?;
        return match ferray_fft::ifft(&arr, n, dim, norm) {
            Ok(signal) => complex_array_to_tensor(&signal),
            Err(e) => Err(FerrotorchError::InvalidArgument {
                message: format!("ifft: {e}"),
            }),
        };
    }

    let input_n = shape[ndim - 2];
    let fft_n = n.unwrap_or(input_n);
    if fft_n == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "ifft: n must be > 0".into(),
        });
    }
    let batch_shape = &shape[..ndim - 2];
    let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let buf = input.gpu_handle()?;

    // `resized` is only initialised when the spectrum must be padded or
    // truncated; it has to outlive the borrow taken for the transform.
    let resized;
    let buf_for_fft: &crate::gpu_dispatch::GpuBufferHandle = if fft_n == input_n {
        buf
    } else {
        resized = if single_precision {
            backend.pad_truncate_complex_f32(buf, batch_size, input_n, fft_n)?
        } else {
            backend.pad_truncate_complex_f64(buf, batch_size, input_n, fft_n)?
        };
        &resized
    };

    let h = if single_precision {
        backend.fft_c2c_f32(buf_for_fft, batch_size, fft_n, true)?
    } else {
        backend.fft_c2c_f64(buf_for_fft, batch_size, fft_n, true)?
    };

    let mut out_shape = batch_shape.to_vec();
    out_shape.extend_from_slice(&[fft_n, 2]);
    Tensor::from_storage(TensorStorage::gpu(h), out_shape, false)
}
/// 1-D real-input FFT with default settings.
///
/// Shorthand for [`rfft_norm`] with `dim = None` and [`FftNorm::Backward`].
pub fn rfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    rfft_norm(input, n, None, default_norm)
}
/// 1-D real-to-complex FFT with explicit axis and normalization.
///
/// `input` is a real tensor with at least one dimension.
///
/// * CUDA inputs with [`FftNorm::Backward`], a `dim` that resolves to the last
///   axis and no length change (`n` absent or equal to the last-axis length)
///   use the GPU real-to-complex kernels. The result has shape
///   `[..., n / 2 + 1, 2]`.
/// * Everything else is converted to a `ferray` real array and handed to
///   `ferray_fft::rfft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a
/// zero-dimensional input, a GPU element type that is neither 4 nor 8 bytes
/// wide, or a `ferray_fft` failure;
/// [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
pub fn rfft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("rfft")?;

    let shape = input.shape();
    let ndim = shape.len();
    if ndim == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "rfft: input must have at least 1 dimension".into(),
        });
    }
    let input_n = shape[ndim - 1];

    let gpu_fast_path = input.is_cuda()
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, ndim)
        && n.map_or(true, |requested| requested == input_n);

    if gpu_fast_path {
        let batch_shape = &shape[..ndim - 1];
        let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;

        let h = if is_f32::<T>() {
            backend.rfft_r2c_f32(buf, batch_size, input_n)?
        } else if is_f64::<T>() {
            backend.rfft_r2c_f64(buf, batch_size, input_n)?
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: "rfft requires f32 or f64".into(),
            });
        };

        // The kernel's output is described here as n / 2 + 1 complex bins.
        let mut out_shape = batch_shape.to_vec();
        out_shape.extend_from_slice(&[input_n / 2 + 1, 2]);
        return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
    }

    let arr = tensor_to_real_array(input, "rfft")?;
    match ferray_fft::rfft(&arr, n, dim, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("rfft: {e}"),
        }),
    }
}
/// 1-D inverse of the real-input FFT with default settings.
///
/// Shorthand for [`irfft_norm`] with `dim = None` and [`FftNorm::Backward`].
pub fn irfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    irfft_norm(input, n, None, default_norm)
}
/// 1-D complex-to-real inverse FFT with explicit axis and normalization.
///
/// `input` must have shape `[..., half_n, 2]`, where the trailing axis holds
/// the `(re, im)` pair. When `n` is `None` the output length defaults to
/// `2 * (half_n - 1)`.
///
/// * CUDA inputs with [`FftNorm::Backward`], a `dim` that resolves to the last
///   signal axis and `half_n == output_n / 2 + 1` use the GPU
///   complex-to-real kernels. The result has shape `[..., output_n]`.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::irfft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, fewer than two dimensions, a zero GPU output
/// length, a GPU element type that is neither 4 nor 8 bytes wide, or a
/// `ferray_fft` failure;
/// [`FerrotorchError::DeviceUnavailable`] when no GPU backend is registered.
pub fn irfft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("irfft")?;
    let shape = input.shape();
    if shape.last() != Some(&2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "irfft: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "irfft: input must have at least 2 dimensions ([..., n/2+1, 2])".into(),
        });
    }
    let half_n = shape[ndim - 2];
    // `saturating_sub` keeps an empty spectrum axis (half_n == 0) from
    // underflowing: the default length becomes 0, the GPU guard below fails
    // (0 != 0 / 2 + 1) and the request falls through to the CPU path instead
    // of panicking (debug) or wrapping to a huge length (release).
    let output_n = n.unwrap_or(2 * half_n.saturating_sub(1));
    if input.is_cuda()
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, ndim - 1)
        && half_n == output_n / 2 + 1
    {
        if output_n == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "irfft: output length must be > 0".into(),
            });
        }
        let batch_shape = &shape[..ndim - 2];
        let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if is_f32::<T>() {
            backend.irfft_c2r_f32(buf, batch_size, output_n)?
        } else if is_f64::<T>() {
            backend.irfft_c2r_f64(buf, batch_size, output_n)?
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: "irfft requires f32 or f64".into(),
            });
        };
        let mut out_shape = batch_shape.to_vec();
        out_shape.push(output_n);
        return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
    }
    let arr = tensor_to_complex_array(input, "irfft")?;
    let result =
        ferray_fft::irfft(&arr, n, dim, norm).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("irfft: {e}"),
        })?;
    real_array_to_tensor(&result)
}
/// 2-D complex FFT with default settings.
///
/// Shorthand for [`fft2_norm`] with no explicit sizes (`s = None`), no
/// explicit axes (`dim = None`) and [`FftNorm::Backward`].
pub fn fft2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    fft2_norm(input, None, None, default_norm)
}
/// 2-D complex-to-complex forward FFT with explicit sizes, axes and
/// normalization.
///
/// `input` must have shape `[..., rows, cols, 2]`, where the trailing axis
/// holds the `(re, im)` pair.
///
/// * CUDA inputs whose leading batch extent is 1, with a 4- or 8-byte float
///   type, [`FftNorm::Backward`] and neither `s` nor `dim` supplied use the
///   GPU 2-D kernels; the output keeps the input shape.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::fft2`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, fewer than three dimensions, or a `ferray_fft`
/// failure; [`FerrotorchError::DeviceUnavailable`] when no GPU backend is
/// registered.
pub fn fft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    dim: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("fft2")?;

    let shape = input.shape();
    if shape.last().copied() != Some(2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "fft2: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: "fft2: input must have at least 3 dimensions ([..., rows, cols, 2])".into(),
        });
    }

    let (rows, cols) = (shape[ndim - 3], shape[ndim - 2]);
    let batch_dims: usize = shape[..ndim - 3].iter().product::<usize>().max(1);
    let single_precision = is_f32::<T>();
    let gpu_fast_path = input.is_cuda()
        && batch_dims == 1
        && (single_precision || is_f64::<T>())
        && norm == FftNorm::Backward
        && dim.is_none()
        && s.is_none();

    if gpu_fast_path {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if single_precision {
            backend.fft2_c2c_f32(buf, rows, cols, false)?
        } else {
            backend.fft2_c2c_f64(buf, rows, cols, false)?
        };
        return Tensor::from_storage(TensorStorage::gpu(h), shape.to_vec(), false);
    }

    let arr = tensor_to_complex_array(input, "fft2")?;
    match ferray_fft::fft2(&arr, s, dim, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("fft2: {e}"),
        }),
    }
}
/// 2-D complex inverse FFT with default settings.
///
/// Shorthand for [`ifft2_norm`] with no explicit sizes (`s = None`), no
/// explicit axes (`dim = None`) and [`FftNorm::Backward`].
pub fn ifft2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ifft2_norm(input, None, None, default_norm)
}
/// 2-D complex-to-complex inverse FFT with explicit sizes, axes and
/// normalization.
///
/// `input` must have shape `[..., rows, cols, 2]`, where the trailing axis
/// holds the `(re, im)` pair.
///
/// * CUDA inputs whose leading batch extent is 1, with a 4- or 8-byte float
///   type, [`FftNorm::Backward`] and neither `s` nor `dim` supplied use the
///   GPU 2-D kernels (inverse flag set); the output keeps the input shape.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::ifft2`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, fewer than three dimensions, or a `ferray_fft`
/// failure; [`FerrotorchError::DeviceUnavailable`] when no GPU backend is
/// registered.
pub fn ifft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    dim: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ifft2")?;

    let shape = input.shape();
    if shape.last().copied() != Some(2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "ifft2: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: "ifft2: input must have at least 3 dimensions ([..., rows, cols, 2])".into(),
        });
    }

    let (rows, cols) = (shape[ndim - 3], shape[ndim - 2]);
    let batch_dims: usize = shape[..ndim - 3].iter().product::<usize>().max(1);
    let single_precision = is_f32::<T>();
    let gpu_fast_path = input.is_cuda()
        && batch_dims == 1
        && (single_precision || is_f64::<T>())
        && norm == FftNorm::Backward
        && dim.is_none()
        && s.is_none();

    if gpu_fast_path {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if single_precision {
            backend.fft2_c2c_f32(buf, rows, cols, true)?
        } else {
            backend.fft2_c2c_f64(buf, rows, cols, true)?
        };
        return Tensor::from_storage(TensorStorage::gpu(h), shape.to_vec(), false);
    }

    let arr = tensor_to_complex_array(input, "ifft2")?;
    match ferray_fft::ifft2(&arr, s, dim, norm) {
        Ok(signal) => complex_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ifft2: {e}"),
        }),
    }
}
/// Converts an interleaved-complex CPU tensor into a `ferray` array of
/// `Complex<f64>`.
///
/// The trailing axis of `input` (length 2) is folded into one complex value,
/// so the resulting array has the input shape minus its last dimension.
/// Elements are widened to `f64`.
///
/// # Errors
///
/// [`FerrotorchError::NotImplementedOnCuda`] for CUDA tensors and
/// [`FerrotorchError::InvalidArgument`] when the trailing dimension is not 2
/// or the `ferray` array cannot be built; both are tagged with `op`.
fn tensor_to_complex_array<T: Float>(
    input: &Tensor<T>,
    op: &'static str,
) -> FerrotorchResult<FerrayArray<Complex<f64>, FerrayIxDyn>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }

    let shape = input.shape();
    if shape.last().copied() != Some(2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }

    // Consecutive (re, im) scalars become one complex element.
    let flat = input.data_vec()?;
    let complex_data: Vec<Complex<f64>> = flat
        .chunks_exact(2)
        .map(|pair| {
            let re = pair[0].to_f64().unwrap();
            let im = pair[1].to_f64().unwrap();
            Complex::new(re, im)
        })
        .collect();

    let inner_shape: Vec<usize> = shape[..shape.len() - 1].to_vec();
    match FerrayArray::from_vec(FerrayIxDyn::new(&inner_shape), complex_data) {
        Ok(arr) => Ok(arr),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("{op}: failed to build ferray array: {e}"),
        }),
    }
}
/// Converts a CPU tensor into a `ferray` array of `f64` with the same shape.
///
/// # Errors
///
/// [`FerrotorchError::NotImplementedOnCuda`] for CUDA tensors and
/// [`FerrotorchError::InvalidArgument`] when the `ferray` array cannot be
/// built; both are tagged with `op`.
fn tensor_to_real_array<T: Float>(
    input: &Tensor<T>,
    op: &'static str,
) -> FerrotorchResult<FerrayArray<f64, FerrayIxDyn>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }

    let values = input.data_vec()?;
    let mut real_data: Vec<f64> = Vec::with_capacity(values.len());
    for value in values.iter() {
        real_data.push(value.to_f64().unwrap());
    }

    match FerrayArray::from_vec(FerrayIxDyn::new(input.shape()), real_data) {
        Ok(arr) => Ok(arr),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("{op}: failed to build ferray array: {e}"),
        }),
    }
}
/// Converts a `ferray` complex array into an interleaved-complex CPU tensor.
///
/// Each element is written as `re` followed by `im`, both cast to `T`; the
/// tensor shape is the array shape with a trailing `2` appended.
///
/// # Errors
///
/// Propagates any failure from `cast` or from `Tensor::from_storage`.
fn complex_array_to_tensor<T: Float>(
    arr: &FerrayArray<Complex<f64>, FerrayIxDyn>,
) -> FerrotorchResult<Tensor<T>> {
    let mut out_shape = arr.shape().to_vec();
    let element_count: usize = out_shape.iter().product();
    out_shape.push(2);

    let mut interleaved: Vec<T> = Vec::with_capacity(element_count * 2);
    for value in arr.iter() {
        let re: T = cast(value.re)?;
        interleaved.push(re);
        let im: T = cast(value.im)?;
        interleaved.push(im);
    }

    Tensor::from_storage(TensorStorage::cpu(interleaved), out_shape, false)
}
/// Converts a `ferray` real array into a CPU tensor of the same shape,
/// casting every element to `T`.
///
/// # Errors
///
/// Propagates the first failure from `cast` or from `Tensor::from_storage`.
fn real_array_to_tensor<T: Float>(
    arr: &FerrayArray<f64, FerrayIxDyn>,
) -> FerrotorchResult<Tensor<T>> {
    let mut values: Vec<T> = Vec::new();
    for &sample in arr.iter() {
        values.push(cast(sample)?);
    }
    let out_shape = arr.shape().to_vec();
    Tensor::from_storage(TensorStorage::cpu(values), out_shape, false)
}
/// N-D complex FFT with the default normalization.
///
/// Shorthand for [`fftn_norm`] with [`FftNorm::Backward`]; `s` and `axes` are
/// forwarded unchanged.
pub fn fftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    fftn_norm(input, s, axes, default_norm)
}
/// N-D complex-to-complex forward FFT with explicit sizes, axes and
/// normalization.
///
/// `input` uses the interleaved complex layout `[..., 2]`.
///
/// A GPU fast path is attempted for CUDA inputs with a 4- or 8-byte float
/// type, `s == None` and [`FftNorm::Backward`]:
///
/// * `axes == None` with exactly 2 or 3 spatial dimensions uses the dedicated
///   2-D / 3-D kernels;
/// * `axes == Some(..)` uses the axes kernel when the (possibly negative)
///   axes are a permutation of the innermost `axes.len()` spatial dimensions.
///
/// The GPU output keeps the input shape. Every other case is converted to a
/// `ferray` complex array and handed to `ferray_fft::fftn`; that conversion
/// rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure; [`FerrotorchError::DeviceUnavailable`] when the GPU
/// path is attempted and no backend is registered.
pub fn fftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("fftn")?;
    if input.is_cuda()
        && (is_f32::<T>() || is_f64::<T>())
        && s.is_none()
        && norm == FftNorm::Backward
    {
        let shape = input.shape();
        let ndim = shape.len();
        if ndim >= 2 && shape[ndim - 1] == 2 {
            let spatial_ndim = ndim - 1;
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            match axes {
                None => {
                    if spatial_ndim == 2 {
                        let h = shape[0];
                        let w = shape[1];
                        let h_out = if is_f32::<T>() {
                            backend.fftn2d_c2c_f32(input.gpu_handle()?, h, w, false)?
                        } else {
                            backend.fftn2d_c2c_f64(input.gpu_handle()?, h, w, false)?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                    if spatial_ndim == 3 {
                        let d = shape[0];
                        let h = shape[1];
                        let w = shape[2];
                        let h_out = if is_f32::<T>() {
                            backend.fftn3d_c2c_f32(input.gpu_handle()?, d, h, w, false)?
                        } else {
                            backend.fftn3d_c2c_f64(input.gpu_handle()?, d, h, w, false)?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                }
                Some(ax) => {
                    // Resolve negative axes. An axis below `-spatial_ndim`
                    // wraps to a huge index and is rejected by the
                    // comparison below.
                    let mut sorted_axes: Vec<usize> = ax
                        .iter()
                        .map(|&a| {
                            if a < 0 {
                                (spatial_ndim as isize + a) as usize
                            } else {
                                a as usize
                            }
                        })
                        .collect();
                    sorted_axes.sort_unstable();
                    // The sorted axes must be exactly the innermost
                    // `len` spatial dimensions. `checked_sub` guards the
                    // case of more axes than spatial dimensions, where a
                    // plain subtraction would underflow.
                    let covers_innermost = spatial_ndim
                        .checked_sub(sorted_axes.len())
                        .map_or(false, |first| {
                            sorted_axes.iter().copied().eq(first..spatial_ndim)
                        });
                    if covers_innermost {
                        let spatial_shape = &shape[..spatial_ndim];
                        let h_out = if is_f32::<T>() {
                            backend.fftn_axes_c2c_f32(
                                input.gpu_handle()?,
                                spatial_shape,
                                &sorted_axes,
                                false,
                            )?
                        } else {
                            backend.fftn_axes_c2c_f64(
                                input.gpu_handle()?,
                                spatial_shape,
                                &sorted_axes,
                                false,
                            )?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                }
            }
        }
    }
    let arr = tensor_to_complex_array(input, "fftn")?;
    let result =
        ferray_fft::fftn(&arr, s, axes, norm).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("fftn: {e}"),
        })?;
    complex_array_to_tensor(&result)
}
/// N-D complex inverse FFT with the default normalization.
///
/// Shorthand for [`ifftn_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn ifftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ifftn_norm(input, s, axes, default_norm)
}
/// N-D complex-to-complex inverse FFT with explicit sizes, axes and
/// normalization.
///
/// `input` uses the interleaved complex layout `[..., 2]`.
///
/// A GPU fast path is attempted for CUDA inputs with a 4- or 8-byte float
/// type, `s == None` and [`FftNorm::Backward`]:
///
/// * `axes == None` with exactly 2 or 3 spatial dimensions uses the dedicated
///   2-D / 3-D kernels (inverse flag set);
/// * `axes == Some(..)` uses the axes kernel when the (possibly negative)
///   axes are a permutation of the innermost `axes.len()` spatial dimensions.
///
/// The GPU output keeps the input shape. Every other case is converted to a
/// `ferray` complex array and handed to `ferray_fft::ifftn`; that conversion
/// rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure; [`FerrotorchError::DeviceUnavailable`] when the GPU
/// path is attempted and no backend is registered.
pub fn ifftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ifftn")?;
    if input.is_cuda()
        && (is_f32::<T>() || is_f64::<T>())
        && s.is_none()
        && norm == FftNorm::Backward
    {
        let shape = input.shape();
        let ndim = shape.len();
        if ndim >= 2 && shape[ndim - 1] == 2 {
            let spatial_ndim = ndim - 1;
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            match axes {
                None => {
                    if spatial_ndim == 2 {
                        let h = shape[0];
                        let w = shape[1];
                        let h_out = if is_f32::<T>() {
                            backend.fftn2d_c2c_f32(input.gpu_handle()?, h, w, true)?
                        } else {
                            backend.fftn2d_c2c_f64(input.gpu_handle()?, h, w, true)?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                    if spatial_ndim == 3 {
                        let d = shape[0];
                        let h = shape[1];
                        let w = shape[2];
                        let h_out = if is_f32::<T>() {
                            backend.fftn3d_c2c_f32(input.gpu_handle()?, d, h, w, true)?
                        } else {
                            backend.fftn3d_c2c_f64(input.gpu_handle()?, d, h, w, true)?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                }
                Some(ax) => {
                    // Resolve negative axes. An axis below `-spatial_ndim`
                    // wraps to a huge index and is rejected by the
                    // comparison below.
                    let mut sorted_axes: Vec<usize> = ax
                        .iter()
                        .map(|&a| {
                            if a < 0 {
                                (spatial_ndim as isize + a) as usize
                            } else {
                                a as usize
                            }
                        })
                        .collect();
                    sorted_axes.sort_unstable();
                    // The sorted axes must be exactly the innermost
                    // `len` spatial dimensions. `checked_sub` guards the
                    // case of more axes than spatial dimensions, where a
                    // plain subtraction would underflow.
                    let covers_innermost = spatial_ndim
                        .checked_sub(sorted_axes.len())
                        .map_or(false, |first| {
                            sorted_axes.iter().copied().eq(first..spatial_ndim)
                        });
                    if covers_innermost {
                        let spatial_shape = &shape[..spatial_ndim];
                        let h_out = if is_f32::<T>() {
                            backend.fftn_axes_c2c_f32(
                                input.gpu_handle()?,
                                spatial_shape,
                                &sorted_axes,
                                true,
                            )?
                        } else {
                            backend.fftn_axes_c2c_f64(
                                input.gpu_handle()?,
                                spatial_shape,
                                &sorted_axes,
                                true,
                            )?
                        };
                        return Tensor::from_storage(
                            TensorStorage::gpu(h_out),
                            shape.to_vec(),
                            false,
                        );
                    }
                }
            }
        }
    }
    let arr = tensor_to_complex_array(input, "ifftn")?;
    let result =
        ferray_fft::ifftn(&arr, s, axes, norm).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("ifftn: {e}"),
        })?;
    complex_array_to_tensor(&result)
}
/// N-D real-input FFT with the default normalization.
///
/// Shorthand for [`rfftn_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn rfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    rfftn_norm(input, s, axes, default_norm)
}
/// N-D real-to-complex FFT with explicit sizes, axes and normalization.
///
/// CPU only: the real input is converted to a `ferray` array (CUDA tensors
/// are rejected by that conversion), transformed by `ferray_fft::rfftn`, and
/// returned as an interleaved-complex tensor with a trailing axis of 2.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure.
pub fn rfftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("rfftn")?;
    let real_input = tensor_to_real_array(input, "rfftn")?;
    match ferray_fft::rfftn(&real_input, s, axes, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("rfftn: {e}"),
        }),
    }
}
/// N-D inverse of the real-input FFT with the default normalization.
///
/// Shorthand for [`irfftn_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn irfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    irfftn_norm(input, s, axes, default_norm)
}
/// N-D complex-to-real inverse FFT with explicit sizes, axes and
/// normalization.
///
/// CPU only: the interleaved-complex input (`[..., 2]`) is converted to a
/// `ferray` complex array (CUDA tensors are rejected by that conversion),
/// transformed by `ferray_fft::irfftn`, and returned as a real tensor.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, or a `ferray_fft` failure.
pub fn irfftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("irfftn")?;
    let spectrum = tensor_to_complex_array(input, "irfftn")?;
    match ferray_fft::irfftn(&spectrum, s, axes, norm) {
        Ok(signal) => real_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("irfftn: {e}"),
        }),
    }
}
/// 2-D real-input FFT with the default normalization.
///
/// Shorthand for [`rfft2_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn rfft2<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    rfft2_norm(input, s, axes, default_norm)
}
/// 2-D real-to-complex FFT with explicit sizes, axes and normalization.
///
/// CPU only: the real input is converted to a `ferray` array (CUDA tensors
/// are rejected by that conversion), transformed by `ferray_fft::rfft2`, and
/// returned as an interleaved-complex tensor with a trailing axis of 2.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure.
pub fn rfft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("rfft2")?;
    let real_input = tensor_to_real_array(input, "rfft2")?;
    match ferray_fft::rfft2(&real_input, s, axes, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("rfft2: {e}"),
        }),
    }
}
/// 2-D inverse of the real-input FFT with the default normalization.
///
/// Shorthand for [`irfft2_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn irfft2<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    irfft2_norm(input, s, axes, default_norm)
}
/// 2-D complex-to-real inverse FFT with explicit sizes, axes and
/// normalization.
///
/// CPU only: the interleaved-complex input (`[..., 2]`) is converted to a
/// `ferray` complex array (CUDA tensors are rejected by that conversion),
/// transformed by `ferray_fft::irfft2`, and returned as a real tensor.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, or a `ferray_fft` failure.
pub fn irfft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("irfft2")?;
    let spectrum = tensor_to_complex_array(input, "irfft2")?;
    match ferray_fft::irfft2(&spectrum, s, axes, norm) {
        Ok(signal) => real_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("irfft2: {e}"),
        }),
    }
}
/// 1-D Hermitian FFT with default settings.
///
/// Shorthand for [`hfft_norm`] with `dim = None` and [`FftNorm::Backward`].
pub fn hfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    hfft_norm(input, n, None, default_norm)
}
/// 1-D Hermitian FFT (complex input, real output) with explicit axis and
/// normalization.
///
/// `input` uses the interleaved complex layout `[..., half_in, 2]`. When `n`
/// is `None` the output length defaults to `2 * (half_in - 1)`.
///
/// * CUDA inputs with a 4- or 8-byte float type, [`FftNorm::Backward`], a
///   `dim` that resolves to the last signal axis and
///   `half_in == n_out / 2 + 1` use the GPU kernels; the result has shape
///   `[..., n_out]`.
/// * Everything else is converted to a `ferray` complex array and handed to
///   `ferray_fft::hfft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis on the CPU path, or a `ferray_fft` failure;
/// [`FerrotorchError::DeviceUnavailable`] when the GPU path is taken and no
/// backend is registered.
pub fn hfft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("hfft")?;
    if input.is_cuda()
        && (is_f32::<T>() || is_f64::<T>())
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, input.ndim().saturating_sub(1))
    {
        let shape = input.shape();
        if shape.len() >= 2 && shape.last() == Some(&2) {
            let ndim = shape.len();
            let half_in = shape[ndim - 2];
            // `saturating_sub` keeps an empty spectrum axis (half_in == 0)
            // from underflowing; the guard below then fails and the request
            // falls through to the CPU path instead of panicking.
            let n_out = n.unwrap_or(2 * half_in.saturating_sub(1));
            if half_in == n_out / 2 + 1 {
                let batch_shape = &shape[..ndim - 2];
                let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
                let backend =
                    crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
                let h_out = if is_f32::<T>() {
                    backend.hfft_f32(input.gpu_handle()?, batch_size, half_in, n_out)?
                } else {
                    backend.hfft_f64(input.gpu_handle()?, batch_size, half_in, n_out)?
                };
                let mut out_shape = batch_shape.to_vec();
                out_shape.push(n_out);
                return Tensor::from_storage(TensorStorage::gpu(h_out), out_shape, false);
            }
        }
    }
    let arr = tensor_to_complex_array(input, "hfft")?;
    let result =
        ferray_fft::hfft(&arr, n, dim, norm).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("hfft: {e}"),
        })?;
    real_array_to_tensor(&result)
}
/// 1-D inverse Hermitian FFT with default settings.
///
/// Shorthand for [`ihfft_norm`] with `dim = None` and [`FftNorm::Backward`].
pub fn ihfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ihfft_norm(input, n, None, default_norm)
}
/// 1-D inverse Hermitian FFT (real input, complex output) with explicit axis
/// and normalization.
///
/// * CUDA inputs with a 4- or 8-byte float type, [`FftNorm::Backward`], a
///   `dim` that resolves to the last axis, at least one dimension and no
///   length change (`n` absent or equal to the last-axis length) use the GPU
///   kernels; the result has shape `[..., n / 2 + 1, 2]`.
/// * Everything else is converted to a `ferray` real array and handed to
///   `ferray_fft::ihfft`; that conversion rejects CUDA tensors.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure; [`FerrotorchError::DeviceUnavailable`] when the GPU
/// path is taken and no backend is registered.
pub fn ihfft_norm<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
    dim: Option<isize>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ihfft")?;

    let single_precision = is_f32::<T>();
    let gpu_candidate = input.is_cuda()
        && (single_precision || is_f64::<T>())
        && norm == FftNorm::Backward
        && is_last_signal_axis(dim, input.ndim());

    if gpu_candidate {
        let shape = input.shape();
        let ndim = shape.len();
        // The GPU kernels only handle an unchanged transform length.
        if ndim > 0 && n.map_or(true, |requested| requested == shape[ndim - 1]) {
            let fft_n = shape[ndim - 1];
            let batch_shape = &shape[..ndim - 1];
            let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let buf = input.gpu_handle()?;
            let h_out = if single_precision {
                backend.ihfft_f32(buf, batch_size, fft_n)?
            } else {
                backend.ihfft_f64(buf, batch_size, fft_n)?
            };
            let mut out_shape = batch_shape.to_vec();
            out_shape.extend_from_slice(&[fft_n / 2 + 1, 2]);
            return Tensor::from_storage(TensorStorage::gpu(h_out), out_shape, false);
        }
    }

    let arr = tensor_to_real_array(input, "ihfft")?;
    match ferray_fft::ihfft(&arr, n, dim, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ihfft: {e}"),
        }),
    }
}
/// 2-D Hermitian FFT with the default normalization.
///
/// Shorthand for [`hfft2_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn hfft2<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    hfft2_norm(input, s, axes, default_norm)
}
/// 2-D Hermitian FFT (complex input, real output) with explicit sizes, axes
/// and normalization.
///
/// CPU only: the interleaved-complex input (`[..., 2]`) is converted to a
/// `ferray` complex array (CUDA tensors are rejected by that conversion),
/// transformed by `ferray_fft::hfft2`, and returned as a real tensor.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, or a `ferray_fft` failure.
pub fn hfft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("hfft2")?;
    let spectrum = tensor_to_complex_array(input, "hfft2")?;
    match ferray_fft::hfft2(&spectrum, s, axes, norm) {
        Ok(signal) => real_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("hfft2: {e}"),
        }),
    }
}
/// 2-D inverse Hermitian FFT with the default normalization.
///
/// Shorthand for [`ihfft2_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn ihfft2<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ihfft2_norm(input, s, axes, default_norm)
}
/// 2-D inverse Hermitian FFT (real input, complex output) with explicit
/// sizes, axes and normalization.
///
/// CPU only: the real input is converted to a `ferray` array (CUDA tensors
/// are rejected by that conversion), transformed by `ferray_fft::ihfft2`, and
/// returned as an interleaved-complex tensor with a trailing axis of 2.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure.
pub fn ihfft2_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ihfft2")?;
    let real_input = tensor_to_real_array(input, "ihfft2")?;
    match ferray_fft::ihfft2(&real_input, s, axes, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ihfft2: {e}"),
        }),
    }
}
/// N-D Hermitian FFT with the default normalization.
///
/// Shorthand for [`hfftn_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn hfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    hfftn_norm(input, s, axes, default_norm)
}
/// N-D Hermitian FFT (complex input, real output) with explicit sizes, axes
/// and normalization.
///
/// CPU only: the interleaved-complex input (`[..., 2]`) is converted to a
/// `ferray` complex array (CUDA tensors are rejected by that conversion),
/// transformed by `ferray_fft::hfftn`, and returned as a real tensor.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes, a missing
/// trailing complex axis, or a `ferray_fft` failure.
pub fn hfftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("hfftn")?;
    let spectrum = tensor_to_complex_array(input, "hfftn")?;
    match ferray_fft::hfftn(&spectrum, s, axes, norm) {
        Ok(signal) => real_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("hfftn: {e}"),
        }),
    }
}
/// N-D inverse Hermitian FFT with the default normalization.
///
/// Shorthand for [`ihfftn_norm`] with [`FftNorm::Backward`]; `s` and `axes`
/// are forwarded unchanged.
pub fn ihfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let default_norm = FftNorm::Backward;
    ihfftn_norm(input, s, axes, default_norm)
}
/// N-D inverse Hermitian FFT (real input, complex output) with explicit
/// sizes, axes and normalization.
///
/// CPU only: the real input is converted to a `ferray` array (CUDA tensors
/// are rejected by that conversion), transformed by `ferray_fft::ihfftn`, and
/// returned as an interleaved-complex tensor with a trailing axis of 2.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] for half-precision dtypes or a
/// `ferray_fft` failure.
pub fn ihfftn_norm<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
    norm: FftNorm,
) -> FerrotorchResult<Tensor<T>> {
    reject_half_cpu_fft::<T>("ihfftn")?;
    let real_input = tensor_to_real_array(input, "ihfftn")?;
    match ferray_fft::ihfftn(&real_input, s, axes, norm) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ihfftn: {e}"),
        }),
    }
}
/// Returns the values produced by `ferray_fft::fftfreq(n, d)` as a CPU `f64`
/// tensor of shape `[n]`.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] when `ferray_fft::fftfreq` fails.
pub fn fftfreq(n: usize, d: f64) -> FerrotorchResult<Tensor<f64>> {
    let freqs = match ferray_fft::fftfreq(n, d) {
        Ok(arr) => arr,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("fftfreq: {e}"),
            });
        }
    };
    let mut data: Vec<f64> = Vec::new();
    data.extend(freqs.iter().copied());
    Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
}
/// Returns the values produced by `ferray_fft::rfftfreq(n, d)` as a 1-D CPU
/// `f64` tensor whose length is taken from the returned array's first
/// dimension.
///
/// # Errors
///
/// [`FerrotorchError::InvalidArgument`] when `ferray_fft::rfftfreq` fails.
pub fn rfftfreq(n: usize, d: f64) -> FerrotorchResult<Tensor<f64>> {
    let freqs = match ferray_fft::rfftfreq(n, d) {
        Ok(arr) => arr,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("rfftfreq: {e}"),
            });
        }
    };
    let out_shape = vec![freqs.shape()[0]];
    let mut data: Vec<f64> = Vec::new();
    data.extend(freqs.iter().copied());
    Tensor::from_storage(TensorStorage::cpu(data), out_shape, false)
}
/// Applies `ferray_fft::fftshift` to a CPU tensor along `axes`.
///
/// The tensor is treated as a plain real array of its full shape; `axes` is
/// forwarded unchanged.
///
/// # Errors
///
/// [`FerrotorchError::NotImplementedOnCuda`] for CUDA tensors and
/// [`FerrotorchError::InvalidArgument`] when `ferray_fft::fftshift` fails.
pub fn fftshift<T: Float>(
    input: &Tensor<T>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "fftshift" });
    }
    let real_input = tensor_to_real_array(input, "fftshift")?;
    match ferray_fft::fftshift(&real_input, axes) {
        Ok(shifted) => real_array_to_tensor(&shifted),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("fftshift: {e}"),
        }),
    }
}
/// Applies `ferray_fft::ifftshift` to a CPU tensor along `axes`.
///
/// The tensor is treated as a plain real array of its full shape; `axes` is
/// forwarded unchanged.
///
/// # Errors
///
/// [`FerrotorchError::NotImplementedOnCuda`] for CUDA tensors and
/// [`FerrotorchError::InvalidArgument`] when `ferray_fft::ifftshift` fails.
pub fn ifftshift<T: Float>(
    input: &Tensor<T>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "ifftshift" });
    }
    let real_input = tensor_to_real_array(input, "ifftshift")?;
    match ferray_fft::ifftshift(&real_input, axes) {
        Ok(shifted) => real_array_to_tensor(&shifted),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ifftshift: {e}"),
        }),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
/// Test helper: builds a CPU `f64` tensor from a flat slice and a shape,
/// panicking if construction fails.
fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
    let storage = TensorStorage::cpu(data.to_vec());
    let dims = shape.to_vec();
    Tensor::from_storage(storage, dims, false).unwrap()
}
/// Test helper: asserts that `a` and `b` have equal length and that every
/// pair of corresponding elements differs by strictly less than `tol`.
fn assert_close(a: &[f64], b: &[f64], tol: f64) {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(len_a, len_b, "length mismatch: {} vs {}", len_a, len_b);
    for (i, (x, y)) in a.iter().copied().zip(b.iter().copied()).enumerate() {
        let diff = (x - y).abs();
        assert!(diff < tol, "index {i}: {x} vs {y} (diff {})", diff);
    }
}
/// Test helper: packs `(re, im)` pairs into an interleaved tensor of shape
/// `[pairs.len(), 2]`.
fn complex_tensor(pairs: &[(f64, f64)]) -> Tensor<f64> {
    let interleaved: Vec<f64> = pairs
        .iter()
        .flat_map(|&(re, im)| std::iter::once(re).chain(std::iter::once(im)))
        .collect();
    t(&interleaved, &[pairs.len(), 2])
}
// An all-zero signal must transform to an all-zero spectrum.
#[test]
fn fft_of_zeros() {
    let zeros: Vec<(f64, f64)> = vec![(0.0, 0.0); 4];
    let input = complex_tensor(&zeros);
    let result = fft(&input, None).unwrap();
    assert_eq!(result.shape(), &[4, 2]);
    let d = result.data().unwrap();
    d.iter()
        .for_each(|&v| assert!(v.abs() < 1e-12, "expected 0, got {v}"));
}
// A constant signal of ones puts all energy (value n) into the DC bin.
#[test]
fn fft_of_ones() {
    let n = 8;
    let ones: Vec<(f64, f64)> = std::iter::repeat((1.0, 0.0)).take(n).collect();
    let input = complex_tensor(&ones);
    let result = fft(&input, None).unwrap();
    assert_eq!(result.shape(), &[n, 2]);
    let d = result.data().unwrap();

    let (dc_re, dc_im) = (d[0], d[1]);
    assert!(
        (dc_re - n as f64).abs() < 1e-10,
        "DC re = {}, expected {n}",
        dc_re
    );
    assert!(dc_im.abs() < 1e-10, "DC im = {}", dc_im);

    for i in 1..n {
        let (re, im) = (d[i * 2], d[i * 2 + 1]);
        assert!(re.abs() < 1e-10, "bin {i} re = {}", re);
        assert!(im.abs() < 1e-10, "bin {i} im = {}", im);
    }
}
// A real cosine at frequency k shows magnitude n/2 at bins k and n-k and
// (numerically) nothing elsewhere.
#[test]
fn fft_pure_cosine() {
    let n = 16;
    let k = 3;
    let pi = std::f64::consts::PI;
    let mut pairs: Vec<(f64, f64)> = Vec::with_capacity(n);
    for i in 0..n {
        let sample = (2.0 * pi * k as f64 * i as f64 / n as f64).cos();
        pairs.push((sample, 0.0));
    }
    let input = complex_tensor(&pairs);
    let result = fft(&input, None).unwrap();
    let d = result.data().unwrap();

    for i in 0..n {
        let (re, im) = (d[i * 2], d[i * 2 + 1]);
        let mag = (re * re + im * im).sqrt();
        let is_peak = i == k || i == n - k;
        if !is_peak {
            assert!(mag < 1e-8, "bin {i}: magnitude {mag}, expected ~0");
            continue;
        }
        assert!(
            (mag - n as f64 / 2.0).abs() < 1e-8,
            "bin {i}: magnitude {mag}, expected {}",
            n as f64 / 2.0
        );
    }
}
// ifft(fft(x)) must reproduce x for an arbitrary complex signal.
#[test]
fn fft_ifft_roundtrip() {
    let pairs: [(f64, f64); 6] = [
        (1.0, 2.0),
        (-1.0, 0.5),
        (3.0, -1.0),
        (0.0, 0.0),
        (-2.5, 1.5),
        (0.7, -0.3),
    ];
    let input = complex_tensor(&pairs);
    let spectrum = fft(&input, None).unwrap();
    let recovered = ifft(&spectrum, None).unwrap();
    let d = recovered.data().unwrap();

    for (i, &(re, im)) in pairs.iter().enumerate() {
        let (got_re, got_im) = (d[i * 2], d[i * 2 + 1]);
        assert!((got_re - re).abs() < 1e-10, "re at {i}: {} vs {re}", got_re);
        assert!((got_im - im).abs() < 1e-10, "im at {i}: {} vs {im}", got_im);
    }
}
// irfft(rfft(x), n) must reproduce an even-length real signal.
#[test]
fn rfft_irfft_roundtrip() {
    let original: Vec<f64> = (1..=8).map(f64::from).collect();
    let n = original.len();

    let spectrum = rfft(&t(&original, &[n]), None).unwrap();
    assert_eq!(spectrum.shape(), &[5, 2]);

    let recovered = irfft(&spectrum, Some(n)).unwrap();
    assert_eq!(recovered.shape(), &[n]);
    assert_close(recovered.data().unwrap(), &original, 1e-10);
}
// rfft of a length-n real signal yields n/2 + 1 complex bins.
#[test]
fn rfft_output_shape() {
    let even = t(&[0.0; 8], &[8]);
    let even_spectrum = rfft(&even, None).unwrap();
    assert_eq!(even_spectrum.shape(), &[5, 2]);

    let odd = t(&[0.0; 7], &[7]);
    let odd_spectrum = rfft(&odd, None).unwrap();
    assert_eq!(odd_spectrum.shape(), &[4, 2]);
}
// irfft(rfft(x), n) must reproduce an odd-length real signal when n is given.
#[test]
fn rfft_irfft_roundtrip_odd() {
    let original: Vec<f64> = (1..=5).map(f64::from).collect();
    let n = original.len();

    let spectrum = rfft(&t(&original, &[n]), None).unwrap();
    assert_eq!(spectrum.shape(), &[3, 2]);

    let recovered = irfft(&spectrum, Some(n)).unwrap();
    assert_eq!(recovered.shape(), &[n]);
    let d = recovered.data().unwrap();
    assert_close(d, &original, 1e-10);
}
// Requesting n larger than the input pads the signal; the DC bin is the sum
// of the original samples.
#[test]
fn fft_with_padding() {
    let samples = [(1.0, 0.0), (1.0, 0.0)];
    let result = fft(&complex_tensor(&samples), Some(4)).unwrap();
    assert_eq!(result.shape(), &[4, 2]);
    let dc_re = result.data().unwrap()[0];
    assert!((dc_re - 2.0).abs() < 1e-10);
}
/// Requests `n = 2` for a four-sample signal; the expected output
/// `[3 + 0i, -1 + 0i]` is the DFT of the first two samples `[1, 2]`.
#[test]
fn fft_with_truncation() {
    let signal = complex_tensor(&[(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)]);
    let spectrum = fft(&signal, Some(2)).unwrap();
    assert_eq!(spectrum.shape(), &[2, 2]);

    let flat = spectrum.data().unwrap();
    // Interleaved (re, im) values for bins 0 and 1.
    let expected = [3.0, 0.0, -1.0, 0.0];
    for (slot, &want) in expected.iter().enumerate() {
        assert!((flat[slot] - want).abs() < 1e-10);
    }
}
/// 2-D round trip on a `[2, 3, 2]` tensor: `ifft2(fft2(x))` must keep the
/// shape and reproduce every component within 1e-9.
#[test]
fn fft2_ifft2_roundtrip() {
    let pairs = vec![
        (1.0, 0.0),
        (2.0, 0.0),
        (3.0, 0.0),
        (4.0, 0.0),
        (5.0, 0.0),
        (6.0, 0.0),
    ];
    // Flatten the (re, im) tuples into one interleaved buffer.
    let interleaved: Vec<_> = pairs.iter().flat_map(|&(re, im)| [re, im]).collect();
    let grid = t(&interleaved, &[2, 3, 2]);

    let freq_domain = fft2(&grid).unwrap();
    assert_eq!(freq_domain.shape(), &[2, 3, 2]);
    let rebuilt = ifft2(&freq_domain).unwrap();
    assert_eq!(rebuilt.shape(), &[2, 3, 2]);

    let flat = rebuilt.data().unwrap();
    for (i, &(re, im)) in pairs.iter().enumerate() {
        let got_re = flat[2 * i];
        assert!((got_re - re).abs() < 1e-9, "re at {i}: {} vs {re}", got_re);
        let got_im = flat[2 * i + 1];
        assert!((got_im - im).abs() < 1e-9, "im at {i}: {} vs {im}", got_im);
    }
}
/// Batched transform on shape `[2, 4, 2]`: batch 0 is a unit impulse (every
/// bin is `1 + 0i`), batch 1 is all ones (DC is `4 + 0i`, other bins zero).
#[test]
fn fft_batched() {
    let batches = t(
        &[
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0,
        ],
        &[2, 4, 2],
    );
    let spectrum = fft(&batches, None).unwrap();
    assert_eq!(spectrum.shape(), &[2, 4, 2]);
    let flat = spectrum.data().unwrap();

    // Batch 0: flat spectrum of the impulse.
    for i in 0..4 {
        let (re_at, im_at) = (2 * i, 2 * i + 1);
        assert!((flat[re_at] - 1.0).abs() < 1e-10, "batch0 bin {i} re");
        assert!(flat[im_at].abs() < 1e-10, "batch0 bin {i} im");
    }

    // Batch 1 starts after the 4 complex values (8 scalars) of batch 0.
    let base = 4 * 2;
    assert!((flat[base] - 4.0).abs() < 1e-10, "batch1 DC re");
    assert!(flat[base + 1].abs() < 1e-10, "batch1 DC im");
    for i in 1..4 {
        let re_at = base + 2 * i;
        assert!(flat[re_at].abs() < 1e-10, "batch1 bin {i} re");
        assert!(flat[re_at + 1].abs() < 1e-10, "batch1 bin {i} im");
    }
}
/// `f32` path: the FFT of a 4-point unit impulse is `1 + 0i` in every bin,
/// checked with an `f32`-appropriate tolerance of 1e-5.
#[test]
fn fft_f32() {
    // Four complex zeros, then set the real part of sample 0 to one.
    let mut impulse: Vec<f32> = vec![0.0; 8];
    impulse[0] = 1.0;
    let input = Tensor::from_storage(TensorStorage::cpu(impulse), vec![4, 2], false).unwrap();

    let spectrum = fft(&input, None).unwrap();
    assert_eq!(spectrum.shape(), &[4, 2]);

    let flat = spectrum.data().unwrap();
    for i in 0..4 {
        let re = flat[2 * i];
        assert!((re - 1.0).abs() < 1e-5, "bin {i} re = {}", re);
        let im = flat[2 * i + 1];
        assert!(im.abs() < 1e-5, "bin {i} im = {}", im);
    }
}
/// `fft` on a CPU `f16` tensor must return an error whose rendered text
/// contains "Unsupported dtype".
#[test]
fn fft_f16_cpu_rejects_matching_torch_unsupported_dtype() {
    use half::f16;
    let data: Vec<f16> = [1.0_f32, 0.0, 2.0, 0.0]
        .iter()
        .map(|&v| f16::from_f32(v))
        .collect();
    let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 2], false).unwrap();

    let outcome = fft(&input, None);
    assert!(
        outcome.is_err(),
        "torch.fft.fft(half) raises RuntimeError: Unsupported dtype Half on CPU; \
         ferrotorch must Err too, not silently upcast"
    );

    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("Unsupported dtype"),
        "expected 'Unsupported dtype' (mirrors SpectralOps.cpp:90), got: {msg}"
    );
}
/// `fft` on a CPU `bf16` tensor must return an error whose rendered text
/// contains "Unsupported dtype".
#[test]
fn fft_bf16_cpu_rejects_matching_torch_unsupported_dtype() {
    use half::bf16;
    let data: Vec<bf16> = [1.0_f32, 0.0, 2.0, 0.0]
        .iter()
        .map(|&v| bf16::from_f32(v))
        .collect();
    let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 2], false).unwrap();

    let outcome = fft(&input, None);
    assert!(
        outcome.is_err(),
        "torch.fft.fft(bfloat16) raises RuntimeError: Unsupported dtype BFloat16 on CPU"
    );

    let rendered = outcome.unwrap_err().to_string();
    assert!(rendered.contains("Unsupported dtype"));
}
/// `rfft` must return an error for both `f16` and `bf16` real CPU inputs.
#[test]
fn rfft_f16_and_bf16_cpu_reject() {
    use half::{bf16, f16};
    let ramp = [1.0_f32, 2.0, 3.0, 4.0];

    let f16_real: Vec<f16> = ramp.iter().map(|&v| f16::from_f32(v)).collect();
    let f16_in = Tensor::from_storage(TensorStorage::cpu(f16_real), vec![4], false).unwrap();
    let f16_outcome = rfft(&f16_in, None);
    assert!(
        f16_outcome.is_err(),
        "torch.fft.rfft(half) raises Unsupported dtype Half on CPU"
    );

    let bf16_real: Vec<bf16> = ramp.iter().map(|&v| bf16::from_f32(v)).collect();
    let bf16_in = Tensor::from_storage(TensorStorage::cpu(bf16_real), vec![4], false).unwrap();
    let bf16_outcome = rfft(&bf16_in, None);
    assert!(
        bf16_outcome.is_err(),
        "torch.fft.rfft(bfloat16) raises Unsupported dtype BFloat16 on CPU"
    );
}
/// `fft2`, `fftn`, `ifftn`, `ihfft` and `rfftn` must all return an error
/// for `f16` CPU input.
#[test]
fn nd_and_hermitian_transforms_reject_half() {
    use half::f16;

    // Complex-layout input of shape [2, 2, 2] holding 0..8.
    let complex_vals: Vec<f16> = (0u8..8).map(|k| f16::from_f32(f32::from(k))).collect();
    let complex_in =
        Tensor::from_storage(TensorStorage::cpu(complex_vals), vec![2, 2, 2], false).unwrap();
    assert!(fft2(&complex_in).is_err(), "fft2 must reject f16");
    assert!(fftn(&complex_in, None, None).is_err(), "fftn must reject f16");
    assert!(ifftn(&complex_in, None, None).is_err(), "ifftn must reject f16");

    // Real input of shape [4] holding 1..=4.
    let real_vals: Vec<f16> = (1u8..=4).map(|k| f16::from_f32(f32::from(k))).collect();
    let real_in = Tensor::from_storage(TensorStorage::cpu(real_vals), vec![4], false).unwrap();
    assert!(ihfft(&real_in, None).is_err(), "ihfft must reject f16");
    assert!(rfftn(&real_in, None, None).is_err(), "rfftn must reject f16");
}
/// `fftshift` must accept `f16` input: shifting `0..8` yields
/// `[4, 5, 6, 7, 0, 1, 2, 3]`.
#[test]
fn fftshift_stays_dtype_permissive_for_half() {
    use half::f16;
    let ramp: Vec<f16> = (0u8..8).map(|k| f16::from_f32(f32::from(k))).collect();
    let input = Tensor::from_storage(TensorStorage::cpu(ramp), vec![8], false).unwrap();

    let shifted = fftshift(&input, None).expect("fftshift(half) must succeed like torch");

    // Widen to f32 so the comparison is against plain float literals.
    let got: Vec<f32> = shifted
        .data()
        .unwrap()
        .iter()
        .map(|&v| v.to_f32())
        .collect();
    assert_eq!(got, vec![4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0]);
}
/// For a `[4, 2]` complex input, `fftn` with default arguments must match
/// `fft` within 1e-9.
#[test]
fn fftn_matches_fft_1d() {
    let signal = complex_tensor(&[(1.0, 2.0), (3.0, -1.0), (-2.0, 0.5), (0.0, 1.0)]);
    let one_dim = fft(&signal, None).unwrap();
    let n_dim = fftn(&signal, None, None).unwrap();
    assert_close(one_dim.data().unwrap(), n_dim.data().unwrap(), 1e-9);
}
/// N-D round trip on a `[3, 4, 2]` tensor: `ifftn(fftn(x))` keeps the shape
/// and matches `x` within 1e-9.
#[test]
fn fftn_ifftn_roundtrip_2d() {
    // Sample i is (i, 0.5 * i) for i in 0..12.
    let interleaved: Vec<f64> = (0..12)
        .flat_map(|i| {
            let x = i as f64;
            [x, x * 0.5]
        })
        .collect();
    let grid = t(&interleaved, &[3, 4, 2]);

    let freq_domain = fftn(&grid, None, None).unwrap();
    assert_eq!(freq_domain.shape(), &[3, 4, 2]);

    let rebuilt = ifftn(&freq_domain, None, None).unwrap();
    assert_eq!(rebuilt.shape(), &[3, 4, 2]);
    assert_close(rebuilt.data().unwrap(), grid.data().unwrap(), 1e-9);
}
/// N-D round trip on a `[2, 2, 3, 2]` tensor: `ifftn(fftn(x))` matches `x`
/// within 1e-9.
#[test]
fn fftn_ifftn_roundtrip_3d() {
    // Sample i is (i + 1, 0.3 * i) for i in 0..12.
    let interleaved: Vec<f64> = (0..(2 * 2 * 3))
        .flat_map(|i| {
            let x = i as f64;
            [x + 1.0, x * 0.3]
        })
        .collect();
    let volume = t(&interleaved, &[2, 2, 3, 2]);

    let freq_domain = fftn(&volume, None, None).unwrap();
    assert_eq!(freq_domain.shape(), &[2, 2, 3, 2]);

    let rebuilt = ifftn(&freq_domain, None, None).unwrap();
    assert_close(rebuilt.data().unwrap(), volume.data().unwrap(), 1e-9);
}
/// Real N-D round trip: `rfftn` of a `[3, 4]` input gives `[3, 3, 2]`, and
/// `irfftn` with `s = [3, 4]` restores the input within 1e-9.
#[test]
fn rfftn_irfftn_roundtrip_2d() {
    let samples: Vec<f64> = (1..=12).map(|k: i32| f64::from(k)).collect();
    let grid = t(&samples, &[3, 4]);

    let half_spectrum = rfftn(&grid, None, None).unwrap();
    assert_eq!(half_spectrum.shape(), &[3, 3, 2]);

    let rebuilt = irfftn(&half_spectrum, Some(&[3, 4]), None).unwrap();
    assert_eq!(rebuilt.shape(), &[3, 4]);
    assert_close(rebuilt.data().unwrap(), &samples, 1e-9);
}
/// Hermitian round trip: `ihfft` of 8 real samples gives `[n / 2 + 1, 2]`,
/// and `hfft` with `n = 8` restores the samples within 1e-9.
#[test]
fn hfft_ihfft_roundtrip() {
    let signal = vec![1.0, 2.5, -1.5, 0.5, 3.0, -2.0, 0.0, 1.0];
    let len = signal.len();

    let half_spectrum = ihfft(&t(&signal, &[len]), None).unwrap();
    assert_eq!(half_spectrum.shape(), &[len / 2 + 1, 2]);

    let rebuilt = hfft(&half_spectrum, Some(len)).unwrap();
    assert_eq!(rebuilt.shape(), &[len]);
    assert_close(rebuilt.data().unwrap(), &signal, 1e-9);
}
/// `fftfreq(8, 1.0)` must list non-negative frequencies first, then the
/// negative ones starting at -0.5.
#[test]
fn fftfreq_known_values() {
    let freqs = fftfreq(8, 1.0).unwrap();
    assert_close(
        freqs.data().unwrap(),
        &[0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125],
        1e-12,
    );
}
/// `rfftfreq(8, 1.0)` must yield the five non-negative frequencies from 0
/// up to 0.5.
#[test]
fn rfftfreq_known_values() {
    let freqs = rfftfreq(8, 1.0).unwrap();
    assert_close(freqs.data().unwrap(), &[0.0, 0.125, 0.25, 0.375, 0.5], 1e-12);
}
/// `fftfreq(8, 0.1)` must scale every bin by `1 / (n * d) = 1.25`.
///
/// The expected values are the `fftfreq(8, 1.0)` frequencies divided by the
/// spacing 0.1. All eight bins are asserted, not just bin 1, so the
/// negative-frequency half and its ordering are covered under non-unit
/// spacing too.
#[test]
fn fftfreq_with_sample_spacing() {
    let f = fftfreq(8, 0.1).unwrap();
    let expected = [0.0, 1.25, 2.5, 3.75, -5.0, -3.75, -2.5, -1.25];
    assert_close(f.data().unwrap(), &expected, 1e-10);
}
/// `fftshift` on an even-length ramp `0..8` swaps the two halves.
#[test]
fn fftshift_basic_even() {
    let ramp = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[8]);
    let expected = [4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0];
    let shifted = fftshift(&ramp, None).unwrap();
    assert_close(shifted.data().unwrap(), &expected, 1e-12);
}
/// For an even length (8), `ifftshift(fftshift(x))` must equal `x`.
#[test]
fn fftshift_ifftshift_even_inverse() {
    let ramp = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[8]);
    let there = fftshift(&ramp, None).unwrap();
    let back_again = ifftshift(&there, None).unwrap();
    assert_close(back_again.data().unwrap(), ramp.data().unwrap(), 1e-12);
}
/// For an odd length (5), `ifftshift(fftshift(x))` must equal `x`.
#[test]
fn fftshift_ifftshift_odd_inverse() {
    let ramp = t(&[0.0, 1.0, 2.0, 3.0, 4.0], &[5]);
    let there = fftshift(&ramp, None).unwrap();
    let back_again = ifftshift(&there, None).unwrap();
    assert_close(back_again.data().unwrap(), ramp.data().unwrap(), 1e-12);
}
/// With axes `[-1]` on a `[2, 4]` input, `fftshift` swaps the halves of each
/// row independently.
#[test]
fn fftshift_axes_arg() {
    let grid = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[2, 4]);
    let expected = [2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0];
    let shifted = fftshift(&grid, Some(&[-1])).unwrap();
    assert_close(shifted.data().unwrap(), &expected, 1e-12);
}
/// On a `[2, 3, 2]` complex input, `fft2` and `fftn` with default arguments
/// must agree within 1e-9.
#[test]
fn fftn_agrees_with_fft2_for_2d() {
    // Sample i is (i - 3, 0.7 * i) for i in 0..6.
    let interleaved: Vec<f64> = (0..6)
        .flat_map(|i| {
            let x = i as f64;
            [x - 3.0, x * 0.7]
        })
        .collect();
    let grid = t(&interleaved, &[2, 3, 2]);

    let two_dim = fft2(&grid).unwrap();
    let n_dim = fftn(&grid, None, None).unwrap();
    assert_close(two_dim.data().unwrap(), n_dim.data().unwrap(), 1e-9);
}
/// `rfft2` of a `[3, 4]` real input gives `[3, 3, 2]`, and `irfft2` with
/// `s = [3, 4]` restores the input within 1e-9.
#[test]
fn rfft2_output_shape_and_irfft2_roundtrip() {
    let samples: Vec<f64> = (1..=12).map(|k: i32| f64::from(k)).collect();
    let grid = t(&samples, &[3, 4]);

    let half_spectrum = rfft2(&grid, None, None).unwrap();
    assert_eq!(half_spectrum.shape(), &[3, 3, 2]);

    let rebuilt = irfft2(&half_spectrum, Some(&[3, 4]), None).unwrap();
    assert_eq!(rebuilt.shape(), &[3, 4]);
    assert_close(rebuilt.data().unwrap(), &samples, 1e-9);
}
/// On a `[2, 3, 4]` real input, `rfft2` with defaults must match `rfftn`
/// over axes `[-2, -1]` in both shape and values (1e-9).
#[test]
fn rfft2_matches_rfftn_over_last_two_axes() {
    let samples: Vec<f64> = (1..=24).map(|k: i32| f64::from(k) * 0.5).collect();
    let volume = t(&samples, &[2, 3, 4]);

    let via_rfft2 = rfft2(&volume, None, None).unwrap();
    let via_rfftn = rfftn(&volume, None, Some(&[-2, -1])).unwrap();

    assert_eq!(via_rfft2.shape(), via_rfftn.shape());
    assert_close(via_rfft2.data().unwrap(), via_rfftn.data().unwrap(), 1e-9);
}
/// 2-D Hermitian round trip: `ihfft2` of a `[4, 4]` real input gives
/// `[4, 3, 2]`, and `hfft2` with `s = [4, 4]` restores it within 1e-9.
#[test]
fn ihfft2_hfft2_roundtrip() {
    let samples: Vec<f64> = (1..=16).map(|k: i32| f64::from(k)).collect();
    let grid = t(&samples, &[4, 4]);

    let half_spectrum = ihfft2(&grid, None, None).unwrap();
    assert_eq!(half_spectrum.shape(), &[4, 3, 2]);

    let rebuilt = hfft2(&half_spectrum, Some(&[4, 4]), None).unwrap();
    assert_eq!(rebuilt.shape(), &[4, 4]);
    assert_close(rebuilt.data().unwrap(), &samples, 1e-9);
}
/// 3-D Hermitian round trip: `ihfftn` of a `[2, 2, 4]` real input gives
/// `[2, 2, 3, 2]`, and `hfftn` with `s = [2, 2, 4]` restores it within 1e-9.
#[test]
fn ihfftn_hfftn_roundtrip_3d() {
    let samples: Vec<f64> = (1..=16).map(|k: i32| f64::from(k) * 0.25).collect();
    let volume = t(&samples, &[2, 2, 4]);

    let half_spectrum = ihfftn(&volume, None, None).unwrap();
    assert_eq!(half_spectrum.shape(), &[2, 2, 3, 2]);

    let rebuilt = hfftn(&half_spectrum, Some(&[2, 2, 4]), None).unwrap();
    assert_eq!(rebuilt.shape(), &[2, 2, 4]);
    assert_close(rebuilt.data().unwrap(), &samples, 1e-9);
}
/// `fft_norm_from_str` maps `None` and "backward" to `Backward`, "forward"
/// to `Forward`, "ortho" to `Ortho`, and rejects any other string.
#[test]
fn fft_norm_from_str_maps_modes() {
    let accepted = [
        (None, FftNorm::Backward),
        (Some("backward"), FftNorm::Backward),
        (Some("forward"), FftNorm::Forward),
        (Some("ortho"), FftNorm::Ortho),
    ];
    for (mode, want) in accepted {
        assert_eq!(fft_norm_from_str(mode, "fft").unwrap(), want);
    }

    let rejected = fft_norm_from_str(Some("bogus"), "fft");
    assert!(rejected.is_err());
}
/// For an all-ones signal of length `n = 8`, the real part of the DC bin is
/// `n` under `Backward`, `sqrt(n)` under `Ortho`, and `1` under `Forward`.
#[test]
fn fft_ortho_norm_scales_dc_by_sqrt_n() {
    let n = 8usize;
    let ones: Vec<(f64, f64)> = vec![(1.0, 0.0); n];
    let signal = complex_tensor(&ones);

    let len = n as f64;
    let cases = [
        (FftNorm::Backward, len),
        (FftNorm::Ortho, len.sqrt()),
        (FftNorm::Forward, 1.0),
    ];
    for (norm, want_dc) in cases {
        let spectrum = fft_norm(&signal, None, None, norm).unwrap();
        let dc_re = spectrum.data().unwrap()[0];
        assert!((dc_re - want_dc).abs() < 1e-10);
    }
}
/// With `Ortho` normalization on both sides, `ifft_norm(fft_norm(x))` must
/// reproduce `x` within 1e-10.
#[test]
fn fft_ortho_is_unitary_roundtrip() {
    let pairs = vec![(1.0, 2.0), (-1.0, 0.5), (3.0, -1.0), (0.0, 0.0)];
    let signal = complex_tensor(&pairs);

    let freq_domain = fft_norm(&signal, None, None, FftNorm::Ortho).unwrap();
    let rebuilt = ifft_norm(&freq_domain, None, None, FftNorm::Ortho).unwrap();

    let flat = rebuilt.data().unwrap();
    for (i, &(re, im)) in pairs.iter().enumerate() {
        let got_re = flat[2 * i];
        assert!((got_re - re).abs() < 1e-10, "re {i}");
        let got_im = flat[2 * i + 1];
        assert!((got_im - im).abs() < 1e-10, "im {i}");
    }
}
/// Transforms along `dim = -2` of a `[2, 4, 2]` tensor whose first row is
/// all `1 + 0i` and second row all `2 + 0i`. Each column then holds the
/// length-2 signal `[1, 2]`, so bin 0 is `3` and bin 1 is `-1`.
#[test]
fn fft_dim_transforms_named_axis() {
    let mut interleaved = Vec::with_capacity(2 * 4 * 2);
    for row_value in [1.0, 2.0] {
        for _ in 0..4 {
            interleaved.extend([row_value, 0.0]);
        }
    }
    let grid = t(&interleaved, &[2, 4, 2]);

    let out = fft_norm(&grid, None, Some(-2), FftNorm::Backward).unwrap();
    assert_eq!(out.shape(), &[2, 4, 2]);

    let flat = out.data().unwrap();
    for c in 0..4 {
        // Real parts: output row 0 starts at offset 0, row 1 after 4 complex values.
        let bin0 = flat[c * 2];
        let bin1 = flat[(4 + c) * 2];
        assert!((bin0 - 3.0).abs() < 1e-10, "col {c} bin0 = {bin0}");
        assert!((bin1 - (-1.0)).abs() < 1e-10, "col {c} bin1 = {bin1}");
    }
}
/// `rfft_norm` along `dim = -2` of a `[4, 3]` real input gives `[3, 3, 2]`,
/// and `irfft_norm` with `n = 4` on the same axis restores the input (1e-9).
#[test]
fn rfft_dim_transforms_named_axis() {
    let samples: Vec<f64> = (1..=12).map(|k: i32| f64::from(k)).collect();
    let grid = t(&samples, &[4, 3]);

    let half_spectrum = rfft_norm(&grid, None, Some(-2), FftNorm::Backward).unwrap();
    assert_eq!(half_spectrum.shape(), &[3, 3, 2]);

    let rebuilt = irfft_norm(&half_spectrum, Some(4), Some(-2), FftNorm::Backward).unwrap();
    assert_eq!(rebuilt.shape(), &[4, 3]);
    assert_close(rebuilt.data().unwrap(), &samples, 1e-9);
}
/// `fftn_norm` with `s = [2, 8]` over axes `[0, 1]` must resize a
/// `[3, 4, 2]` input to `[2, 8, 2]`.
#[test]
fn fftn_s_resizes_named_axes() {
    // Sample i is (i, 0) for i in 0..12.
    let interleaved: Vec<f64> = (0..12).flat_map(|i| [i as f64, 0.0]).collect();
    let grid = t(&interleaved, &[3, 4, 2]);

    let resized = fftn_norm(&grid, Some(&[2, 8]), Some(&[0, 1]), FftNorm::Backward).unwrap();
    assert_eq!(resized.shape(), &[2, 8, 2]);
}
/// On a spectrum produced by `ihfftn` over the last two axes of a
/// `[2, 3, 4]` real input, `hfft2` with `s = [3, 4]` must match `hfftn`
/// over axes `[-2, -1]` in both shape and values (1e-9).
#[test]
fn hfft2_matches_hfftn_over_last_two_axes() {
    let samples: Vec<f64> = (1..=24).map(|k: i32| f64::from(k) * 0.3).collect();
    let volume = t(&samples, &[2, 3, 4]);
    let half_spectrum = ihfftn(&volume, Some(&[3, 4]), Some(&[-2, -1])).unwrap();

    let via_hfft2 = hfft2(&half_spectrum, Some(&[3, 4]), None).unwrap();
    let via_hfftn = hfftn(&half_spectrum, Some(&[3, 4]), Some(&[-2, -1])).unwrap();

    assert_eq!(via_hfft2.shape(), via_hfftn.shape());
    assert_close(via_hfft2.data().unwrap(), via_hfftn.data().unwrap(), 1e-9);
}
}