use burn_backend::Backend;
use crate::Tensor;
use crate::TensorPrimitive;
use crate::check;
use crate::check::TensorCheck;
/// Computes the discrete Fourier transform of a real-valued signal along one dimension.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $k$ in the frequency domain is:
$$X\[k\] = \sum_{n=0}^{N-1} x\[n\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) - i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the signal along the specified dimension.
"#
)]
#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
///
/// # Arguments
///
/// * `signal` - The real-valued input tensor.
/// * `dim` - The dimension along which the transform is taken.
/// * `n` - Optional transform length. When `Some`, it must be a power of two and is
///   forwarded to the backend. When `None`, the size of `signal` along `dim` must itself
///   be a power of two.
///
/// # Returns
///
/// A pair `(real, imaginary)` holding the two components of the spectrum as float tensors.
///
/// # Panics
///
/// * If `dim` fails the `TensorCheck::check_dim` validation for a rank-`D` tensor.
/// * If `n` is `None` and the size of `signal` along `dim` is not a power of two.
/// * If `n` is `Some` and is zero or not a power of two.
pub fn rfft<B: Backend, const D: usize>(
    signal: Tensor<B, D>,
    dim: usize,
    n: Option<usize>,
) -> (Tensor<B, D>, Tensor<B, D>) {
    check!(TensorCheck::check_dim::<D>(dim));

    if let Some(n) = n {
        // An explicit length is validated on its own; the zero check comes first purely
        // to give a more specific message than the power-of-two check would.
        assert!(n >= 1, "rfft: n must be >= 1, got {n}");
        assert!(
            n.is_power_of_two(),
            "rfft: n must be a power of two, got {n}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
    } else {
        // Without an explicit length, the input's own extent along `dim` is the length.
        check!(TensorCheck::check_is_power_of_two::<D>(&signal.shape(), dim));
    }

    // NOTE(review): how the backend reconciles `n` with the input extent (padding vs.
    // truncation) and how many frequency bins it returns along `dim` are decided by
    // `B::rfft` and are not visible here — confirm against the backend before documenting.
    let input = signal.primitive.tensor();
    let (re, im) = B::rfft(input, dim, n);

    let real_part = Tensor::new(TensorPrimitive::Float(re));
    let imag_part = Tensor::new(TensorPrimitive::Float(im));
    (real_part, imag_part)
}
/// Computes the inverse discrete Fourier transform back to a real-valued signal along one
/// dimension. It is the counterpart of [`rfft`]: the spectrum is supplied as its real and
/// imaginary parts in two separate tensors.
///
#[cfg_attr(
    doc,
    doc = r#"
The mathematical formulation for each element $n$ in the time domain is:
$$x\[n\] = \frac{1}{N} \sum_{k=0}^{N-1} X\[k\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) + i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
where $N$ is the size of the reconstructed signal.
"#
)]
#[cfg_attr(not(doc), doc = r"x\[n\] = (1/N) * Σ X\[k\] * exp(i*2πkn/N)")]
///
/// # Arguments
///
/// * `spectrum_re` - Real part of the input spectrum.
/// * `spectrum_im` - Imaginary part of the input spectrum; must have the same shape as
///   `spectrum_re`.
/// * `dim` - The dimension along which the inverse transform is taken.
/// * `n` - Optional length of the reconstructed signal. When `Some`, it must be a power
///   of two and is forwarded to the backend.
///
/// # Returns
///
/// The reconstructed real-valued signal as a float tensor.
///
/// # Panics
///
/// * If `dim` fails the `TensorCheck::check_dim` validation for a rank-`D` tensor.
/// * If `spectrum_re` and `spectrum_im` have different shapes.
/// * If `n` is `Some` and is zero or not a power of two.
pub fn irfft<B: Backend, const D: usize>(
    spectrum_re: Tensor<B, D>,
    spectrum_im: Tensor<B, D>,
    dim: usize,
    n: Option<usize>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_dim::<D>(dim));

    // The two inputs are the components of a single complex tensor, so they must line up
    // element-for-element. Validate up front so a mismatch fails with a clear message
    // instead of being handed to the backend.
    assert_eq!(
        spectrum_re.shape(),
        spectrum_im.shape(),
        "irfft: spectrum_re and spectrum_im must have the same shape"
    );

    if let Some(n) = n {
        // The zero check precedes the power-of-two check only to give a more specific message.
        assert!(n >= 1, "irfft: n must be >= 1, got {n}");
        assert!(
            n.is_power_of_two(),
            "irfft: n must be a power of two, got {n}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
    }
    // NOTE(review): when `n` is `None` no output-length validation happens at this layer;
    // the default length is chosen by `B::irfft` — confirm it enforces its own constraints.

    let signal = B::irfft(
        spectrum_re.primitive.tensor(),
        spectrum_im.primitive.tensor(),
        dim,
        n,
    );
    Tensor::new(TensorPrimitive::Float(signal))
}