Skip to main content

burn_tensor/tensor/signal/
fft.rs

1use alloc::vec;
2use burn_backend::Backend;
3
4use crate::Tensor;
5use crate::TensorPrimitive;
6use crate::check;
7use crate::check::TensorCheck;
8
9/// Computes the 1-dimensional discrete Fourier Transform of real-valued input.
10///
11/// Since the input is real, the Hermitian symmetry is exploited, and only the
12/// first non-redundant values are returned ($N/2 + 1$).
13/// For now, the autodiff is not yet supported
14///
15#[cfg_attr(
16    doc,
17    doc = r#"
18The mathematical formulation for each element $k$ in the frequency domain is:
19
20$$X\[k\] = \sum_{n=0}^{N-1} x\[n\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) - i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
21
22where $N$ is the size of the signal along the specified dimension.
23"#
24)]
25#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
26///
27/// # Arguments
28///
29/// * `signal` - The input tensor containing the real-valued signal.
30/// * `dim` - The dimension along which to take the FFT.
31/// * `n` - Optional FFT length. When `None`, the signal must be a power of two along `dim`.
32///   When `Some(n)`, `n` must also be a power of two; the signal is truncated or zero-padded
33///   to length `n`. Non-power-of-two `n` is rejected with a panic (true arbitrary-size DFT
34///   support via Bluestein's algorithm is tracked as a follow-up).
35///
36/// # Returns
37///
38/// A tuple containing:
39/// 1. The real part of the spectrum. Output length along `dim` is `n / 2 + 1` (using `n` or
40///    `signal_len` respectively).
41/// 2. The imaginary part of the spectrum (same shape).
42///
43/// # Example
44///
45/// ```rust
46/// use burn_tensor::backend::Backend;
47/// use burn_tensor::Tensor;
48///
49/// fn example<B: Backend>() {
50///     let device = B::Device::default();
51///     let signal = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
52///     let (real, imag) = burn_tensor::signal::rfft(signal, 0, None);
53/// }
54/// ```
55pub fn rfft<B: Backend, const D: usize>(
56    signal: Tensor<B, D>,
57    dim: usize,
58    n: Option<usize>,
59) -> (Tensor<B, D>, Tensor<B, D>) {
60    check!(TensorCheck::check_dim::<D>(dim));
61
62    match n {
63        None => check!(TensorCheck::check_is_power_of_two::<D>(
64            &signal.shape(),
65            dim
66        )),
67        Some(n) => {
68            assert!(n >= 1, "rfft: n must be >= 1, got {n}");
69            assert!(
70                n.is_power_of_two(),
71                "rfft: n must be a power of two, got {n}. True non-power-of-two \
72                 DFT support is tracked as a follow-up (Bluestein's algorithm)."
73            );
74        }
75    }
76
77    let (spectrum_re, spectrum_im) = B::rfft(signal.primitive.tensor(), dim, n);
78    (
79        Tensor::new(TensorPrimitive::Float(spectrum_re)),
80        Tensor::new(TensorPrimitive::Float(spectrum_im)),
81    )
82}
83
84/// Computes the 1-dimensional inverse discrete Fourier Transform for real-valued signals.
85///
86/// This function reconstructs the real-valued time-domain signal from the
87/// first non-redundant values ($N/2 + 1$) of the frequency-domain spectrum.
88/// For now, the autodiff is not yet supported.
89///
90#[cfg_attr(
91    doc,
92    doc = r#"
93The mathematical formulation for each element $n$ in the time domain is:
94
95$$x\[n\] = \frac{1}{N} \sum_{k=0}^{N-1} X\[k\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) + i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
96
97where $N$ is the size of the reconstructed signal.
98"#
99)]
100#[cfg_attr(not(doc), doc = r"x\[n\] = (1/N) * Σ X\[k\] * exp(i*2πkn/N)")]
101///
102/// # Arguments
103///
104/// * `spectrum_re` - The real part of the spectrum.
105/// * `spectrum_im` - The imaginary part of the spectrum.
106/// * `dim` - The dimension along which to take the inverse FFT.
107/// * `n` - Optional output signal length. When `None`, the reconstructed signal length
108///   `2 * (size - 1)` must be a power of two. When `Some(n)`, `n` must also be a power of
109///   two and the output has exactly `n` samples. Non-power-of-two `n` is rejected.
110///
111/// # Returns
112///
113/// The reconstructed real-valued signal.
114///
115/// # Example
116///
117/// ```rust
118/// use burn_tensor::backend::Backend;
119/// use burn_tensor::Tensor;
120///
121/// fn example<B: Backend>() {
122///     let device = B::Device::default();
123///     let real = Tensor::<B, 1>::from_floats([10.0, -2.0, 2.0], &device);
124///     let imag = Tensor::<B, 1>::from_floats([0.0, 2.0, 0.0], &device);
125///     let signal = burn_tensor::signal::irfft(real, imag, 0, None);
126/// }
127/// ```
128pub fn irfft<B: Backend, const D: usize>(
129    spectrum_re: Tensor<B, D>,
130    spectrum_im: Tensor<B, D>,
131    dim: usize,
132    n: Option<usize>,
133) -> Tensor<B, D> {
134    check!(TensorCheck::check_dim::<D>(dim));
135
136    if let Some(n) = n {
137        assert!(n >= 1, "irfft: n must be >= 1, got {n}");
138        assert!(
139            n.is_power_of_two(),
140            "irfft: n must be a power of two, got {n}. True non-power-of-two \
141             DFT support is tracked as a follow-up (Bluestein's algorithm)."
142        );
143    }
144
145    let signal = B::irfft(
146        spectrum_re.primitive.tensor(),
147        spectrum_im.primitive.tensor(),
148        dim,
149        n,
150    );
151    Tensor::new(TensorPrimitive::Float(signal))
152}
153
154/// Computes the 1-dimensional discrete Fourier Transform of complex-valued input.
155///
156/// Internally calls [`rfft`] on the real and imaginary parts separately,
157/// extends each half-spectrum to the full `N`-bin spectrum via Hermitian
158/// symmetry.
159///
160/// Autodiff is not yet supported.
161///
162#[cfg_attr(
163    doc,
164    doc = r#"
165
166Due to the linearity of the Fourier Transform, a complex-valued signal $x\[n\] = x_{re}\[n\] + i x_{im}\[n\]$ can be transformed by applying the FFT to its real and imaginary parts separately:
167
168$$ \text{FFT}(x\[n\]) = \text{FFT}(x_{re}\[n\]) + i \text{FFT}(x_{im}\[n\]) $$
169
170Since $x_{re}\[n\]$ and $x_{im}\[n\]$ are purely real, their transforms can be computed efficiently using the real FFT ([`rfft`]). The full spectrum is then reconstructed by exploiting Hermitian symmetry.
171"#
172)]
173#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
174///
175/// # Arguments
176///
177/// * `signal_re` - The real part of the complex input signal.
178/// * `signal_im` - The imaginary part of the complex input signal. Must have the
179///   same shape as `signal_re`.
180/// * `dim` - The dimension along which to take the FFT.
181/// * `n` - Optional FFT length. When `None`, the signal must be a power of two
182///   along `dim`. When `Some(n)`, `n` must also be a power of two; the signal is
183///   truncated or zero-padded to length `n`.
184///
185/// # Returns
186///
187/// A tuple `(re, im)` representing the full complex spectrum, each with `n`
188/// elements along `dim`.
189///
190/// # Example
191///
192/// ```rust
193/// use burn_tensor::backend::Backend;
194/// use burn_tensor::Tensor;
195///
196/// fn example<B: Backend>() {
197///     let device = B::Device::default();
198///     let re = Tensor::<B, 1>::from_floats([1.0, 0.0, -1.0, 0.0], &device);
199///     let im = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0, -1.0], &device);
200///     let (spec_re, spec_im) = burn_tensor::signal::cfft(re, im, 0, None);
201/// }
202/// ```
203pub fn cfft<B: Backend, const D: usize>(
204    signal_re: Tensor<B, D>,
205    signal_im: Tensor<B, D>,
206    dim: usize,
207    n: Option<usize>,
208) -> (Tensor<B, D>, Tensor<B, D>) {
209    assert!(
210        signal_re.shape() == signal_im.shape(),
211        "cfft: signal_re and signal_im must have the same shape, \
212         got {:?} and {:?}",
213        signal_re.shape(),
214        signal_im.shape(),
215    );
216
217    check!(TensorCheck::check_dim::<D>(dim));
218    let fft_size = n.unwrap_or(signal_re.dims()[dim]);
219
220    // rfft validates power-of-two and n constraints internally
221    let (xr, xi) = rfft(signal_re, dim, n);
222    let (yr, yi) = rfft(signal_im, dim, n);
223
224    // Extend half-spectra (N/2+1 bins) to full N-bin spectra via Hermitian symmetry
225    let (xr, xi) = hermitian_extend(xr, xi, dim, fft_size);
226    let (yr, yi) = hermitian_extend(yr, yi, dim, fft_size);
227
228    // FFT(z) = FFT(x) + i·FFT(y)
229    //        = (Xr + i·Xi) + i·(Yr + i·Yi)
230    //        = (Xr - Yi) + i·(Xi + Yr)
231    (xr - yi, xi + yr)
232}
233
234/// Extend a half-spectrum from [`rfft`] (`N/2 + 1` bins) to the full `N`-bin
235/// spectrum using Hermitian symmetry: `X[k] = conj(X[N-k])` for `k > N/2`.
236pub(super) fn hermitian_extend<B: Backend, const D: usize>(
237    half_re: Tensor<B, D>,
238    half_im: Tensor<B, D>,
239    dim: usize,
240    full_len: usize,
241) -> (Tensor<B, D>, Tensor<B, D>) {
242    let half_len = half_re.dims()[dim]; // N/2 + 1
243
244    // For N <= 2, the half-spectrum already covers all bins
245    if full_len <= half_len {
246        return (half_re, half_im);
247    }
248
249    // Mirror bins: reverse of bins 1..N/2-1 (skipping the Nyquist bin),
250    // with conjugated imaginary part. This produces X[N/2+1], X[N/2+2], ..., X[N-1]
251    let mirror_len = full_len - half_len; // N/2 - 1
252    let mirror_re = half_re
253        .clone()
254        .narrow(dim, 1, mirror_len)
255        .flip([dim as isize]);
256    let mirror_im = half_im
257        .clone()
258        .narrow(dim, 1, mirror_len)
259        .flip([dim as isize])
260        .neg();
261
262    // Full spectrum = [half_spectrum, conjugate_mirror]
263    let full_re = Tensor::cat(vec![half_re, mirror_re], dim);
264    let full_im = Tensor::cat(vec![half_im, mirror_im], dim);
265
266    (full_re, full_im)
267}