use ferray_core::Array as FerrayArray;
use ferray_core::IxDyn as FerrayIxDyn;
use ferray_fft::FftNorm;
use rustfft::FftPlanner;
use rustfft::num_complex::Complex;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::numeric_cast::cast;
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Returns `true` when `T` has the storage width of an `f32`.
///
/// NOTE(review): this is a width check, not a type-identity check — any other 4-byte
/// `Float` implementor would also be reported as `f32`.
#[inline]
fn is_f32<T: Float>() -> bool {
    let width = std::mem::size_of::<T>();
    width == std::mem::size_of::<f32>()
}
/// Returns `true` when `T` has the storage width of an `f64`.
///
/// NOTE(review): this is a width check, not a type-identity check — any other 8-byte
/// `Float` implementor would also be reported as `f64`.
#[inline]
fn is_f64<T: Float>() -> bool {
    let width = std::mem::size_of::<T>();
    width == std::mem::size_of::<f64>()
}
fn complex_to_pairs<T: Float>(data: &[Complex<f64>]) -> FerrotorchResult<Vec<T>> {
let mut out = Vec::with_capacity(data.len() * 2);
for c in data {
out.push(cast(c.re)?);
out.push(cast(c.im)?);
}
Ok(out)
}
fn fft_1d_last_axis(data: &mut [Complex<f64>], batch_shape: &[usize], n: usize, inverse: bool) {
let mut planner = FftPlanner::<f64>::new();
let fft = if inverse {
planner.plan_fft_inverse(n)
} else {
planner.plan_fft_forward(n)
};
let batch_size: usize = if batch_shape.is_empty() {
1
} else {
batch_shape.iter().product()
};
for b in 0..batch_size {
let offset = b * n;
fft.process(&mut data[offset..offset + n]);
}
if inverse {
let scale = 1.0 / n as f64;
for v in data.iter_mut() {
v.re *= scale;
v.im *= scale;
}
}
}
pub fn fft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
let shape = input.shape();
if shape.is_empty() || *shape.last().unwrap() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"fft: input must have trailing dimension 2 (complex), got shape {:?}",
shape
),
});
}
let ndim = shape.len();
if ndim < 2 {
return Err(FerrotorchError::InvalidArgument {
message: "fft: input must have at least 2 dimensions ([..., n, 2])".into(),
});
}
let input_n = shape[ndim - 2];
let fft_n = n.unwrap_or(input_n);
if fft_n == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "fft: n must be > 0".into(),
});
}
let batch_shape = &shape[..ndim - 2];
let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
if input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let buf = input.gpu_handle()?;
let (transformed_handle, owned);
let buf_for_fft: &crate::gpu_dispatch::GpuBufferHandle = if fft_n == input_n {
buf
} else if is_f32::<T>() {
owned = backend.pad_truncate_complex_f32(buf, batch_size, input_n, fft_n)?;
transformed_handle = &owned;
transformed_handle
} else {
owned = backend.pad_truncate_complex_f64(buf, batch_size, input_n, fft_n)?;
transformed_handle = &owned;
transformed_handle
};
let h = if is_f32::<T>() {
backend.fft_c2c_f32(buf_for_fft, batch_size, fft_n, false)?
} else {
backend.fft_c2c_f64(buf_for_fft, batch_size, fft_n, false)?
};
let mut out_shape = batch_shape.to_vec();
out_shape.push(fft_n);
out_shape.push(2);
return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
}
let data = input.data_vec()?;
let mut complex_data = Vec::with_capacity(batch_size * fft_n);
for b in 0..batch_size {
let src_offset = b * input_n * 2;
let copy_len = input_n.min(fft_n);
for i in 0..copy_len {
let re = data[src_offset + i * 2].to_f64().unwrap();
let im = data[src_offset + i * 2 + 1].to_f64().unwrap();
complex_data.push(Complex::new(re, im));
}
for _ in copy_len..fft_n {
complex_data.push(Complex::new(0.0, 0.0));
}
}
fft_1d_last_axis(&mut complex_data, batch_shape, fft_n, false);
let result_data = complex_to_pairs::<T>(&complex_data)?;
let mut out_shape = batch_shape.to_vec();
out_shape.push(fft_n);
out_shape.push(2);
Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
}
pub fn ifft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
let shape = input.shape();
if shape.is_empty() || *shape.last().unwrap() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"ifft: input must have trailing dimension 2 (complex), got shape {:?}",
shape
),
});
}
let ndim = shape.len();
if ndim < 2 {
return Err(FerrotorchError::InvalidArgument {
message: "ifft: input must have at least 2 dimensions ([..., n, 2])".into(),
});
}
let input_n = shape[ndim - 2];
let fft_n = n.unwrap_or(input_n);
if fft_n == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "ifft: n must be > 0".into(),
});
}
let batch_shape = &shape[..ndim - 2];
let batch_size: usize = batch_shape.iter().product::<usize>().max(1);
if input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let buf = input.gpu_handle()?;
let (transformed_handle, owned);
let buf_for_fft: &crate::gpu_dispatch::GpuBufferHandle = if fft_n == input_n {
buf
} else if is_f32::<T>() {
owned = backend.pad_truncate_complex_f32(buf, batch_size, input_n, fft_n)?;
transformed_handle = &owned;
transformed_handle
} else {
owned = backend.pad_truncate_complex_f64(buf, batch_size, input_n, fft_n)?;
transformed_handle = &owned;
transformed_handle
};
let h = if is_f32::<T>() {
backend.fft_c2c_f32(buf_for_fft, batch_size, fft_n, true)?
} else {
backend.fft_c2c_f64(buf_for_fft, batch_size, fft_n, true)?
};
let mut out_shape = batch_shape.to_vec();
out_shape.push(fft_n);
out_shape.push(2);
return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
}
let data = input.data_vec()?;
let mut complex_data = Vec::with_capacity(batch_size * fft_n);
for b in 0..batch_size {
let src_offset = b * input_n * 2;
let copy_len = input_n.min(fft_n);
for i in 0..copy_len {
let re = data[src_offset + i * 2].to_f64().unwrap();
let im = data[src_offset + i * 2 + 1].to_f64().unwrap();
complex_data.push(Complex::new(re, im));
}
for _ in copy_len..fft_n {
complex_data.push(Complex::new(0.0, 0.0));
}
}
fft_1d_last_axis(&mut complex_data, batch_shape, fft_n, true);
let result_data = complex_to_pairs::<T>(&complex_data)?;
let mut out_shape = batch_shape.to_vec();
out_shape.push(fft_n);
out_shape.push(2);
Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
}
/// Computes the 1-D FFT of real input along the last axis.
///
/// `input` has shape `[..., input_n]`. Each row is zero-padded or truncated to `n`
/// samples (default `input_n`). Only the `n / 2 + 1` non-negative-frequency bins are
/// returned, as `(re, im)` pairs of shape `[..., n / 2 + 1, 2]`.
///
/// CUDA tensors take the GPU path only when no padding/truncation is requested.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] for a 0-d input, `n == 0`, or a CUDA
/// input that is not `f32`/`f64`. Returns [`FerrotorchError::DeviceUnavailable`] when a
/// CUDA input has no GPU backend.
pub fn rfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    let (&input_n, batch_shape) = match shape.split_last() {
        Some(split) => split,
        None => {
            return Err(FerrotorchError::InvalidArgument {
                message: "rfft: input must have at least 1 dimension".into(),
            })
        }
    };
    let fft_n = n.unwrap_or(input_n);
    if fft_n == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "rfft: n must be > 0".into(),
        });
    }
    // Not clamped with `.max(1)`: an empty batch must stay empty instead of becoming a
    // phantom row that indexes past the end of the empty data buffer.
    let batch_size: usize = batch_shape.iter().product();
    let half_n = fft_n / 2 + 1;
    let mut out_shape = batch_shape.to_vec();
    out_shape.extend_from_slice(&[half_n, 2]);

    if input.is_cuda() && fft_n == input_n {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if is_f32::<T>() {
            backend.rfft_r2c_f32(buf, batch_size, fft_n)?
        } else if is_f64::<T>() {
            backend.rfft_r2c_f64(buf, batch_size, fft_n)?
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: "rfft requires f32 or f64".into(),
            });
        };
        return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
    }

    let data = input.data_vec()?;
    let copy_len = input_n.min(fft_n);
    let mut complex_data = Vec::with_capacity(batch_size * fft_n);
    for b in 0..batch_size {
        let row = &data[b * input_n..][..copy_len];
        complex_data.extend(row.iter().map(|v| Complex::new(v.to_f64().unwrap(), 0.0)));
        // Zero-pad up to `fft_n` samples (no-op when truncating).
        complex_data.resize(complex_data.len() + (fft_n - copy_len), Complex::new(0.0, 0.0));
    }
    fft_1d_last_axis(&mut complex_data, batch_shape, fft_n, false);

    // For real input, bins above n/2 are conjugates of lower bins, so only the first
    // `half_n` bins of each row are kept.
    let mut result_data: Vec<T> = Vec::with_capacity(batch_size * half_n * 2);
    for b in 0..batch_size {
        for c in &complex_data[b * fft_n..][..half_n] {
            result_data.push(cast(c.re)?);
            result_data.push(cast(c.im)?);
        }
    }
    Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
}
/// Computes the inverse of [`rfft`], producing a real signal from the
/// non-negative-frequency half of a Hermitian spectrum.
///
/// `input` has shape `[..., m, 2]`, holding `m` complex bins as `(re, im)` pairs. The
/// output has shape `[..., n]`, where `n` defaults to `2 * (m - 1)`.
///
/// As in numpy/torch `irfft`, the spectrum is first cropped or zero-padded to the
/// `n / 2 + 1` bins that a length-`n` real signal has. The remaining bins are then
/// filled in by Hermitian symmetry, `X[n - k] = conj(X[k])`.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] for a malformed shape, a zero output
/// length (including `m <= 1` with `n` unset), or a CUDA input that is not `f32`/`f64`.
/// Returns [`FerrotorchError::DeviceUnavailable`] when a CUDA input has no GPU backend.
pub fn irfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    if shape.last() != Some(&2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "irfft: input must have trailing dimension 2 (complex), got shape {:?}",
                shape
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "irfft: input must have at least 2 dimensions ([..., n/2+1, 2])".into(),
        });
    }
    let half_n = shape[ndim - 2];
    // `saturating_sub` keeps an empty spectrum (`half_n == 0`) from underflowing. It then
    // hits the `output_n == 0` error below instead of panicking (debug) or requesting a
    // near-`usize::MAX` allocation (release).
    let output_n = n.unwrap_or_else(|| 2 * half_n.saturating_sub(1));
    if output_n == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "irfft: output length must be > 0".into(),
        });
    }
    let batch_shape = &shape[..ndim - 2];
    // Not clamped with `.max(1)`: an empty batch must stay empty.
    let batch_size: usize = batch_shape.iter().product();
    // Number of independent bins in the spectrum of a real signal of length `output_n`.
    let needed = output_n / 2 + 1;

    if input.is_cuda() && half_n == needed {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if is_f32::<T>() {
            backend.irfft_c2r_f32(buf, batch_size, output_n)?
        } else if is_f64::<T>() {
            backend.irfft_c2r_f64(buf, batch_size, output_n)?
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: "irfft requires f32 or f64".into(),
            });
        };
        let mut out_shape = batch_shape.to_vec();
        out_shape.push(output_n);
        return Tensor::from_storage(TensorStorage::gpu(h), out_shape, false);
    }

    let data = input.data_vec()?;
    // Input bins at or beyond `needed` are ignored (cropped); missing bins read as zero.
    let used = half_n.min(needed);
    let mut complex_data = Vec::with_capacity(batch_size * output_n);
    for b in 0..batch_size {
        let row = &data[b * half_n * 2..][..used * 2];
        for k in 0..output_n {
            // Bins at or above `needed` mirror a lower bin: X[k] = conj(X[output_n - k]).
            let (src, mirrored) = if k < needed {
                (k, false)
            } else {
                (output_n - k, true)
            };
            let bin = if src < used {
                let re = row[src * 2].to_f64().unwrap();
                let im = row[src * 2 + 1].to_f64().unwrap();
                Complex::new(re, if mirrored { -im } else { im })
            } else {
                Complex::new(0.0, 0.0)
            };
            complex_data.push(bin);
        }
    }
    fft_1d_last_axis(&mut complex_data, batch_shape, output_n, true);
    // The Hermitian spectrum yields a real signal, so the imaginary parts are dropped.
    let result_data: Vec<T> = complex_data
        .iter()
        .map(|c| cast(c.re))
        .collect::<FerrotorchResult<_>>()?;
    let mut out_shape = batch_shape.to_vec();
    out_shape.push(output_n);
    Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
}
/// Computes the 2-D discrete Fourier transform over the `(rows, cols)` axes.
///
/// `input` must have shape `[..., rows, cols, 2]`, with `(re, im)` pairs in the
/// trailing axis. The output has the same shape.
///
/// Unbatched CUDA `f32`/`f64` tensors are transformed on-device. Batched CUDA input is
/// rejected up front, because the CPU fallback's row pass cannot run on CUDA.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] for a malformed shape, and
/// [`FerrotorchError::NotImplementedOnCuda`] for batched CUDA input. Backend and 1-D FFT
/// errors are propagated.
pub fn fft2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    fft2_c2c(input, false, "fft2")
}

/// Computes the 2-D inverse discrete Fourier transform over the `(rows, cols)` axes.
///
/// Shape handling, CUDA dispatch and errors are the same as for [`fft2`].
pub fn ifft2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    fft2_c2c(input, true, "ifft2")
}

/// Shared implementation of [`fft2`] and [`ifft2`].
///
/// `op` prefixes error messages and names the operation in CUDA errors.
fn fft2_c2c<T: Float>(
    input: &Tensor<T>,
    inverse: bool,
    op: &'static str,
) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    if shape.last() != Some(&2) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: input must have trailing dimension 2 (complex), got shape {shape:?}"
            ),
        });
    }
    let ndim = shape.len();
    if ndim < 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: input must have at least 3 dimensions ([..., rows, cols, 2])"
            ),
        });
    }
    let rows = shape[ndim - 3];
    let cols = shape[ndim - 2];
    // Not clamped with `.max(1)`: an empty batch must not be mistaken for a single plane
    // and handed to the GPU kernel with an empty buffer.
    let batch_dims: usize = shape[..ndim - 3].iter().product();

    if input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
        // The device kernel handles one plane. The CPU row pass refuses CUDA tensors, so
        // fail now rather than after a wasted column transform on the GPU.
        if batch_dims != 1 {
            return Err(FerrotorchError::NotImplementedOnCuda { op });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let buf = input.gpu_handle()?;
        let h = if is_f32::<T>() {
            backend.fft2_c2c_f32(buf, rows, cols, inverse)?
        } else {
            backend.fft2_c2c_f64(buf, rows, cols, inverse)?
        };
        return Tensor::from_storage(TensorStorage::gpu(h), shape.to_vec(), false);
    }

    // Separable 2-D transform: transform along `cols` first, then along `rows`.
    let after_cols = if inverse {
        ifft(input, Some(cols))?
    } else {
        fft(input, Some(cols))?
    };
    fft_2d_row_pass(&after_cols, rows, cols, inverse)
}
/// Second stage of the CPU 2-D FFT: transforms along the `rows` axis of a
/// `[..., rows, cols, 2]` tensor whose `cols` axis has already been transformed.
///
/// The row axis is made contiguous by transposing each `rows x cols` plane, running
/// the batched 1-D [`fft`]/[`ifft`] over it, and transposing back.
///
/// # Errors
///
/// Returns [`FerrotorchError::NotImplementedOnCuda`] for CUDA tensors, and propagates
/// errors from the 1-D transform and tensor construction.
fn fft_2d_row_pass<T: Float>(
    input: &Tensor<T>,
    rows: usize,
    cols: usize,
    inverse: bool,
) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda() {
        let op = if inverse { "ifft2" } else { "fft2" };
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let shape = input.shape();
    let batch_shape = &shape[..shape.len() - 3];
    // Not clamped with `.max(1)`: an empty batch must stay empty instead of indexing
    // past the end of the empty data buffer.
    let batch_size: usize = batch_shape.iter().product();

    let data = input.data_vec()?;
    let transposed = transpose_complex_planes(&data, batch_size, rows, cols);
    let mut trans_shape = batch_shape.to_vec();
    trans_shape.extend_from_slice(&[cols, rows, 2]);
    let trans_tensor = Tensor::from_storage(TensorStorage::cpu(transposed), trans_shape, false)?;

    let transformed = if inverse {
        ifft(&trans_tensor, Some(rows))?
    } else {
        fft(&trans_tensor, Some(rows))?
    };

    // Each plane is now `cols x rows`; transpose back to `rows x cols`.
    let result = transpose_complex_planes(&transformed.data_vec()?, batch_size, cols, rows);
    let mut out_shape = batch_shape.to_vec();
    out_shape.extend_from_slice(&[rows, cols, 2]);
    Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}

/// Transposes each `rows x cols` plane of interleaved `(re, im)` pairs into a
/// `cols x rows` plane. `data` holds `batch` such planes back-to-back.
fn transpose_complex_planes<T: Float>(data: &[T], batch: usize, rows: usize, cols: usize) -> Vec<T> {
    let mut out = vec![<T as num_traits::Zero>::zero(); data.len()];
    let plane = rows * cols * 2;
    for b in 0..batch {
        let base = b * plane;
        for r in 0..rows {
            for c in 0..cols {
                let src = base + (r * cols + c) * 2;
                let dst = base + (c * rows + r) * 2;
                out[dst] = data[src];
                out[dst + 1] = data[src + 1];
            }
        }
    }
    out
}
/// Converts a CPU `[..., 2]` tensor of `(re, im)` pairs into a complex ferray array.
///
/// The resulting array has the tensor's shape with the trailing pair axis removed.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda { op }` for CUDA tensors. Returns `InvalidArgument` if
/// the trailing dimension is not 2, or if ferray rejects the shape/data combination.
fn tensor_to_complex_array<T: Float>(
    input: &Tensor<T>,
    op: &'static str,
) -> FerrotorchResult<FerrayArray<Complex<f64>, FerrayIxDyn>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let shape = input.shape();
    let leading_dims = match shape.split_last() {
        Some((&2, leading)) => leading,
        _ => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "{op}: input must have trailing dimension 2 (complex), got shape {shape:?}"
                ),
            })
        }
    };
    let pairs = input.data_vec()?;
    let values: Vec<Complex<f64>> = pairs
        .chunks_exact(2)
        .map(|p| Complex::new(p[0].to_f64().unwrap(), p[1].to_f64().unwrap()))
        .collect();
    match FerrayArray::from_vec(FerrayIxDyn::new(leading_dims), values) {
        Ok(array) => Ok(array),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("{op}: failed to build ferray array: {e}"),
        }),
    }
}
/// Converts a CPU tensor of any shape into an `f64` ferray array with the same shape.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda { op }` for CUDA tensors, and `InvalidArgument` if
/// ferray rejects the shape/data combination.
fn tensor_to_real_array<T: Float>(
    input: &Tensor<T>,
    op: &'static str,
) -> FerrotorchResult<FerrayArray<f64, FerrayIxDyn>> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let widened: Vec<f64> = input
        .data_vec()?
        .iter()
        .map(|x| x.to_f64().unwrap())
        .collect();
    match FerrayArray::from_vec(FerrayIxDyn::new(input.shape()), widened) {
        Ok(array) => Ok(array),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("{op}: failed to build ferray array: {e}"),
        }),
    }
}
fn complex_array_to_tensor<T: Float>(
arr: &FerrayArray<Complex<f64>, FerrayIxDyn>,
) -> FerrotorchResult<Tensor<T>> {
let shape = arr.shape().to_vec();
let total: usize = shape.iter().product();
let mut out_data: Vec<T> = Vec::with_capacity(total * 2);
for c in arr.iter() {
out_data.push(cast(c.re)?);
out_data.push(cast(c.im)?);
}
let mut out_shape = shape;
out_shape.push(2);
Tensor::from_storage(TensorStorage::cpu(out_data), out_shape, false)
}
/// Converts an `f64` ferray array into a CPU tensor of `T` with the same shape.
///
/// # Errors
///
/// Propagates the first `cast` failure, or a tensor-construction error.
fn real_array_to_tensor<T: Float>(
    arr: &FerrayArray<f64, FerrayIxDyn>,
) -> FerrotorchResult<Tensor<T>> {
    let mut narrowed: Vec<T> = Vec::with_capacity(arr.shape().iter().product());
    for &v in arr.iter() {
        narrowed.push(cast(v)?);
    }
    Tensor::from_storage(TensorStorage::cpu(narrowed), arr.shape().to_vec(), false)
}
/// N-dimensional complex FFT of a `[..., 2]` tensor of `(re, im)` pairs.
///
/// This delegates to `ferray_fft::fftn` with `FftNorm::Backward`. `s` and `axes` are
/// forwarded unchanged and refer to the complex array, i.e. the tensor shape without
/// the trailing pair axis. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a bad shape
/// or a ferray error.
pub fn fftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let signal = tensor_to_complex_array(input, "fftn")?;
    match ferray_fft::fftn(&signal, s, axes, FftNorm::Backward) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("fftn: {e}"),
        }),
    }
}
/// N-dimensional inverse complex FFT of a `[..., 2]` tensor of `(re, im)` pairs.
///
/// This delegates to `ferray_fft::ifftn` with `FftNorm::Backward`. `s` and `axes` are
/// forwarded unchanged. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a bad shape
/// or a ferray error.
pub fn ifftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let spectrum = tensor_to_complex_array(input, "ifftn")?;
    match ferray_fft::ifftn(&spectrum, s, axes, FftNorm::Backward) {
        Ok(signal) => complex_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ifftn: {e}"),
        }),
    }
}
/// N-dimensional FFT of a real tensor, returning a complex `[..., 2]` tensor.
///
/// This delegates to `ferray_fft::rfftn` with `FftNorm::Backward`. `s` and `axes` are
/// forwarded unchanged. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a ferray
/// error.
pub fn rfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let signal = tensor_to_real_array(input, "rfftn")?;
    match ferray_fft::rfftn(&signal, s, axes, FftNorm::Backward) {
        Ok(spectrum) => complex_array_to_tensor(&spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("rfftn: {e}"),
        }),
    }
}
/// N-dimensional inverse of [`rfftn`]: a complex `[..., 2]` tensor in, a real tensor out.
///
/// This delegates to `ferray_fft::irfftn` with `FftNorm::Backward`. `s` and `axes` are
/// forwarded unchanged. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a bad shape
/// or a ferray error.
pub fn irfftn<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let spectrum = tensor_to_complex_array(input, "irfftn")?;
    match ferray_fft::irfftn(&spectrum, s, axes, FftNorm::Backward) {
        Ok(signal) => real_array_to_tensor(&signal),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("irfftn: {e}"),
        }),
    }
}
/// FFT of a Hermitian-symmetric signal given by its `[..., 2]` complex half, producing
/// a real tensor.
///
/// This delegates to `ferray_fft::hfft` with `FftNorm::Backward`. `n` (output length)
/// is forwarded unchanged, and the axis argument is left to ferray's default. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a bad shape
/// or a ferray error.
pub fn hfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let half_spectrum = tensor_to_complex_array(input, "hfft")?;
    match ferray_fft::hfft(&half_spectrum, n, None, FftNorm::Backward) {
        Ok(real_out) => real_array_to_tensor(&real_out),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("hfft: {e}"),
        }),
    }
}
/// Inverse of [`hfft`]: maps a real tensor to the `[..., 2]` complex half of a
/// Hermitian spectrum.
///
/// This delegates to `ferray_fft::ihfft` with `FftNorm::Backward`. `n` is forwarded
/// unchanged, and the axis argument is left to ferray's default. CPU only.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda` for CUDA input, and `InvalidArgument` for a ferray
/// error.
pub fn ihfft<T: Float>(input: &Tensor<T>, n: Option<usize>) -> FerrotorchResult<Tensor<T>> {
    let real_in = tensor_to_real_array(input, "ihfft")?;
    match ferray_fft::ihfft(&real_in, n, None, FftNorm::Backward) {
        Ok(half_spectrum) => complex_array_to_tensor(&half_spectrum),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ihfft: {e}"),
        }),
    }
}
/// Sample frequencies of an `n`-point FFT with sample spacing `d`, as a 1-D tensor of
/// length `n`. The order is non-negative frequencies first, then negative ones.
///
/// # Errors
///
/// Returns `InvalidArgument` when `ferray_fft::fftfreq` rejects the arguments.
pub fn fftfreq(n: usize, d: f64) -> FerrotorchResult<Tensor<f64>> {
    let freqs = match ferray_fft::fftfreq(n, d) {
        Ok(arr) => arr,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("fftfreq: {e}"),
            })
        }
    };
    let values = freqs.iter().copied().collect::<Vec<f64>>();
    Tensor::from_storage(TensorStorage::cpu(values), vec![n], false)
}
/// Non-negative sample frequencies of an `n`-point real FFT with sample spacing `d`,
/// as a 1-D tensor whose length comes from the array ferray returns.
///
/// # Errors
///
/// Returns `InvalidArgument` when `ferray_fft::rfftfreq` rejects the arguments.
pub fn rfftfreq(n: usize, d: f64) -> FerrotorchResult<Tensor<f64>> {
    let freqs = match ferray_fft::rfftfreq(n, d) {
        Ok(arr) => arr,
        Err(e) => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("rfftfreq: {e}"),
            })
        }
    };
    let values = freqs.iter().copied().collect::<Vec<f64>>();
    Tensor::from_storage(TensorStorage::cpu(values), vec![freqs.shape()[0]], false)
}
/// Moves the zero-frequency entry to the center of the requested axes, delegating to
/// `ferray_fft::fftshift` (`axes` is forwarded unchanged).
///
/// The tensor is treated as a plain real array over its full shape.
/// NOTE(review): if `axes == None` shifts every axis (numpy semantics), the trailing
/// `(re, im)` axis of a complex `[..., 2]` tensor is shifted too. Pass explicit axes for
/// complex data.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda { op: "fftshift" }` for CUDA tensors, and
/// `InvalidArgument` for a ferray error.
pub fn fftshift<T: Float>(
    input: &Tensor<T>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    // `tensor_to_real_array` rejects CUDA tensors with this same op name.
    let values = tensor_to_real_array(input, "fftshift")?;
    match ferray_fft::fftshift(&values, axes) {
        Ok(shifted) => real_array_to_tensor(&shifted),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("fftshift: {e}"),
        }),
    }
}
/// Inverse of [`fftshift`], delegating to `ferray_fft::ifftshift` (`axes` is forwarded
/// unchanged). It differs from `fftshift` only for odd-length axes.
///
/// The tensor is treated as a plain real array over its full shape.
/// NOTE(review): with `axes == None`, a complex tensor's trailing `(re, im)` axis may be
/// shifted as well — pass explicit axes for complex data.
///
/// # Errors
///
/// Returns `NotImplementedOnCuda { op: "ifftshift" }` for CUDA tensors, and
/// `InvalidArgument` for a ferray error.
pub fn ifftshift<T: Float>(
    input: &Tensor<T>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    // `tensor_to_real_array` rejects CUDA tensors with this same op name.
    let values = tensor_to_real_array(input, "ifftshift")?;
    match ferray_fft::ifftshift(&values, axes) {
        Ok(unshifted) => real_array_to_tensor(&unshifted),
        Err(e) => Err(FerrotorchError::InvalidArgument {
            message: format!("ifftshift: {e}"),
        }),
    }
}
// Unit tests for the FFT front-ends. Expected values come from closed-form DFTs
// (zeros, constants, impulses, a pure cosine) or from forward/inverse round-trips,
// compared element-wise with absolute tolerances.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;
    // Builds a CPU f64 tensor from `data` with the given `shape`, unwrapping the result.
    fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }
    // Asserts equal lengths and `|a[i] - b[i]| < tol` at every index.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(
            a.len(),
            b.len(),
            "length mismatch: {} vs {}",
            a.len(),
            b.len()
        );
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < tol,
                "index {i}: {x} vs {y} (diff {})",
                (x - y).abs()
            );
        }
    }
    // Builds a `[pairs.len(), 2]` complex tensor from `(re, im)` pairs.
    fn complex_tensor(pairs: &[(f64, f64)]) -> Tensor<f64> {
        let mut data = Vec::with_capacity(pairs.len() * 2);
        for &(re, im) in pairs {
            data.push(re);
            data.push(im);
        }
        t(&data, &[pairs.len(), 2])
    }
    // The DFT of an all-zero signal is all zeros.
    #[test]
    fn fft_of_zeros() {
        let input = complex_tensor(&[(0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]);
        let result = fft(&input, None).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        let d = result.data().unwrap();
        for &v in d {
            assert!(v.abs() < 1e-12, "expected 0, got {v}");
        }
    }
    // A constant signal of ones puts all energy in the DC bin (value n); other bins are 0.
    #[test]
    fn fft_of_ones() {
        let n = 8;
        let pairs: Vec<(f64, f64)> = vec![(1.0, 0.0); n];
        let input = complex_tensor(&pairs);
        let result = fft(&input, None).unwrap();
        assert_eq!(result.shape(), &[n, 2]);
        let d = result.data().unwrap();
        assert!(
            (d[0] - n as f64).abs() < 1e-10,
            "DC re = {}, expected {n}",
            d[0]
        );
        assert!(d[1].abs() < 1e-10, "DC im = {}", d[1]);
        for i in 1..n {
            assert!(d[i * 2].abs() < 1e-10, "bin {i} re = {}", d[i * 2]);
            assert!(d[i * 2 + 1].abs() < 1e-10, "bin {i} im = {}", d[i * 2 + 1]);
        }
    }
    // cos(2*pi*k*i/n) has magnitude n/2 at bins k and n-k and ~0 everywhere else.
    #[test]
    fn fft_pure_cosine() {
        let n = 16;
        let k = 3; let pi = std::f64::consts::PI;
        let pairs: Vec<(f64, f64)> = (0..n)
            .map(|i| ((2.0 * pi * k as f64 * i as f64 / n as f64).cos(), 0.0))
            .collect();
        let input = complex_tensor(&pairs);
        let result = fft(&input, None).unwrap();
        let d = result.data().unwrap();
        for i in 0..n {
            // Magnitude of bin i, from its interleaved (re, im) pair.
            let mag = (d[i * 2] * d[i * 2] + d[i * 2 + 1] * d[i * 2 + 1]).sqrt();
            if i == k || i == n - k {
                assert!(
                    (mag - n as f64 / 2.0).abs() < 1e-8,
                    "bin {i}: magnitude {mag}, expected {}",
                    n as f64 / 2.0
                );
            } else {
                assert!(mag < 1e-8, "bin {i}: magnitude {mag}, expected ~0");
            }
        }
    }
    // ifft(fft(x)) recovers x (ifft applies the 1/n scaling).
    #[test]
    fn fft_ifft_roundtrip() {
        let pairs = vec![
            (1.0, 2.0),
            (-1.0, 0.5),
            (3.0, -1.0),
            (0.0, 0.0),
            (-2.5, 1.5),
            (0.7, -0.3),
        ];
        let input = complex_tensor(&pairs);
        let spectrum = fft(&input, None).unwrap();
        let recovered = ifft(&spectrum, None).unwrap();
        let d = recovered.data().unwrap();
        for (i, &(re, im)) in pairs.iter().enumerate() {
            assert!(
                (d[i * 2] - re).abs() < 1e-10,
                "re at {i}: {} vs {re}",
                d[i * 2]
            );
            assert!(
                (d[i * 2 + 1] - im).abs() < 1e-10,
                "im at {i}: {} vs {im}",
                d[i * 2 + 1]
            );
        }
    }
    // Even-length real round-trip: rfft yields n/2+1 bins and irfft(.., n) recovers x.
    #[test]
    fn rfft_irfft_roundtrip() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let n = original.len();
        let input = t(&original, &[n]);
        let spectrum = rfft(&input, None).unwrap();
        assert_eq!(spectrum.shape(), &[5, 2]);
        let recovered = irfft(&spectrum, Some(n)).unwrap();
        assert_eq!(recovered.shape(), &[n]);
        let d = recovered.data().unwrap();
        assert_close(d, &original, 1e-10);
    }
    // rfft output has n/2+1 bins for both even (8 -> 5) and odd (7 -> 4) lengths.
    #[test]
    fn rfft_output_shape() {
        let input = t(&[0.0; 8], &[8]);
        let result = rfft(&input, None).unwrap();
        assert_eq!(result.shape(), &[5, 2]);
        let input = t(&[0.0; 7], &[7]);
        let result = rfft(&input, None).unwrap();
        assert_eq!(result.shape(), &[4, 2]); }
    // Odd-length real round-trip: the explicit n is needed because 2*(m-1) would be even.
    #[test]
    fn rfft_irfft_roundtrip_odd() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let n = original.len();
        let input = t(&original, &[n]);
        let spectrum = rfft(&input, None).unwrap();
        assert_eq!(spectrum.shape(), &[3, 2]);
        let recovered = irfft(&spectrum, Some(n)).unwrap();
        assert_eq!(recovered.shape(), &[n]);
        assert_close(recovered.data().unwrap(), &original, 1e-10);
    }
    // Zero-padding [1, 1] to length 4 leaves the DC bin equal to the sum, 2.
    #[test]
    fn fft_with_padding() {
        let input = complex_tensor(&[(1.0, 0.0), (1.0, 0.0)]);
        let result = fft(&input, Some(4)).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        let d = result.data().unwrap();
        assert!((d[0] - 2.0).abs() < 1e-10);
    }
    // Truncating to n = 2 keeps [1, 2], whose DFT is [3, -1].
    #[test]
    fn fft_with_truncation() {
        let input = complex_tensor(&[(1.0, 0.0), (2.0, 0.0), (3.0, 0.0), (4.0, 0.0)]);
        let result = fft(&input, Some(2)).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        let d = result.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-10);
        assert!(d[1].abs() < 1e-10);
        assert!((d[2] - (-1.0)).abs() < 1e-10);
        assert!(d[3].abs() < 1e-10);
    }
    // ifft2(fft2(x)) recovers a non-square 2x3 complex matrix.
    #[test]
    fn fft2_ifft2_roundtrip() {
        let pairs = vec![
            (1.0, 0.0),
            (2.0, 0.0),
            (3.0, 0.0),
            (4.0, 0.0),
            (5.0, 0.0),
            (6.0, 0.0),
        ];
        let mut data = Vec::new();
        for &(re, im) in &pairs {
            data.push(re);
            data.push(im);
        }
        let input = t(&data, &[2, 3, 2]);
        let spectrum = fft2(&input).unwrap();
        assert_eq!(spectrum.shape(), &[2, 3, 2]);
        let recovered = ifft2(&spectrum).unwrap();
        assert_eq!(recovered.shape(), &[2, 3, 2]);
        let d = recovered.data().unwrap();
        for (i, &(re, im)) in pairs.iter().enumerate() {
            assert!(
                (d[i * 2] - re).abs() < 1e-9,
                "re at {i}: {} vs {re}",
                d[i * 2]
            );
            assert!(
                (d[i * 2 + 1] - im).abs() < 1e-9,
                "im at {i}: {} vs {im}",
                d[i * 2 + 1]
            );
        }
    }
    // Leading dims are batches: batch 0 is an impulse (flat spectrum of ones), batch 1 is
    // all ones (DC = 4, other bins 0).
    #[test]
    fn fft_batched() {
        let data = vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0,
        ];
        let input = t(&data, &[2, 4, 2]);
        let result = fft(&input, None).unwrap();
        assert_eq!(result.shape(), &[2, 4, 2]);
        let d = result.data().unwrap();
        for i in 0..4 {
            assert!((d[i * 2] - 1.0).abs() < 1e-10, "batch0 bin {i} re");
            assert!(d[i * 2 + 1].abs() < 1e-10, "batch0 bin {i} im");
        }
        // Offset of batch 1: 4 complex bins * 2 scalars each.
        let off = 4 * 2;
        assert!((d[off] - 4.0).abs() < 1e-10, "batch1 DC re");
        assert!(d[off + 1].abs() < 1e-10, "batch1 DC im");
        for i in 1..4 {
            assert!(d[off + i * 2].abs() < 1e-10, "batch1 bin {i} re");
            assert!(d[off + i * 2 + 1].abs() < 1e-10, "batch1 bin {i} im");
        }
    }
    // f32 tensors go through the same path; an impulse gives a flat spectrum of ones.
    #[test]
    fn fft_f32() {
        let data: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![4, 2], false).unwrap();
        let result = fft(&input, None).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        let d = result.data().unwrap();
        for i in 0..4 {
            assert!((d[i * 2] - 1.0).abs() < 1e-5, "bin {i} re = {}", d[i * 2]);
            assert!(d[i * 2 + 1].abs() < 1e-5, "bin {i} im = {}", d[i * 2 + 1]);
        }
    }
    // On a 1-D complex signal, the ferray-backed fftn agrees with the rustfft-backed fft.
    #[test]
    fn fftn_matches_fft_1d() {
        let pairs = vec![(1.0, 2.0), (3.0, -1.0), (-2.0, 0.5), (0.0, 1.0)];
        let input = complex_tensor(&pairs);
        let by_fft = fft(&input, None).unwrap();
        let by_fftn = fftn(&input, None, None).unwrap();
        assert_close(by_fft.data().unwrap(), by_fftn.data().unwrap(), 1e-9);
    }
    // fftn/ifftn round-trip over a 3x4 complex array (shape [3, 4, 2]).
    #[test]
    fn fftn_ifftn_roundtrip_2d() {
        let mut data = Vec::with_capacity(24);
        for i in 0..12 {
            data.push(i as f64);
            data.push((i as f64) * 0.5);
        }
        let input = t(&data, &[3, 4, 2]);
        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.shape(), &[3, 4, 2]);
        let recovered = ifftn(&spectrum, None, None).unwrap();
        assert_eq!(recovered.shape(), &[3, 4, 2]);
        assert_close(recovered.data().unwrap(), input.data().unwrap(), 1e-9);
    }
    // fftn/ifftn round-trip over a 2x2x3 complex array (shape [2, 2, 3, 2]).
    #[test]
    fn fftn_ifftn_roundtrip_3d() {
        let mut data = Vec::with_capacity(2 * 2 * 3 * 2);
        for i in 0..(2 * 2 * 3) {
            data.push(i as f64 + 1.0);
            data.push((i as f64) * 0.3);
        }
        let input = t(&data, &[2, 2, 3, 2]);
        let spectrum = fftn(&input, None, None).unwrap();
        assert_eq!(spectrum.shape(), &[2, 2, 3, 2]);
        let recovered = ifftn(&spectrum, None, None).unwrap();
        assert_close(recovered.data().unwrap(), input.data().unwrap(), 1e-9);
    }
    // rfftn halves only the last axis (4 -> 3 bins); irfftn with s = [3, 4] recovers x.
    #[test]
    fn rfftn_irfftn_roundtrip_2d() {
        let original: Vec<f64> = (1..=12).map(|x| x as f64).collect();
        let input = t(&original, &[3, 4]);
        let spectrum = rfftn(&input, None, None).unwrap();
        assert_eq!(spectrum.shape(), &[3, 3, 2]);
        let recovered = irfftn(&spectrum, Some(&[3, 4]), None).unwrap();
        assert_eq!(recovered.shape(), &[3, 4]);
        assert_close(recovered.data().unwrap(), &original, 1e-9);
    }
    // ihfft maps a real length-8 signal to 5 complex bins; hfft(.., Some(8)) recovers it.
    #[test]
    fn hfft_ihfft_roundtrip() {
        let original = vec![1.0, 2.5, -1.5, 0.5, 3.0, -2.0, 0.0, 1.0];
        let n = original.len();
        let input = t(&original, &[n]);
        let spectrum = ihfft(&input, None).unwrap();
        assert_eq!(spectrum.shape(), &[n / 2 + 1, 2]);
        let recovered = hfft(&spectrum, Some(n)).unwrap();
        assert_eq!(recovered.shape(), &[n]);
        assert_close(recovered.data().unwrap(), &original, 1e-9);
    }
    // Standard ordering: non-negative frequencies first, then negative ones.
    #[test]
    fn fftfreq_known_values() {
        let f = fftfreq(8, 1.0).unwrap();
        let expected = [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];
        assert_close(f.data().unwrap(), &expected, 1e-12);
    }
    // rfftfreq returns only the n/2+1 non-negative frequencies, including Nyquist.
    #[test]
    fn rfftfreq_known_values() {
        let f = rfftfreq(8, 1.0).unwrap();
        let expected = [0.0, 0.125, 0.25, 0.375, 0.5];
        assert_close(f.data().unwrap(), &expected, 1e-12);
    }
    // Frequencies scale as 1 / (n * d): bin 1 at d = 0.1 is 1 / 0.8 = 1.25.
    #[test]
    fn fftfreq_with_sample_spacing() {
        let f = fftfreq(8, 0.1).unwrap();
        let d = f.data().unwrap();
        assert!((d[1] - 1.25).abs() < 1e-10);
    }
    // For even n, fftshift rotates by n/2.
    #[test]
    fn fftshift_basic_even() {
        let input = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[8]);
        let shifted = fftshift(&input, None).unwrap();
        let d = shifted.data().unwrap();
        assert_close(d, &[4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0], 1e-12);
    }
    // ifftshift undoes fftshift for even lengths.
    #[test]
    fn fftshift_ifftshift_even_inverse() {
        let input = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[8]);
        let shifted = fftshift(&input, None).unwrap();
        let unshifted = ifftshift(&shifted, None).unwrap();
        assert_close(unshifted.data().unwrap(), input.data().unwrap(), 1e-12);
    }
    // ifftshift undoes fftshift for odd lengths, where the two shifts differ.
    #[test]
    fn fftshift_ifftshift_odd_inverse() {
        let input = t(&[0.0, 1.0, 2.0, 3.0, 4.0], &[5]);
        let shifted = fftshift(&input, None).unwrap();
        let unshifted = ifftshift(&shifted, None).unwrap();
        assert_close(unshifted.data().unwrap(), input.data().unwrap(), 1e-12);
    }
    // With axes = [-1], only the last axis of a 2x4 array is shifted (each row by 2).
    #[test]
    fn fftshift_axes_arg() {
        let input = t(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[2, 4]);
        let shifted = fftshift(&input, Some(&[-1])).unwrap();
        assert_close(
            shifted.data().unwrap(),
            &[2.0, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0],
            1e-12,
        );
    }
    // The separable fft2 path agrees with ferray's fftn on a 2x3 complex array.
    #[test]
    fn fftn_agrees_with_fft2_for_2d() {
        let mut data = Vec::with_capacity(2 * 3 * 2);
        for i in 0..6 {
            data.push((i as f64) - 3.0);
            data.push((i as f64) * 0.7);
        }
        let input = t(&data, &[2, 3, 2]);
        let by_fft2 = fft2(&input).unwrap();
        let by_fftn = fftn(&input, None, None).unwrap();
        assert_close(by_fft2.data().unwrap(), by_fftn.data().unwrap(), 1e-9);
    }
}