rlx_voxtral_tts/backbone/
layer.rs1use crate::config::TextConfig;
19use crate::math::{rms_norm, silu};
20use anyhow::{Context, Result, ensure};
21use ndarray::{Array2, ArrayView2};
22use std::collections::HashMap;
23
24pub struct DecoderLayer {
25 wq: Array2<f32>,
26 wk: Array2<f32>,
27 wv: Array2<f32>,
28 wo: Array2<f32>,
29 lora: Option<crate::lora::LayerLora>,
30 lora_scale: f32,
31 attn_norm: Array1Like,
32 ffn_norm: Array1Like,
33 w1: Array2<f32>,
34 w2: Array2<f32>,
35 w3: Array2<f32>,
36 n_heads: usize,
37 n_kv_heads: usize,
38 head_dim: usize,
39}
40
41type Array1Like = ndarray::Array1<f32>;
42
43pub struct LayerKv {
44 pub k: Array2<f32>,
45 pub v: Array2<f32>,
46}
47
48impl DecoderLayer {
49 pub fn load(
50 map: &HashMap<String, (Vec<f32>, Vec<usize>)>,
51 prefix: &str,
52 cfg: &TextConfig,
53 ) -> Result<Self> {
54 let tp = |s: &str| format!("{prefix}.{s}");
55 Ok(Self {
56 wq: take2d(map, &tp("attention.wq.weight"))?,
57 wk: take2d(map, &tp("attention.wk.weight"))?,
58 wv: take2d(map, &tp("attention.wv.weight"))?,
59 wo: take2d(map, &tp("attention.wo.weight"))?,
60 lora: None,
61 lora_scale: 1.0,
62 attn_norm: take1d(map, &tp("attention_norm.weight"))?,
63 ffn_norm: take1d(map, &tp("ffn_norm.weight"))?,
64 w1: take2d(map, &tp("feed_forward.w1.weight"))?,
65 w2: take2d(map, &tp("feed_forward.w2.weight"))?,
66 w3: take2d(map, &tp("feed_forward.w3.weight"))?,
67 n_heads: cfg.num_attention_heads,
68 n_kv_heads: cfg.num_key_value_heads,
69 head_dim: cfg.head_dim,
70 })
71 }
72
73 pub fn set_lora(&mut self, lora: crate::lora::LayerLora, scale: f32) {
74 self.lora = Some(lora);
75 self.lora_scale = scale;
76 }
77
78 pub fn forward(
79 &self,
80 x: ArrayView2<f32>,
81 cos: &[f32],
82 sin: &[f32],
83 start_pos: usize,
84 kv: &mut LayerKv,
85 ) -> Result<Array2<f32>> {
86 let eps = 1e-5f32;
87 let h = rms_norm(x, self.attn_norm.view(), eps);
88 let scale = self.lora_scale;
89 let lora = self.lora.as_ref();
90 let mut q =
91 crate::lora::apply_lora_linear(&h, &self.wq, lora.and_then(|l| l.wq.as_ref()), scale);
92 let mut k =
93 crate::lora::apply_lora_linear(&h, &self.wk, lora.and_then(|l| l.wk.as_ref()), scale);
94 let v =
95 crate::lora::apply_lora_linear(&h, &self.wv, lora.and_then(|l| l.wv.as_ref()), scale);
96 super::rope::apply_rope_qk(
97 &mut q,
98 &mut k,
99 cos,
100 sin,
101 start_pos,
102 self.n_heads,
103 self.n_kv_heads,
104 self.head_dim,
105 );
106
107 if kv.k.nrows() == 0 {
108 kv.k = k;
109 kv.v = v;
110 } else {
111 let (t_new, d) = k.dim();
112 let (t_old, _) = kv.k.dim();
113 let mut k_cat = Array2::<f32>::zeros((t_old + t_new, d));
114 let mut v_cat = Array2::<f32>::zeros((t_old + t_new, d));
115 for i in 0..t_old {
116 for j in 0..d {
117 k_cat[[i, j]] = kv.k[[i, j]];
118 v_cat[[i, j]] = kv.v[[i, j]];
119 }
120 }
121 for i in 0..t_new {
122 for j in 0..d {
123 k_cat[[t_old + i, j]] = k[[i, j]];
124 v_cat[[t_old + i, j]] = v[[i, j]];
125 }
126 }
127 kv.k = k_cat;
128 kv.v = v_cat;
129 }
130
131 let attn = gqa_attention(
132 q.view(),
133 kv.k.view(),
134 kv.v.view(),
135 self.n_heads,
136 self.n_kv_heads,
137 self.head_dim,
138 );
139 let attn_out = crate::lora::apply_lora_linear(
140 &attn,
141 &self.wo,
142 lora.and_then(|l| l.wo.as_ref()),
143 scale,
144 );
145 let mut out = x.to_owned() + attn_out;
146 let h2 = rms_norm(out.view(), self.ffn_norm.view(), eps);
147 let w1 =
148 crate::lora::apply_lora_linear(&h2, &self.w1, lora.and_then(|l| l.w1.as_ref()), scale);
149 let w3 =
150 crate::lora::apply_lora_linear(&h2, &self.w3, lora.and_then(|l| l.w3.as_ref()), scale);
151 let swiglu = &silu(w1.view()) * w3;
152 let ff = crate::lora::apply_lora_linear(
153 &swiglu,
154 &self.w2,
155 lora.and_then(|l| l.w2.as_ref()),
156 scale,
157 );
158 out = out + ff;
159 Ok(out)
160 }
161}
162
163fn gqa_attention(
164 q: ArrayView2<f32>,
165 k: ArrayView2<f32>,
166 v: ArrayView2<f32>,
167 n_heads: usize,
168 n_kv_heads: usize,
169 head_dim: usize,
170) -> Array2<f32> {
171 let (t_q, _) = q.dim();
172 let t_k = k.dim().0;
173 let repeats = n_heads / n_kv_heads;
174 let mut out = Array2::<f32>::zeros((t_q, n_heads * head_dim));
175 for qi in 0..t_q {
176 for hi in 0..n_heads {
177 let kv_h = hi / repeats;
178 let mut max_w = f32::NEG_INFINITY;
179 let mut weights = vec![0f32; t_k];
180 for ki in 0..t_k {
181 if ki > qi + (t_k - t_q) {
182 continue;
183 }
184 let mut dot = 0f32;
185 for di in 0..head_dim {
186 dot += q[[qi, hi * head_dim + di]] * k[[ki, kv_h * head_dim + di]];
187 }
188 dot /= (head_dim as f32).sqrt();
189 weights[ki] = dot;
190 max_w = max_w.max(dot);
191 }
192 let mut sum = 0f32;
193 for w in weights.iter_mut() {
194 *w = (*w - max_w).exp();
195 sum += *w;
196 }
197 for w in weights.iter_mut() {
198 *w /= sum.max(1e-12);
199 }
200 for di in 0..head_dim {
201 let mut acc = 0f32;
202 for ki in 0..t_k {
203 acc += weights[ki] * v[[ki, kv_h * head_dim + di]];
204 }
205 out[[qi, hi * head_dim + di]] = acc;
206 }
207 }
208 }
209 out
210}
211
212fn take2d(map: &HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str) -> Result<Array2<f32>> {
213 let (data, shape) = map.get(key).with_context(|| format!("missing {key}"))?;
214 ensure!(shape.len() == 2);
215 Array2::from_shape_vec((shape[0], shape[1]), data.clone()).with_context(|| key.to_string())
216}
217
218fn take1d(map: &HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str) -> Result<Array1Like> {
219 let (data, shape) = map.get(key).with_context(|| format!("missing {key}"))?;
220 ensure!(shape.len() == 1);
221 Array1Like::from_shape_vec(shape[0], data.clone()).with_context(|| key.to_string())
222}