// rlx_fft/runner.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Compiled inference session for learned FFT / IFFT.

use crate::butterfly::{
    build_butterfly_forward_graph, build_butterfly_inverse_graph, butterfly_forward_real_batch,
    butterfly_inverse_complex_batch,
};
use crate::config::{FftLearnConfig, TransformDir};
use crate::reference::{fft_real_batch, ifft_complex_batch, max_abs_error, mse};
use crate::twiddle::exact_twiddles;
use crate::weights::WeightStore;
use anyhow::{Result, bail};
use rlx_runtime::{CompiledGraph, Device};

29pub struct FftLearnRunner {
30    cfg: FftLearnConfig,
31    direction: TransformDir,
32    twiddles: Vec<f32>,
33    compiled: Option<(Device, CompiledGraph)>,
34}
35
36impl FftLearnRunner {
37    pub fn new(cfg: FftLearnConfig) -> Result<Self> {
38        Self::new_dir(cfg, TransformDir::Forward)
39    }
40
41    pub fn new_ifft(cfg: FftLearnConfig) -> Result<Self> {
42        Self::new_dir(cfg, TransformDir::Inverse)
43    }
44
45    pub fn new_dir(cfg: FftLearnConfig, direction: TransformDir) -> Result<Self> {
46        cfg.validate()?;
47        Ok(Self {
48            twiddles: exact_twiddles(&cfg),
49            cfg,
50            direction,
51            compiled: None,
52        })
53    }
54
55    pub fn with_weights(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
56        Self::with_weights_dir(cfg, weights, TransformDir::Forward)
57    }
58
59    pub fn with_weights_ifft(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
60        Self::with_weights_dir(cfg, weights, TransformDir::Inverse)
61    }
62
63    pub fn with_weights_dir(
64        cfg: FftLearnConfig,
65        weights: &WeightStore,
66        direction: TransformDir,
67    ) -> Result<Self> {
68        let mut this = Self::new_dir(cfg, direction)?;
69        this.twiddles = weights.to_twiddles(this.cfg.n_fft)?;
70        Ok(this)
71    }
72
73    pub fn load_compiled(&mut self, device: Device) -> Result<()> {
74        let built = if self.direction.is_forward() {
75            build_butterfly_forward_graph(&self.cfg)?
76        } else {
77            build_butterfly_inverse_graph(&self.cfg)?
78        };
79        let store = WeightStore::from_twiddles(&self.twiddles, self.cfg.n_fft);
80        let mut compiled = crate::compile::try_compile_graph(device, built.graph)?;
81        store.apply_butterfly(&mut compiled, self.cfg.batch, self.cfg.n_fft);
82        self.compiled = Some((device, compiled));
83        Ok(())
84    }
85
86    pub fn forward_eager(&self, input: &[f32]) -> Result<Vec<f32>> {
87        if self.direction.is_forward() {
88            butterfly_forward_real_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
89        } else {
90            butterfly_inverse_complex_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
91        }
92    }
93
94    pub fn forward(&mut self, input: &[f32]) -> Result<Vec<f32>> {
95        if self.compiled.is_some() {
96            self.forward_compiled(input)
97        } else {
98            self.forward_eager(input)
99        }
100    }
101
102    fn forward_compiled(&mut self, input: &[f32]) -> Result<Vec<f32>> {
103        let expected = if self.direction.is_forward() {
104            self.cfg.batch * self.cfg.n_fft
105        } else {
106            self.cfg.batch * self.cfg.n_fft * 2
107        };
108        if input.len() != expected {
109            bail!("input len {} != expected {}", input.len(), expected);
110        }
111        let Some((_, ref mut exec)) = self.compiled else {
112            bail!("compiled session not loaded");
113        };
114        let input_name = if self.direction.is_forward() {
115            "signal"
116        } else {
117            "spectrum"
118        };
119        let outputs = exec.run(&[(input_name, input)]);
120        outputs
121            .into_iter()
122            .next()
123            .ok_or_else(|| anyhow::anyhow!("butterfly graph produced no outputs"))
124    }
125
126    pub fn compare_reference(&self, input: &[f32]) -> Result<(f32, f32)> {
127        let pred = self.forward_eager(input)?;
128        let target = if self.direction.is_forward() {
129            fft_real_batch(input, self.cfg.batch, self.cfg.n_fft)?
130        } else {
131            ifft_complex_batch(input, self.cfg.batch, self.cfg.n_fft)?
132        };
133        Ok((
134            crate::reference::mse(&pred, &target),
135            max_abs_error(&pred, &target),
136        ))
137    }
138
139    pub fn config(&self) -> &FftLearnConfig {
140        &self.cfg
141    }
142
143    pub fn direction(&self) -> TransformDir {
144        self.direction
145    }
146}