1use crate::config::WhisperConfig;
20use crate::weights::WhisperWeightPrefix;
21use anyhow::{Context, Result};
22use std::collections::HashMap;
23
/// A named tensor stored as `(key, flat data, shape)`.
type WeightEntry = (String, Vec<f32>, Vec<usize>);
/// A bias tensor that may be absent.
type OptionalBias = Option<WeightEntry>;
/// A list of optional bias tensors (used in this file with one slot per layer).
type OptionalBiasList = Vec<OptionalBias>;
27
28#[derive(Debug, Clone)]
30pub struct FusedEncoderWeights {
31 pub layer_qkv_w: Vec<WeightEntry>,
32 pub layer_qkv_b: OptionalBiasList,
33}
34
35impl FusedEncoderWeights {
36 pub fn from_checkpoint(
37 tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
38 cfg: &WhisperConfig,
39 pfx: &WhisperWeightPrefix,
40 ) -> Result<Self> {
41 let d = cfg.d_model;
42 let mut layer_qkv_w = Vec::with_capacity(cfg.encoder_layers);
43 let mut layer_qkv_b = Vec::with_capacity(cfg.encoder_layers);
44
45 for i in 0..cfg.encoder_layers {
46 let qw = format_enc_layer(pfx, i, "self_attn.q_proj.weight");
47 let kw = format_enc_layer(pfx, i, "self_attn.k_proj.weight");
48 let vw = format_enc_layer(pfx, i, "self_attn.v_proj.weight");
49 let qb = format_enc_layer(pfx, i, "self_attn.q_proj.bias");
50 let kb = format_enc_layer(pfx, i, "self_attn.k_proj.bias");
51 let vb = format_enc_layer(pfx, i, "self_attn.v_proj.bias");
52
53 let (qw_d, qw_s) = tensors.get(&qw).with_context(|| qw.clone())?;
54 let (kw_d, kw_s) = tensors.get(&kw).with_context(|| kw.clone())?;
55 let (vw_d, vw_s) = tensors.get(&vw).with_context(|| vw.clone())?;
56 ensure_mat(qw_s, d, d)?;
57 ensure_mat(kw_s, d, d)?;
58 ensure_mat(vw_s, d, d)?;
59
60 let w_key = format!("fused.enc.{i}.self_attn.qkv.weight");
61 layer_qkv_w.push((
62 w_key,
63 concat_qkv_weights(qw_d, kw_d, vw_d, d),
64 vec![d, 3 * d],
65 ));
66
67 let b_key = format!("fused.enc.{i}.self_attn.qkv.bias");
68 let bias = match (tensors.get(&qb), tensors.get(&kb), tensors.get(&vb)) {
69 (Some((qb_d, _)), Some((kb_d, _)), Some((vb_d, _))) => {
70 Some((b_key, concat_qkv_bias(qb_d, kb_d, vb_d, d), vec![3 * d]))
71 }
72 _ => None,
73 };
74 layer_qkv_b.push(bias);
75 }
76 Ok(Self {
77 layer_qkv_w,
78 layer_qkv_b,
79 })
80 }
81
82 pub fn merge_into_tensors(&self, tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>) {
83 for (k, data, shape) in &self.layer_qkv_w {
84 tensors.insert(k.clone(), (data.clone(), shape.clone()));
85 }
86 for (k, data, shape) in self.layer_qkv_b.iter().flatten() {
87 tensors.insert(k.clone(), (data.clone(), shape.clone()));
88 }
89 }
90
91 pub fn qkv_w_key(&self, layer: usize) -> &str {
92 &self.layer_qkv_w[layer].0
93 }
94
95 pub fn qkv_b_key(&self, layer: usize) -> Option<&str> {
96 self.layer_qkv_b[layer].as_ref().map(|(k, _, _)| k.as_str())
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct FusedDecoderWeights {
103 pub layer_qkv_w: Vec<WeightEntry>,
104 pub layer_qkv_b: OptionalBiasList,
105}
106
107impl FusedDecoderWeights {
108 pub fn from_checkpoint(
110 tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
111 cfg: &WhisperConfig,
112 pfx: &WhisperWeightPrefix,
113 ) -> Result<Self> {
114 let d = cfg.d_model;
115 let mut layer_qkv_w = Vec::with_capacity(cfg.decoder_layers);
116 let mut layer_qkv_b = Vec::with_capacity(cfg.decoder_layers);
117
118 for i in 0..cfg.decoder_layers {
119 let qw = format_layer(pfx, i, "self_attn.q_proj.weight");
120 let kw = format_layer(pfx, i, "self_attn.k_proj.weight");
121 let vw = format_layer(pfx, i, "self_attn.v_proj.weight");
122 let qb = format_layer(pfx, i, "self_attn.q_proj.bias");
123 let kb = format_layer(pfx, i, "self_attn.k_proj.bias");
124 let vb = format_layer(pfx, i, "self_attn.v_proj.bias");
125
126 let (qw_d, qw_s) = tensors.get(&qw).with_context(|| qw.clone())?;
127 let (kw_d, kw_s) = tensors.get(&kw).with_context(|| kw.clone())?;
128 let (vw_d, vw_s) = tensors.get(&vw).with_context(|| vw.clone())?;
129 ensure_mat(qw_s, d, d)?;
130 ensure_mat(kw_s, d, d)?;
131 ensure_mat(vw_s, d, d)?;
132
133 let w_key = format!("fused.dec.{i}.self_attn.qkv.weight");
134 let w_data = concat_qkv_weights(qw_d, kw_d, vw_d, d);
135 layer_qkv_w.push((w_key, w_data, vec![d, 3 * d]));
136
137 let b_key = format!("fused.dec.{i}.self_attn.qkv.bias");
138 let bias = match (tensors.get(&qb), tensors.get(&kb), tensors.get(&vb)) {
139 (Some((qb_d, _)), Some((kb_d, _)), Some((vb_d, _))) => {
140 let b_data = concat_qkv_bias(qb_d, kb_d, vb_d, d);
141 Some((b_key, b_data, vec![3 * d]))
142 }
143 _ => None,
144 };
145 layer_qkv_b.push(bias);
146 }
147
148 Ok(Self {
149 layer_qkv_w,
150 layer_qkv_b,
151 })
152 }
153
154 pub fn merge_into_tensors(&self, tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>) {
155 for (k, data, shape) in &self.layer_qkv_w {
156 tensors.insert(k.clone(), (data.clone(), shape.clone()));
157 }
158 for (k, data, shape) in self.layer_qkv_b.iter().flatten() {
159 tensors.insert(k.clone(), (data.clone(), shape.clone()));
160 }
161 }
162
163 pub fn merge_into_params(&self, params: &mut HashMap<String, Vec<f32>>) {
164 for (k, data, _) in &self.layer_qkv_w {
165 params.insert(k.clone(), data.clone());
166 }
167 for (k, data, _) in self.layer_qkv_b.iter().flatten() {
168 params.insert(k.clone(), data.clone());
169 }
170 }
171
172 pub fn qkv_w_key(&self, layer: usize) -> &str {
173 &self.layer_qkv_w[layer].0
174 }
175
176 pub fn qkv_b_key(&self, layer: usize) -> Option<&str> {
177 self.layer_qkv_b[layer].as_ref().map(|(k, _, _)| k.as_str())
178 }
179}
180
/// Returns the checkpoint key for `suffix` in decoder layer `i` by delegating to
/// `pfx.dec_layer(i, suffix)`.
fn format_layer(pfx: &WhisperWeightPrefix, i: usize, suffix: &str) -> String {
    pfx.dec_layer(i, suffix)
}
184
/// Returns the checkpoint key for `suffix` in encoder layer `i` by delegating to
/// `pfx.enc_layer(i, suffix)`.
fn format_enc_layer(pfx: &WhisperWeightPrefix, i: usize, suffix: &str) -> String {
    pfx.enc_layer(i, suffix)
}
188
189fn ensure_mat(shape: &[usize], rows: usize, cols: usize) -> Result<()> {
190 anyhow::ensure!(
191 shape == [rows, cols],
192 "expected mat [{rows}, {cols}], got {shape:?}"
193 );
194 Ok(())
195}
196
/// Fuses three row-major `[d, d]` matrices into one row-major `[d, 3 * d]` matrix.
///
/// Each input is transposed on the way in, so output row `i` holds column `i` of `q`,
/// followed by column `i` of `k`, followed by column `i` of `v`; i.e. the result is
/// `[qᵀ | kᵀ | vᵀ]`.
///
/// # Panics
///
/// Panics if `d > 0` and any input holds fewer than `d * d` elements.
fn concat_qkv_weights(q: &[f32], k: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let width = 3 * d;
    let mut fused = Vec::with_capacity(d * width);
    for col in 0..d {
        // One output row: the same source column taken from q, then k, then v.
        for src in [q, k, v] {
            fused.extend((0..d).map(|row| src[row * d + col]));
        }
    }
    fused
}
210
/// Concatenates three length-`d` bias vectors into one `[q | k | v]` vector of length `3 * d`.
///
/// # Panics
///
/// Panics if any of `q`, `k` or `v` does not hold exactly `d` elements.
fn concat_qkv_bias(q: &[f32], k: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let mut fused = vec![0f32; 3 * d];
    {
        // Carve the output into three disjoint length-`d` windows.
        let (q_dst, rest) = fused.split_at_mut(d);
        let (k_dst, v_dst) = rest.split_at_mut(d);
        q_dst.copy_from_slice(q);
        k_dst.copy_from_slice(k);
        v_dst.copy_from_slice(v);
    }
    fused
}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplies the row vector `x` (length `in_f`) by the row-major `[in_f, out_f]`
    /// matrix `w`, accumulating each output in ascending input order.
    fn mm_row(x: &[f32], w: &[f32], in_f: usize, out_f: usize) -> Vec<f32> {
        (0..out_f)
            .map(|col| (0..in_f).fold(0f32, |acc, row| acc + x[row] * w[row * out_f + col]))
            .collect()
    }

    /// Same product as `mm_row`, but reads a `d`-column slice starting at column `col_off`
    /// of a row-major `[d, 3 * d]` fused matrix.
    fn mm_row_fused_qkv(x: &[f32], fused: &[f32], d: usize, col_off: usize) -> Vec<f32> {
        let stride = 3 * d;
        (0..d)
            .map(|col| {
                (0..d).fold(0f32, |acc, row| acc + x[row] * fused[row * stride + col_off + col])
            })
            .collect()
    }

    /// Transposes a row-major `[d, d]` matrix (`[out, in]` becomes `[in, out]`).
    fn hf_linear_transposed(w: &[f32], d: usize) -> Vec<f32> {
        (0..d)
            .flat_map(|i| (0..d).map(move |o| w[o * d + i]))
            .collect()
    }

    /// Each third of the fused matrix must give the same projection as the corresponding
    /// transposed input matrix.
    #[test]
    fn concat_qkv_matches_hf_linear_layout() {
        let d = 16usize;
        let n = d * d;
        let q: Vec<f32> = (0..n).map(|i| (i as f32 * 0.013).sin()).collect();
        let k: Vec<f32> = (0..n).map(|i| (i as f32 * 0.017).cos()).collect();
        let v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.019).sin()).collect();
        let x: Vec<f32> = (0..d).map(|i| (i as f32 + 1.0) * 0.05).collect();

        let fused = concat_qkv_weights(&q, &k, &v, d);
        for (slot, proj) in [&q, &k, &v].iter().enumerate() {
            let col_off = slot * d;
            let reference = mm_row(&x, &hf_linear_transposed(proj, d), d, d);
            let actual = mm_row_fused_qkv(&x, &fused, d, col_off);
            let mx = reference
                .iter()
                .zip(&actual)
                .fold(0f32, |worst, (a, b)| worst.max((a - b).abs()));
            assert!(mx < 1e-6, "proj max_abs={mx}");
        }
    }
}