// rlx_fft/bench.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Benchmark learned butterfly FFT vs `rustfft` and native RLX `Op::Fft`.
17
18use crate::butterfly::{butterfly_forward_real_batch, butterfly_inverse_complex_batch};
19use crate::config::{FftLearnConfig, TransformDir};
20use crate::device::resolve_train_device;
21use crate::reference::{fft_real_batch, ifft_complex_batch, max_abs_error};
22use crate::runner::FftLearnRunner;
23use crate::train::{random_batch, random_complex_batch};
24use crate::twiddle::exact_twiddles;
25use crate::weights::{EncDecWeights, WeightStore, load_safetensors};
26use anyhow::{Result, ensure};
27use rand::prelude::*;
28use rlx_runtime::Device;
29use std::path::Path;
30use std::time::Instant;
31
/// Timing and accuracy results of one benchmark run, as built by [`bench_all_dir`].
///
/// Every `*_ms` field is the mean wall-clock time per iteration in milliseconds
/// (total elapsed time divided by `iters`). Every `*_err` field is the
/// `max_abs_error` between that path's output and the reference transform from
/// `crate::reference`.
#[derive(Debug, Clone)]
pub struct BenchReport {
    /// Transform direction that was benchmarked.
    pub direction: TransformDir,
    /// Transform length the benchmark was configured with.
    pub n_fft: usize,
    /// Batch size the benchmark was configured with.
    pub batch: usize,
    /// Number of timed iterations per path.
    pub iters: usize,
    /// Device the compiled graphs were built for.
    pub device: Device,
    /// Where the butterfly twiddles came from: `exact twiddles`,
    /// `learned (<path>)`, or `learned encdec encoder|decoder (<path>)`.
    pub butterfly_weights: String,
    /// Mean time of the reference transform (`fft_real_batch` / `ifft_complex_batch`).
    pub rustfft_ms: f64,
    /// Mean time of the compiled native RLX FFT graph.
    pub rlx_fft_ms: f64,
    /// Mean time of the eager butterfly implementation.
    pub butterfly_eager_ms: f64,
    /// Mean time of the compiled butterfly graph; NaN when that path was not
    /// requested or failed.
    pub butterfly_compiled_ms: f64,
    /// Max absolute error of the native RLX FFT output vs the reference.
    pub rlx_fft_err: f32,
    /// Max absolute error of the eager butterfly output vs the reference.
    pub butterfly_eager_err: f32,
    /// Max absolute error of the compiled butterfly output vs the reference;
    /// NaN when that path was not requested or failed.
    pub butterfly_compiled_err: f32,
}
49
50pub fn bench_all_dir(
51    n_fft: usize,
52    batch: usize,
53    iters: usize,
54    dir: TransformDir,
55    device: Device,
56    with_butterfly_compiled: bool,
57    weights_path: Option<&Path>,
58) -> Result<BenchReport> {
59    ensure!(iters >= 1);
60    let cfg = FftLearnConfig::new(n_fft, batch)?;
61    let (twiddles, butterfly_weights) = resolve_butterfly_weights(&cfg, dir, weights_path)?;
62    let mut rng = rand::rngs::StdRng::seed_from_u64(1);
63
64    let (signal, spectrum_interleaved, rlx_input, rlx_input_name) = if dir.is_forward() {
65        let signal = random_batch(&mut rng, batch, n_fft);
66        (signal, Vec::new(), None, "")
67    } else {
68        let spectrum = random_complex_batch(&mut rng, batch, n_fft);
69        let block = crate::rlx_fft::interleaved_to_block(&spectrum, batch, n_fft);
70        (Vec::new(), spectrum, Some(block), "spectrum")
71    };
72
73    let rustfft_ms = time_iters(iters, || {
74        if dir.is_forward() {
75            let _ = fft_real_batch(&signal, batch, n_fft)?;
76        } else {
77            let _ = ifft_complex_batch(&spectrum_interleaved, batch, n_fft)?;
78        }
79        Ok(())
80    })?;
81
82    eprintln!("[bench] compiling native RLX Op::Fft on {device:?}…");
83    let mut rlx_exec = crate::rlx_fft::compile_rlx_fft(&cfg, dir, device)?;
84    let rlx_fft_ms = time_iters(iters, || {
85        if dir.is_forward() {
86            rlx_exec.run(&[("signal", &signal)]);
87        } else {
88            let block = rlx_input.as_ref().expect("ifft block");
89            rlx_exec.run(&[(rlx_input_name, block)]);
90        }
91        Ok(())
92    })?;
93
94    let target = if dir.is_forward() {
95        fft_real_batch(&signal, batch, n_fft)?
96    } else {
97        ifft_complex_batch(&spectrum_interleaved, batch, n_fft)?
98    };
99
100    let rlx_out = if dir.is_forward() {
101        rlx_exec.run(&[("signal", &signal)])
102    } else {
103        rlx_exec.run(&[(rlx_input_name, rlx_input.as_ref().unwrap())])
104    };
105    let rlx_pred = crate::reference::block_to_interleaved(&rlx_out[0], batch, n_fft);
106    let rlx_fft_err = max_abs_error(&rlx_pred, &target);
107
108    let butterfly_eager_ms = time_iters(iters, || {
109        if dir.is_forward() {
110            let _ = butterfly_forward_real_batch(&signal, &twiddles, batch, n_fft)?;
111        } else {
112            let _ =
113                butterfly_inverse_complex_batch(&spectrum_interleaved, &twiddles, batch, n_fft)?;
114        }
115        Ok(())
116    })?;
117
118    let compiled_input = if dir.is_forward() {
119        signal.clone()
120    } else {
121        spectrum_interleaved.clone()
122    };
123
124    let eager_pred = if dir.is_forward() {
125        butterfly_forward_real_batch(&signal, &twiddles, batch, n_fft)?
126    } else {
127        butterfly_inverse_complex_batch(&spectrum_interleaved, &twiddles, batch, n_fft)?
128    };
129    let butterfly_eager_err = max_abs_error(&eager_pred, &target);
130
131    let (butterfly_compiled_ms, butterfly_compiled_err) = if with_butterfly_compiled {
132        eprintln!("[bench] compiling learned butterfly graph on {device:?}…");
133        match bench_butterfly_compiled(
134            &cfg,
135            dir,
136            device,
137            &compiled_input,
138            &target,
139            iters,
140            &twiddles,
141        ) {
142            Ok(v) => v,
143            Err(e) => {
144                eprintln!("[bench] butterfly compiled skipped: {e:#}");
145                (f64::NAN, f32::NAN)
146            }
147        }
148    } else {
149        (f64::NAN, f32::NAN)
150    };
151
152    Ok(BenchReport {
153        direction: dir,
154        n_fft,
155        batch,
156        iters,
157        device,
158        butterfly_weights,
159        rustfft_ms,
160        rlx_fft_ms,
161        butterfly_eager_ms,
162        butterfly_compiled_ms,
163        rlx_fft_err,
164        butterfly_eager_err,
165        butterfly_compiled_err,
166    })
167}
168
169pub fn bench_all(
170    n_fft: usize,
171    batch: usize,
172    iters: usize,
173    dir: TransformDir,
174    device_name: &str,
175    with_butterfly_compiled: bool,
176    weights_path: Option<&Path>,
177) -> Result<BenchReport> {
178    let device = resolve_train_device(Some(device_name))?;
179    bench_all_dir(
180        n_fft,
181        batch,
182        iters,
183        dir,
184        device,
185        with_butterfly_compiled,
186        weights_path,
187    )
188}
189
190/// Legacy API: rustfft vs butterfly eager only.
191pub fn bench_reference_vs_learned_dir(
192    n_fft: usize,
193    batch: usize,
194    iters: usize,
195    dir: TransformDir,
196) -> Result<(f64, f64, f32)> {
197    let report = bench_all_dir(n_fft, batch, iters, dir, Device::Cpu, false, None)?;
198    Ok((
199        report.rustfft_ms,
200        report.butterfly_eager_ms,
201        report.butterfly_eager_err,
202    ))
203}
204
205pub fn bench_reference_vs_learned(
206    n_fft: usize,
207    batch: usize,
208    iters: usize,
209) -> Result<(f64, f64, f32)> {
210    bench_reference_vs_learned_dir(n_fft, batch, iters, TransformDir::Forward)
211}
212
213fn bench_butterfly_compiled(
214    cfg: &FftLearnConfig,
215    dir: TransformDir,
216    device: Device,
217    input: &[f32],
218    target: &[f32],
219    iters: usize,
220    twiddles: &[f32],
221) -> Result<(f64, f32)> {
222    let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
223    let mut runner = FftLearnRunner::with_weights_dir(cfg.clone(), &store, dir)?;
224    runner.load_compiled(device)?;
225    let _ = runner.forward(input)?;
226    let ms = time_iters(iters, || {
227        let _ = runner.forward(input)?;
228        Ok(())
229    })?;
230    let pred = runner.forward(input)?;
231    Ok((ms, max_abs_error(&pred, target)))
232}
233
234fn resolve_butterfly_weights(
235    cfg: &FftLearnConfig,
236    dir: TransformDir,
237    weights_path: Option<&Path>,
238) -> Result<(Vec<f32>, String)> {
239    let Some(path) = weights_path else {
240        return Ok((exact_twiddles(cfg), "exact twiddles".to_string()));
241    };
242
243    let store = load_safetensors(path)?;
244    if let Ok(tw) = store.to_twiddles(cfg.n_fft) {
245        return Ok((tw, format!("learned ({})", path.display())));
246    }
247
248    let encdec = EncDecWeights::from_merged(&store, cfg.n_fft)?;
249    let tw = if dir.is_forward() {
250        encdec.encoder_twiddles(cfg.n_fft)?
251    } else {
252        encdec.decoder_twiddles(cfg.n_fft)?
253    };
254    Ok((
255        tw,
256        format!(
257            "learned encdec {} ({})",
258            if dir.is_forward() {
259                "encoder"
260            } else {
261                "decoder"
262            },
263            path.display()
264        ),
265    ))
266}
267
268fn time_iters(iters: usize, mut f: impl FnMut() -> Result<()>) -> Result<f64> {
269    let t0 = Instant::now();
270    for _ in 0..iters {
271        f()?;
272    }
273    Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::FftLearnConfig;
    use crate::rlx_fft::{build_rlx_fft_forward_graph, build_rlx_fft_inverse_graph};

    /// Both RLX FFT graph builders run on a small valid config without panicking.
    #[test]
    fn rlx_fft_graph_builds() {
        let small_cfg = FftLearnConfig::new(64, 2).unwrap();
        let _ = build_rlx_fft_forward_graph(&small_cfg);
        let _ = build_rlx_fft_inverse_graph(&small_cfg);
    }

    /// End-to-end forward benchmark on CPU: timings are non-negative and the
    /// native RLX FFT stays within 1e-3 of the reference.
    #[test]
    #[ignore = "slow compile; run with `cargo test -p rlx-fft bench_cpu_forward_smoke -- --ignored`"]
    fn bench_cpu_forward_smoke() {
        let BenchReport {
            rustfft_ms,
            rlx_fft_ms,
            rlx_fft_err,
            ..
        } = bench_all_dir(64, 4, 3, TransformDir::Forward, Device::Cpu, false, None)
            .expect("bench");
        assert!(rustfft_ms >= 0.0);
        assert!(rlx_fft_ms >= 0.0);
        assert!(rlx_fft_err < 1e-3);
    }
}