1use crate::butterfly::{build_butterfly_forward_graph, butterfly_forward_real_batch};
19use crate::config::FftLearnConfig;
20use crate::domain::train_domain_twiddles;
21use crate::fused::{build_fused_spectral_graph, fused_spectral_eager, unit_mask};
22use crate::q8::Q8Twiddles;
23use crate::reference::{fft_real_batch, max_abs_error};
24use crate::stockham::{build_stockham_forward_graph, stockham_forward_real_batch};
25use crate::train::random_batch;
26use crate::twiddle::exact_twiddles;
27use crate::unitary::{UnitaryWeights, train_unitary_quick};
28use crate::welch::{
29 WelchParams, compile_welch_rlx_fft, welch_butterfly, welch_rlx_op_fft, welch_rustfft,
30};
31use anyhow::{Result, bail};
32use rand::prelude::*;
33use rlx_runtime::{CompiledGraph, Device};
34use std::time::Instant;
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
37pub enum FftVariantId {
38 Rustfft,
39 RlxOpFft,
40 RlxOpIfft,
41 ButterflyEager,
42 ButterflyCompiled,
43 StockhamEager,
44 StockhamCompiled,
45 FusedSpectralEager,
46 FusedSpectralCompiled,
47 ButterflyQ8,
48 ButterflyUnitary,
49 DomainTwiddle,
50 WelchRustfft,
51 WelchRlxOpFft,
52 WelchButterflyEager,
53 WelchButterflyCompiled,
54}
55
56impl FftVariantId {
57 pub fn all() -> &'static [Self] {
58 Self::all_with_compiled(true)
59 }
60
61 pub fn all_eager() -> &'static [Self] {
62 Self::all_with_compiled(false)
63 }
64
65 fn all_with_compiled(compiled: bool) -> &'static [Self] {
66 if compiled {
67 &[
68 Self::Rustfft,
69 Self::RlxOpFft,
70 Self::RlxOpIfft,
71 Self::ButterflyEager,
72 Self::ButterflyCompiled,
73 Self::StockhamEager,
74 Self::StockhamCompiled,
75 Self::FusedSpectralEager,
76 Self::FusedSpectralCompiled,
77 Self::ButterflyQ8,
78 Self::ButterflyUnitary,
79 Self::DomainTwiddle,
80 ]
81 } else {
82 &[
83 Self::Rustfft,
84 Self::RlxOpFft,
85 Self::RlxOpIfft,
86 Self::ButterflyEager,
87 Self::StockhamEager,
88 Self::FusedSpectralEager,
89 Self::ButterflyQ8,
90 Self::ButterflyUnitary,
91 Self::DomainTwiddle,
92 ]
93 }
94 }
95
96 pub fn welch_variants(with_compiled: bool) -> &'static [Self] {
97 if with_compiled {
98 &[
99 Self::WelchRustfft,
100 Self::WelchRlxOpFft,
101 Self::WelchButterflyEager,
102 Self::WelchButterflyCompiled,
103 ]
104 } else {
105 &[
106 Self::WelchRustfft,
107 Self::WelchRlxOpFft,
108 Self::WelchButterflyEager,
109 ]
110 }
111 }
112
113 pub fn is_welch(self) -> bool {
114 matches!(
115 self,
116 Self::WelchRustfft
117 | Self::WelchRlxOpFft
118 | Self::WelchButterflyEager
119 | Self::WelchButterflyCompiled
120 )
121 }
122
123 pub fn tier(self) -> &'static str {
124 match self {
125 Self::Rustfft
126 | Self::RlxOpFft
127 | Self::RlxOpIfft
128 | Self::ButterflyEager
129 | Self::WelchRustfft
130 | Self::WelchRlxOpFft
131 | Self::WelchButterflyEager => "baseline",
132 Self::StockhamEager
133 | Self::StockhamCompiled
134 | Self::FusedSpectralEager
135 | Self::FusedSpectralCompiled => "A",
136 Self::ButterflyCompiled | Self::ButterflyQ8 | Self::WelchButterflyCompiled => "B",
137 Self::ButterflyUnitary | Self::DomainTwiddle => "C",
138 }
139 }
140
141 pub fn label(self) -> &'static str {
142 match self {
143 Self::Rustfft => "rustfft",
144 Self::RlxOpFft => "rlx_op_fft",
145 Self::RlxOpIfft => "rlx_op_ifft",
146 Self::ButterflyEager => "butterfly_eager",
147 Self::ButterflyCompiled => "butterfly_compiled",
148 Self::StockhamEager => "stockham_eager",
149 Self::StockhamCompiled => "stockham_compiled",
150 Self::FusedSpectralEager => "fused_spectral_eager",
151 Self::FusedSpectralCompiled => "fused_spectral_compiled",
152 Self::ButterflyQ8 => "butterfly_q8",
153 Self::ButterflyUnitary => "butterfly_unitary",
154 Self::DomainTwiddle => "domain_twiddle",
155 Self::WelchRustfft => "welch_rustfft",
156 Self::WelchRlxOpFft => "welch_rlx_op_fft",
157 Self::WelchButterflyEager => "welch_butterfly_eager",
158 Self::WelchButterflyCompiled => "welch_butterfly_compiled",
159 }
160 }
161
162 pub fn supports_inverse(self) -> bool {
163 matches!(self, Self::Rustfft | Self::RlxOpIfft | Self::ButterflyEager)
164 }
165
166 pub fn needs_training(self) -> bool {
167 matches!(self, Self::ButterflyUnitary | Self::DomainTwiddle)
168 }
169}
170
171pub struct VariantState {
172 pub twiddles: Vec<f32>,
173 pub q8: Option<Q8Twiddles>,
174 pub unitary: Option<UnitaryWeights>,
175 pub mask: Vec<f32>,
176 compiled_butterfly: Option<CompiledGraph>,
177 compiled_stockham: Option<CompiledGraph>,
178 compiled_fused: Option<CompiledGraph>,
179 compiled_rlx: Option<CompiledGraph>,
180 compiled_rlx_inv: Option<CompiledGraph>,
181 compiled_welch_rlx: Option<CompiledGraph>,
182 compiled_welch_butterfly: Option<CompiledGraph>,
183 spectrum_block: Vec<f32>,
184 inverse_spectrum: Vec<f32>,
185}
186
187impl VariantState {
188 pub fn new(cfg: &FftLearnConfig) -> Self {
189 Self {
190 twiddles: exact_twiddles(cfg),
191 q8: None,
192 unitary: None,
193 mask: unit_mask(cfg.n_fft),
194 compiled_butterfly: None,
195 compiled_stockham: None,
196 compiled_fused: None,
197 compiled_rlx: None,
198 compiled_rlx_inv: None,
199 compiled_welch_rlx: None,
200 compiled_welch_butterfly: None,
201 spectrum_block: Vec::new(),
202 inverse_spectrum: Vec::new(),
203 }
204 }
205
206 pub fn set_inverse_input_block(&mut self, block: Vec<f32>) {
207 self.spectrum_block = block;
208 }
209
210 pub fn set_inverse_spectrum(&mut self, spectrum: Vec<f32>) {
211 self.inverse_spectrum = spectrum;
212 }
213
214 pub fn prepare(
215 &mut self,
216 variant: FftVariantId,
217 cfg: &FftLearnConfig,
218 device: Device,
219 train_steps: usize,
220 seed: u64,
221 ) -> Result<()> {
222 match variant {
223 FftVariantId::ButterflyQ8 => {
224 self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
225 }
226 FftVariantId::ButterflyUnitary => {
227 if cfg.n_fft <= 64 && cfg.batch <= 8 && train_steps > 0 {
228 let (w, _) = train_unitary_quick(cfg, train_steps.min(25), 1e-3, seed)?;
229 self.unitary = Some(w);
230 } else {
231 self.unitary = Some(UnitaryWeights::exact_init(cfg));
232 }
233 }
234 FftVariantId::DomainTwiddle => {
235 let (tw, _) = train_domain_twiddles(cfg, train_steps, 5e-4, seed)?;
236 self.twiddles = tw;
237 }
238 FftVariantId::ButterflyCompiled if self.compiled_butterfly.is_none() => {
239 self.compiled_butterfly = Some(compile_butterfly(cfg, device, &self.twiddles)?);
240 }
241 FftVariantId::StockhamCompiled if self.compiled_stockham.is_none() => {
242 self.compiled_stockham = Some(compile_stockham(cfg, device, &self.twiddles)?);
243 }
244 FftVariantId::FusedSpectralCompiled if self.compiled_fused.is_none() => {
245 self.compiled_fused = Some(compile_fused(cfg, device, &self.mask)?);
246 }
247 FftVariantId::RlxOpFft if self.compiled_rlx.is_none() => {
248 self.compiled_rlx = Some(crate::rlx_fft::compile_rlx_fft(
249 cfg,
250 crate::config::TransformDir::Forward,
251 device,
252 )?);
253 }
254 FftVariantId::RlxOpIfft if self.compiled_rlx_inv.is_none() => {
255 self.compiled_rlx_inv = Some(crate::rlx_fft::compile_rlx_fft(
256 cfg,
257 crate::config::TransformDir::Inverse,
258 device,
259 )?);
260 }
261 FftVariantId::WelchRlxOpFft if self.compiled_welch_rlx.is_none() => {
262 let params = WelchParams::for_n_fft(cfg.n_fft);
263 self.compiled_welch_rlx = Some(compile_welch_rlx_fft(cfg.batch, params, device)?);
264 }
265 FftVariantId::WelchButterflyCompiled if self.compiled_welch_butterfly.is_none() => {
266 let params = WelchParams::for_n_fft(cfg.n_fft);
267 let welch_cfg = FftLearnConfig::new(cfg.n_fft, cfg.batch * params.n_segments)?;
268 self.compiled_welch_butterfly =
269 Some(compile_butterfly(&welch_cfg, device, &self.twiddles)?);
270 }
271 _ => {}
272 }
273 Ok(())
274 }
275
276 pub fn forward(
277 &mut self,
278 variant: FftVariantId,
279 signal: &[f32],
280 cfg: &FftLearnConfig,
281 ) -> Result<Vec<f32>> {
282 let n = cfg.n_fft;
283 let batch = cfg.batch;
284 match variant {
285 FftVariantId::Rustfft => fft_real_batch(signal, batch, n),
286 FftVariantId::RlxOpFft => {
287 let exec = self.compiled_rlx.as_mut().expect("rlx compiled");
288 Ok(crate::rlx_fft::rlx_fft_forward(exec, signal, batch, n))
289 }
290 FftVariantId::ButterflyEager | FftVariantId::DomainTwiddle => {
291 butterfly_forward_real_batch(signal, &self.twiddles, batch, n)
292 }
293 FftVariantId::ButterflyCompiled => {
294 let exec = self
295 .compiled_butterfly
296 .as_mut()
297 .expect("butterfly compiled");
298 Ok(exec.run(&[("signal", signal)]).remove(0))
299 }
300 FftVariantId::StockhamEager => {
301 stockham_forward_real_batch(signal, &self.twiddles, batch, n)
302 }
303 FftVariantId::StockhamCompiled => {
304 let exec = self.compiled_stockham.as_mut().expect("stockham compiled");
305 Ok(exec.run(&[("signal", signal)]).remove(0))
306 }
307 FftVariantId::FusedSpectralEager => {
308 fused_spectral_eager(signal, &self.twiddles, &self.mask, batch, n)
309 }
310 FftVariantId::FusedSpectralCompiled => {
311 let exec = self.compiled_fused.as_mut().expect("fused compiled");
312 Ok(exec.run(&[("signal", signal)]).remove(0))
313 }
314 FftVariantId::ButterflyQ8 => self
315 .q8
316 .as_ref()
317 .expect("q8")
318 .forward_real_batch(signal, batch, n),
319 FftVariantId::ButterflyUnitary => self
320 .unitary
321 .as_ref()
322 .expect("unitary")
323 .forward_real_batch(signal, batch, n),
324 FftVariantId::RlxOpIfft => bail!("rlx_op_ifft is inverse-only; call inverse()"),
325 FftVariantId::WelchRustfft
326 | FftVariantId::WelchRlxOpFft
327 | FftVariantId::WelchButterflyEager
328 | FftVariantId::WelchButterflyCompiled => {
329 bail!("{} is welch-only; call welch()", variant.label())
330 }
331 }
332 }
333
334 pub fn welch(
335 &mut self,
336 variant: FftVariantId,
337 signal: &[f32],
338 cfg: &FftLearnConfig,
339 ) -> Result<Vec<f32>> {
340 let params = WelchParams::for_n_fft(cfg.n_fft);
341 let batch = cfg.batch;
342 match variant {
343 FftVariantId::WelchRustfft => welch_rustfft(signal, batch, params),
344 FftVariantId::WelchRlxOpFft => {
345 let exec = self
346 .compiled_welch_rlx
347 .as_mut()
348 .expect("welch rlx compiled");
349 welch_rlx_op_fft(exec, signal, batch, params)
350 }
351 FftVariantId::WelchButterflyEager => {
352 welch_butterfly(signal, &self.twiddles, batch, params)
353 }
354 FftVariantId::WelchButterflyCompiled => {
355 let window = crate::welch::hann_window(params.n_fft);
356 let segs = crate::welch::welch_windowed_segments(signal, batch, params, &window)?;
357 let exec = self
358 .compiled_welch_butterfly
359 .as_mut()
360 .expect("welch butterfly compiled");
361 let spec = exec.run(&[("signal", &segs)]).remove(0);
362 Ok(crate::welch::average_welch_psd(
363 &spec,
364 batch,
365 params.n_segments,
366 params.n_fft,
367 ))
368 }
369 other => bail!("variant {} has no welch path", other.label()),
370 }
371 }
372
373 pub fn inverse(&mut self, variant: FftVariantId, cfg: &FftLearnConfig) -> Result<Vec<f32>> {
374 let n = cfg.n_fft;
375 let batch = cfg.batch;
376 match variant {
377 FftVariantId::Rustfft => {
378 crate::reference::ifft_complex_batch(&self.inverse_spectrum, batch, n)
379 }
380 FftVariantId::RlxOpIfft => {
381 let exec = self.compiled_rlx_inv.as_mut().expect("rlx inv compiled");
382 Ok(crate::rlx_fft::rlx_fft_inverse_block(
383 exec,
384 &self.spectrum_block,
385 batch,
386 n,
387 ))
388 }
389 FftVariantId::ButterflyEager => crate::butterfly::butterfly_inverse_complex_batch(
390 &self.inverse_spectrum,
391 &self.twiddles,
392 batch,
393 n,
394 ),
395 other => bail!("variant {} has no inverse path", other.label()),
396 }
397 }
398}
399
400fn compile_butterfly(
401 cfg: &FftLearnConfig,
402 device: Device,
403 twiddles: &[f32],
404) -> Result<CompiledGraph> {
405 use crate::weights::WeightStore;
406 let built = build_butterfly_forward_graph(cfg)?;
407 let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
408 let mut compiled = crate::compile::try_compile_graph(device, built.graph)?;
409 store.apply_butterfly(&mut compiled, cfg.batch, cfg.n_fft);
410 Ok(compiled)
411}
412
413fn compile_stockham(
414 cfg: &FftLearnConfig,
415 device: Device,
416 twiddles: &[f32],
417) -> Result<CompiledGraph> {
418 use crate::weights::WeightStore;
419 let (graph, _names) = build_stockham_forward_graph(cfg)?;
420 let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
421 let mut compiled = crate::compile::try_compile_graph(device, graph)?;
422 store.apply_butterfly(&mut compiled, cfg.batch, cfg.n_fft);
423 Ok(compiled)
424}
425
426fn compile_fused(cfg: &FftLearnConfig, device: Device, mask: &[f32]) -> Result<CompiledGraph> {
427 let (graph, names) = build_fused_spectral_graph(cfg)?;
428 let mut compiled = crate::compile::try_compile_graph(device, graph)?;
429 for (i, name) in names.iter().enumerate() {
430 compiled.set_param(name, &[mask[i]]);
431 }
432 Ok(compiled)
433}
434
435pub fn variants_for_direction(with_compiled: bool, forward: bool) -> Vec<FftVariantId> {
436 let all = if with_compiled {
437 FftVariantId::all()
438 } else {
439 FftVariantId::all_eager()
440 };
441 all.iter()
442 .copied()
443 .filter(|v| {
444 if v.is_welch() {
445 return false;
446 }
447 if forward {
448 !matches!(v, FftVariantId::RlxOpIfft)
449 } else {
450 v.supports_inverse()
451 }
452 })
453 .collect()
454}
455
456pub fn bench_variant_ms(
457 state: &mut VariantState,
458 variant: FftVariantId,
459 cfg: &FftLearnConfig,
460 signal: &[f32],
461 iters: usize,
462) -> Result<f64> {
463 let _ = state.forward(variant, signal, cfg)?;
464 let t0 = Instant::now();
465 for _ in 0..iters {
466 state.forward(variant, signal, cfg)?;
467 }
468 Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
469}
470
471pub fn variants_for_welch(with_compiled: bool) -> Vec<FftVariantId> {
472 FftVariantId::welch_variants(with_compiled).to_vec()
473}
474
475pub fn bench_variant_ms_welch(
476 state: &mut VariantState,
477 variant: FftVariantId,
478 cfg: &FftLearnConfig,
479 signal: &[f32],
480 iters: usize,
481) -> Result<f64> {
482 let _ = state.welch(variant, signal, cfg)?;
483 let t0 = Instant::now();
484 for _ in 0..iters {
485 state.welch(variant, signal, cfg)?;
486 }
487 Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
488}
489
490pub fn variant_welch_error(
491 state: &mut VariantState,
492 variant: FftVariantId,
493 cfg: &FftLearnConfig,
494 signal: &[f32],
495) -> Result<f32> {
496 if matches!(variant, FftVariantId::WelchRustfft) {
497 return Ok(0.0);
498 }
499 let pred = state.welch(variant, signal, cfg)?;
500 let params = WelchParams::for_n_fft(cfg.n_fft);
501 let target = welch_rustfft(signal, cfg.batch, params)?;
502 Ok(max_abs_error(&pred, &target))
503}
504
505pub fn fixed_ablation_welch_signal(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
506 let params = WelchParams::for_n_fft(n_fft);
507 let frame = params.frame_len();
508 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
509 random_batch(&mut rng, batch, frame)
510}
511
512pub fn bench_variant_ms_inverse(
513 state: &mut VariantState,
514 variant: FftVariantId,
515 cfg: &FftLearnConfig,
516 iters: usize,
517) -> Result<f64> {
518 let _ = state.inverse(variant, cfg)?;
519 let t0 = Instant::now();
520 for _ in 0..iters {
521 state.inverse(variant, cfg)?;
522 }
523 Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
524}
525
526pub fn variant_spectrum_error(
527 state: &mut VariantState,
528 variant: FftVariantId,
529 cfg: &FftLearnConfig,
530 signal: &[f32],
531) -> Result<f32> {
532 if matches!(
533 variant,
534 FftVariantId::FusedSpectralEager | FftVariantId::FusedSpectralCompiled
535 ) {
536 return Ok(0.0);
537 }
538 let pred = state.forward(variant, signal, cfg)?;
539 let target = fft_real_batch(signal, cfg.batch, cfg.n_fft)?;
540 Ok(max_abs_error(&pred, &target))
541}
542
543pub fn variant_inverse_error(
544 state: &mut VariantState,
545 variant: FftVariantId,
546 cfg: &FftLearnConfig,
547) -> Result<f32> {
548 let pred = state.inverse(variant, cfg)?;
549 let target =
550 crate::reference::ifft_complex_batch(&state.inverse_spectrum, cfg.batch, cfg.n_fft)?;
551 Ok(max_abs_error(&pred, &target))
552}
553
554pub fn fixed_ablation_signal(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
555 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
556 random_batch(&mut rng, batch, n_fft)
557}
558
559pub fn fixed_ablation_spectrum(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
560 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
561 crate::train::random_complex_batch(&mut rng, batch, n_fft)
562}
563
564pub fn ensure_variant_ready(variant: FftVariantId, device: Device) -> Result<()> {
565 if matches!(
566 variant,
567 FftVariantId::ButterflyCompiled
568 | FftVariantId::StockhamCompiled
569 | FftVariantId::FusedSpectralCompiled
570 | FftVariantId::RlxOpFft
571 | FftVariantId::RlxOpIfft
572 | FftVariantId::WelchRlxOpFft
573 | FftVariantId::WelchButterflyCompiled
574 ) && device == Device::Cpu
575 {
576 return Ok(());
577 }
578 if matches!(variant, FftVariantId::Rustfft) {
579 return Ok(());
580 }
581 Ok(())
582}
583
584pub fn skip_on_device(variant: FftVariantId, device: Device) -> bool {
585 let _ = (variant, device);
586 false
587}
588
589pub fn validate_variant_output(
590 variant: FftVariantId,
591 pred_len: usize,
592 cfg: &FftLearnConfig,
593) -> Result<()> {
594 let expected = cfg.batch * cfg.n_fft * 2;
595 if pred_len != expected {
596 bail!(
597 "variant {} output len {pred_len} != expected {expected}",
598 variant.label()
599 );
600 }
601 Ok(())
602}