1use crate::butterfly::{
19 build_butterfly_forward_graph, build_butterfly_inverse_graph, butterfly_forward_real_batch,
20 butterfly_inverse_complex_batch,
21};
22use crate::config::{FftLearnConfig, TransformDir};
23use crate::reference::{fft_real_batch, ifft_complex_batch, max_abs_error};
24use crate::twiddle::exact_twiddles;
25use crate::weights::WeightStore;
26use anyhow::{Result, bail};
27use rlx_runtime::{CompiledGraph, Device};
28
29pub struct FftLearnRunner {
30 cfg: FftLearnConfig,
31 direction: TransformDir,
32 twiddles: Vec<f32>,
33 compiled: Option<(Device, CompiledGraph)>,
34}
35
36impl FftLearnRunner {
37 pub fn new(cfg: FftLearnConfig) -> Result<Self> {
38 Self::new_dir(cfg, TransformDir::Forward)
39 }
40
41 pub fn new_ifft(cfg: FftLearnConfig) -> Result<Self> {
42 Self::new_dir(cfg, TransformDir::Inverse)
43 }
44
45 pub fn new_dir(cfg: FftLearnConfig, direction: TransformDir) -> Result<Self> {
46 cfg.validate()?;
47 Ok(Self {
48 twiddles: exact_twiddles(&cfg),
49 cfg,
50 direction,
51 compiled: None,
52 })
53 }
54
55 pub fn with_weights(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
56 Self::with_weights_dir(cfg, weights, TransformDir::Forward)
57 }
58
59 pub fn with_weights_ifft(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
60 Self::with_weights_dir(cfg, weights, TransformDir::Inverse)
61 }
62
63 pub fn with_weights_dir(
64 cfg: FftLearnConfig,
65 weights: &WeightStore,
66 direction: TransformDir,
67 ) -> Result<Self> {
68 let mut this = Self::new_dir(cfg, direction)?;
69 this.twiddles = weights.to_twiddles(this.cfg.n_fft)?;
70 Ok(this)
71 }
72
73 pub fn load_compiled(&mut self, device: Device) -> Result<()> {
74 let built = if self.direction.is_forward() {
75 build_butterfly_forward_graph(&self.cfg)?
76 } else {
77 build_butterfly_inverse_graph(&self.cfg)?
78 };
79 let store = WeightStore::from_twiddles(&self.twiddles, self.cfg.n_fft);
80 let mut compiled = crate::compile::try_compile_graph(device, built.graph)?;
81 store.apply_butterfly(&mut compiled, self.cfg.batch, self.cfg.n_fft);
82 self.compiled = Some((device, compiled));
83 Ok(())
84 }
85
86 pub fn forward_eager(&self, input: &[f32]) -> Result<Vec<f32>> {
87 if self.direction.is_forward() {
88 butterfly_forward_real_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
89 } else {
90 butterfly_inverse_complex_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
91 }
92 }
93
94 pub fn forward(&mut self, input: &[f32]) -> Result<Vec<f32>> {
95 if self.compiled.is_some() {
96 self.forward_compiled(input)
97 } else {
98 self.forward_eager(input)
99 }
100 }
101
102 fn forward_compiled(&mut self, input: &[f32]) -> Result<Vec<f32>> {
103 let expected = if self.direction.is_forward() {
104 self.cfg.batch * self.cfg.n_fft
105 } else {
106 self.cfg.batch * self.cfg.n_fft * 2
107 };
108 if input.len() != expected {
109 bail!("input len {} != expected {}", input.len(), expected);
110 }
111 let Some((_, ref mut exec)) = self.compiled else {
112 bail!("compiled session not loaded");
113 };
114 let input_name = if self.direction.is_forward() {
115 "signal"
116 } else {
117 "spectrum"
118 };
119 let outputs = exec.run(&[(input_name, input)]);
120 outputs
121 .into_iter()
122 .next()
123 .ok_or_else(|| anyhow::anyhow!("butterfly graph produced no outputs"))
124 }
125
126 pub fn compare_reference(&self, input: &[f32]) -> Result<(f32, f32)> {
127 let pred = self.forward_eager(input)?;
128 let target = if self.direction.is_forward() {
129 fft_real_batch(input, self.cfg.batch, self.cfg.n_fft)?
130 } else {
131 ifft_complex_batch(input, self.cfg.batch, self.cfg.n_fft)?
132 };
133 Ok((
134 crate::reference::mse(&pred, &target),
135 max_abs_error(&pred, &target),
136 ))
137 }
138
139 pub fn config(&self) -> &FftLearnConfig {
140 &self.cfg
141 }
142
143 pub fn direction(&self) -> TransformDir {
144 self.direction
145 }
146}