use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Direction of a Fourier transform.
///
/// Used together with [`FftNormalization`] to pick the scale factor applied to a transform.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftDirection {
    /// The forward transform.
    Forward,
    /// The inverse transform.
    Inverse,
}
/// Normalization convention for FFTs, i.e. which direction(s) carry a `1/n`-style scale.
///
/// The concrete factors are produced by `FftNormalization::factor`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FftNormalization {
    /// No scaling in either direction (factor `1`).
    None,
    /// Forward transform unscaled; inverse transform scaled by `1/n`. This is the default.
    #[default]
    Backward,
    /// Both directions scaled by `1/sqrt(n)`.
    Ortho,
    /// Forward transform scaled by `1/n`; inverse transform unscaled.
    Forward,
}
impl FftNormalization {
    /// Returns the scale factor associated with this normalization mode for a transform of
    /// length `n` in the given `direction`.
    ///
    /// | mode       | forward      | inverse      |
    /// |------------|--------------|--------------|
    /// | `None`     | `1`          | `1`          |
    /// | `Backward` | `1`          | `1/n`        |
    /// | `Ortho`    | `1/sqrt(n)`  | `1/sqrt(n)`  |
    /// | `Forward`  | `1/n`        | `1`          |
    ///
    /// NOTE(review): no guard for `n == 0`. Every scaled entry then evaluates to `inf`
    /// (division by zero). Callers presumably never pass an empty length; confirm.
    #[inline]
    pub fn factor(self, direction: FftDirection, n: usize) -> f64 {
        let len = n as f64;
        let is_forward = matches!(direction, FftDirection::Forward);
        match self {
            Self::None => 1.0,
            Self::Ortho => 1.0 / len.sqrt(),
            // Backward: only the inverse pass is scaled.
            Self::Backward if is_forward => 1.0,
            Self::Backward => 1.0 / len,
            // Forward: only the forward pass is scaled.
            Self::Forward if is_forward => 1.0 / len,
            Self::Forward => 1.0,
        }
    }
}
pub trait FftAlgorithms<R: Runtime> {
fn fft(
&self,
input: &Tensor<R>,
direction: FftDirection,
norm: FftNormalization,
) -> Result<Tensor<R>> {
let _ = (input, direction, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::fft",
})
}
fn fft_dim(
&self,
input: &Tensor<R>,
dim: isize,
direction: FftDirection,
norm: FftNormalization,
) -> Result<Tensor<R>> {
let _ = (input, dim, direction, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::fft_dim",
})
}
fn rfft(&self, input: &Tensor<R>, norm: FftNormalization) -> Result<Tensor<R>> {
let _ = (input, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::rfft",
})
}
fn irfft(
&self,
input: &Tensor<R>,
n: Option<usize>,
norm: FftNormalization,
) -> Result<Tensor<R>> {
let _ = (input, n, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::irfft",
})
}
fn fft2(
&self,
input: &Tensor<R>,
direction: FftDirection,
norm: FftNormalization,
) -> Result<Tensor<R>> {
let _ = (input, direction, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::fft2",
})
}
fn rfft2(&self, input: &Tensor<R>, norm: FftNormalization) -> Result<Tensor<R>> {
let _ = (input, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::rfft2",
})
}
fn irfft2(
&self,
input: &Tensor<R>,
s: Option<(usize, usize)>,
norm: FftNormalization,
) -> Result<Tensor<R>> {
let _ = (input, s, norm);
Err(Error::NotImplemented {
feature: "FftAlgorithms::irfft2",
})
}
fn fftshift(&self, input: &Tensor<R>) -> Result<Tensor<R>> {
let _ = input;
Err(Error::NotImplemented {
feature: "FftAlgorithms::fftshift",
})
}
fn ifftshift(&self, input: &Tensor<R>) -> Result<Tensor<R>> {
let _ = input;
Err(Error::NotImplemented {
feature: "FftAlgorithms::ifftshift",
})
}
fn fftfreq(&self, n: usize, d: f64, dtype: DType, device: &R::Device) -> Result<Tensor<R>> {
let _ = (n, d, dtype, device);
Err(Error::NotImplemented {
feature: "FftAlgorithms::fftfreq",
})
}
fn rfftfreq(&self, n: usize, d: f64, dtype: DType, device: &R::Device) -> Result<Tensor<R>> {
let _ = (n, d, dtype, device);
Err(Error::NotImplemented {
feature: "FftAlgorithms::rfftfreq",
})
}
}
/// Returns `true` if `n` is a power of two (`1, 2, 4, …`).
///
/// `0` is not a power of two. This delegates to the standard library's inherent
/// `usize::is_power_of_two` rather than hand-rolling the `n & (n - 1)` bit trick.
#[inline]
pub fn is_power_of_two(n: usize) -> bool {
    // Method-call syntax resolves to the inherent `usize` method, not to this function.
    n.is_power_of_two()
}
pub fn validate_fft_size(n: usize, op: &'static str) -> Result<()> {
if !is_power_of_two(n) {
let next = n.next_power_of_two();
let prev = next / 2;
let suggestion = if prev >= 2 && prev < n {
format!(
"{} requires power-of-2 size, got {}. \
Consider truncating to {} or padding to {}.",
op, n, prev, next
)
} else {
format!(
"{} requires power-of-2 size, got {}. \
Consider padding to {}.",
op, n, next
)
};
return Err(Error::InvalidArgument {
arg: "n",
reason: suggestion,
});
}
Ok(())
}
pub fn validate_fft_complex_dtype(dtype: DType, op: &'static str) -> Result<()> {
if !dtype.is_complex() {
return Err(Error::UnsupportedDType { dtype, op });
}
Ok(())
}
pub fn validate_rfft_real_dtype(dtype: DType, op: &'static str) -> Result<()> {
match dtype {
DType::F32 | DType::F64 => Ok(()),
_ => Err(Error::UnsupportedDType { dtype, op }),
}
}
pub fn complex_dtype_for_real(real_dtype: DType) -> Result<DType> {
match real_dtype {
DType::F32 => Ok(DType::Complex64),
DType::F64 => Ok(DType::Complex128),
_ => Err(Error::UnsupportedDType {
dtype: real_dtype,
op: "rfft",
}),
}
}
pub fn real_dtype_for_complex(complex_dtype: DType) -> Result<DType> {
match complex_dtype {
DType::Complex64 => Ok(DType::F32),
DType::Complex128 => Ok(DType::F64),
_ => Err(Error::UnsupportedDType {
dtype: complex_dtype,
op: "irfft",
}),
}
}
use std::f64::consts::PI;
/// Precomputes the `n / 2` twiddle factors `exp(s * 2πi * k / n)` for `k in 0..n / 2` in
/// single precision. The sign is `s = -1` when `inverse` is `false` and `s = +1` when it is
/// `true`.
///
/// Angles and their cosine/sine are evaluated in `f64`; only the final values are narrowed to
/// `f32`. Returns an empty vector when `n < 2`.
pub fn generate_twiddles_c64(n: usize, inverse: bool) -> Vec<crate::dtype::Complex64> {
    let direction_sign = if inverse { 1.0 } else { -1.0 };
    let len = n as f64;
    let half = n / 2;
    let mut twiddles = Vec::with_capacity(half);
    for k in 0..half {
        // Keep the `sign * 2π * k / n` evaluation order so results are bit-for-bit stable.
        let theta = direction_sign * 2.0 * PI * (k as f64) / len;
        let (sin, cos) = theta.sin_cos();
        twiddles.push(crate::dtype::Complex64::new(cos as f32, sin as f32));
    }
    twiddles
}
/// Precomputes the `n / 2` twiddle factors `exp(s * 2πi * k / n)` for `k in 0..n / 2` in
/// double precision. The sign is `s = -1` when `inverse` is `false` and `s = +1` when it is
/// `true`.
///
/// Returns an empty vector when `n < 2`.
pub fn generate_twiddles_c128(n: usize, inverse: bool) -> Vec<crate::dtype::Complex128> {
    let direction_sign = if inverse { 1.0 } else { -1.0 };
    let len = n as f64;
    (0..n / 2)
        // Same `sign * 2π * k / n` evaluation order as the single-precision variant.
        .map(|k| direction_sign * 2.0 * PI * (k as f64) / len)
        .map(f64::sin_cos)
        .map(|(sin, cos)| crate::dtype::Complex128::new(cos, sin))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_power_of_two() {
        assert!(is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(is_power_of_two(4));
        assert!(is_power_of_two(1024));
        // Largest representable power of two.
        assert!(is_power_of_two(1usize << (usize::BITS - 1)));
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(3));
        assert!(!is_power_of_two(7));
        assert!(!is_power_of_two(usize::MAX));
    }

    #[test]
    fn test_validate_fft_size() {
        assert!(validate_fft_size(1, "fft").is_ok());
        assert!(validate_fft_size(4, "fft").is_ok());
        assert!(validate_fft_size(1024, "fft").is_ok());
        assert!(validate_fft_size(0, "fft").is_err());
        assert!(validate_fft_size(7, "fft").is_err());
    }

    /// Extracts the `reason` text of the `InvalidArgument` error for a non-power-of-two `n`.
    fn size_error_reason(n: usize) -> String {
        match validate_fft_size(n, "fft") {
            Err(Error::InvalidArgument { arg, reason }) => {
                assert_eq!(arg, "n");
                reason.to_string()
            }
            other => panic!("expected InvalidArgument for n = {}, got {:?}", n, other),
        }
    }

    #[test]
    fn test_validate_fft_size_suggestions() {
        let reason = size_error_reason(6);
        assert!(reason.contains("truncating to 4"), "{}", reason);
        assert!(reason.contains("padding to 8"), "{}", reason);

        let reason = size_error_reason(3);
        assert!(reason.contains("truncating to 2"), "{}", reason);
        assert!(reason.contains("padding to 4"), "{}", reason);

        // Nothing sensible to truncate to for n == 0; only padding is suggested.
        let reason = size_error_reason(0);
        assert!(!reason.contains("truncating"), "{}", reason);
        assert!(reason.contains("padding to 1"), "{}", reason);
    }

    #[test]
    fn test_normalization_factor() {
        let n = 8;
        assert_eq!(
            FftNormalization::Backward.factor(FftDirection::Forward, n),
            1.0
        );
        assert_eq!(
            FftNormalization::Backward.factor(FftDirection::Inverse, n),
            0.125
        );
        let sqrt_inv = 1.0 / (n as f64).sqrt();
        assert!(
            (FftNormalization::Ortho.factor(FftDirection::Forward, n) - sqrt_inv).abs() < 1e-10
        );
        assert!(
            (FftNormalization::Ortho.factor(FftDirection::Inverse, n) - sqrt_inv).abs() < 1e-10
        );
    }

    #[test]
    fn test_normalization_factor_none_and_forward() {
        let n = 8;
        assert_eq!(FftNormalization::None.factor(FftDirection::Forward, n), 1.0);
        assert_eq!(FftNormalization::None.factor(FftDirection::Inverse, n), 1.0);
        assert_eq!(
            FftNormalization::Forward.factor(FftDirection::Forward, n),
            0.125
        );
        assert_eq!(
            FftNormalization::Forward.factor(FftDirection::Inverse, n),
            1.0
        );
        assert_eq!(FftNormalization::default(), FftNormalization::Backward);
    }

    #[test]
    fn test_twiddle_generation() {
        let twiddles = generate_twiddles_c64(8, false);
        assert_eq!(twiddles.len(), 4);
        assert!((twiddles[0].re - 1.0).abs() < 1e-6);
        assert!(twiddles[0].im.abs() < 1e-6);
        assert!(twiddles[2].re.abs() < 1e-6);
        assert!((twiddles[2].im - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_twiddle_generation_inverse_and_small_sizes() {
        // Inverse twiddles use a positive exponent: w_2 = exp(+iπ/2) = i for n = 8.
        let twiddles = generate_twiddles_c64(8, true);
        assert_eq!(twiddles.len(), 4);
        assert!((twiddles[0].re - 1.0).abs() < 1e-6);
        assert!(twiddles[2].re.abs() < 1e-6);
        assert!((twiddles[2].im - 1.0).abs() < 1e-6);

        assert!(generate_twiddles_c64(0, false).is_empty());
        assert!(generate_twiddles_c64(1, false).is_empty());
        assert_eq!(generate_twiddles_c128(16, false).len(), 8);
        assert!(generate_twiddles_c128(1, true).is_empty());
    }

    #[test]
    fn test_complex_dtype_conversion() {
        assert_eq!(
            complex_dtype_for_real(DType::F32).unwrap(),
            DType::Complex64
        );
        assert_eq!(
            complex_dtype_for_real(DType::F64).unwrap(),
            DType::Complex128
        );
        assert!(complex_dtype_for_real(DType::I32).is_err());
    }

    #[test]
    fn test_real_dtype_conversion() {
        assert_eq!(real_dtype_for_complex(DType::Complex64).unwrap(), DType::F32);
        assert_eq!(real_dtype_for_complex(DType::Complex128).unwrap(), DType::F64);
        assert!(real_dtype_for_complex(DType::F32).is_err());
        // Round trip through both mappings.
        for real in [DType::F32, DType::F64] {
            let complex = complex_dtype_for_real(real).unwrap();
            assert_eq!(real_dtype_for_complex(complex).unwrap(), real);
        }
    }

    #[test]
    fn test_dtype_validation() {
        assert!(validate_rfft_real_dtype(DType::F32, "rfft").is_ok());
        assert!(validate_rfft_real_dtype(DType::F64, "rfft").is_ok());
        assert!(validate_rfft_real_dtype(DType::I32, "rfft").is_err());
        assert!(validate_fft_complex_dtype(DType::Complex64, "fft").is_ok());
        assert!(validate_fft_complex_dtype(DType::F32, "fft").is_err());
    }
}
#[allow(dead_code)]
fn _verify_cpu_fft_impl() {
fn assert_fft_impl<T: FftAlgorithms<crate::runtime::cpu::CpuRuntime>>() {}
assert_fft_impl::<crate::runtime::cpu::CpuClient>();
}
/// Compile-time assertion (with the `cuda` feature) that `CudaClient` implements
/// [`FftAlgorithms`] for `CudaRuntime`.
///
/// Never called; a missing impl makes this function fail to type-check.
#[cfg(feature = "cuda")]
#[allow(dead_code)] // intentionally unused: the check happens during type-checking
fn _verify_cuda_fft_impl() {
    fn implements_fft<Rt: Runtime, Client: FftAlgorithms<Rt>>() {}
    implements_fft::<crate::runtime::cuda::CudaRuntime, crate::runtime::cuda::CudaClient>();
}
/// Compile-time assertion (with the `wgpu` feature) that `WgpuClient` implements
/// [`FftAlgorithms`] for `WgpuRuntime`.
///
/// Never called; a missing impl makes this function fail to type-check.
#[cfg(feature = "wgpu")]
#[allow(dead_code)] // intentionally unused: the check happens during type-checking
fn _verify_wgpu_fft_impl() {
    fn implements_fft<Rt: Runtime, Client: FftAlgorithms<Rt>>() {}
    implements_fft::<crate::runtime::wgpu::WgpuRuntime, crate::runtime::wgpu::WgpuClient>();
}