burn_tensor/tensor/signal/fft.rs
1use alloc::vec;
2use burn_backend::Backend;
3
4use crate::Tensor;
5use crate::TensorPrimitive;
6use crate::check;
7use crate::check::TensorCheck;
8
9/// Computes the 1-dimensional discrete Fourier Transform of real-valued input.
10///
11/// Since the input is real, the Hermitian symmetry is exploited, and only the
12/// first non-redundant values are returned ($N/2 + 1$).
13/// For now, the autodiff is not yet supported
14///
15#[cfg_attr(
16 doc,
17 doc = r#"
18The mathematical formulation for each element $k$ in the frequency domain is:
19
20$$X\[k\] = \sum_{n=0}^{N-1} x\[n\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) - i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
21
22where $N$ is the size of the signal along the specified dimension.
23"#
24)]
25#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
26///
27/// # Arguments
28///
29/// * `signal` - The input tensor containing the real-valued signal.
30/// * `dim` - The dimension along which to take the FFT.
31/// * `n` - Optional FFT length. When `None`, the signal must be a power of two along `dim`.
32/// When `Some(n)`, `n` must also be a power of two; the signal is truncated or zero-padded
33/// to length `n`. Non-power-of-two `n` is rejected with a panic (true arbitrary-size DFT
34/// support via Bluestein's algorithm is tracked as a follow-up).
35///
36/// # Returns
37///
38/// A tuple containing:
39/// 1. The real part of the spectrum. Output length along `dim` is `n / 2 + 1` (using `n` or
40/// `signal_len` respectively).
41/// 2. The imaginary part of the spectrum (same shape).
42///
43/// # Example
44///
45/// ```rust
46/// use burn_tensor::backend::Backend;
47/// use burn_tensor::Tensor;
48///
49/// fn example<B: Backend>() {
50/// let device = B::Device::default();
51/// let signal = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
52/// let (real, imag) = burn_tensor::signal::rfft(signal, 0, None);
53/// }
54/// ```
55pub fn rfft<B: Backend, const D: usize>(
56 signal: Tensor<B, D>,
57 dim: usize,
58 n: Option<usize>,
59) -> (Tensor<B, D>, Tensor<B, D>) {
60 check!(TensorCheck::check_dim::<D>(dim));
61
62 match n {
63 None => check!(TensorCheck::check_is_power_of_two::<D>(
64 &signal.shape(),
65 dim
66 )),
67 Some(n) => {
68 assert!(n >= 1, "rfft: n must be >= 1, got {n}");
69 assert!(
70 n.is_power_of_two(),
71 "rfft: n must be a power of two, got {n}. True non-power-of-two \
72 DFT support is tracked as a follow-up (Bluestein's algorithm)."
73 );
74 }
75 }
76
77 let (spectrum_re, spectrum_im) = B::rfft(signal.primitive.tensor(), dim, n);
78 (
79 Tensor::new(TensorPrimitive::Float(spectrum_re)),
80 Tensor::new(TensorPrimitive::Float(spectrum_im)),
81 )
82}
83
84/// Computes the 1-dimensional inverse discrete Fourier Transform for real-valued signals.
85///
86/// This function reconstructs the real-valued time-domain signal from the
87/// first non-redundant values ($N/2 + 1$) of the frequency-domain spectrum.
88/// For now, the autodiff is not yet supported.
89///
90#[cfg_attr(
91 doc,
92 doc = r#"
93The mathematical formulation for each element $n$ in the time domain is:
94
95$$x\[n\] = \frac{1}{N} \sum_{k=0}^{N-1} X\[k\] \left\[ \cos\left(\frac{2\pi kn}{N}\right) + i \sin\left(\frac{2\pi kn}{N}\right) \right\]$$
96
97where $N$ is the size of the reconstructed signal.
98"#
99)]
100#[cfg_attr(not(doc), doc = r"x\[n\] = (1/N) * Σ X\[k\] * exp(i*2πkn/N)")]
101///
102/// # Arguments
103///
104/// * `spectrum_re` - The real part of the spectrum.
105/// * `spectrum_im` - The imaginary part of the spectrum.
106/// * `dim` - The dimension along which to take the inverse FFT.
107/// * `n` - Optional output signal length. When `None`, the reconstructed signal length
108/// `2 * (size - 1)` must be a power of two. When `Some(n)`, `n` must also be a power of
109/// two and the output has exactly `n` samples. Non-power-of-two `n` is rejected.
110///
111/// # Returns
112///
113/// The reconstructed real-valued signal.
114///
115/// # Example
116///
117/// ```rust
118/// use burn_tensor::backend::Backend;
119/// use burn_tensor::Tensor;
120///
121/// fn example<B: Backend>() {
122/// let device = B::Device::default();
123/// let real = Tensor::<B, 1>::from_floats([10.0, -2.0, 2.0], &device);
124/// let imag = Tensor::<B, 1>::from_floats([0.0, 2.0, 0.0], &device);
125/// let signal = burn_tensor::signal::irfft(real, imag, 0, None);
126/// }
127/// ```
128pub fn irfft<B: Backend, const D: usize>(
129 spectrum_re: Tensor<B, D>,
130 spectrum_im: Tensor<B, D>,
131 dim: usize,
132 n: Option<usize>,
133) -> Tensor<B, D> {
134 check!(TensorCheck::check_dim::<D>(dim));
135
136 if let Some(n) = n {
137 assert!(n >= 1, "irfft: n must be >= 1, got {n}");
138 assert!(
139 n.is_power_of_two(),
140 "irfft: n must be a power of two, got {n}. True non-power-of-two \
141 DFT support is tracked as a follow-up (Bluestein's algorithm)."
142 );
143 }
144
145 let signal = B::irfft(
146 spectrum_re.primitive.tensor(),
147 spectrum_im.primitive.tensor(),
148 dim,
149 n,
150 );
151 Tensor::new(TensorPrimitive::Float(signal))
152}
153
154/// Computes the 1-dimensional discrete Fourier Transform of complex-valued input.
155///
156/// Internally calls [`rfft`] on the real and imaginary parts separately,
157/// extends each half-spectrum to the full `N`-bin spectrum via Hermitian
158/// symmetry.
159///
160/// Autodiff is not yet supported.
161///
162#[cfg_attr(
163 doc,
164 doc = r#"
165
166Due to the linearity of the Fourier Transform, a complex-valued signal $x\[n\] = x_{re}\[n\] + i x_{im}\[n\]$ can be transformed by applying the FFT to its real and imaginary parts separately:
167
168$$ \text{FFT}(x\[n\]) = \text{FFT}(x_{re}\[n\]) + i \text{FFT}(x_{im}\[n\]) $$
169
170Since $x_{re}\[n\]$ and $x_{im}\[n\]$ are purely real, their transforms can be computed efficiently using the real FFT ([`rfft`]). The full spectrum is then reconstructed by exploiting Hermitian symmetry.
171"#
172)]
173#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
174///
175/// # Arguments
176///
177/// * `signal_re` - The real part of the complex input signal.
178/// * `signal_im` - The imaginary part of the complex input signal. Must have the
179/// same shape as `signal_re`.
180/// * `dim` - The dimension along which to take the FFT.
181/// * `n` - Optional FFT length. When `None`, the signal must be a power of two
182/// along `dim`. When `Some(n)`, `n` must also be a power of two; the signal is
183/// truncated or zero-padded to length `n`.
184///
185/// # Returns
186///
187/// A tuple `(re, im)` representing the full complex spectrum, each with `n`
188/// elements along `dim`.
189///
190/// # Example
191///
192/// ```rust
193/// use burn_tensor::backend::Backend;
194/// use burn_tensor::Tensor;
195///
196/// fn example<B: Backend>() {
197/// let device = B::Device::default();
198/// let re = Tensor::<B, 1>::from_floats([1.0, 0.0, -1.0, 0.0], &device);
199/// let im = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0, -1.0], &device);
200/// let (spec_re, spec_im) = burn_tensor::signal::cfft(re, im, 0, None);
201/// }
202/// ```
203pub fn cfft<B: Backend, const D: usize>(
204 signal_re: Tensor<B, D>,
205 signal_im: Tensor<B, D>,
206 dim: usize,
207 n: Option<usize>,
208) -> (Tensor<B, D>, Tensor<B, D>) {
209 assert!(
210 signal_re.shape() == signal_im.shape(),
211 "cfft: signal_re and signal_im must have the same shape, \
212 got {:?} and {:?}",
213 signal_re.shape(),
214 signal_im.shape(),
215 );
216
217 check!(TensorCheck::check_dim::<D>(dim));
218 let fft_size = n.unwrap_or(signal_re.dims()[dim]);
219
220 // rfft validates power-of-two and n constraints internally
221 let (xr, xi) = rfft(signal_re, dim, n);
222 let (yr, yi) = rfft(signal_im, dim, n);
223
224 // Extend half-spectra (N/2+1 bins) to full N-bin spectra via Hermitian symmetry
225 let (xr, xi) = hermitian_extend(xr, xi, dim, fft_size);
226 let (yr, yi) = hermitian_extend(yr, yi, dim, fft_size);
227
228 // FFT(z) = FFT(x) + i·FFT(y)
229 // = (Xr + i·Xi) + i·(Yr + i·Yi)
230 // = (Xr - Yi) + i·(Xi + Yr)
231 (xr - yi, xi + yr)
232}
233
234/// Extend a half-spectrum from [`rfft`] (`N/2 + 1` bins) to the full `N`-bin
235/// spectrum using Hermitian symmetry: `X[k] = conj(X[N-k])` for `k > N/2`.
236pub(super) fn hermitian_extend<B: Backend, const D: usize>(
237 half_re: Tensor<B, D>,
238 half_im: Tensor<B, D>,
239 dim: usize,
240 full_len: usize,
241) -> (Tensor<B, D>, Tensor<B, D>) {
242 let half_len = half_re.dims()[dim]; // N/2 + 1
243
244 // For N <= 2, the half-spectrum already covers all bins
245 if full_len <= half_len {
246 return (half_re, half_im);
247 }
248
249 // Mirror bins: reverse of bins 1..N/2-1 (skipping the Nyquist bin),
250 // with conjugated imaginary part. This produces X[N/2+1], X[N/2+2], ..., X[N-1]
251 let mirror_len = full_len - half_len; // N/2 - 1
252 let mirror_re = half_re
253 .clone()
254 .narrow(dim, 1, mirror_len)
255 .flip([dim as isize]);
256 let mirror_im = half_im
257 .clone()
258 .narrow(dim, 1, mirror_len)
259 .flip([dim as isize])
260 .neg();
261
262 // Full spectrum = [half_spectrum, conjugate_mirror]
263 let full_re = Tensor::cat(vec![half_re, mirror_re], dim);
264 let full_im = Tensor::cat(vec![half_im, mirror_im], dim);
265
266 (full_re, full_im)
267}