Skip to main content

rlx_whisper/
fused.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Pre-fused encoder/decoder weights (QKV projection + tied logit embed), inspired by
//! [fast-whisper-burn](https://github.com/AdrianEddy/fast-whisper-burn).
18
19use crate::config::WhisperConfig;
20use crate::weights::WhisperWeightPrefix;
21use anyhow::{Context, Result};
22use std::collections::HashMap;
23
/// A named tensor: `(key, flat f32 data, shape)`.
type WeightEntry = (String, Vec<f32>, Vec<usize>);
/// A tensor entry that may be absent (used for a layer whose fused bias is missing).
type OptionalBias = Option<WeightEntry>;
/// One optional bias entry per layer.
type OptionalBiasList = Vec<Option<WeightEntry>>;
27
28/// Fused encoder self-attention QKV per layer (`[d_model, 3*d_model]`).
29#[derive(Debug, Clone)]
30pub struct FusedEncoderWeights {
31    pub layer_qkv_w: Vec<WeightEntry>,
32    pub layer_qkv_b: OptionalBiasList,
33}
34
35impl FusedEncoderWeights {
36    pub fn from_checkpoint(
37        tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
38        cfg: &WhisperConfig,
39        pfx: &WhisperWeightPrefix,
40    ) -> Result<Self> {
41        let d = cfg.d_model;
42        let mut layer_qkv_w = Vec::with_capacity(cfg.encoder_layers);
43        let mut layer_qkv_b = Vec::with_capacity(cfg.encoder_layers);
44
45        for i in 0..cfg.encoder_layers {
46            let qw = format_enc_layer(pfx, i, "self_attn.q_proj.weight");
47            let kw = format_enc_layer(pfx, i, "self_attn.k_proj.weight");
48            let vw = format_enc_layer(pfx, i, "self_attn.v_proj.weight");
49            let qb = format_enc_layer(pfx, i, "self_attn.q_proj.bias");
50            let kb = format_enc_layer(pfx, i, "self_attn.k_proj.bias");
51            let vb = format_enc_layer(pfx, i, "self_attn.v_proj.bias");
52
53            let (qw_d, qw_s) = tensors.get(&qw).with_context(|| qw.clone())?;
54            let (kw_d, kw_s) = tensors.get(&kw).with_context(|| kw.clone())?;
55            let (vw_d, vw_s) = tensors.get(&vw).with_context(|| vw.clone())?;
56            ensure_mat(qw_s, d, d)?;
57            ensure_mat(kw_s, d, d)?;
58            ensure_mat(vw_s, d, d)?;
59
60            let w_key = format!("fused.enc.{i}.self_attn.qkv.weight");
61            layer_qkv_w.push((
62                w_key,
63                concat_qkv_weights(qw_d, kw_d, vw_d, d),
64                vec![d, 3 * d],
65            ));
66
67            let b_key = format!("fused.enc.{i}.self_attn.qkv.bias");
68            let bias = match (tensors.get(&qb), tensors.get(&kb), tensors.get(&vb)) {
69                (Some((qb_d, _)), Some((kb_d, _)), Some((vb_d, _))) => {
70                    Some((b_key, concat_qkv_bias(qb_d, kb_d, vb_d, d), vec![3 * d]))
71                }
72                _ => None,
73            };
74            layer_qkv_b.push(bias);
75        }
76        Ok(Self {
77            layer_qkv_w,
78            layer_qkv_b,
79        })
80    }
81
82    pub fn merge_into_tensors(&self, tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>) {
83        for (k, data, shape) in &self.layer_qkv_w {
84            tensors.insert(k.clone(), (data.clone(), shape.clone()));
85        }
86        for (k, data, shape) in self.layer_qkv_b.iter().flatten() {
87            tensors.insert(k.clone(), (data.clone(), shape.clone()));
88        }
89    }
90
91    pub fn qkv_w_key(&self, layer: usize) -> &str {
92        &self.layer_qkv_w[layer].0
93    }
94
95    pub fn qkv_b_key(&self, layer: usize) -> Option<&str> {
96        self.layer_qkv_b[layer].as_ref().map(|(k, _, _)| k.as_str())
97    }
98}
99
100/// Fused decoder self-attn QKV per layer (`[d_model, 3*d_model]`).
101#[derive(Debug, Clone)]
102pub struct FusedDecoderWeights {
103    pub layer_qkv_w: Vec<WeightEntry>,
104    pub layer_qkv_b: OptionalBiasList,
105}
106
107impl FusedDecoderWeights {
108    /// Build fused QKV `[d_model, 3*d_model]` per decoder layer.
109    pub fn from_checkpoint(
110        tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
111        cfg: &WhisperConfig,
112        pfx: &WhisperWeightPrefix,
113    ) -> Result<Self> {
114        let d = cfg.d_model;
115        let mut layer_qkv_w = Vec::with_capacity(cfg.decoder_layers);
116        let mut layer_qkv_b = Vec::with_capacity(cfg.decoder_layers);
117
118        for i in 0..cfg.decoder_layers {
119            let qw = format_layer(pfx, i, "self_attn.q_proj.weight");
120            let kw = format_layer(pfx, i, "self_attn.k_proj.weight");
121            let vw = format_layer(pfx, i, "self_attn.v_proj.weight");
122            let qb = format_layer(pfx, i, "self_attn.q_proj.bias");
123            let kb = format_layer(pfx, i, "self_attn.k_proj.bias");
124            let vb = format_layer(pfx, i, "self_attn.v_proj.bias");
125
126            let (qw_d, qw_s) = tensors.get(&qw).with_context(|| qw.clone())?;
127            let (kw_d, kw_s) = tensors.get(&kw).with_context(|| kw.clone())?;
128            let (vw_d, vw_s) = tensors.get(&vw).with_context(|| vw.clone())?;
129            ensure_mat(qw_s, d, d)?;
130            ensure_mat(kw_s, d, d)?;
131            ensure_mat(vw_s, d, d)?;
132
133            let w_key = format!("fused.dec.{i}.self_attn.qkv.weight");
134            let w_data = concat_qkv_weights(qw_d, kw_d, vw_d, d);
135            layer_qkv_w.push((w_key, w_data, vec![d, 3 * d]));
136
137            let b_key = format!("fused.dec.{i}.self_attn.qkv.bias");
138            let bias = match (tensors.get(&qb), tensors.get(&kb), tensors.get(&vb)) {
139                (Some((qb_d, _)), Some((kb_d, _)), Some((vb_d, _))) => {
140                    let b_data = concat_qkv_bias(qb_d, kb_d, vb_d, d);
141                    Some((b_key, b_data, vec![3 * d]))
142                }
143                _ => None,
144            };
145            layer_qkv_b.push(bias);
146        }
147
148        Ok(Self {
149            layer_qkv_w,
150            layer_qkv_b,
151        })
152    }
153
154    pub fn merge_into_tensors(&self, tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>) {
155        for (k, data, shape) in &self.layer_qkv_w {
156            tensors.insert(k.clone(), (data.clone(), shape.clone()));
157        }
158        for (k, data, shape) in self.layer_qkv_b.iter().flatten() {
159            tensors.insert(k.clone(), (data.clone(), shape.clone()));
160        }
161    }
162
163    pub fn merge_into_params(&self, params: &mut HashMap<String, Vec<f32>>) {
164        for (k, data, _) in &self.layer_qkv_w {
165            params.insert(k.clone(), data.clone());
166        }
167        for (k, data, _) in self.layer_qkv_b.iter().flatten() {
168            params.insert(k.clone(), data.clone());
169        }
170    }
171
172    pub fn qkv_w_key(&self, layer: usize) -> &str {
173        &self.layer_qkv_w[layer].0
174    }
175
176    pub fn qkv_b_key(&self, layer: usize) -> Option<&str> {
177        self.layer_qkv_b[layer].as_ref().map(|(k, _, _)| k.as_str())
178    }
179}
180
181fn format_layer(pfx: &WhisperWeightPrefix, i: usize, suffix: &str) -> String {
182    pfx.dec_layer(i, suffix)
183}
184
185fn format_enc_layer(pfx: &WhisperWeightPrefix, i: usize, suffix: &str) -> String {
186    pfx.enc_layer(i, suffix)
187}
188
189fn ensure_mat(shape: &[usize], rows: usize, cols: usize) -> Result<()> {
190    anyhow::ensure!(
191        shape == [rows, cols],
192        "expected mat [{rows}, {cols}], got {shape:?}"
193    );
194    Ok(())
195}
196
/// HF `[out, in]` → fused `mm(x, w)` layout `[in, out]` per Q/K/V block (see `linear_fused_qkv`).
///
/// The result is a row-major `[d, 3*d]` matrix. Row `r` holds column `r` of Q, then of K,
/// then of V, each transposed out of its `[d, d]` `[out, in]` source.
fn concat_qkv_weights(q: &[f32], k: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let width = 3 * d;
    let mut fused = vec![0f32; d * width];
    if d == 0 {
        // Nothing to copy; `chunks_exact_mut(0)` would panic.
        return fused;
    }
    for (row_idx, row) in fused.chunks_exact_mut(width).enumerate() {
        let (q_dst, rest) = row.split_at_mut(d);
        let (k_dst, v_dst) = rest.split_at_mut(d);
        for col in 0..d {
            // Source element at [out = col, in = row_idx].
            let src = col * d + row_idx;
            q_dst[col] = q[src];
            k_dst[col] = k[src];
            v_dst[col] = v[src];
        }
    }
    fused
}
210
/// Concatenates the Q, K and V bias vectors (each `d` long) into one `[3*d]` vector.
///
/// Panics if any input is not exactly `d` elements long.
fn concat_qkv_bias(q: &[f32], k: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let mut fused = vec![0f32; 3 * d];
    let (q_dst, tail) = fused.split_at_mut(d);
    let (k_dst, v_dst) = tail.split_at_mut(d);
    q_dst.copy_from_slice(q);
    k_dst.copy_from_slice(k);
    v_dst.copy_from_slice(v);
    fused
}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference `y = x · w` for one input row `x` and a row-major `[in_f, out_f]` matrix `w`.
    fn mm_row(x: &[f32], w: &[f32], in_f: usize, out_f: usize) -> Vec<f32> {
        (0..out_f)
            .map(|o| (0..in_f).map(|i| x[i] * w[i * out_f + o]).sum::<f32>())
            .collect()
    }

    /// `y = x · fused[:, col_off..col_off + d]` for a row-major `[d, 3*d]` fused matrix.
    fn mm_row_fused_qkv(x: &[f32], fused: &[f32], d: usize, col_off: usize) -> Vec<f32> {
        (0..d)
            .map(|o| (0..d).map(|i| x[i] * fused[i * 3 * d + col_off + o]).sum::<f32>())
            .collect()
    }

    /// Transposes a `[d, d]` HF `[out, in]` weight into `[in, out]`.
    fn hf_linear_transposed(w: &[f32], d: usize) -> Vec<f32> {
        let mut out = vec![0f32; d * d];
        for o in 0..d {
            for i in 0..d {
                out[i * d + o] = w[o * d + i];
            }
        }
        out
    }

    #[test]
    fn concat_qkv_matches_hf_linear_layout() {
        let d = 16usize;
        let q: Vec<f32> = (0..d * d).map(|i| (i as f32 * 0.013).sin()).collect();
        let k: Vec<f32> = (0..d * d).map(|i| (i as f32 * 0.017).cos()).collect();
        let v: Vec<f32> = (0..d * d).map(|i| (i as f32 * 0.019).sin()).collect();
        let x: Vec<f32> = (0..d).map(|i| (i as f32 + 1.0) * 0.05).collect();

        let fused = concat_qkv_weights(&q, &k, &v, d);
        for (proj, col_off) in [(&q, 0usize), (&k, d), (&v, 2 * d)] {
            let w = hf_linear_transposed(proj, d);
            let expected = mm_row(&x, &w, d, d);
            let got = mm_row_fused_qkv(&x, &fused, d, col_off);
            let mx = expected
                .iter()
                .zip(&got)
                .map(|(a, b)| (a - b).abs())
                .fold(0f32, f32::max);
            assert!(mx < 1e-6, "proj max_abs={mx}");
        }
    }

    #[test]
    fn concat_qkv_weights_small_known_layout() {
        // HF [out, in] rows: q = [[1, 2], [3, 4]], etc.; the fused row r is column r of each.
        let q = [1.0f32, 2.0, 3.0, 4.0];
        let k = [5.0f32, 6.0, 7.0, 8.0];
        let v = [9.0f32, 10.0, 11.0, 12.0];
        let expected: Vec<f32> =
            vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
        assert_eq!(concat_qkv_weights(&q, &k, &v, 2), expected);
    }

    #[test]
    fn concat_qkv_bias_places_q_k_v_in_order() {
        let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(
            concat_qkv_bias(&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], 2),
            expected
        );
    }

    #[test]
    fn concat_handles_zero_width() {
        assert!(concat_qkv_weights(&[], &[], &[], 0).is_empty());
        assert!(concat_qkv_bias(&[], &[], &[], 0).is_empty());
    }
}