use crate::tensor::{complex::Complex, Tensor};
use num_traits::Float;
impl<
        T: Float
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive
            + Default
            + Clone
            + std::fmt::Debug,
    > Tensor<T>
{
    /// Computes the discrete Fourier transform of a real-valued tensor.
    ///
    /// Only 1-D tensors are currently supported. Any other rank returns an error.
    ///
    /// * `n` - transform length. The input is zero-padded or truncated to this
    ///   length. Defaults to the size of `dim`.
    /// * `dim` - dimension to transform. Negative values count from the end.
    ///   Defaults to `-1`.
    /// * `norm` - `None`/`"backward"` (no scaling), `"ortho"` (`1/sqrt(n)`) or
    ///   `"forward"` (`1/n`).
    ///
    /// Returns the `(real, imaginary)` parts of the spectrum, each of length `n`.
    ///
    /// # Errors
    /// Fails for non-1-D input, `n == 0`, an unrecognised `norm`, or a `dim`
    /// rejected by `resolve_dim`.
    pub fn fft(
        &self,
        n: Option<usize>,
        dim: Option<isize>,
        norm: Option<&str>,
    ) -> Result<(Self, Self), String> {
        let actual_dim = self.resolve_dim(dim.unwrap_or(-1))?;
        let fft_len = n.unwrap_or(self.shape()[actual_dim]);
        if self.shape().len() == 1 && actual_dim == 0 {
            return self.fft_1d_basic(fft_len, norm, false);
        }
        Err("Multi-dimensional FFT not implemented in this version".to_string())
    }

    /// Computes the inverse DFT of the complex spectrum `real_part + i*imag_part`.
    ///
    /// The receiver is not read. `dim` and `n` refer to the spectrum passed in.
    /// The `norm` modes mirror [`Tensor::fft`]: the default `"backward"` scales
    /// the inverse by `1/n`, `"ortho"` by `1/sqrt(n)`, and `"forward"` not at all.
    ///
    /// # Errors
    /// Fails if the two parts differ in shape, for non-1-D input, `n == 0`, or
    /// an unrecognised `norm`.
    pub fn ifft(
        &self,
        real_part: &Self,
        imag_part: &Self,
        n: Option<usize>,
        dim: Option<isize>,
        norm: Option<&str>,
    ) -> Result<(Self, Self), String> {
        if real_part.shape() != imag_part.shape() {
            return Err("Real and imaginary parts must have the same shape".to_string());
        }
        // `dim` describes the spectrum being inverted, so resolve it against
        // `real_part` (whose shape is indexed below) rather than the receiver.
        let actual_dim = real_part.resolve_dim(dim.unwrap_or(-1))?;
        let fft_len = n.unwrap_or(real_part.shape()[actual_dim]);
        if real_part.shape().len() == 1 && actual_dim == 0 {
            return self.ifft_1d_basic(real_part, imag_part, fft_len, norm);
        }
        Err("Multi-dimensional IFFT not implemented in this version".to_string())
    }

    /// Like [`Tensor::fft`], but keeps only the `n / 2 + 1` non-negative-frequency
    /// bins.
    pub fn rfft(
        &self,
        n: Option<usize>,
        dim: Option<isize>,
        norm: Option<&str>,
    ) -> Result<(Self, Self), String> {
        let (real_full, imag_full) = self.fft(n, dim, norm)?;
        // A real signal has a conjugate-symmetric spectrum, so the upper bins are redundant.
        let rfft_len = real_full.shape()[0] / 2 + 1;
        let keep_half = |full: &Self| -> Self {
            Tensor::from_vec(full.as_slice().unwrap()[..rfft_len].to_vec(), vec![rfft_len])
        };
        Ok((keep_half(&real_full), keep_half(&imag_full)))
    }

    /// Moves the zero-frequency bin to the centre of each selected dimension.
    ///
    /// Along an axis of length `len`, the data is rolled right by `len / 2`. For
    /// odd lengths, bin 0 therefore lands at index `len / 2`. `dim` selects the
    /// axes (negative values count from the end). `None` shifts every axis.
    pub fn fftshift(&self, dim: Option<&[isize]>) -> Result<Self, String> {
        let mut result = self.clone();
        for axis in self.fft_shift_dims(dim)? {
            // Rolling right by floor(len/2) equals rotating left by ceil(len/2).
            let left_shift = (self.shape()[axis] + 1) / 2;
            result = Self::_shift_along_axis(&result, axis, left_shift)?;
        }
        Ok(result)
    }

    /// Inverse of [`Tensor::fftshift`]: rolls each selected axis left by `len / 2`.
    ///
    /// `dim` follows the same conventions as in `fftshift`.
    pub fn ifftshift(&self, dim: Option<&[isize]>) -> Result<Self, String> {
        let mut result = self.clone();
        for axis in self.fft_shift_dims(dim)? {
            let left_shift = self.shape()[axis] / 2;
            result = Self::_shift_along_axis(&result, axis, left_shift)?;
        }
        Ok(result)
    }

    /// Resolves the `dim` argument of `fftshift`/`ifftshift` into non-negative
    /// axis indices. `None` selects every axis.
    fn fft_shift_dims(&self, dim: Option<&[isize]>) -> Result<Vec<usize>, String> {
        let ndim = self.shape().len();
        match dim {
            None => Ok((0..ndim).collect()),
            Some(dims) => dims
                .iter()
                .map(|&d| {
                    let adjusted = if d < 0 { d + ndim as isize } else { d };
                    if adjusted < 0 || adjusted as usize >= ndim {
                        return Err(format!("Dimension {} is out of bounds", d));
                    }
                    Ok(adjusted as usize)
                })
                .collect(),
        }
    }

    /// Rotates `tensor` left by `shift` positions along `axis`, so that
    /// `out[.., i, ..] = in[.., (i + shift) % len, ..]`.
    ///
    /// Assumes the flat data returned by `as_slice()` is in row-major order.
    fn _shift_along_axis(tensor: &Self, axis: usize, shift: usize) -> Result<Self, String> {
        let shape = tensor.shape();
        if axis >= shape.len() {
            return Err(format!("Axis {} is out of bounds", axis));
        }
        let axis_size = shape[axis];
        let data = tensor.as_slice().unwrap();
        // An empty tensor (which covers axis_size == 0) or a full-period shift is a no-op.
        if data.is_empty() || shift % axis_size == 0 {
            return Ok(tensor.clone());
        }
        let shift = shift % axis_size;
        let inner_size: usize = shape[axis + 1..].iter().product();
        let mut result_data = Vec::with_capacity(data.len());
        // Each outer slab holds `axis_size` runs of `inner_size` contiguous
        // elements. Emitting the runs in rotated order keeps the output in
        // storage order (outer, axis, inner).
        for slab in data.chunks(axis_size * inner_size) {
            for i in 0..axis_size {
                let start = ((i + shift) % axis_size) * inner_size;
                result_data.extend_from_slice(&slab[start..start + inner_size]);
            }
        }
        Ok(Tensor::from_vec(result_data, shape.to_vec()))
    }

    /// Multiplies the tensor by a window of type `window_type` along `axis`.
    ///
    /// `axis` defaults to the last dimension.
    ///
    /// # Errors
    /// Fails if `axis` is out of bounds, or if it is omitted on a 0-dimensional
    /// tensor.
    pub fn apply_window(
        &self,
        window_type: WindowType,
        axis: Option<usize>,
    ) -> Result<Self, String> {
        let ndim = self.shape().len();
        let target_axis = match axis {
            Some(axis) => axis,
            // `ndim - 1` would underflow for a 0-d tensor.
            None => ndim
                .checked_sub(1)
                .ok_or_else(|| "Cannot apply a window to a 0-dimensional tensor".to_string())?,
        };
        if target_axis >= ndim {
            return Err(format!("Axis {} is out of bounds", target_axis));
        }
        let window = Self::create_window(window_type, self.shape()[target_axis])?;
        self.apply_window_along_axis(&window, target_axis)
    }

    /// Builds a symmetric window of `size` samples, sampled at phases
    /// `2π·i / (size - 1)`, so both endpoints are included.
    ///
    /// A length-1 window is `[1]`; the general formulas would divide 0 by 0
    /// and produce NaN. A length-0 window is empty.
    fn create_window(window_type: WindowType, size: usize) -> Result<Self, String> {
        if size == 1 {
            return Ok(Tensor::from_vec(vec![T::one()], vec![1]));
        }
        // For size == 0 the range below is empty, so this denominator is never used.
        let denom = size.saturating_sub(1) as f64;
        let window_data: Vec<T> = (0..size)
            .map(|i| {
                let phase = 2.0 * std::f64::consts::PI * i as f64 / denom;
                let value = match window_type {
                    WindowType::Hann => 0.5 * (1.0 - phase.cos()),
                    WindowType::Hamming => 0.54 - 0.46 * phase.cos(),
                    WindowType::Blackman => {
                        0.42 - 0.5 * phase.cos() + 0.08 * (2.0 * phase).cos()
                    }
                    WindowType::Rectangular => 1.0,
                };
                T::from(value).unwrap()
            })
            .collect();
        Ok(Tensor::from_vec(window_data, vec![size]))
    }

    /// Multiplies every element by `window[i]`, where `i` is that element's index
    /// along `axis`.
    ///
    /// Assumes the flat data returned by `as_slice()` is in row-major order.
    fn apply_window_along_axis(&self, window: &Self, axis: usize) -> Result<Self, String> {
        let shape = self.shape();
        if axis >= shape.len() {
            return Err(format!("Axis {} is out of bounds", axis));
        }
        let window_size = window.numel();
        if shape[axis] != window_size {
            return Err(format!(
                "Window size {} doesn't match axis size {}",
                window_size, shape[axis]
            ));
        }
        let data = self.as_slice().unwrap();
        let window_data = window.as_slice().unwrap();
        // Elements sharing an index along `axis` occupy runs of `inner_size`
        // consecutive flat positions. If `inner_size` (or `window_size`) is 0,
        // `data` is empty and the closure never runs.
        let inner_size: usize = shape[axis + 1..].iter().product();
        let result_data: Vec<T> = data
            .iter()
            .enumerate()
            .map(|(flat, &x)| x * window_data[(flat / inner_size) % window_size])
            .collect();
        Ok(Tensor::from_vec(result_data, shape.to_vec()))
    }

    /// Lifts this 1-D real tensor to complex and runs the shared DFT (`dft_1d`).
    fn fft_1d_basic(
        &self,
        n: usize,
        norm: Option<&str>,
        inverse: bool,
    ) -> Result<(Self, Self), String> {
        let input: Vec<Complex<T>> = self
            .as_slice()
            .unwrap()
            .iter()
            .map(|&x| Complex::new(x, T::zero()))
            .collect();
        Self::dft_1d(&input, n, norm, inverse)
    }

    /// Combines `real_part` and `imag_part` element-wise into complex samples
    /// and runs the shared DFT (`dft_1d`) in the inverse direction.
    fn ifft_1d_basic(
        &self,
        real_part: &Self,
        imag_part: &Self,
        n: usize,
        norm: Option<&str>,
    ) -> Result<(Self, Self), String> {
        let real_data = real_part.as_slice().unwrap();
        let imag_data = imag_part.as_slice().unwrap();
        let input: Vec<Complex<T>> = real_data
            .iter()
            .zip(imag_data.iter())
            .map(|(&re, &im)| Complex::new(re, im))
            .collect();
        Self::dft_1d(&input, n, norm, true)
    }

    /// Naive O(n²) DFT shared by the forward and inverse transforms.
    ///
    /// `input` is truncated or implicitly zero-padded to `n` points. `inverse`
    /// flips the sign of the exponent. The output scaling is chosen by
    /// `dft_norm_scale`.
    fn dft_1d(
        input: &[Complex<T>],
        n: usize,
        norm: Option<&str>,
        inverse: bool,
    ) -> Result<(Self, Self), String> {
        if n == 0 {
            return Err("Invalid number of FFT data points (0) specified".to_string());
        }
        let scale = Self::dft_norm_scale(n, norm, inverse)?;
        let sign = if inverse { 1.0 } else { -1.0 };
        let mut real_out = Vec::with_capacity(n);
        let mut imag_out = Vec::with_capacity(n);
        for k in 0..n {
            let mut sum = Complex::new(T::zero(), T::zero());
            // `take(n)` truncates. Samples past the end of `input` are zero
            // padding and would contribute nothing, so they are skipped.
            for (j, &x) in input.iter().take(n).enumerate() {
                // Reducing j*k modulo n keeps the angle within one period, which
                // preserves twiddle accuracy (notably for f32) as n grows.
                let phase = ((j * k) % n) as f64 / n as f64;
                let angle = T::from(sign * 2.0 * std::f64::consts::PI * phase).unwrap();
                sum = sum + x * Complex::new(angle.cos(), angle.sin());
            }
            real_out.push(sum.re * scale);
            imag_out.push(sum.im * scale);
        }
        Ok((
            Tensor::from_vec(real_out, vec![n]),
            Tensor::from_vec(imag_out, vec![n]),
        ))
    }

    /// Returns the factor applied to every output bin.
    ///
    /// * `"backward"` (the default) scales only the inverse transform, by `1/n`.
    /// * `"forward"` scales only the forward transform, by `1/n`.
    /// * `"ortho"` scales both directions by `1/sqrt(n)`.
    fn dft_norm_scale(n: usize, norm: Option<&str>, inverse: bool) -> Result<T, String> {
        let n = n as f64;
        let scale = match norm {
            None | Some("backward") => {
                if inverse {
                    1.0 / n
                } else {
                    1.0
                }
            }
            Some("ortho") => 1.0 / n.sqrt(),
            Some("forward") => {
                if inverse {
                    1.0
                } else {
                    1.0 / n
                }
            }
            Some(other) => {
                return Err(format!(
                    "Invalid norm value {:?}; should be \"backward\", \"ortho\" or \"forward\"",
                    other
                ))
            }
        };
        Ok(T::from(scale).unwrap())
    }
}
/// Window shape used by `Tensor::apply_window`.
///
/// Each window is sampled symmetrically at phases `θ = 2π·i / (N - 1)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// Hann (raised-cosine) window: `0.5 * (1 - cos θ)`.
    Hann,
    /// Hamming window: `0.54 - 0.46 * cos θ`.
    Hamming,
    /// Blackman window: `0.42 - 0.5 * cos θ + 0.08 * cos 2θ`.
    Blackman,
    /// Rectangular window: all ones, so the data is left unchanged.
    Rectangular,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!((a - e).abs() < tol, "Mismatch at index {}: {} vs {}", i, a, e);
        }
    }

    #[test]
    fn test_fft_values() {
        // X[k] = sum_j x[j] * exp(-2πi·jk/4) for x = [1, 2, 3, 4].
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![4]);
        let (real, imag) = tensor.fft(None, None, None).unwrap();
        assert_eq!(real.shape(), tensor.shape());
        assert_eq!(imag.shape(), tensor.shape());
        assert_close(real.as_slice().unwrap(), &[10.0, -2.0, -2.0, -2.0], 1e-4);
        assert_close(imag.as_slice().unwrap(), &[0.0, 2.0, 0.0, -2.0], 1e-4);
    }

    #[test]
    fn test_fft_zero_padding() {
        // [1, 2] padded to [1, 2, 0, 0].
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0], vec![2]);
        let (real, imag) = tensor.fft(Some(4), None, None).unwrap();
        assert_close(real.as_slice().unwrap(), &[3.0, 1.0, -1.0, 1.0], 1e-4);
        assert_close(imag.as_slice().unwrap(), &[0.0, -2.0, 0.0, 2.0], 1e-4);
    }

    #[test]
    fn test_fft_forward_norms() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![4]);
        let (real, _) = tensor.fft(None, None, Some("forward")).unwrap();
        assert_close(real.as_slice().unwrap(), &[2.5, -0.5, -0.5, -0.5], 1e-4);
        let (real, imag) = tensor.fft(None, None, Some("ortho")).unwrap();
        assert_close(real.as_slice().unwrap(), &[5.0, -1.0, -1.0, -1.0], 1e-4);
        assert_close(imag.as_slice().unwrap(), &[0.0, 1.0, 0.0, -1.0], 1e-4);
    }

    #[test]
    fn test_rfft_values() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![4]);
        let (real, imag) = tensor.rfft(None, None, None).unwrap();
        assert_eq!(real.shape(), &[3]);
        assert_eq!(imag.shape(), &[3]);
        assert_close(real.as_slice().unwrap(), &[10.0, -2.0, -2.0], 1e-4);
        assert_close(imag.as_slice().unwrap(), &[0.0, 2.0, 0.0], 1e-4);
    }

    #[test]
    fn test_fftshift() {
        let tensor = Tensor::from_vec(vec![0.0f32, 1.0, 2.0, 3.0], vec![4]);
        let shifted = tensor.fftshift(None).unwrap();
        assert_eq!(shifted.shape(), tensor.shape());
        assert_eq!(shifted.as_slice().unwrap(), &[2.0f32, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn test_ifftshift() {
        let tensor = Tensor::from_vec(vec![0.0f32, 1.0, 2.0, 3.0], vec![4]);
        let shifted = tensor.ifftshift(None).unwrap();
        assert_eq!(shifted.shape(), tensor.shape());
        assert_eq!(shifted.as_slice().unwrap(), &[2.0f32, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn test_window_functions() {
        let tensor = Tensor::from_vec(vec![1.0f32; 8], vec![8]);
        let windowed = tensor.apply_window(WindowType::Hann, None).unwrap();
        assert_eq!(windowed.shape(), tensor.shape());
        let windowed_data = windowed.as_slice().unwrap();
        assert!(windowed_data[0].abs() < 1e-6);
        assert!(windowed_data[7].abs() < 1e-6);
        let rectangular = tensor.apply_window(WindowType::Rectangular, None).unwrap();
        assert_eq!(rectangular.as_slice().unwrap(), tensor.as_slice().unwrap());
    }

    #[test]
    fn test_window_along_leading_axis() {
        // A length-3 Hann window is [0, 1, 0]; along axis 0 it scales whole rows.
        let tensor = Tensor::from_vec(vec![1.0f32; 6], vec![3, 2]);
        let windowed = tensor.apply_window(WindowType::Hann, Some(0)).unwrap();
        assert_eq!(windowed.shape(), tensor.shape());
        assert_close(
            windowed.as_slice().unwrap(),
            &[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
            1e-5,
        );
    }

    #[test]
    fn test_create_window() {
        let hann = Tensor::<f32>::create_window(WindowType::Hann, 4).unwrap();
        assert_eq!(hann.shape(), &[4]);
        let hann_data = hann.as_slice().unwrap();
        assert!(hann_data[0].abs() < 1e-6);
        assert!(hann_data[3].abs() < 1e-6);

        let hann5 = Tensor::<f32>::create_window(WindowType::Hann, 5).unwrap();
        assert_close(hann5.as_slice().unwrap(), &[0.0, 0.5, 1.0, 0.5, 0.0], 1e-5);

        let hamming5 = Tensor::<f32>::create_window(WindowType::Hamming, 5).unwrap();
        assert_close(
            hamming5.as_slice().unwrap(),
            &[0.08, 0.54, 1.0, 0.54, 0.08],
            1e-5,
        );

        let rect = Tensor::<f32>::create_window(WindowType::Rectangular, 5).unwrap();
        let rect_data = rect.as_slice().unwrap();
        for &val in rect_data {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_multidimensional_fftshift() {
        let tensor = Tensor::from_vec((0..8).map(|x| x as f32).collect(), vec![2, 4]);
        let shifted = tensor.fftshift(Some(&[1])).unwrap();
        assert_eq!(shifted.shape(), tensor.shape());
        // Each row [0, 1, 2, 3] / [4, 5, 6, 7] is rolled independently.
        assert_eq!(
            shifted.as_slice().unwrap(),
            &[2.0f32, 3.0, 0.0, 1.0, 6.0, 7.0, 4.0, 5.0]
        );
    }

    #[test]
    fn test_fftshift_round_trip() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
        let shifted = tensor.fftshift(None).unwrap();
        let back = shifted.ifftshift(None).unwrap();
        assert_close(back.as_slice().unwrap(), tensor.as_slice().unwrap(), 1e-6);
    }
}