Skip to main content

rlx_ir/ops/
fft_ops.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! gpu-fft-shaped graph helpers: real-input FFT, spectrum split/merge, PSD, STFT,
//! FFT-based convolution, and sample-frequency tables.
//!
//! These methods compose primitive ops (`Op::Fft`, `narrow`, `concat`, …) into
//! NumPy/JAX-style signal-processing building blocks. Backends that cannot lower
//! the full subgraph (for example Metal MPSGraph on `Op::Fft`) still execute
//! the underlying `Op::Fft` nodes via thunks or host fallback.

23use crate::infer::GraphExt as _;
24use crate::{DType, Graph, NodeId, Op, Shape, fft::FftNorm};
25
26impl Graph {
27    /// Zero-pad the last axis to the next power of two (no-op when already pow2).
28    pub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId {
29        let shape = self.shape(x).clone();
30        let rank = shape.rank();
31        let last = rank - 1;
32        let n = shape.dim(last).unwrap_static();
33        let n_pad = crate::fft::next_pow2(n);
34        if n_pad == n {
35            return x;
36        }
37        let pad_len = n_pad - n;
38        let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
39        pad_dims[last] = pad_len;
40        let pad_shape = Shape::new(&pad_dims, shape.dtype());
41        let zeros = self.zeros_tensor(&pad_shape);
42        self.concat_(vec![x, zeros], last)
43    }
44
45    /// Split a 2N real-block spectrum into separate real / imag tensors.
46    pub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId) {
47        let shape = self.shape(spectrum).clone();
48        let meta = crate::fft::fft_meta(&shape);
49        let last = shape.rank() - 1;
50        let n = meta.n_complex;
51        let re = self.narrow_(spectrum, last, 0, n);
52        let im = self.narrow_(spectrum, last, n, n);
53        (re, im)
54    }
55
56    /// Real-input FFT (gpu-fft `fft`): auto zero-pads to pow2, returns `(re, im)`.
57    pub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
58        assert_eq!(
59            self.shape(x).dtype(),
60            DType::F32,
61            "fft_real: requires F32 real input"
62        );
63        let padded = self.pad_last_axis_to_pow2(x);
64        let shape = self.shape(padded).clone();
65        let rank = shape.rank();
66        let last = rank - 1;
67        let n = shape.dim(last).unwrap_static();
68        let mut im_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
69        im_dims[last] = n;
70        let im_shape = Shape::new(&im_dims, DType::F32);
71        let zero_im = self.zeros_tensor(&im_shape);
72        let block = self.concat_(vec![padded, zero_im], last);
73        let spectrum = self.fft_norm(block, false, norm);
74        self.split_spectrum(spectrum)
75    }
76
77    /// Batched real-input FFT — same as `fft_real` when the last axis is signal
78    /// length; leading axes are independent batch dimensions.
79    pub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
80        self.fft_real(x, norm)
81    }
82
83    /// Real-input FFT with half-spectrum output (`n_pad/2 + 1` complex bins).
84    ///
85    /// The input is zero-padded to the next power of two along the last axis
86    /// before the transform, matching NumPy `rfft` padding semantics.
87    pub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
88        let (re, im) = self.fft_real(x, norm);
89        let rank = self.shape(re).rank();
90        let last = rank - 1;
91        let n = self.shape(re).dim(last).unwrap_static();
92        let half = n / 2 + 1;
93        (
94            self.narrow_(re, last, 0, half),
95            self.narrow_(im, last, 0, half),
96        )
97    }
98
99    /// Inverse real FFT from half-spectrum `(re, im)` with Hermitian symmetry.
100    ///
101    /// Mirrors the conjugate half of the spectrum (excluding DC and Nyquist) before
102    /// calling [`Self::ifft_spectrum`], then truncates to length `n`.
103    pub fn irfft(&mut self, re_half: NodeId, im_half: NodeId, n: usize, norm: FftNorm) -> NodeId {
104        assert_eq!(
105            *self.shape(re_half),
106            *self.shape(im_half),
107            "irfft: re/im shape mismatch"
108        );
109        let n_pad = crate::fft::next_pow2(n);
110        let half = n_pad / 2 + 1;
111        let rank = self.shape(re_half).rank();
112        let last = rank - 1;
113        assert_eq!(
114            self.shape(re_half).dim(last).unwrap_static(),
115            half,
116            "irfft: expected half-spectrum length {half}, got {}",
117            self.shape(re_half).dim(last).unwrap_static()
118        );
119        let (re_full, im_full) = if half > 2 {
120            let mirror_len = half - 2;
121            let mirror_re = self.narrow_(re_half, last, 1, mirror_len);
122            let mirror_im = self.narrow_(im_half, last, 1, mirror_len);
123            let mirror_re_rev = self.reverse_last_axis(mirror_re);
124            let mirror_im_rev = self.reverse_last_axis(mirror_im);
125            let neg = self.scalar_f32(-1.0);
126            let mirror_im_neg = self.mul(mirror_im_rev, neg);
127            (
128                self.concat_(vec![re_half, mirror_re_rev], last),
129                self.concat_(vec![im_half, mirror_im_neg], last),
130            )
131        } else {
132            (re_half, im_half)
133        };
134        let recovered = self.ifft_spectrum(re_full, im_full, norm);
135        self.narrow_(recovered, last, 0, n)
136    }
137
138    /// Short-time Fourier transform: `[..., T]` → `[frames, ..., 2·half]` (re/im block per frame).
139    ///
140    /// Each frame is `rfft`'d with length `frame_len` and hop `hop` along the last axis.
141    pub fn stft(&mut self, x: NodeId, frame_len: usize, hop: usize, norm: FftNorm) -> NodeId {
142        assert!(
143            frame_len > 0 && hop > 0,
144            "stft: frame_len and hop must be positive"
145        );
146        let shape = self.shape(x).clone();
147        let rank = shape.rank();
148        let last = rank - 1;
149        let t = shape.dim(last).unwrap_static();
150        assert!(
151            t >= frame_len,
152            "stft: signal length {t} < frame_len {frame_len}"
153        );
154        let n_frames = 1 + (t - frame_len) / hop;
155        let mut frames = Vec::with_capacity(n_frames);
156        for f in 0..n_frames {
157            let start = f * hop;
158            let frame = self.narrow_(x, last, start, frame_len);
159            let (re, im) = self.rfft(frame, norm);
160            let block = self.concat_(vec![re, im], last);
161            frames.push(block);
162        }
163        if frames.len() == 1 {
164            let f = frames[0];
165            let mut dims: Vec<i64> = self
166                .shape(f)
167                .dims()
168                .iter()
169                .map(|d| d.unwrap_static() as i64)
170                .collect();
171            dims.insert(0, 1);
172            return self.reshape_(f, dims);
173        }
174        let mut rows = Vec::new();
175        for f in frames {
176            let mut dims: Vec<i64> = self
177                .shape(f)
178                .dims()
179                .iter()
180                .map(|d| d.unwrap_static() as i64)
181                .collect();
182            dims.insert(0, 1);
183            rows.push(self.reshape_(f, dims));
184        }
185        self.concat_(rows, 0)
186    }
187
188    /// 1D convolution via the convolution theorem (`rfft` → complex multiply → `irfft`).
189    ///
190    /// Both inputs are zero-padded to at least `n_fft` (or the next power of two covering
191    /// `len(a) + len(b) - 1` when `n_fft` is small).
192    pub fn fft_conv1d(&mut self, a: NodeId, b: NodeId, n_fft: usize, norm: FftNorm) -> NodeId {
193        let n_fft = n_fft.max(crate::fft::next_pow2(
194            self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
195                + self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
196                - 1,
197        ));
198        let pad_a = self.pad_axis_to_len(a, n_fft);
199        let pad_b = self.pad_axis_to_len(b, n_fft);
200        let (a_re, a_im) = self.rfft(pad_a, norm);
201        let (b_re, b_im) = self.rfft(pad_b, norm);
202        let ar_br = self.mul(a_re, b_re);
203        let ai_bi = self.mul(a_im, b_im);
204        let prod_re = self.sub(ar_br, ai_bi);
205        let ar_bi = self.mul(a_re, b_im);
206        let ai_br = self.mul(a_im, b_re);
207        let prod_im = self.add(ar_bi, ai_br);
208        let out_len = self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
209            + self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
210            - 1;
211        self.irfft(prod_re, prod_im, out_len.max(1), norm)
212    }
213
214    /// Constant tensor of FFT sample frequencies (length `n`, f64).
215    pub fn fftfreq_tensor(&mut self, n: usize) -> NodeId {
216        let xs = crate::fft::fftfreq(n);
217        let mut bytes = Vec::with_capacity(n * 8);
218        for x in &xs {
219            bytes.extend_from_slice(&x.to_le_bytes());
220        }
221        self.add_node(
222            Op::Constant { data: bytes },
223            vec![],
224            Shape::new(&[n], DType::F64),
225        )
226    }
227
228    /// Constant tensor of rFFT sample frequencies (length `n/2 + 1`, f64).
229    pub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId {
230        let xs = crate::fft::rfftfreq(n);
231        let half = xs.len();
232        let mut bytes = Vec::with_capacity(half * 8);
233        for x in &xs {
234            bytes.extend_from_slice(&x.to_le_bytes());
235        }
236        self.add_node(
237            Op::Constant { data: bytes },
238            vec![],
239            Shape::new(&[half], DType::F64),
240        )
241    }
242
243    /// Power spectral density from real input: `rfft` → `psd`.
244    pub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId {
245        let (re, im) = self.rfft(x, norm);
246        self.psd(re, im)
247    }
248
249    /// Inverse FFT from separate real / imag spectra (gpu-fft `ifft` real part).
250    pub fn ifft_spectrum(&mut self, re: NodeId, im: NodeId, norm: FftNorm) -> NodeId {
251        let re_shape = self.shape(re).clone();
252        assert_eq!(
253            re_shape,
254            *self.shape(im),
255            "ifft_spectrum: re/im shape mismatch"
256        );
257        let rank = re_shape.rank();
258        let last = rank - 1;
259        let n = re_shape.dim(last).unwrap_static();
260        let block = self.concat_(vec![re, im], last);
261        let full = self.fft_norm(block, true, norm);
262        self.narrow_(full, last, 0, n)
263    }
264
265    /// Power spectral density: `(re² + im²) / N` (gpu-fft `psd::psd`).
266    pub fn psd(&mut self, re: NodeId, im: NodeId) -> NodeId {
267        let n = self
268            .shape(re)
269            .dim(self.shape(re).rank() - 1)
270            .unwrap_static();
271        let re2 = self.mul(re, re);
272        let im2 = self.mul(im, im);
273        let power = self.add(re2, im2);
274        let inv_n = self.scalar_f32(1.0 / n as f32);
275        self.mul(power, inv_n)
276    }
277
278    fn reverse_last_axis(&mut self, x: NodeId) -> NodeId {
279        let shape = self.shape(x).clone();
280        let rank = shape.rank();
281        let last = rank - 1;
282        let len = shape.dim(last).unwrap_static();
283        if len <= 1 {
284            return x;
285        }
286        let prefix_elems: usize = shape
287            .dims()
288            .iter()
289            .take(last)
290            .map(|d| d.unwrap_static())
291            .product();
292        let mut idx_bytes = Vec::with_capacity(prefix_elems * len * 4);
293        for _ in 0..prefix_elems.max(1) {
294            for i in (0..len).rev() {
295                idx_bytes.extend_from_slice(&(i as i32).to_le_bytes());
296            }
297        }
298        let idx_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
299        let idx = self.add_node(
300            Op::Constant { data: idx_bytes },
301            vec![],
302            Shape::new(&idx_dims, DType::I32),
303        );
304        self.gather_(x, idx, last)
305    }
306
307    fn pad_axis_to_len(&mut self, x: NodeId, len: usize) -> NodeId {
308        let shape = self.shape(x).clone();
309        let last = shape.rank() - 1;
310        let n = shape.dim(last).unwrap_static();
311        if n >= len {
312            return self.narrow_(x, last, 0, len);
313        }
314        let pad_len = len - n;
315        let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
316        pad_dims[last] = pad_len;
317        let zeros = self.zeros_tensor(&Shape::new(&pad_dims, shape.dtype()));
318        self.concat_(vec![x, zeros], last)
319    }
320
321    fn zeros_tensor(&mut self, shape: &Shape) -> NodeId {
322        let n = shape.num_elements().unwrap();
323        let bytes = vec![0u8; n * shape.dtype().size_bytes()];
324        self.add_node(Op::Constant { data: bytes }, vec![], shape.clone())
325    }
326
327    fn scalar_f32(&mut self, v: f32) -> NodeId {
328        self.add_node(
329            Op::Constant {
330                data: v.to_le_bytes().to_vec(),
331            },
332            vec![],
333            Shape::scalar(DType::F32),
334        )
335    }
336}