Skip to main content

rlx_fft/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Named twiddle parameters for training and compiled inference.
17
18use crate::config::TransformDir;
19use crate::twiddle::{TwiddleSet, twiddle_index, twiddle_name_set};
20use anyhow::{Context, Result, ensure};
21use safetensors::SafeTensors;
22use std::collections::HashMap;
23use std::path::Path;
24
25#[derive(Debug, Clone, Default)]
26pub struct WeightStore(pub HashMap<String, Vec<f32>>);
27
28impl WeightStore {
29    pub fn from_twiddles(twiddles: &[f32], n_fft: usize) -> Self {
30        Self::from_twiddles_dir(twiddles, n_fft, TransformDir::Forward)
31    }
32
33    pub fn from_twiddles_dir(twiddles: &[f32], n_fft: usize, dir: TransformDir) -> Self {
34        let _ = dir;
35        Self::from_twiddles_set(twiddles, n_fft, TwiddleSet::Shared)
36    }
37
38    pub fn from_twiddles_set(twiddles: &[f32], n_fft: usize, set: TwiddleSet) -> Self {
39        let half = n_fft / 2;
40        let stages = n_fft.trailing_zeros() as usize;
41        let mut store = Self::default();
42        for s in 0..stages {
43            for b in 0..half {
44                let base = twiddle_index(s, b, half, 0);
45                store
46                    .0
47                    .insert(twiddle_name_set(set, s, b, "re"), vec![twiddles[base]]);
48                store
49                    .0
50                    .insert(twiddle_name_set(set, s, b, "im"), vec![twiddles[base + 1]]);
51            }
52        }
53        store
54    }
55
56    pub fn to_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
57        self.to_twiddles_dir(n_fft, TransformDir::Forward)
58    }
59
60    pub fn to_twiddles_dir(&self, n_fft: usize, dir: TransformDir) -> Result<Vec<f32>> {
61        let _ = dir;
62        self.to_twiddles_set(n_fft, TwiddleSet::Shared)
63    }
64
65    pub fn to_twiddles_set(&self, n_fft: usize, set: TwiddleSet) -> Result<Vec<f32>> {
66        let half = n_fft / 2;
67        let stages = n_fft.trailing_zeros() as usize;
68        let mut out = vec![0f32; stages * half * 2];
69        for s in 0..stages {
70            for b in 0..half {
71                let base = twiddle_index(s, b, half, 0);
72                let re_name = twiddle_name_set(set, s, b, "re");
73                let im_name = twiddle_name_set(set, s, b, "im");
74                out[base] = *self
75                    .0
76                    .get(&re_name)
77                    .with_context(|| format!("missing twiddle param {re_name}"))?
78                    .first()
79                    .context("empty twiddle re")?;
80                out[base + 1] = *self
81                    .0
82                    .get(&im_name)
83                    .with_context(|| format!("missing twiddle param {im_name}"))?
84                    .first()
85                    .context("empty twiddle im")?;
86            }
87        }
88        Ok(out)
89    }
90
91    pub fn merge(&self, other: &Self) -> Self {
92        let mut out = self.clone();
93        for (k, v) in &other.0 {
94            out.0.insert(k.clone(), v.clone());
95        }
96        out
97    }
98
99    pub fn apply(&self, exec: &mut rlx_runtime::CompiledGraph) {
100        for (name, data) in &self.0 {
101            exec.set_param(name, data);
102        }
103    }
104
105    /// Bind twiddles for a compiled butterfly graph.
106    pub fn apply_butterfly(
107        &self,
108        exec: &mut rlx_runtime::CompiledGraph,
109        _batch: usize,
110        _n_fft: usize,
111    ) {
112        self.apply(exec);
113    }
114
115    /// Bind twiddles only for non-skip ternary butterflies present in a pruned graph.
116    pub fn apply_butterfly_for_gates(
117        &self,
118        exec: &mut rlx_runtime::CompiledGraph,
119        n_fft: usize,
120        gates: &[i8],
121    ) {
122        use crate::pruned::{gate_count, gate_index};
123        use crate::ternary_gates::GateMode;
124        let half = n_fft / 2;
125        let stages = n_fft.trailing_zeros() as usize;
126        if gates.len() < gate_count(n_fft) {
127            return;
128        }
129        for s in 0..stages {
130            for b in 0..half {
131                let gi = gate_index(s, b, half);
132                if GateMode::from_i8(gates[gi]) == GateMode::Skip {
133                    continue;
134                }
135                let re_name = twiddle_name_set(TwiddleSet::Shared, s, b, "re");
136                let im_name = twiddle_name_set(TwiddleSet::Shared, s, b, "im");
137                if let Some(v) = self.0.get(&re_name) {
138                    exec.set_param(&re_name, v);
139                }
140                if let Some(v) = self.0.get(&im_name) {
141                    exec.set_param(&im_name, v);
142                }
143            }
144        }
145    }
146}
147
148/// Encoder + decoder twiddle checkpoints.
149#[derive(Debug, Clone, Default)]
150pub struct EncDecWeights {
151    pub encoder: WeightStore,
152    pub decoder: WeightStore,
153}
154
155impl EncDecWeights {
156    pub fn from_twiddles(encoder: &[f32], decoder: &[f32], n_fft: usize) -> Self {
157        Self {
158            encoder: WeightStore::from_twiddles_set(encoder, n_fft, TwiddleSet::Encoder),
159            decoder: WeightStore::from_twiddles_set(decoder, n_fft, TwiddleSet::Decoder),
160        }
161    }
162
163    pub fn merged(&self) -> WeightStore {
164        self.encoder.merge(&self.decoder)
165    }
166
167    pub fn encoder_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
168        self.encoder.to_twiddles_set(n_fft, TwiddleSet::Encoder)
169    }
170
171    pub fn decoder_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
172        self.decoder.to_twiddles_set(n_fft, TwiddleSet::Decoder)
173    }
174
175    pub fn from_merged(store: &WeightStore, n_fft: usize) -> Result<Self> {
176        Ok(Self {
177            encoder: {
178                let tw = store.to_twiddles_set(n_fft, TwiddleSet::Encoder)?;
179                WeightStore::from_twiddles_set(&tw, n_fft, TwiddleSet::Encoder)
180            },
181            decoder: {
182                let tw = store.to_twiddles_set(n_fft, TwiddleSet::Decoder)?;
183                WeightStore::from_twiddles_set(&tw, n_fft, TwiddleSet::Decoder)
184            },
185        })
186    }
187}
188
189pub fn export_safetensors(path: &Path, weights: &WeightStore) -> Result<()> {
190    ensure!(!weights.0.is_empty(), "no weights to export");
191    let mut storages: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
192    for (name, data) in &weights.0 {
193        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
194        storages.push((name.clone(), bytes, vec![data.len()]));
195    }
196    let mut views: HashMap<String, safetensors::tensor::TensorView> = HashMap::new();
197    for (name, bytes, shape) in &storages {
198        views.insert(
199            name.clone(),
200            safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape.clone(), bytes)
201                .context("tensor view")?,
202        );
203    }
204    if let Some(parent) = path.parent() {
205        std::fs::create_dir_all(parent)?;
206    }
207    safetensors::serialize_to_file(&views, None, path)
208        .with_context(|| format!("write {}", path.display()))?;
209    Ok(())
210}
211
212pub fn load_safetensors(path: &Path) -> Result<WeightStore> {
213    let bytes = std::fs::read(path)?;
214    let st = SafeTensors::deserialize(&bytes)?;
215    let mut store = WeightStore::default();
216    for name in st.names() {
217        let view = st.tensor(name)?;
218        ensure!(
219            view.dtype() == safetensors::Dtype::F32,
220            "expected f32 weights in {path:?}, got {:?} for {name}",
221            view.dtype()
222        );
223        let data: Vec<f32> = view
224            .data()
225            .chunks_exact(4)
226            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
227            .collect();
228        store.0.insert(name.to_string(), data);
229    }
230    Ok(store)
231}