1use crate::butterfly::{butterfly_forward_real_batch, butterfly_inverse_complex_batch};
19use crate::config::{FftLearnConfig, TransformDir};
20use crate::device::resolve_train_device;
21use crate::reference::{fft_real_batch, ifft_complex_batch, max_abs_error};
22use crate::runner::FftLearnRunner;
23use crate::train::{random_batch, random_complex_batch};
24use crate::twiddle::exact_twiddles;
25use crate::weights::{EncDecWeights, WeightStore, load_safetensors};
26use anyhow::{Result, ensure};
27use rand::prelude::*;
28use rlx_runtime::Device;
29use std::path::Path;
30use std::time::Instant;
31
32#[derive(Debug, Clone)]
33pub struct BenchReport {
34 pub direction: TransformDir,
35 pub n_fft: usize,
36 pub batch: usize,
37 pub iters: usize,
38 pub device: Device,
39 pub butterfly_weights: String,
41 pub rustfft_ms: f64,
42 pub rlx_fft_ms: f64,
43 pub butterfly_eager_ms: f64,
44 pub butterfly_compiled_ms: f64,
45 pub rlx_fft_err: f32,
46 pub butterfly_eager_err: f32,
47 pub butterfly_compiled_err: f32,
48}
49
50pub fn bench_all_dir(
51 n_fft: usize,
52 batch: usize,
53 iters: usize,
54 dir: TransformDir,
55 device: Device,
56 with_butterfly_compiled: bool,
57 weights_path: Option<&Path>,
58) -> Result<BenchReport> {
59 ensure!(iters >= 1);
60 let cfg = FftLearnConfig::new(n_fft, batch)?;
61 let (twiddles, butterfly_weights) = resolve_butterfly_weights(&cfg, dir, weights_path)?;
62 let mut rng = rand::rngs::StdRng::seed_from_u64(1);
63
64 let (signal, spectrum_interleaved, rlx_input, rlx_input_name) = if dir.is_forward() {
65 let signal = random_batch(&mut rng, batch, n_fft);
66 (signal, Vec::new(), None, "")
67 } else {
68 let spectrum = random_complex_batch(&mut rng, batch, n_fft);
69 let block = crate::rlx_fft::interleaved_to_block(&spectrum, batch, n_fft);
70 (Vec::new(), spectrum, Some(block), "spectrum")
71 };
72
73 let rustfft_ms = time_iters(iters, || {
74 if dir.is_forward() {
75 let _ = fft_real_batch(&signal, batch, n_fft)?;
76 } else {
77 let _ = ifft_complex_batch(&spectrum_interleaved, batch, n_fft)?;
78 }
79 Ok(())
80 })?;
81
82 eprintln!("[bench] compiling native RLX Op::Fft on {device:?}…");
83 let mut rlx_exec = crate::rlx_fft::compile_rlx_fft(&cfg, dir, device)?;
84 let rlx_fft_ms = time_iters(iters, || {
85 if dir.is_forward() {
86 rlx_exec.run(&[("signal", &signal)]);
87 } else {
88 let block = rlx_input.as_ref().expect("ifft block");
89 rlx_exec.run(&[(rlx_input_name, block)]);
90 }
91 Ok(())
92 })?;
93
94 let target = if dir.is_forward() {
95 fft_real_batch(&signal, batch, n_fft)?
96 } else {
97 ifft_complex_batch(&spectrum_interleaved, batch, n_fft)?
98 };
99
100 let rlx_out = if dir.is_forward() {
101 rlx_exec.run(&[("signal", &signal)])
102 } else {
103 rlx_exec.run(&[(rlx_input_name, rlx_input.as_ref().unwrap())])
104 };
105 let rlx_pred = crate::reference::block_to_interleaved(&rlx_out[0], batch, n_fft);
106 let rlx_fft_err = max_abs_error(&rlx_pred, &target);
107
108 let butterfly_eager_ms = time_iters(iters, || {
109 if dir.is_forward() {
110 let _ = butterfly_forward_real_batch(&signal, &twiddles, batch, n_fft)?;
111 } else {
112 let _ =
113 butterfly_inverse_complex_batch(&spectrum_interleaved, &twiddles, batch, n_fft)?;
114 }
115 Ok(())
116 })?;
117
118 let compiled_input = if dir.is_forward() {
119 signal.clone()
120 } else {
121 spectrum_interleaved.clone()
122 };
123
124 let eager_pred = if dir.is_forward() {
125 butterfly_forward_real_batch(&signal, &twiddles, batch, n_fft)?
126 } else {
127 butterfly_inverse_complex_batch(&spectrum_interleaved, &twiddles, batch, n_fft)?
128 };
129 let butterfly_eager_err = max_abs_error(&eager_pred, &target);
130
131 let (butterfly_compiled_ms, butterfly_compiled_err) = if with_butterfly_compiled {
132 eprintln!("[bench] compiling learned butterfly graph on {device:?}…");
133 match bench_butterfly_compiled(
134 &cfg,
135 dir,
136 device,
137 &compiled_input,
138 &target,
139 iters,
140 &twiddles,
141 ) {
142 Ok(v) => v,
143 Err(e) => {
144 eprintln!("[bench] butterfly compiled skipped: {e:#}");
145 (f64::NAN, f32::NAN)
146 }
147 }
148 } else {
149 (f64::NAN, f32::NAN)
150 };
151
152 Ok(BenchReport {
153 direction: dir,
154 n_fft,
155 batch,
156 iters,
157 device,
158 butterfly_weights,
159 rustfft_ms,
160 rlx_fft_ms,
161 butterfly_eager_ms,
162 butterfly_compiled_ms,
163 rlx_fft_err,
164 butterfly_eager_err,
165 butterfly_compiled_err,
166 })
167}
168
169pub fn bench_all(
170 n_fft: usize,
171 batch: usize,
172 iters: usize,
173 dir: TransformDir,
174 device_name: &str,
175 with_butterfly_compiled: bool,
176 weights_path: Option<&Path>,
177) -> Result<BenchReport> {
178 let device = resolve_train_device(Some(device_name))?;
179 bench_all_dir(
180 n_fft,
181 batch,
182 iters,
183 dir,
184 device,
185 with_butterfly_compiled,
186 weights_path,
187 )
188}
189
190pub fn bench_reference_vs_learned_dir(
192 n_fft: usize,
193 batch: usize,
194 iters: usize,
195 dir: TransformDir,
196) -> Result<(f64, f64, f32)> {
197 let report = bench_all_dir(n_fft, batch, iters, dir, Device::Cpu, false, None)?;
198 Ok((
199 report.rustfft_ms,
200 report.butterfly_eager_ms,
201 report.butterfly_eager_err,
202 ))
203}
204
205pub fn bench_reference_vs_learned(
206 n_fft: usize,
207 batch: usize,
208 iters: usize,
209) -> Result<(f64, f64, f32)> {
210 bench_reference_vs_learned_dir(n_fft, batch, iters, TransformDir::Forward)
211}
212
213fn bench_butterfly_compiled(
214 cfg: &FftLearnConfig,
215 dir: TransformDir,
216 device: Device,
217 input: &[f32],
218 target: &[f32],
219 iters: usize,
220 twiddles: &[f32],
221) -> Result<(f64, f32)> {
222 let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
223 let mut runner = FftLearnRunner::with_weights_dir(cfg.clone(), &store, dir)?;
224 runner.load_compiled(device)?;
225 let _ = runner.forward(input)?;
226 let ms = time_iters(iters, || {
227 let _ = runner.forward(input)?;
228 Ok(())
229 })?;
230 let pred = runner.forward(input)?;
231 Ok((ms, max_abs_error(&pred, target)))
232}
233
234fn resolve_butterfly_weights(
235 cfg: &FftLearnConfig,
236 dir: TransformDir,
237 weights_path: Option<&Path>,
238) -> Result<(Vec<f32>, String)> {
239 let Some(path) = weights_path else {
240 return Ok((exact_twiddles(cfg), "exact twiddles".to_string()));
241 };
242
243 let store = load_safetensors(path)?;
244 if let Ok(tw) = store.to_twiddles(cfg.n_fft) {
245 return Ok((tw, format!("learned ({})", path.display())));
246 }
247
248 let encdec = EncDecWeights::from_merged(&store, cfg.n_fft)?;
249 let tw = if dir.is_forward() {
250 encdec.encoder_twiddles(cfg.n_fft)?
251 } else {
252 encdec.decoder_twiddles(cfg.n_fft)?
253 };
254 Ok((
255 tw,
256 format!(
257 "learned encdec {} ({})",
258 if dir.is_forward() {
259 "encoder"
260 } else {
261 "decoder"
262 },
263 path.display()
264 ),
265 ))
266}
267
268fn time_iters(iters: usize, mut f: impl FnMut() -> Result<()>) -> Result<f64> {
269 let t0 = Instant::now();
270 for _ in 0..iters {
271 f()?;
272 }
273 Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::FftLearnConfig;

    /// Both native FFT graph builders can be called for a small configuration.
    #[test]
    fn rlx_fft_graph_builds() {
        use crate::rlx_fft::{build_rlx_fft_forward_graph, build_rlx_fft_inverse_graph};

        let small_cfg = FftLearnConfig::new(64, 2).unwrap();
        let _ = build_rlx_fft_forward_graph(&small_cfg);
        let _ = build_rlx_fft_inverse_graph(&small_cfg);
    }

    /// End-to-end forward benchmark on CPU: timings are non-negative and the
    /// native op matches the reference to within 1e-3.
    #[test]
    #[ignore = "slow compile; run with `cargo test -p rlx-fft bench_cpu_forward_smoke -- --ignored`"]
    fn bench_cpu_forward_smoke() {
        let outcome = bench_all_dir(64, 4, 3, TransformDir::Forward, Device::Cpu, false, None);
        let report = outcome.expect("bench");
        assert!(report.rustfft_ms >= 0.0);
        assert!(report.rlx_fft_ms >= 0.0);
        assert!(report.rlx_fft_err < 1e-3);
    }
}