use crate::tensor::TensorStorage;
use crate::{Result, Tensor, TensorError};
use num_complex::Complex;
use oxifft::{Direction, Flags, Plan};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use scirs2_core::numeric::{Float, FromPrimitive, Signed, Zero};
use std::fmt::Debug;
/// Reinterprets a slice of `num_complex::Complex<T>` as a slice of
/// `oxifft::kernel::Complex<T>` without copying.
///
/// The returned slice covers the same memory and has the same element count
/// as `data`.
#[inline]
fn to_oxifft_complex<T: oxifft::Float>(data: &[Complex<T>]) -> &[oxifft::kernel::Complex<T>] {
    let len = data.len();
    let ptr = data.as_ptr().cast::<oxifft::kernel::Complex<T>>();
    // SAFETY: `ptr` and `len` come from a live shared slice, so the memory is
    // valid for reads for the returned lifetime. This additionally relies on
    // both complex types having identical size, alignment and field order.
    // NOTE(review): the layout of `oxifft::kernel::Complex` is not visible in
    // this file — confirm it matches `num_complex::Complex`.
    unsafe { std::slice::from_raw_parts(ptr, len) }
}
/// Reinterprets a mutable slice of `num_complex::Complex<T>` as a mutable
/// slice of `oxifft::kernel::Complex<T>` without copying.
///
/// Writes through the returned slice land in `data`'s memory.
#[inline]
fn to_oxifft_complex_mut<T: oxifft::Float>(
    data: &mut [Complex<T>],
) -> &mut [oxifft::kernel::Complex<T>] {
    let len = data.len();
    let ptr = data.as_mut_ptr().cast::<oxifft::kernel::Complex<T>>();
    // SAFETY: `ptr` and `len` come from a live exclusive slice, and the
    // returned borrow keeps `data` mutably borrowed, so no aliasing occurs.
    // This additionally relies on both complex types having identical size,
    // alignment and field order.
    // NOTE(review): the layout of `oxifft::kernel::Complex` is not visible in
    // this file — confirm it matches `num_complex::Complex`.
    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
/// Computes the forward 1-D DFT of a real tensor along its last axis.
///
/// Every run of `n` consecutive elements (where `n` is the length of the last
/// axis) is promoted to complex values with a zero imaginary part and
/// transformed independently. The output has the same shape as the input.
/// No scaling is applied to the transform output here.
///
/// # Errors
///
/// * [`TensorError::InvalidShape`] if the input is 0-dimensional, if the last
///   axis has length zero, if the FFT plan cannot be created, or if the output
///   array cannot be built from the computed data.
/// * An unsupported-operation error if the CPU array cannot be viewed as a
///   flat slice (`as_slice` returns `None`).
///
/// With the `gpu` feature, GPU tensors are copied to the CPU and transformed
/// there.
pub fn fft<T>(input: &Tensor<T>) -> Result<Tensor<Complex<T>>>
where
    T: Float
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + Signed
        + Debug
        + Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + oxifft::Float,
    Complex<T>: Default,
{
    match &input.storage {
        TensorStorage::Cpu(arr) => {
            let shape = arr.shape();
            let ndim = shape.len();
            if ndim == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "fft".to_string(),
                    reason: "FFT requires at least 1D input".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let n = shape[ndim - 1];
            // A zero-length transform axis would otherwise reach a division
            // by zero (and zero-sized chunking) below.
            if n == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "fft".to_string(),
                    reason: "FFT requires a non-empty last axis".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let plan = Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
                TensorError::InvalidShape {
                    operation: "fft".to_string(),
                    reason: "Failed to create FFT plan".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                }
            })?;
            let input_slice = arr.as_slice().ok_or_else(|| {
                TensorError::unsupported_operation_simple(
                    "Cannot get slice from input array".to_string(),
                )
            })?;

            let total_elements: usize = shape.iter().product();
            let mut output_data: Vec<Complex<T>> = vec![Complex::zero(); total_elements];
            // One scratch buffer reused for every row instead of a fresh
            // allocation per transform.
            let mut input_buffer: Vec<Complex<T>> = vec![Complex::zero(); n];

            for (row_in, row_out) in input_slice
                .chunks_exact(n)
                .zip(output_data.chunks_exact_mut(n))
            {
                for (dst, &src) in input_buffer.iter_mut().zip(row_in) {
                    *dst = Complex::new(src, T::zero());
                }
                // The plan writes straight into the destination row.
                plan.execute(
                    to_oxifft_complex(&input_buffer),
                    to_oxifft_complex_mut(row_out),
                );
            }

            let output_array = ArrayD::from_shape_vec(IxDyn(shape), output_data).map_err(|e| {
                TensorError::InvalidShape {
                    operation: "fft".to_string(),
                    reason: e.to_string(),
                    shape: None,
                    context: None,
                }
            })?;
            Ok(Tensor::from_array(output_array))
        }
        #[cfg(feature = "gpu")]
        TensorStorage::Gpu(_gpu_buffer) => {
            // No GPU implementation here: fall back to the CPU path.
            let cpu_tensor = input.to_cpu()?;
            fft(&cpu_tensor)
        }
    }
}
/// Computes the inverse 1-D DFT of a complex tensor along its last axis.
///
/// Every run of `n` consecutive elements (where `n` is the length of the last
/// axis) is transformed independently with a backward plan, and each output
/// value is then divided by `n`. The output has the same shape as the input.
///
/// # Errors
///
/// * [`TensorError::InvalidShape`] if the input is 0-dimensional, if the last
///   axis has length zero, if the plan cannot be created, or if the output
///   array cannot be built from the computed data.
/// * An unsupported-operation error if the CPU array cannot be viewed as a
///   flat slice (`as_slice` returns `None`), or — with the `gpu` feature — if
///   the tensor lives on the GPU.
///
/// # Panics
///
/// Panics if `n` cannot be converted to `T`.
pub fn ifft<T>(input: &Tensor<Complex<T>>) -> Result<Tensor<Complex<T>>>
where
    T: Float
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + Signed
        + Debug
        + Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + oxifft::Float,
    Complex<T>: Default,
{
    match &input.storage {
        TensorStorage::Cpu(arr) => {
            let shape = arr.shape();
            let ndim = shape.len();
            if ndim == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "ifft".to_string(),
                    reason: "IFFT requires at least 1D input".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let n = shape[ndim - 1];
            // A zero-length transform axis would otherwise reach a division
            // by zero (and zero-sized chunking) below.
            if n == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "ifft".to_string(),
                    reason: "IFFT requires a non-empty last axis".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let plan = Plan::dft_1d(n, Direction::Backward, Flags::ESTIMATE).ok_or_else(|| {
                TensorError::InvalidShape {
                    operation: "ifft".to_string(),
                    reason: "Failed to create IFFT plan".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                }
            })?;
            let input_slice = arr.as_slice().ok_or_else(|| {
                TensorError::unsupported_operation_simple(
                    "Cannot get slice from input array".to_string(),
                )
            })?;

            let total_elements: usize = shape.iter().product();
            let mut output_data: Vec<Complex<T>> = vec![Complex::zero(); total_elements];
            // Loop-invariant scale factor, converted once.
            let n_t = T::from(n).expect("n must be convertible to float type");

            for (row_in, row_out) in input_slice
                .chunks_exact(n)
                .zip(output_data.chunks_exact_mut(n))
            {
                // The input is already complex, so it is handed to the plan
                // directly and the result is written into the destination row.
                plan.execute(to_oxifft_complex(row_in), to_oxifft_complex_mut(row_out));
                for val in row_out.iter_mut() {
                    *val /= n_t;
                }
            }

            let output_array = ArrayD::from_shape_vec(IxDyn(shape), output_data).map_err(|e| {
                TensorError::InvalidShape {
                    operation: "ifft".to_string(),
                    reason: e.to_string(),
                    shape: None,
                    context: None,
                }
            })?;
            Ok(Tensor::from_array(output_array))
        }
        #[cfg(feature = "gpu")]
        TensorStorage::Gpu(_gpu_buffer) => {
            // NOTE(review): `fft`/`rfft` fall back to the CPU for GPU tensors,
            // while this path reports an error — consider aligning them.
            Err(TensorError::unsupported_operation_simple(
                "GPU IFFT not yet implemented".to_string(),
            ))
        }
    }
}
/// Computes the forward 1-D DFT of a real tensor along its last axis, keeping
/// only the first `n / 2 + 1` output bins of each transform.
///
/// Every run of `n` consecutive elements (where `n` is the length of the last
/// axis) is promoted to complex values with a zero imaginary part, transformed
/// with a full-length forward plan, and truncated. The output shape equals
/// the input shape with the last axis replaced by `n / 2 + 1`.
///
/// # Errors
///
/// * [`TensorError::InvalidShape`] if the input is 0-dimensional, if the last
///   axis has length zero, if the plan cannot be created, or if the output
///   array cannot be built from the computed data.
/// * An unsupported-operation error if the CPU array cannot be viewed as a
///   flat slice (`as_slice` returns `None`).
///
/// With the `gpu` feature, GPU tensors are copied to the CPU and transformed
/// there.
pub fn rfft<T>(input: &Tensor<T>) -> Result<Tensor<Complex<T>>>
where
    T: Float
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + Signed
        + Debug
        + Default
        + bytemuck::Pod
        + bytemuck::Zeroable
        + oxifft::Float,
    Complex<T>: Default,
{
    match &input.storage {
        TensorStorage::Cpu(arr) => {
            let shape = arr.shape();
            let ndim = shape.len();
            if ndim == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "rfft".to_string(),
                    reason: "RFFT requires at least 1D input".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let n = shape[ndim - 1];
            // A zero-length transform axis would otherwise reach a division
            // by zero (and zero-sized chunking) below.
            if n == 0 {
                return Err(TensorError::InvalidShape {
                    operation: "rfft".to_string(),
                    reason: "RFFT requires a non-empty last axis".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                });
            }
            let output_len = n / 2 + 1;
            let plan = Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
                TensorError::InvalidShape {
                    operation: "rfft".to_string(),
                    reason: "Failed to create RFFT plan".to_string(),
                    shape: Some(shape.to_vec()),
                    context: None,
                }
            })?;
            let input_slice = arr.as_slice().ok_or_else(|| {
                TensorError::unsupported_operation_simple(
                    "Cannot get slice from input array".to_string(),
                )
            })?;

            let mut output_shape = shape.to_vec();
            output_shape[ndim - 1] = output_len;
            let output_total: usize = output_shape.iter().product();
            let mut output_data: Vec<Complex<T>> = vec![Complex::zero(); output_total];

            // Scratch buffers reused for every row instead of fresh
            // allocations per transform.
            let mut input_buffer: Vec<Complex<T>> = vec![Complex::zero(); n];
            let mut full_output: Vec<Complex<T>> = vec![Complex::zero(); n];

            for (row_in, row_out) in input_slice
                .chunks_exact(n)
                .zip(output_data.chunks_exact_mut(output_len))
            {
                for (dst, &src) in input_buffer.iter_mut().zip(row_in) {
                    *dst = Complex::new(src, T::zero());
                }
                plan.execute(
                    to_oxifft_complex(&input_buffer),
                    to_oxifft_complex_mut(&mut full_output),
                );
                // Keep only the leading `output_len` bins of the full spectrum.
                row_out.copy_from_slice(&full_output[..output_len]);
            }

            let output_array = ArrayD::from_shape_vec(IxDyn(&output_shape), output_data)
                .map_err(|e| TensorError::InvalidShape {
                    operation: "rfft".to_string(),
                    reason: e.to_string(),
                    shape: None,
                    context: None,
                })?;
            Ok(Tensor::from_array(output_array))
        }
        #[cfg(feature = "gpu")]
        TensorStorage::Gpu(_gpu_buffer) => {
            // No GPU implementation here: fall back to the CPU path.
            let cpu_tensor = input.to_cpu()?;
            rfft(&cpu_tensor)
        }
    }
}