// rlx_fft/variants.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! FFT implementation variants for the ablation study (tiers A/B/C plus baselines).

use crate::butterfly::{build_butterfly_forward_graph, butterfly_forward_real_batch};
use crate::config::FftLearnConfig;
use crate::domain::train_domain_twiddles;
use crate::fused::{build_fused_spectral_graph, fused_spectral_eager, unit_mask};
use crate::q8::Q8Twiddles;
use crate::reference::{fft_real_batch, max_abs_error};
use crate::stockham::{build_stockham_forward_graph, stockham_forward_real_batch};
use crate::train::random_batch;
use crate::twiddle::exact_twiddles;
use crate::unitary::{UnitaryWeights, train_unitary_quick};
use crate::welch::{
    WelchParams, compile_welch_rlx_fft, welch_butterfly, welch_rlx_op_fft, welch_rustfft,
};
use anyhow::{Context, Result, bail};
use rand::prelude::*;
use rlx_runtime::{CompiledGraph, Device};
use std::time::Instant;

/// Identifies one FFT implementation in the ablation study.
///
/// Non-Welch variants run through `VariantState::forward` (and, where
/// `supports_inverse()` is true, `VariantState::inverse`); the `Welch*`
/// variants run only through `VariantState::welch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum FftVariantId {
    /// Reference path from `crate::reference` (`fft_real_batch` / `ifft_complex_batch`).
    Rustfft,
    /// Forward transform through a graph built by `crate::rlx_fft::compile_rlx_fft`.
    RlxOpFft,
    /// Inverse-only counterpart of `RlxOpFft`; `forward` rejects it.
    RlxOpIfft,
    /// Eager butterfly using the state's current twiddles.
    ButterflyEager,
    /// Butterfly graph compiled with the twiddles current when it is first prepared.
    ButterflyCompiled,
    /// Eager Stockham transform using the state's current twiddles.
    StockhamEager,
    /// Compiled Stockham graph.
    StockhamCompiled,
    /// Eager fused spectral path using the state's twiddles and mask.
    FusedSpectralEager,
    /// Compiled fused spectral graph whose parameters are loaded from the mask.
    FusedSpectralCompiled,
    /// Butterfly using `Q8Twiddles` built from the twiddles in `prepare`.
    ButterflyQ8,
    /// Butterfly using `UnitaryWeights` (briefly trained for small configs, otherwise exact init).
    ButterflyUnitary,
    /// Eager butterfly using twiddles produced by `train_domain_twiddles`.
    DomainTwiddle,
    /// Welch estimate computed with `welch_rustfft` (the Welch reference).
    WelchRustfft,
    /// Welch estimate using a graph from `compile_welch_rlx_fft`.
    WelchRlxOpFft,
    /// Welch estimate using the eager butterfly (`welch_butterfly`).
    WelchButterflyEager,
    /// Welch estimate using a compiled butterfly over all windowed segments.
    WelchButterflyCompiled,
}

56impl FftVariantId {
57    pub fn all() -> &'static [Self] {
58        Self::all_with_compiled(true)
59    }
60
61    pub fn all_eager() -> &'static [Self] {
62        Self::all_with_compiled(false)
63    }
64
65    fn all_with_compiled(compiled: bool) -> &'static [Self] {
66        if compiled {
67            &[
68                Self::Rustfft,
69                Self::RlxOpFft,
70                Self::RlxOpIfft,
71                Self::ButterflyEager,
72                Self::ButterflyCompiled,
73                Self::StockhamEager,
74                Self::StockhamCompiled,
75                Self::FusedSpectralEager,
76                Self::FusedSpectralCompiled,
77                Self::ButterflyQ8,
78                Self::ButterflyUnitary,
79                Self::DomainTwiddle,
80            ]
81        } else {
82            &[
83                Self::Rustfft,
84                Self::RlxOpFft,
85                Self::RlxOpIfft,
86                Self::ButterflyEager,
87                Self::StockhamEager,
88                Self::FusedSpectralEager,
89                Self::ButterflyQ8,
90                Self::ButterflyUnitary,
91                Self::DomainTwiddle,
92            ]
93        }
94    }
95
96    pub fn welch_variants(with_compiled: bool) -> &'static [Self] {
97        if with_compiled {
98            &[
99                Self::WelchRustfft,
100                Self::WelchRlxOpFft,
101                Self::WelchButterflyEager,
102                Self::WelchButterflyCompiled,
103            ]
104        } else {
105            &[
106                Self::WelchRustfft,
107                Self::WelchRlxOpFft,
108                Self::WelchButterflyEager,
109            ]
110        }
111    }
112
113    pub fn is_welch(self) -> bool {
114        matches!(
115            self,
116            Self::WelchRustfft
117                | Self::WelchRlxOpFft
118                | Self::WelchButterflyEager
119                | Self::WelchButterflyCompiled
120        )
121    }
122
123    pub fn tier(self) -> &'static str {
124        match self {
125            Self::Rustfft
126            | Self::RlxOpFft
127            | Self::RlxOpIfft
128            | Self::ButterflyEager
129            | Self::WelchRustfft
130            | Self::WelchRlxOpFft
131            | Self::WelchButterflyEager => "baseline",
132            Self::StockhamEager
133            | Self::StockhamCompiled
134            | Self::FusedSpectralEager
135            | Self::FusedSpectralCompiled => "A",
136            Self::ButterflyCompiled | Self::ButterflyQ8 | Self::WelchButterflyCompiled => "B",
137            Self::ButterflyUnitary | Self::DomainTwiddle => "C",
138        }
139    }
140
141    pub fn label(self) -> &'static str {
142        match self {
143            Self::Rustfft => "rustfft",
144            Self::RlxOpFft => "rlx_op_fft",
145            Self::RlxOpIfft => "rlx_op_ifft",
146            Self::ButterflyEager => "butterfly_eager",
147            Self::ButterflyCompiled => "butterfly_compiled",
148            Self::StockhamEager => "stockham_eager",
149            Self::StockhamCompiled => "stockham_compiled",
150            Self::FusedSpectralEager => "fused_spectral_eager",
151            Self::FusedSpectralCompiled => "fused_spectral_compiled",
152            Self::ButterflyQ8 => "butterfly_q8",
153            Self::ButterflyUnitary => "butterfly_unitary",
154            Self::DomainTwiddle => "domain_twiddle",
155            Self::WelchRustfft => "welch_rustfft",
156            Self::WelchRlxOpFft => "welch_rlx_op_fft",
157            Self::WelchButterflyEager => "welch_butterfly_eager",
158            Self::WelchButterflyCompiled => "welch_butterfly_compiled",
159        }
160    }
161
162    pub fn supports_inverse(self) -> bool {
163        matches!(self, Self::Rustfft | Self::RlxOpIfft | Self::ButterflyEager)
164    }
165
166    pub fn needs_training(self) -> bool {
167        matches!(self, Self::ButterflyUnitary | Self::DomainTwiddle)
168    }
169}
170
/// Per-run state shared by all FFT variants: twiddles, optional Q8/unitary
/// weights, the fused-spectral mask, lazily compiled graphs, and the inputs
/// consumed by the inverse paths.
pub struct VariantState {
    /// Twiddle table read by the eager butterfly/Stockham/fused/Welch-butterfly
    /// paths and by graphs compiled in `prepare`; starts as `exact_twiddles(cfg)`
    /// and is replaced when `DomainTwiddle` is prepared.
    pub twiddles: Vec<f32>,
    /// Q8 twiddles for `ButterflyQ8`, filled in by `prepare`.
    pub q8: Option<Q8Twiddles>,
    /// Weights for `ButterflyUnitary`, filled in by `prepare`.
    pub unitary: Option<UnitaryWeights>,
    /// Mask values for the fused spectral variants; starts as `unit_mask(n_fft)`.
    pub mask: Vec<f32>,
    // Compiled graphs: each is built the first time `prepare` sees the matching
    // variant and reused on every later call.
    compiled_butterfly: Option<CompiledGraph>,
    compiled_stockham: Option<CompiledGraph>,
    compiled_fused: Option<CompiledGraph>,
    compiled_rlx: Option<CompiledGraph>,
    compiled_rlx_inv: Option<CompiledGraph>,
    compiled_welch_rlx: Option<CompiledGraph>,
    compiled_welch_butterfly: Option<CompiledGraph>,
    /// Input for the `RlxOpIfft` inverse path (set via `set_inverse_input_block`).
    spectrum_block: Vec<f32>,
    /// Input for the `Rustfft` / `ButterflyEager` inverse paths (set via `set_inverse_spectrum`).
    inverse_spectrum: Vec<f32>,
}

187impl VariantState {
188    pub fn new(cfg: &FftLearnConfig) -> Self {
189        Self {
190            twiddles: exact_twiddles(cfg),
191            q8: None,
192            unitary: None,
193            mask: unit_mask(cfg.n_fft),
194            compiled_butterfly: None,
195            compiled_stockham: None,
196            compiled_fused: None,
197            compiled_rlx: None,
198            compiled_rlx_inv: None,
199            compiled_welch_rlx: None,
200            compiled_welch_butterfly: None,
201            spectrum_block: Vec::new(),
202            inverse_spectrum: Vec::new(),
203        }
204    }
205
206    pub fn set_inverse_input_block(&mut self, block: Vec<f32>) {
207        self.spectrum_block = block;
208    }
209
210    pub fn set_inverse_spectrum(&mut self, spectrum: Vec<f32>) {
211        self.inverse_spectrum = spectrum;
212    }
213
214    pub fn prepare(
215        &mut self,
216        variant: FftVariantId,
217        cfg: &FftLearnConfig,
218        device: Device,
219        train_steps: usize,
220        seed: u64,
221    ) -> Result<()> {
222        match variant {
223            FftVariantId::ButterflyQ8 => {
224                self.q8 = Some(Q8Twiddles::from_f32(&self.twiddles));
225            }
226            FftVariantId::ButterflyUnitary => {
227                if cfg.n_fft <= 64 && cfg.batch <= 8 && train_steps > 0 {
228                    let (w, _) = train_unitary_quick(cfg, train_steps.min(25), 1e-3, seed)?;
229                    self.unitary = Some(w);
230                } else {
231                    self.unitary = Some(UnitaryWeights::exact_init(cfg));
232                }
233            }
234            FftVariantId::DomainTwiddle => {
235                let (tw, _) = train_domain_twiddles(cfg, train_steps, 5e-4, seed)?;
236                self.twiddles = tw;
237            }
238            FftVariantId::ButterflyCompiled if self.compiled_butterfly.is_none() => {
239                self.compiled_butterfly = Some(compile_butterfly(cfg, device, &self.twiddles)?);
240            }
241            FftVariantId::StockhamCompiled if self.compiled_stockham.is_none() => {
242                self.compiled_stockham = Some(compile_stockham(cfg, device, &self.twiddles)?);
243            }
244            FftVariantId::FusedSpectralCompiled if self.compiled_fused.is_none() => {
245                self.compiled_fused = Some(compile_fused(cfg, device, &self.mask)?);
246            }
247            FftVariantId::RlxOpFft if self.compiled_rlx.is_none() => {
248                self.compiled_rlx = Some(crate::rlx_fft::compile_rlx_fft(
249                    cfg,
250                    crate::config::TransformDir::Forward,
251                    device,
252                )?);
253            }
254            FftVariantId::RlxOpIfft if self.compiled_rlx_inv.is_none() => {
255                self.compiled_rlx_inv = Some(crate::rlx_fft::compile_rlx_fft(
256                    cfg,
257                    crate::config::TransformDir::Inverse,
258                    device,
259                )?);
260            }
261            FftVariantId::WelchRlxOpFft if self.compiled_welch_rlx.is_none() => {
262                let params = WelchParams::for_n_fft(cfg.n_fft);
263                self.compiled_welch_rlx = Some(compile_welch_rlx_fft(cfg.batch, params, device)?);
264            }
265            FftVariantId::WelchButterflyCompiled if self.compiled_welch_butterfly.is_none() => {
266                let params = WelchParams::for_n_fft(cfg.n_fft);
267                let welch_cfg = FftLearnConfig::new(cfg.n_fft, cfg.batch * params.n_segments)?;
268                self.compiled_welch_butterfly =
269                    Some(compile_butterfly(&welch_cfg, device, &self.twiddles)?);
270            }
271            _ => {}
272        }
273        Ok(())
274    }
275
276    pub fn forward(
277        &mut self,
278        variant: FftVariantId,
279        signal: &[f32],
280        cfg: &FftLearnConfig,
281    ) -> Result<Vec<f32>> {
282        let n = cfg.n_fft;
283        let batch = cfg.batch;
284        match variant {
285            FftVariantId::Rustfft => fft_real_batch(signal, batch, n),
286            FftVariantId::RlxOpFft => {
287                let exec = self.compiled_rlx.as_mut().expect("rlx compiled");
288                Ok(crate::rlx_fft::rlx_fft_forward(exec, signal, batch, n))
289            }
290            FftVariantId::ButterflyEager | FftVariantId::DomainTwiddle => {
291                butterfly_forward_real_batch(signal, &self.twiddles, batch, n)
292            }
293            FftVariantId::ButterflyCompiled => {
294                let exec = self
295                    .compiled_butterfly
296                    .as_mut()
297                    .expect("butterfly compiled");
298                Ok(exec.run(&[("signal", signal)]).remove(0))
299            }
300            FftVariantId::StockhamEager => {
301                stockham_forward_real_batch(signal, &self.twiddles, batch, n)
302            }
303            FftVariantId::StockhamCompiled => {
304                let exec = self.compiled_stockham.as_mut().expect("stockham compiled");
305                Ok(exec.run(&[("signal", signal)]).remove(0))
306            }
307            FftVariantId::FusedSpectralEager => {
308                fused_spectral_eager(signal, &self.twiddles, &self.mask, batch, n)
309            }
310            FftVariantId::FusedSpectralCompiled => {
311                let exec = self.compiled_fused.as_mut().expect("fused compiled");
312                Ok(exec.run(&[("signal", signal)]).remove(0))
313            }
314            FftVariantId::ButterflyQ8 => self
315                .q8
316                .as_ref()
317                .expect("q8")
318                .forward_real_batch(signal, batch, n),
319            FftVariantId::ButterflyUnitary => self
320                .unitary
321                .as_ref()
322                .expect("unitary")
323                .forward_real_batch(signal, batch, n),
324            FftVariantId::RlxOpIfft => bail!("rlx_op_ifft is inverse-only; call inverse()"),
325            FftVariantId::WelchRustfft
326            | FftVariantId::WelchRlxOpFft
327            | FftVariantId::WelchButterflyEager
328            | FftVariantId::WelchButterflyCompiled => {
329                bail!("{} is welch-only; call welch()", variant.label())
330            }
331        }
332    }
333
334    pub fn welch(
335        &mut self,
336        variant: FftVariantId,
337        signal: &[f32],
338        cfg: &FftLearnConfig,
339    ) -> Result<Vec<f32>> {
340        let params = WelchParams::for_n_fft(cfg.n_fft);
341        let batch = cfg.batch;
342        match variant {
343            FftVariantId::WelchRustfft => welch_rustfft(signal, batch, params),
344            FftVariantId::WelchRlxOpFft => {
345                let exec = self
346                    .compiled_welch_rlx
347                    .as_mut()
348                    .expect("welch rlx compiled");
349                welch_rlx_op_fft(exec, signal, batch, params)
350            }
351            FftVariantId::WelchButterflyEager => {
352                welch_butterfly(signal, &self.twiddles, batch, params)
353            }
354            FftVariantId::WelchButterflyCompiled => {
355                let window = crate::welch::hann_window(params.n_fft);
356                let segs = crate::welch::welch_windowed_segments(signal, batch, params, &window)?;
357                let exec = self
358                    .compiled_welch_butterfly
359                    .as_mut()
360                    .expect("welch butterfly compiled");
361                let spec = exec.run(&[("signal", &segs)]).remove(0);
362                Ok(crate::welch::average_welch_psd(
363                    &spec,
364                    batch,
365                    params.n_segments,
366                    params.n_fft,
367                ))
368            }
369            other => bail!("variant {} has no welch path", other.label()),
370        }
371    }
372
373    pub fn inverse(&mut self, variant: FftVariantId, cfg: &FftLearnConfig) -> Result<Vec<f32>> {
374        let n = cfg.n_fft;
375        let batch = cfg.batch;
376        match variant {
377            FftVariantId::Rustfft => {
378                crate::reference::ifft_complex_batch(&self.inverse_spectrum, batch, n)
379            }
380            FftVariantId::RlxOpIfft => {
381                let exec = self.compiled_rlx_inv.as_mut().expect("rlx inv compiled");
382                Ok(crate::rlx_fft::rlx_fft_inverse_block(
383                    exec,
384                    &self.spectrum_block,
385                    batch,
386                    n,
387                ))
388            }
389            FftVariantId::ButterflyEager => crate::butterfly::butterfly_inverse_complex_batch(
390                &self.inverse_spectrum,
391                &self.twiddles,
392                batch,
393                n,
394            ),
395            other => bail!("variant {} has no inverse path", other.label()),
396        }
397    }
398}
399
400fn compile_butterfly(
401    cfg: &FftLearnConfig,
402    device: Device,
403    twiddles: &[f32],
404) -> Result<CompiledGraph> {
405    use crate::weights::WeightStore;
406    let built = build_butterfly_forward_graph(cfg)?;
407    let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
408    let mut compiled = crate::compile::try_compile_graph(device, built.graph)?;
409    store.apply_butterfly(&mut compiled, cfg.batch, cfg.n_fft);
410    Ok(compiled)
411}
412
413fn compile_stockham(
414    cfg: &FftLearnConfig,
415    device: Device,
416    twiddles: &[f32],
417) -> Result<CompiledGraph> {
418    use crate::weights::WeightStore;
419    let (graph, _names) = build_stockham_forward_graph(cfg)?;
420    let store = WeightStore::from_twiddles(twiddles, cfg.n_fft);
421    let mut compiled = crate::compile::try_compile_graph(device, graph)?;
422    store.apply_butterfly(&mut compiled, cfg.batch, cfg.n_fft);
423    Ok(compiled)
424}
425
426fn compile_fused(cfg: &FftLearnConfig, device: Device, mask: &[f32]) -> Result<CompiledGraph> {
427    let (graph, names) = build_fused_spectral_graph(cfg)?;
428    let mut compiled = crate::compile::try_compile_graph(device, graph)?;
429    for (i, name) in names.iter().enumerate() {
430        compiled.set_param(name, &[mask[i]]);
431    }
432    Ok(compiled)
433}
434
435pub fn variants_for_direction(with_compiled: bool, forward: bool) -> Vec<FftVariantId> {
436    let all = if with_compiled {
437        FftVariantId::all()
438    } else {
439        FftVariantId::all_eager()
440    };
441    all.iter()
442        .copied()
443        .filter(|v| {
444            if v.is_welch() {
445                return false;
446            }
447            if forward {
448                !matches!(v, FftVariantId::RlxOpIfft)
449            } else {
450                v.supports_inverse()
451            }
452        })
453        .collect()
454}
455
456pub fn bench_variant_ms(
457    state: &mut VariantState,
458    variant: FftVariantId,
459    cfg: &FftLearnConfig,
460    signal: &[f32],
461    iters: usize,
462) -> Result<f64> {
463    let _ = state.forward(variant, signal, cfg)?;
464    let t0 = Instant::now();
465    for _ in 0..iters {
466        state.forward(variant, signal, cfg)?;
467    }
468    Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
469}
470
471pub fn variants_for_welch(with_compiled: bool) -> Vec<FftVariantId> {
472    FftVariantId::welch_variants(with_compiled).to_vec()
473}
474
475pub fn bench_variant_ms_welch(
476    state: &mut VariantState,
477    variant: FftVariantId,
478    cfg: &FftLearnConfig,
479    signal: &[f32],
480    iters: usize,
481) -> Result<f64> {
482    let _ = state.welch(variant, signal, cfg)?;
483    let t0 = Instant::now();
484    for _ in 0..iters {
485        state.welch(variant, signal, cfg)?;
486    }
487    Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
488}
489
490pub fn variant_welch_error(
491    state: &mut VariantState,
492    variant: FftVariantId,
493    cfg: &FftLearnConfig,
494    signal: &[f32],
495) -> Result<f32> {
496    if matches!(variant, FftVariantId::WelchRustfft) {
497        return Ok(0.0);
498    }
499    let pred = state.welch(variant, signal, cfg)?;
500    let params = WelchParams::for_n_fft(cfg.n_fft);
501    let target = welch_rustfft(signal, cfg.batch, params)?;
502    Ok(max_abs_error(&pred, &target))
503}
504
505pub fn fixed_ablation_welch_signal(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
506    let params = WelchParams::for_n_fft(n_fft);
507    let frame = params.frame_len();
508    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
509    random_batch(&mut rng, batch, frame)
510}
511
512pub fn bench_variant_ms_inverse(
513    state: &mut VariantState,
514    variant: FftVariantId,
515    cfg: &FftLearnConfig,
516    iters: usize,
517) -> Result<f64> {
518    let _ = state.inverse(variant, cfg)?;
519    let t0 = Instant::now();
520    for _ in 0..iters {
521        state.inverse(variant, cfg)?;
522    }
523    Ok(t0.elapsed().as_secs_f64() * 1000.0 / iters as f64)
524}
525
526pub fn variant_spectrum_error(
527    state: &mut VariantState,
528    variant: FftVariantId,
529    cfg: &FftLearnConfig,
530    signal: &[f32],
531) -> Result<f32> {
532    if matches!(
533        variant,
534        FftVariantId::FusedSpectralEager | FftVariantId::FusedSpectralCompiled
535    ) {
536        return Ok(0.0);
537    }
538    let pred = state.forward(variant, signal, cfg)?;
539    let target = fft_real_batch(signal, cfg.batch, cfg.n_fft)?;
540    Ok(max_abs_error(&pred, &target))
541}
542
543pub fn variant_inverse_error(
544    state: &mut VariantState,
545    variant: FftVariantId,
546    cfg: &FftLearnConfig,
547) -> Result<f32> {
548    let pred = state.inverse(variant, cfg)?;
549    let target =
550        crate::reference::ifft_complex_batch(&state.inverse_spectrum, cfg.batch, cfg.n_fft)?;
551    Ok(max_abs_error(&pred, &target))
552}
553
554pub fn fixed_ablation_signal(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
555    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
556    random_batch(&mut rng, batch, n_fft)
557}
558
559pub fn fixed_ablation_spectrum(seed: u64, batch: usize, n_fft: usize) -> Vec<f32> {
560    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
561    crate::train::random_complex_batch(&mut rng, batch, n_fft)
562}
563
564pub fn ensure_variant_ready(variant: FftVariantId, device: Device) -> Result<()> {
565    if matches!(
566        variant,
567        FftVariantId::ButterflyCompiled
568            | FftVariantId::StockhamCompiled
569            | FftVariantId::FusedSpectralCompiled
570            | FftVariantId::RlxOpFft
571            | FftVariantId::RlxOpIfft
572            | FftVariantId::WelchRlxOpFft
573            | FftVariantId::WelchButterflyCompiled
574    ) && device == Device::Cpu
575    {
576        return Ok(());
577    }
578    if matches!(variant, FftVariantId::Rustfft) {
579        return Ok(());
580    }
581    Ok(())
582}
583
584pub fn skip_on_device(variant: FftVariantId, device: Device) -> bool {
585    let _ = (variant, device);
586    false
587}
588
589pub fn validate_variant_output(
590    variant: FftVariantId,
591    pred_len: usize,
592    cfg: &FftLearnConfig,
593) -> Result<()> {
594    let expected = cfg.batch * cfg.n_fft * 2;
595    if pred_len != expected {
596        bail!(
597            "variant {} output len {pred_len} != expected {expected}",
598            variant.label()
599        );
600    }
601    Ok(())
602}