use burn_backend::Backend;
use crate::Tensor;
use crate::TensorPrimitive;
use crate::check;
use crate::check::TensorCheck;
/// Computes the discrete Fourier transform of a real-valued signal along `dim`.
///
/// The complex spectrum is returned as two real tensors: its real part and its
/// imaginary part. How many frequency bins are produced is decided by the backend's
/// `B::rfft`. NOTE(review): a real FFT conventionally returns only the non-redundant
/// half of the spectrum (`N / 2 + 1` bins); confirm against the backend implementation.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $k$ in the frequency domain is:
$$X\[k\] = \sum_{n=0}^{N-1} x\[n\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) - i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the signal along the specified dimension.
"#
)]
#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
///
/// # Arguments
///
/// * `signal` - The real-valued input tensor.
/// * `dim` - The dimension along which the transform is taken.
///
/// # Returns
///
/// A tuple `(real, imaginary)` holding the two components of the spectrum.
///
/// # Panics
///
/// Through the `check!` tensor checks:
/// * if `dim` is rejected by `TensorCheck::check_dim` for a rank-`D` tensor;
/// * if the size of `signal` along `dim` is rejected by `TensorCheck::check_is_power_of_two`.
pub fn rfft<B: Backend, const D: usize>(
    signal: Tensor<B, D>,
    dim: usize,
) -> (Tensor<B, D>, Tensor<B, D>) {
    // The axis must be validated before its length is inspected.
    check!(TensorCheck::check_dim::<D>(dim));
    check!(TensorCheck::check_is_power_of_two::<D>(&signal.shape(), dim));

    let input = signal.primitive.tensor();
    let (re, im) = B::rfft(input, dim);

    let real: Tensor<B, D> = Tensor::new(TensorPrimitive::Float(re));
    let imaginary: Tensor<B, D> = Tensor::new(TensorPrimitive::Float(im));
    (real, imaginary)
}
/// Computes the inverse real discrete Fourier transform along `dim`.
///
/// Reconstructs a real-valued signal from a spectrum supplied as two real tensors
/// holding its real and imaginary parts, such as the pair returned by `rfft`.
/// The length of the reconstructed signal is decided by the backend's `B::irfft`.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $n$ in the time domain is:
$$x\[n\] = \frac{1}{N} \sum_{k=0}^{N-1} X\[k\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) + i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the reconstructed signal.
"#
)]
#[cfg_attr(not(doc), doc = r"x\[n\] = (1/N) * Σ X\[k\] * exp(i*2πkn/N)")]
///
/// The sum above is written over the full length-`N` spectrum. NOTE(review): a real
/// inverse FFT conventionally receives only the non-redundant bins and treats the rest
/// as implied by Hermitian symmetry (`X[N - k] = conj(X[k])`); confirm the expected
/// input length and the resulting `N` against the backend implementation.
///
/// # Arguments
///
/// * `spectrum_re` - Real part of the spectrum.
/// * `spectrum_im` - Imaginary part of the spectrum; must have the same shape as `spectrum_re`.
/// * `dim` - The dimension along which the inverse transform is taken.
///
/// # Returns
///
/// The reconstructed real-valued signal.
///
/// # Panics
///
/// * If `dim` is rejected by `TensorCheck::check_dim` (a `check!` tensor check).
/// * If `spectrum_re` and `spectrum_im` do not have the same shape.
pub fn irfft<B: Backend, const D: usize>(
    spectrum_re: Tensor<B, D>,
    spectrum_im: Tensor<B, D>,
    dim: usize,
) -> Tensor<B, D> {
    check!(TensorCheck::check_dim::<D>(dim));

    // Both tensors describe one complex spectrum element-by-element, so a shape
    // mismatch is a caller bug; reject it here with a clear message instead of
    // letting it surface as an obscure failure inside the backend kernel.
    let re_shape = spectrum_re.shape();
    let im_shape = spectrum_im.shape();
    assert_eq!(
        re_shape, im_shape,
        "irfft: the real and imaginary parts of the spectrum must have the same shape"
    );

    let signal = B::irfft(
        spectrum_re.primitive.tensor(),
        spectrum_im.primitive.tensor(),
        dim,
    );
    Tensor::new(TensorPrimitive::Float(signal))
}