use alloc::vec;
use burn_backend::Backend;
use crate::Tensor;
use crate::TensorPrimitive;
use crate::check;
use crate::check::TensorCheck;
/// Computes the discrete Fourier transform of a real-valued signal along `dim`.
///
/// The arguments are validated here and the transform itself is delegated to
/// the backend's `rfft`. The spectrum is returned as two float tensors holding
/// its real and imaginary parts.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $k$ in the frequency domain is:
$$X\[k\] = \sum_{n=0}^{N-1} x\[n\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) - i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the signal along the specified dimension.
"#
)]
#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
///
/// # Arguments
///
/// * `signal` - The real-valued input tensor.
/// * `dim` - The dimension along which the transform is computed.
/// * `n` - Optional transform length, forwarded unchanged to the backend. When
///   `None`, the size of `signal` along `dim` is validated instead.
///
/// # Returns
///
/// A tuple `(spectrum_re, spectrum_im)` with the real and imaginary parts of
/// the spectrum produced by the backend.
///
/// # Panics
///
/// * If `n` is `Some` and is zero or not a power of two.
///
/// In addition, `check!` reports a failure when `dim` is not a valid dimension
/// for rank `D`, or when `n` is `None` and the size of `signal` along `dim` is
/// not a power of two.
pub fn rfft<B: Backend, const D: usize>(
    signal: Tensor<B, D>,
    dim: usize,
    n: Option<usize>,
) -> (Tensor<B, D>, Tensor<B, D>) {
    check!(TensorCheck::check_dim::<D>(dim));

    if let Some(n) = n {
        // An explicit length was requested: validate the request itself.
        assert!(n >= 1, "rfft: n must be >= 1, got {n}");
        assert!(
            n.is_power_of_two(),
            "rfft: n must be a power of two, got {n}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
    } else {
        // No explicit length: the signal's own size along `dim` is what gets
        // validated.
        let shape = signal.shape();
        check!(TensorCheck::check_is_power_of_two::<D>(&shape, dim));
    }

    let (re_primitive, im_primitive) = B::rfft(signal.primitive.tensor(), dim, n);

    // Wrap the backend primitives back into high-level float tensors.
    let real = Tensor::new(TensorPrimitive::Float(re_primitive));
    let imag = Tensor::new(TensorPrimitive::Float(im_primitive));
    (real, imag)
}
/// Computes the inverse of [`rfft`], reconstructing a signal along `dim` from
/// the real and imaginary parts of its spectrum.
///
/// The arguments are validated here and the transform itself is delegated to
/// the backend's `irfft`.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $n$ in the time domain is:
$$x\[n\] = \frac{1}{N} \sum_{k=0}^{N-1} X\[k\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) + i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the reconstructed signal.
"#
)]
#[cfg_attr(not(doc), doc = r"x\[n\] = (1/N) * Σ X\[k\] * exp(i*2πkn/N)")]
///
/// # Arguments
///
/// * `spectrum_re` - Real part of the spectrum.
/// * `spectrum_im` - Imaginary part of the spectrum; must have the same shape
///   as `spectrum_re`.
/// * `dim` - The dimension along which the inverse transform is computed.
/// * `n` - Optional length, forwarded unchanged to the backend (presumably the
///   size of the reconstructed signal along `dim` - confirm against the
///   backend contract).
///
/// # Returns
///
/// The reconstructed signal as a float tensor.
///
/// # Panics
///
/// * If `spectrum_re` and `spectrum_im` do not have the same shape.
/// * If `n` is `Some` and is zero or not a power of two.
///
/// In addition, `check!` reports a failure when `dim` is not a valid dimension
/// for rank `D`.
pub fn irfft<B: Backend, const D: usize>(
    spectrum_re: Tensor<B, D>,
    spectrum_im: Tensor<B, D>,
    dim: usize,
    n: Option<usize>,
) -> Tensor<B, D> {
    // Both tensors describe the same complex spectrum, so a shape mismatch is
    // always a caller error. Reject it here with a clear message (the same
    // contract `cfft` enforces on its real/imaginary pair) instead of letting
    // it surface inside the backend.
    assert!(
        spectrum_re.shape() == spectrum_im.shape(),
        "irfft: spectrum_re and spectrum_im must have the same shape, \
         got {:?} and {:?}",
        spectrum_re.shape(),
        spectrum_im.shape(),
    );
    check!(TensorCheck::check_dim::<D>(dim));

    if let Some(n) = n {
        assert!(n >= 1, "irfft: n must be >= 1, got {n}");
        assert!(
            n.is_power_of_two(),
            "irfft: n must be a power of two, got {n}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
    }

    let signal = B::irfft(
        spectrum_re.primitive.tensor(),
        spectrum_im.primitive.tensor(),
        dim,
        n,
    );

    // Wrap the backend primitive back into a high-level float tensor.
    Tensor::new(TensorPrimitive::Float(signal))
}
/// Computes the discrete Fourier transform of a complex-valued signal that is
/// supplied as separate real and imaginary tensors.
///
#[cfg_attr(
    doc,
    doc = r#"
Due to the linearity of the Fourier Transform, a complex-valued signal $x\[n\] = x_{re}\[n\] + i x_{im}\[n\]$ can be transformed by applying the FFT to its real and imaginary parts separately:
$$ \text{FFT}(x\[n\]) = \text{FFT}(x_{re}\[n\]) + i \text{FFT}(x_{im}\[n\]) $$
Since $x_{re}\[n\]$ and $x_{im}\[n\]$ are purely real, their transforms can be computed efficiently using the real FFT ([`rfft`]). The full spectrum is then reconstructed by exploiting Hermitian symmetry.
"#
)]
#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
///
/// # Arguments
///
/// * `signal_re` - Real part of the input signal.
/// * `signal_im` - Imaginary part of the input signal; must have the same
///   shape as `signal_re`.
/// * `dim` - The dimension along which the transform is computed.
/// * `n` - Optional transform length. When `None`, the size of `signal_re`
///   along `dim` is used as the FFT size.
///
/// # Returns
///
/// A tuple `(re, im)` with the real and imaginary parts of the spectrum.
///
/// # Panics
///
/// * If `signal_re` and `signal_im` do not have the same shape.
/// * Under the same conditions as [`rfft`], which is applied to each part.
///
/// In addition, `check!` reports a failure when `dim` is not a valid dimension
/// for rank `D`.
pub fn cfft<B: Backend, const D: usize>(
    signal_re: Tensor<B, D>,
    signal_im: Tensor<B, D>,
    dim: usize,
    n: Option<usize>,
) -> (Tensor<B, D>, Tensor<B, D>) {
    let shape_re = signal_re.shape();
    let shape_im = signal_im.shape();
    assert!(
        shape_re == shape_im,
        "cfft: signal_re and signal_im must have the same shape, \
         got {:?} and {:?}",
        shape_re,
        shape_im,
    );
    check!(TensorCheck::check_dim::<D>(dim));

    // Length of the full spectrum: the requested size, or the input's own size.
    let fft_size = match n {
        Some(requested) => requested,
        None => signal_re.dims()[dim],
    };

    // Transform each real-valued component on its own.
    let (a_re, a_im) = rfft(signal_re, dim, n);
    let (b_re, b_im) = rfft(signal_im, dim, n);

    // Rebuild both full spectra from their halves before combining them.
    let (a_re, a_im) = hermitian_extend(a_re, a_im, dim, fft_size);
    let (b_re, b_im) = hermitian_extend(b_re, b_im, dim, fft_size);

    // With A = FFT(signal_re) and B = FFT(signal_im):
    // A + iB = (A_re - B_im) + i (A_im + B_re).
    let out_re = a_re - b_im;
    let out_im = a_im + b_re;
    (out_re, out_im)
}
/// Extends a half spectrum to `full_len` entries along `dim` using Hermitian
/// symmetry.
///
/// The missing tail is built from the slice of the half spectrum that starts
/// at index 1 and holds `full_len - half_len` entries, reversed along `dim`.
/// The real part is copied as is and the imaginary part is negated, i.e. the
/// appended entries are the complex conjugates of the mirrored ones.
///
/// # Arguments
///
/// * `half_re` - Real part of the half spectrum.
/// * `half_im` - Imaginary part of the half spectrum.
/// * `dim` - The dimension along which the spectrum is extended.
/// * `full_len` - Target number of entries along `dim`.
///
/// # Returns
///
/// The extended `(re, im)` pair. When the input already has at least
/// `full_len` entries along `dim`, both tensors are returned unchanged.
pub(super) fn hermitian_extend<B: Backend, const D: usize>(
    half_re: Tensor<B, D>,
    half_im: Tensor<B, D>,
    dim: usize,
    full_len: usize,
) -> (Tensor<B, D>, Tensor<B, D>) {
    let present = half_re.dims()[dim];
    if present >= full_len {
        // Nothing is missing along `dim`.
        return (half_re, half_im);
    }

    let missing = full_len - present;
    let axis = dim as isize;

    // Index 0 is skipped: the mirrored tail starts from the entry at index 1.
    let tail_re = half_re.clone().narrow(dim, 1, missing).flip([axis]);
    // Conjugation: same mirroring, with the sign of the imaginary part flipped.
    let tail_im = half_im.clone().narrow(dim, 1, missing).flip([axis]).neg();

    let extended_re = Tensor::cat(vec![half_re, tail_re], dim);
    let extended_im = Tensor::cat(vec![half_im, tail_im], dim);
    (extended_re, extended_im)
}