// rlx_voxtral_tts/backbone/layer.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Ministral decoder layer: grouped-query attention (GQA) with RoPE, followed by a SwiGLU
//! feed-forward block.
17
18use crate::config::TextConfig;
19use crate::math::{rms_norm, silu};
20use anyhow::{Context, Result, ensure};
21use ndarray::{Array2, ArrayView2};
22use std::collections::HashMap;
23
24pub struct DecoderLayer {
25    wq: Array2<f32>,
26    wk: Array2<f32>,
27    wv: Array2<f32>,
28    wo: Array2<f32>,
29    lora: Option<crate::lora::LayerLora>,
30    lora_scale: f32,
31    attn_norm: Array1Like,
32    ffn_norm: Array1Like,
33    w1: Array2<f32>,
34    w2: Array2<f32>,
35    w3: Array2<f32>,
36    n_heads: usize,
37    n_kv_heads: usize,
38    head_dim: usize,
39}
40
41type Array1Like = ndarray::Array1<f32>;
42
43pub struct LayerKv {
44    pub k: Array2<f32>,
45    pub v: Array2<f32>,
46}
47
48impl DecoderLayer {
49    pub fn load(
50        map: &HashMap<String, (Vec<f32>, Vec<usize>)>,
51        prefix: &str,
52        cfg: &TextConfig,
53    ) -> Result<Self> {
54        let tp = |s: &str| format!("{prefix}.{s}");
55        Ok(Self {
56            wq: take2d(map, &tp("attention.wq.weight"))?,
57            wk: take2d(map, &tp("attention.wk.weight"))?,
58            wv: take2d(map, &tp("attention.wv.weight"))?,
59            wo: take2d(map, &tp("attention.wo.weight"))?,
60            lora: None,
61            lora_scale: 1.0,
62            attn_norm: take1d(map, &tp("attention_norm.weight"))?,
63            ffn_norm: take1d(map, &tp("ffn_norm.weight"))?,
64            w1: take2d(map, &tp("feed_forward.w1.weight"))?,
65            w2: take2d(map, &tp("feed_forward.w2.weight"))?,
66            w3: take2d(map, &tp("feed_forward.w3.weight"))?,
67            n_heads: cfg.num_attention_heads,
68            n_kv_heads: cfg.num_key_value_heads,
69            head_dim: cfg.head_dim,
70        })
71    }
72
73    pub fn set_lora(&mut self, lora: crate::lora::LayerLora, scale: f32) {
74        self.lora = Some(lora);
75        self.lora_scale = scale;
76    }
77
78    pub fn forward(
79        &self,
80        x: ArrayView2<f32>,
81        cos: &[f32],
82        sin: &[f32],
83        start_pos: usize,
84        kv: &mut LayerKv,
85    ) -> Result<Array2<f32>> {
86        let eps = 1e-5f32;
87        let h = rms_norm(x, self.attn_norm.view(), eps);
88        let scale = self.lora_scale;
89        let lora = self.lora.as_ref();
90        let mut q =
91            crate::lora::apply_lora_linear(&h, &self.wq, lora.and_then(|l| l.wq.as_ref()), scale);
92        let mut k =
93            crate::lora::apply_lora_linear(&h, &self.wk, lora.and_then(|l| l.wk.as_ref()), scale);
94        let v =
95            crate::lora::apply_lora_linear(&h, &self.wv, lora.and_then(|l| l.wv.as_ref()), scale);
96        super::rope::apply_rope_qk(
97            &mut q,
98            &mut k,
99            cos,
100            sin,
101            start_pos,
102            self.n_heads,
103            self.n_kv_heads,
104            self.head_dim,
105        );
106
107        if kv.k.nrows() == 0 {
108            kv.k = k;
109            kv.v = v;
110        } else {
111            let (t_new, d) = k.dim();
112            let (t_old, _) = kv.k.dim();
113            let mut k_cat = Array2::<f32>::zeros((t_old + t_new, d));
114            let mut v_cat = Array2::<f32>::zeros((t_old + t_new, d));
115            for i in 0..t_old {
116                for j in 0..d {
117                    k_cat[[i, j]] = kv.k[[i, j]];
118                    v_cat[[i, j]] = kv.v[[i, j]];
119                }
120            }
121            for i in 0..t_new {
122                for j in 0..d {
123                    k_cat[[t_old + i, j]] = k[[i, j]];
124                    v_cat[[t_old + i, j]] = v[[i, j]];
125                }
126            }
127            kv.k = k_cat;
128            kv.v = v_cat;
129        }
130
131        let attn = gqa_attention(
132            q.view(),
133            kv.k.view(),
134            kv.v.view(),
135            self.n_heads,
136            self.n_kv_heads,
137            self.head_dim,
138        );
139        let attn_out = crate::lora::apply_lora_linear(
140            &attn,
141            &self.wo,
142            lora.and_then(|l| l.wo.as_ref()),
143            scale,
144        );
145        let mut out = x.to_owned() + attn_out;
146        let h2 = rms_norm(out.view(), self.ffn_norm.view(), eps);
147        let w1 =
148            crate::lora::apply_lora_linear(&h2, &self.w1, lora.and_then(|l| l.w1.as_ref()), scale);
149        let w3 =
150            crate::lora::apply_lora_linear(&h2, &self.w3, lora.and_then(|l| l.w3.as_ref()), scale);
151        let swiglu = &silu(w1.view()) * w3;
152        let ff = crate::lora::apply_lora_linear(
153            &swiglu,
154            &self.w2,
155            lora.and_then(|l| l.w2.as_ref()),
156            scale,
157        );
158        out = out + ff;
159        Ok(out)
160    }
161}
162
163fn gqa_attention(
164    q: ArrayView2<f32>,
165    k: ArrayView2<f32>,
166    v: ArrayView2<f32>,
167    n_heads: usize,
168    n_kv_heads: usize,
169    head_dim: usize,
170) -> Array2<f32> {
171    let (t_q, _) = q.dim();
172    let t_k = k.dim().0;
173    let repeats = n_heads / n_kv_heads;
174    let mut out = Array2::<f32>::zeros((t_q, n_heads * head_dim));
175    for qi in 0..t_q {
176        for hi in 0..n_heads {
177            let kv_h = hi / repeats;
178            let mut max_w = f32::NEG_INFINITY;
179            let mut weights = vec![0f32; t_k];
180            for ki in 0..t_k {
181                if ki > qi + (t_k - t_q) {
182                    continue;
183                }
184                let mut dot = 0f32;
185                for di in 0..head_dim {
186                    dot += q[[qi, hi * head_dim + di]] * k[[ki, kv_h * head_dim + di]];
187                }
188                dot /= (head_dim as f32).sqrt();
189                weights[ki] = dot;
190                max_w = max_w.max(dot);
191            }
192            let mut sum = 0f32;
193            for w in weights.iter_mut() {
194                *w = (*w - max_w).exp();
195                sum += *w;
196            }
197            for w in weights.iter_mut() {
198                *w /= sum.max(1e-12);
199            }
200            for di in 0..head_dim {
201                let mut acc = 0f32;
202                for ki in 0..t_k {
203                    acc += weights[ki] * v[[ki, kv_h * head_dim + di]];
204                }
205                out[[qi, hi * head_dim + di]] = acc;
206            }
207        }
208    }
209    out
210}
211
212fn take2d(map: &HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str) -> Result<Array2<f32>> {
213    let (data, shape) = map.get(key).with_context(|| format!("missing {key}"))?;
214    ensure!(shape.len() == 2);
215    Array2::from_shape_vec((shape[0], shape[1]), data.clone()).with_context(|| key.to_string())
216}
217
218fn take1d(map: &HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str) -> Result<Array1Like> {
219    let (data, shape) = map.get(key).with_context(|| format!("missing {key}"))?;
220    ensure!(shape.len() == 1);
221    Array1Like::from_shape_vec(shape[0], data.clone()).with_context(|| key.to_string())
222}