Skip to main content

rlx_fft/
twiddle.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Twiddle-factor initialization for butterfly stages.
17
18use crate::config::{FftLearnConfig, TransformDir};
19use std::f32::consts::TAU;
20
/// Identifies the learned butterfly network that owns a twiddle parameter.
///
/// The variant decides which parameter-name prefix [`twiddle_name_set`] uses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TwiddleSet {
    /// Shared forward twiddles, named `twiddle.s*.b*.*`.
    Shared,
    /// Encoder (forward FFT) twiddles, named `encoder.twiddle.s*.b*.*`.
    Encoder,
    /// Decoder (inverse FFT) twiddles, named `decoder.twiddle.s*.b*.*`.
    Decoder,
}
31
32/// Flat twiddle buffer: for each stage `s` and butterfly `b`, store `(re, im)`.
33/// Length = `num_stages * (n_fft/2) * 2`.
34pub fn exact_twiddles(cfg: &FftLearnConfig) -> Vec<f32> {
35    let half = cfg.n_fft / 2;
36    let stages = cfg.num_stages();
37    let mut out = vec![0f32; stages * half * 2];
38    for s in 0..stages {
39        let stride = 1usize << s;
40        for b in 0..half {
41            let k = b % stride;
42            let m = (2 * stride) as f32;
43            let exp = -TAU * k as f32 / m;
44            let base = (s * half + b) * 2;
45            out[base] = exp.cos();
46            out[base + 1] = exp.sin();
47        }
48    }
49    out
50}
51
52pub fn exact_twiddles_dir(cfg: &FftLearnConfig, dir: TransformDir) -> Vec<f32> {
53    let _ = dir;
54    exact_twiddles(cfg)
55}
56
57pub fn twiddle_name(stage: usize, butterfly: usize, part: &str) -> String {
58    twiddle_name_set(TwiddleSet::Shared, stage, butterfly, part)
59}
60
61pub fn twiddle_name_set(set: TwiddleSet, stage: usize, butterfly: usize, part: &str) -> String {
62    let prefix = match set {
63        TwiddleSet::Shared => "twiddle",
64        TwiddleSet::Encoder => "encoder.twiddle",
65        TwiddleSet::Decoder => "decoder.twiddle",
66    };
67    format!("{prefix}.s{stage}.b{butterfly}.{part}")
68}
69
70pub fn twiddle_name_dir(dir: TransformDir, stage: usize, butterfly: usize, part: &str) -> String {
71    let _ = dir;
72    twiddle_name(stage, butterfly, part)
73}
74
/// Flat offset of one twiddle component in the table built by [`exact_twiddles`].
///
/// `half` is the number of butterflies per stage (`n_fft / 2`). `part` selects
/// the component: `0` for the real part (cosine) and `1` for the imaginary
/// part (sine).
///
/// # Panics
///
/// In debug builds, panics if `part > 1` or `butterfly >= half`. Either would
/// otherwise silently alias a neighbouring entry of the flat layout.
#[must_use]
pub fn twiddle_index(stage: usize, butterfly: usize, half: usize, part: usize) -> usize {
    debug_assert!(part < 2, "twiddle part must be 0 (re) or 1 (im), got {part}");
    debug_assert!(
        butterfly < half,
        "butterfly {butterfly} out of range for {half} butterflies per stage"
    );
    (stage * half + butterfly) * 2 + part
}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn twiddle_magnitude_is_one() {
        let cfg = FftLearnConfig::new(64, 1).unwrap();
        let tw = exact_twiddles(&cfg);
        // The table always stores (re, im) pairs, so exact chunks cover it fully.
        for chunk in tw.chunks_exact(2) {
            let mag = (chunk[0] * chunk[0] + chunk[1] * chunk[1]).sqrt();
            assert!((mag - 1.0).abs() < 1e-6, "mag={mag}");
        }
    }

    #[test]
    fn twiddle_table_has_expected_length() {
        let cfg = FftLearnConfig::new(64, 1).unwrap();
        let half = cfg.n_fft / 2;
        assert_eq!(exact_twiddles(&cfg).len(), cfg.num_stages() * half * 2);
    }

    /// Checks angles (not just magnitudes) against an f64 reference, and that
    /// `twiddle_index` addresses the same layout `exact_twiddles` writes.
    #[test]
    fn twiddle_values_match_reference_angles() {
        let cfg = FftLearnConfig::new(64, 1).unwrap();
        let half = cfg.n_fft / 2;
        let tw = exact_twiddles(&cfg);
        for stage in 0..cfg.num_stages() {
            let stride = 1usize << stage;
            for b in 0..half {
                let k = (b % stride) as f64;
                let angle = -std::f64::consts::TAU * k / (2 * stride) as f64;
                let re = tw[twiddle_index(stage, b, half, 0)];
                let im = tw[twiddle_index(stage, b, half, 1)];
                assert!(
                    (f64::from(re) - angle.cos()).abs() < 1e-5,
                    "s={stage} b={b} re={re}"
                );
                assert!(
                    (f64::from(im) - angle.sin()).abs() < 1e-5,
                    "s={stage} b={b} im={im}"
                );
            }
        }
    }

    #[test]
    fn twiddle_names_use_set_prefix() {
        assert_eq!(twiddle_name(1, 2, "re"), "twiddle.s1.b2.re");
        assert_eq!(
            twiddle_name_set(TwiddleSet::Shared, 1, 2, "re"),
            "twiddle.s1.b2.re"
        );
        assert_eq!(
            twiddle_name_set(TwiddleSet::Encoder, 0, 3, "im"),
            "encoder.twiddle.s0.b3.im"
        );
        assert_eq!(
            twiddle_name_set(TwiddleSet::Decoder, 4, 5, "re"),
            "decoder.twiddle.s4.b5.re"
        );
    }
}