use crate::butterfly::{
build_butterfly_forward_graph, build_butterfly_inverse_graph, butterfly_forward_real_batch,
butterfly_inverse_complex_batch,
};
use crate::config::{FftLearnConfig, TransformDir};
use crate::reference::{fft_real_batch, ifft_complex_batch, max_abs_error};
use crate::twiddle::exact_twiddles;
use crate::weights::WeightStore;
use anyhow::{Result, bail};
use rlx_runtime::{CompiledGraph, Device};
/// Runs a butterfly-structured FFT (forward) or inverse FFT over a fixed-size batch.
///
/// The transform runs eagerly in Rust by default. Once
/// [`FftLearnRunner::load_compiled`] has succeeded, it runs through a compiled graph instead.
pub struct FftLearnRunner {
    /// Transform configuration; `validate()` is called on it by every constructor.
    cfg: FftLearnConfig,
    /// Whether this runner computes the forward or the inverse transform.
    direction: TransformDir,
    /// Twiddle factors used by both the eager path and the compiled graph.
    /// These are the exact values by default, or values converted from a `WeightStore`.
    twiddles: Vec<f32>,
    /// Target device and compiled graph; `None` until `load_compiled` succeeds.
    compiled: Option<(Device, CompiledGraph)>,
}
impl FftLearnRunner {
    /// Creates a forward-FFT runner that uses exact twiddle factors.
    ///
    /// # Errors
    /// Returns an error if `cfg.validate()` fails.
    pub fn new(cfg: FftLearnConfig) -> Result<Self> {
        Self::new_dir(cfg, TransformDir::Forward)
    }

    /// Creates an inverse-FFT runner that uses exact twiddle factors.
    ///
    /// # Errors
    /// Returns an error if `cfg.validate()` fails.
    pub fn new_ifft(cfg: FftLearnConfig) -> Result<Self> {
        Self::new_dir(cfg, TransformDir::Inverse)
    }

    /// Creates a runner for `direction`.
    ///
    /// Twiddles come from [`exact_twiddles`], and no compiled graph is loaded.
    ///
    /// # Errors
    /// Returns an error if `cfg.validate()` fails.
    pub fn new_dir(cfg: FftLearnConfig, direction: TransformDir) -> Result<Self> {
        cfg.validate()?;
        Ok(Self {
            twiddles: exact_twiddles(&cfg),
            cfg,
            direction,
            compiled: None,
        })
    }

    /// Creates a forward-FFT runner whose twiddles are loaded from `weights`.
    ///
    /// # Errors
    /// Returns an error if the config is invalid or the weights cannot be converted.
    pub fn with_weights(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
        Self::with_weights_dir(cfg, weights, TransformDir::Forward)
    }

    /// Creates an inverse-FFT runner whose twiddles are loaded from `weights`.
    ///
    /// # Errors
    /// Returns an error if the config is invalid or the weights cannot be converted.
    pub fn with_weights_ifft(cfg: FftLearnConfig, weights: &WeightStore) -> Result<Self> {
        Self::with_weights_dir(cfg, weights, TransformDir::Inverse)
    }

    /// Creates a runner for `direction` whose twiddles come from
    /// `weights.to_twiddles(cfg.n_fft)`.
    ///
    /// # Errors
    /// Returns an error if `cfg.validate()` fails, or if `to_twiddles` fails.
    /// The config is validated first.
    pub fn with_weights_dir(
        cfg: FftLearnConfig,
        weights: &WeightStore,
        direction: TransformDir,
    ) -> Result<Self> {
        cfg.validate()?;
        // Load the stored twiddles directly. Computing the exact ones first (as
        // `new_dir` does) would only be thrown away.
        let twiddles = weights.to_twiddles(cfg.n_fft)?;
        Ok(Self {
            cfg,
            direction,
            twiddles,
            compiled: None,
        })
    }

    /// Builds the butterfly graph for this runner's direction and compiles it on `device`.
    /// It then writes the current twiddles into the compiled graph.
    ///
    /// On success, this replaces any previously loaded graph. After that,
    /// [`forward`](Self::forward) runs through the compiled graph.
    ///
    /// # Errors
    /// Returns an error if graph construction or compilation fails. In that case
    /// `self.compiled` is left unchanged.
    pub fn load_compiled(&mut self, device: Device) -> Result<()> {
        let built = if self.direction.is_forward() {
            build_butterfly_forward_graph(&self.cfg)?
        } else {
            build_butterfly_inverse_graph(&self.cfg)?
        };
        let store = WeightStore::from_twiddles(&self.twiddles, self.cfg.n_fft);
        let mut compiled = crate::compile::try_compile_graph(device, built.graph)?;
        // NOTE(review): any return value of `apply_butterfly` is not inspected here.
        // Confirm that it cannot fail.
        store.apply_butterfly(&mut compiled, self.cfg.batch, self.cfg.n_fft);
        self.compiled = Some((device, compiled));
        Ok(())
    }

    /// Runs the transform with the eager butterfly implementation, ignoring any compiled graph.
    ///
    /// - Forward runners pass `input` to `butterfly_forward_real_batch`.
    /// - Inverse runners pass `input` to `butterfly_inverse_complex_batch`.
    ///
    /// # Errors
    /// Propagates whatever error the butterfly helper returns. Presumably this
    /// covers a malformed input length; confirm against those helpers.
    pub fn forward_eager(&self, input: &[f32]) -> Result<Vec<f32>> {
        if self.direction.is_forward() {
            butterfly_forward_real_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
        } else {
            butterfly_inverse_complex_batch(input, &self.twiddles, self.cfg.batch, self.cfg.n_fft)
        }
    }

    /// Runs the transform, using the compiled graph if one is loaded and the eager path otherwise.
    ///
    /// # Errors
    /// Returns whatever error the chosen path returns.
    pub fn forward(&mut self, input: &[f32]) -> Result<Vec<f32>> {
        if self.compiled.is_some() {
            self.forward_compiled(input)
        } else {
            self.forward_eager(input)
        }
    }

    /// Feeds `input` into the compiled graph and returns the graph's first output.
    ///
    /// Expected input:
    /// - Forward runners take `batch * n_fft` values, fed as `"signal"`.
    /// - Inverse runners take `batch * n_fft * 2` values, fed as `"spectrum"`. These are
    ///   presumably interleaved real/imaginary pairs; confirm against the graph builder.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the input length is wrong,
    /// - no compiled graph is loaded, or
    /// - the graph produced no outputs.
    fn forward_compiled(&mut self, input: &[f32]) -> Result<Vec<f32>> {
        // The graph input name and the input length both depend only on the direction.
        let (input_name, expected) = if self.direction.is_forward() {
            ("signal", self.cfg.batch * self.cfg.n_fft)
        } else {
            ("spectrum", self.cfg.batch * self.cfg.n_fft * 2)
        };
        if input.len() != expected {
            bail!("input len {} != expected {}", input.len(), expected);
        }
        let Some((_, exec)) = self.compiled.as_mut() else {
            bail!("compiled session not loaded");
        };
        exec.run(&[(input_name, input)])
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("butterfly graph produced no outputs"))
    }

    /// Compares this runner's eager output against the reference FFT or IFFT of `input`.
    ///
    /// This always uses [`forward_eager`](Self::forward_eager), even when a compiled graph
    /// is loaded.
    ///
    /// Returns `(mse, max_abs_error)` between the prediction and the reference.
    ///
    /// # Errors
    /// Propagates errors from the eager path or from the reference transform.
    pub fn compare_reference(&self, input: &[f32]) -> Result<(f32, f32)> {
        let pred = self.forward_eager(input)?;
        let target = if self.direction.is_forward() {
            fft_real_batch(input, self.cfg.batch, self.cfg.n_fft)?
        } else {
            ifft_complex_batch(input, self.cfg.batch, self.cfg.n_fft)?
        };
        Ok((
            crate::reference::mse(&pred, &target),
            max_abs_error(&pred, &target),
        ))
    }

    /// Returns the configuration this runner was built with.
    pub fn config(&self) -> &FftLearnConfig {
        &self.cfg
    }

    /// Returns whether this runner computes the forward or the inverse transform.
    pub fn direction(&self) -> TransformDir {
        self.direction
    }
}