1use crate::config::TransformDir;
19use crate::twiddle::{TwiddleSet, twiddle_index, twiddle_name_set};
20use anyhow::{Context, Result, ensure};
21use safetensors::SafeTensors;
22use std::collections::HashMap;
23use std::path::Path;
24
25#[derive(Debug, Clone, Default)]
26pub struct WeightStore(pub HashMap<String, Vec<f32>>);
27
28impl WeightStore {
29 pub fn from_twiddles(twiddles: &[f32], n_fft: usize) -> Self {
30 Self::from_twiddles_dir(twiddles, n_fft, TransformDir::Forward)
31 }
32
33 pub fn from_twiddles_dir(twiddles: &[f32], n_fft: usize, dir: TransformDir) -> Self {
34 let _ = dir;
35 Self::from_twiddles_set(twiddles, n_fft, TwiddleSet::Shared)
36 }
37
38 pub fn from_twiddles_set(twiddles: &[f32], n_fft: usize, set: TwiddleSet) -> Self {
39 let half = n_fft / 2;
40 let stages = n_fft.trailing_zeros() as usize;
41 let mut store = Self::default();
42 for s in 0..stages {
43 for b in 0..half {
44 let base = twiddle_index(s, b, half, 0);
45 store
46 .0
47 .insert(twiddle_name_set(set, s, b, "re"), vec![twiddles[base]]);
48 store
49 .0
50 .insert(twiddle_name_set(set, s, b, "im"), vec![twiddles[base + 1]]);
51 }
52 }
53 store
54 }
55
56 pub fn to_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
57 self.to_twiddles_dir(n_fft, TransformDir::Forward)
58 }
59
60 pub fn to_twiddles_dir(&self, n_fft: usize, dir: TransformDir) -> Result<Vec<f32>> {
61 let _ = dir;
62 self.to_twiddles_set(n_fft, TwiddleSet::Shared)
63 }
64
65 pub fn to_twiddles_set(&self, n_fft: usize, set: TwiddleSet) -> Result<Vec<f32>> {
66 let half = n_fft / 2;
67 let stages = n_fft.trailing_zeros() as usize;
68 let mut out = vec![0f32; stages * half * 2];
69 for s in 0..stages {
70 for b in 0..half {
71 let base = twiddle_index(s, b, half, 0);
72 let re_name = twiddle_name_set(set, s, b, "re");
73 let im_name = twiddle_name_set(set, s, b, "im");
74 out[base] = *self
75 .0
76 .get(&re_name)
77 .with_context(|| format!("missing twiddle param {re_name}"))?
78 .first()
79 .context("empty twiddle re")?;
80 out[base + 1] = *self
81 .0
82 .get(&im_name)
83 .with_context(|| format!("missing twiddle param {im_name}"))?
84 .first()
85 .context("empty twiddle im")?;
86 }
87 }
88 Ok(out)
89 }
90
91 pub fn merge(&self, other: &Self) -> Self {
92 let mut out = self.clone();
93 for (k, v) in &other.0 {
94 out.0.insert(k.clone(), v.clone());
95 }
96 out
97 }
98
99 pub fn apply(&self, exec: &mut rlx_runtime::CompiledGraph) {
100 for (name, data) in &self.0 {
101 exec.set_param(name, data);
102 }
103 }
104
105 pub fn apply_butterfly(
107 &self,
108 exec: &mut rlx_runtime::CompiledGraph,
109 _batch: usize,
110 _n_fft: usize,
111 ) {
112 self.apply(exec);
113 }
114
115 pub fn apply_butterfly_for_gates(
117 &self,
118 exec: &mut rlx_runtime::CompiledGraph,
119 n_fft: usize,
120 gates: &[i8],
121 ) {
122 use crate::pruned::{gate_count, gate_index};
123 use crate::ternary_gates::GateMode;
124 let half = n_fft / 2;
125 let stages = n_fft.trailing_zeros() as usize;
126 if gates.len() < gate_count(n_fft) {
127 return;
128 }
129 for s in 0..stages {
130 for b in 0..half {
131 let gi = gate_index(s, b, half);
132 if GateMode::from_i8(gates[gi]) == GateMode::Skip {
133 continue;
134 }
135 let re_name = twiddle_name_set(TwiddleSet::Shared, s, b, "re");
136 let im_name = twiddle_name_set(TwiddleSet::Shared, s, b, "im");
137 if let Some(v) = self.0.get(&re_name) {
138 exec.set_param(&re_name, v);
139 }
140 if let Some(v) = self.0.get(&im_name) {
141 exec.set_param(&im_name, v);
142 }
143 }
144 }
145 }
146}
147
148#[derive(Debug, Clone, Default)]
150pub struct EncDecWeights {
151 pub encoder: WeightStore,
152 pub decoder: WeightStore,
153}
154
155impl EncDecWeights {
156 pub fn from_twiddles(encoder: &[f32], decoder: &[f32], n_fft: usize) -> Self {
157 Self {
158 encoder: WeightStore::from_twiddles_set(encoder, n_fft, TwiddleSet::Encoder),
159 decoder: WeightStore::from_twiddles_set(decoder, n_fft, TwiddleSet::Decoder),
160 }
161 }
162
163 pub fn merged(&self) -> WeightStore {
164 self.encoder.merge(&self.decoder)
165 }
166
167 pub fn encoder_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
168 self.encoder.to_twiddles_set(n_fft, TwiddleSet::Encoder)
169 }
170
171 pub fn decoder_twiddles(&self, n_fft: usize) -> Result<Vec<f32>> {
172 self.decoder.to_twiddles_set(n_fft, TwiddleSet::Decoder)
173 }
174
175 pub fn from_merged(store: &WeightStore, n_fft: usize) -> Result<Self> {
176 Ok(Self {
177 encoder: {
178 let tw = store.to_twiddles_set(n_fft, TwiddleSet::Encoder)?;
179 WeightStore::from_twiddles_set(&tw, n_fft, TwiddleSet::Encoder)
180 },
181 decoder: {
182 let tw = store.to_twiddles_set(n_fft, TwiddleSet::Decoder)?;
183 WeightStore::from_twiddles_set(&tw, n_fft, TwiddleSet::Decoder)
184 },
185 })
186 }
187}
188
189pub fn export_safetensors(path: &Path, weights: &WeightStore) -> Result<()> {
190 ensure!(!weights.0.is_empty(), "no weights to export");
191 let mut storages: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
192 for (name, data) in &weights.0 {
193 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
194 storages.push((name.clone(), bytes, vec![data.len()]));
195 }
196 let mut views: HashMap<String, safetensors::tensor::TensorView> = HashMap::new();
197 for (name, bytes, shape) in &storages {
198 views.insert(
199 name.clone(),
200 safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape.clone(), bytes)
201 .context("tensor view")?,
202 );
203 }
204 if let Some(parent) = path.parent() {
205 std::fs::create_dir_all(parent)?;
206 }
207 safetensors::serialize_to_file(&views, None, path)
208 .with_context(|| format!("write {}", path.display()))?;
209 Ok(())
210}
211
212pub fn load_safetensors(path: &Path) -> Result<WeightStore> {
213 let bytes = std::fs::read(path)?;
214 let st = SafeTensors::deserialize(&bytes)?;
215 let mut store = WeightStore::default();
216 for name in st.names() {
217 let view = st.tensor(name)?;
218 ensure!(
219 view.dtype() == safetensors::Dtype::F32,
220 "expected f32 weights in {path:?}, got {:?} for {name}",
221 view.dtype()
222 );
223 let data: Vec<f32> = view
224 .data()
225 .chunks_exact(4)
226 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
227 .collect();
228 store.0.insert(name.to_string(), data);
229 }
230 Ok(store)
231}