Skip to main content

rlx_bert/
bert.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! BERT graph builder — constructs RLX IR from config + weights.
17
18use anyhow::Result;
19use rlx_core::config::BertConfig;
20use rlx_core::weight_map::WeightMap;
21use rlx_ir::*;
22use std::collections::HashMap;
23
24/// Build a BERT encoder IR graph from config and weights.
25///
26/// Returns the graph and a map of param_name → weight data.
27/// The graph expects inputs: `input_ids [B,S]`, `attention_mask [B,S]`, `token_type_ids [B,S]`.
28/// Output: `hidden_states [B, S, H]`.
29/// Build a BERT encoder IR graph.
30///
31/// `batch` and `seq` are the concrete dimensions for this compilation.
32/// The graph will be compiled for exactly these dimensions.
33/// Call again with different dims to recompile for a different size.
34pub fn build_bert_graph(
35    cfg: &BertConfig,
36    weights: &mut WeightMap,
37) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
38    build_bert_graph_sized(cfg, weights, 1, 1)
39}
40
41pub fn build_bert_graph_sized(
42    cfg: &BertConfig,
43    weights: &mut WeightMap,
44    batch: usize,
45    seq: usize,
46) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
47    rlx_core::flow_util::graph_from_built(crate::flow::build_bert_built(cfg, weights, batch, seq)?)
48}
49
50/// Load a parameter: register in graph + store weight data.
51#[allow(dead_code)]
52fn load_param(
53    g: &mut Graph,
54    params: &mut HashMap<String, Vec<f32>>,
55    weights: &mut WeightMap,
56    key: &str,
57    _expected_shape: &[usize],
58    transpose: bool,
59) -> Result<NodeId> {
60    let (data, shape) = if transpose {
61        weights.take_transposed(key)?
62    } else {
63        weights.take(key)?
64    };
65    let name = key.to_string();
66    let ir_shape = Shape::new(&shape, DType::F32);
67    let id = g.param(&name, ir_shape);
68    params.insert(name, data);
69    Ok(id)
70}
71
72/// Fuse Q/K/V weights into single [H, 3H] matrix (BERT-style keys).
73#[allow(dead_code)]
74fn load_fused_qkv(
75    g: &mut Graph,
76    params: &mut HashMap<String, Vec<f32>>,
77    weights: &mut WeightMap,
78    layer_prefix: &str,
79    h: usize,
80    _nh: usize,
81    _dh: usize,
82) -> Result<(NodeId, NodeId)> {
83    let (wq, _) =
84        weights.take_transposed(&format!("{layer_prefix}.attention.self.query.weight"))?;
85    let (wk, _) = weights.take_transposed(&format!("{layer_prefix}.attention.self.key.weight"))?;
86    let (wv, _) =
87        weights.take_transposed(&format!("{layer_prefix}.attention.self.value.weight"))?;
88
89    let bq = weights
90        .take(&format!("{layer_prefix}.attention.self.query.bias"))?
91        .0;
92    let bk = weights
93        .take(&format!("{layer_prefix}.attention.self.key.bias"))?
94        .0;
95    let bv = weights
96        .take(&format!("{layer_prefix}.attention.self.value.bias"))?
97        .0;
98
99    // Concatenate: [H, H] + [H, H] + [H, H] → [H, 3H]
100    let mut fused_w = vec![0f32; h * 3 * h];
101    let mut fused_b = vec![0f32; 3 * h];
102    for row in 0..h {
103        fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
104        fused_w[row * 3 * h + h..row * 3 * h + 2 * h].copy_from_slice(&wk[row * h..(row + 1) * h]);
105        fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
106            .copy_from_slice(&wv[row * h..(row + 1) * h]);
107    }
108    fused_b[..h].copy_from_slice(&bq);
109    fused_b[h..2 * h].copy_from_slice(&bk);
110    fused_b[2 * h..].copy_from_slice(&bv);
111
112    let w_name = format!("{layer_prefix}.attention.qkv.weight");
113    let b_name = format!("{layer_prefix}.attention.qkv.bias");
114    let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
115    let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
116    params.insert(w_name, fused_w);
117    params.insert(b_name, fused_b);
118
119    Ok((w_id, b_id))
120}
121
122/// mpnet-style QKV fusion (different key names).
123#[allow(dead_code)]
124fn load_fused_qkv_mpnet(
125    g: &mut Graph,
126    params: &mut HashMap<String, Vec<f32>>,
127    weights: &mut WeightMap,
128    layer_prefix: &str,
129    h: usize,
130    nh: usize,
131    dh: usize,
132) -> Result<(NodeId, NodeId)> {
133    // Try mpnet keys
134    let q_key = format!("{layer_prefix}.attention.attn.q.weight");
135    if weights.has(&q_key) {
136        let (wq, _) = weights.take_transposed(&q_key)?;
137        let (wk, _) =
138            weights.take_transposed(&format!("{layer_prefix}.attention.attn.k.weight"))?;
139        let (wv, _) =
140            weights.take_transposed(&format!("{layer_prefix}.attention.attn.v.weight"))?;
141        let bq = weights
142            .take(&format!("{layer_prefix}.attention.attn.q.bias"))?
143            .0;
144        let bk = weights
145            .take(&format!("{layer_prefix}.attention.attn.k.bias"))?
146            .0;
147        let bv = weights
148            .take(&format!("{layer_prefix}.attention.attn.v.bias"))?
149            .0;
150
151        let mut fused_w = vec![0f32; h * 3 * h];
152        let mut fused_b = vec![0f32; 3 * h];
153        for row in 0..h {
154            fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
155            fused_w[row * 3 * h + h..row * 3 * h + 2 * h]
156                .copy_from_slice(&wk[row * h..(row + 1) * h]);
157            fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
158                .copy_from_slice(&wv[row * h..(row + 1) * h]);
159        }
160        fused_b[..h].copy_from_slice(&bq);
161        fused_b[h..2 * h].copy_from_slice(&bk);
162        fused_b[2 * h..].copy_from_slice(&bv);
163
164        let w_name = format!("{layer_prefix}.attention.qkv.weight");
165        let b_name = format!("{layer_prefix}.attention.qkv.bias");
166        let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
167        let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
168        params.insert(w_name, fused_w);
169        params.insert(b_name, fused_b);
170        return Ok((w_id, b_id));
171    }
172
173    // Fallback: already-fused QKV
174    let fused_key = format!("{layer_prefix}.attention.self.qkv.weight");
175    if weights.has(&fused_key) {
176        let (data, _) = weights.take_transposed(&fused_key)?;
177        let bias = weights
178            .take(&format!("{layer_prefix}.attention.self.qkv.bias"))?
179            .0;
180        let w_name = format!("{layer_prefix}.attention.qkv.weight");
181        let b_name = format!("{layer_prefix}.attention.qkv.bias");
182        let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
183        let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
184        params.insert(w_name, data);
185        params.insert(b_name, bias);
186        return Ok((w_id, b_id));
187    }
188
189    // Fallback to BERT style
190    load_fused_qkv(g, params, weights, layer_prefix, h, nh, dh)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: build a one-layer BERT graph from constant-filled fake
    /// weights, then check that the IR verifies, that parameters were
    /// collected, and that the graph declares at least one output.
    #[test]
    fn build_tiny_bert_graph() {
        // Minimal single-layer configuration.
        let cfg = BertConfig {
            vocab_size: 100,
            hidden_size: 64,
            num_hidden_layers: 1,
            num_attention_heads: 2,
            intermediate_size: 256,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            hidden_act: "gelu".into(),
        };

        let h = cfg.hidden_size;
        let ffn = cfg.intermediate_size;

        // (weight key, shape) for every tensor the builder is given.
        let specs: Vec<(&str, Vec<usize>)> = vec![
            // Embeddings
            ("embeddings.word_embeddings.weight", vec![100, h]),
            ("embeddings.position_embeddings.weight", vec![32, h]),
            ("embeddings.token_type_embeddings.weight", vec![2, h]),
            ("embeddings.LayerNorm.weight", vec![h]),
            ("embeddings.LayerNorm.bias", vec![h]),
            // Layer 0 — attention
            ("encoder.layer.0.attention.self.query.weight", vec![h, h]),
            ("encoder.layer.0.attention.self.query.bias", vec![h]),
            ("encoder.layer.0.attention.self.key.weight", vec![h, h]),
            ("encoder.layer.0.attention.self.key.bias", vec![h]),
            ("encoder.layer.0.attention.self.value.weight", vec![h, h]),
            ("encoder.layer.0.attention.self.value.bias", vec![h]),
            ("encoder.layer.0.attention.output.dense.weight", vec![h, h]),
            ("encoder.layer.0.attention.output.dense.bias", vec![h]),
            (
                "encoder.layer.0.attention.output.LayerNorm.weight",
                vec![h],
            ),
            ("encoder.layer.0.attention.output.LayerNorm.bias", vec![h]),
            // Layer 0 — FFN
            ("encoder.layer.0.intermediate.dense.weight", vec![ffn, h]),
            ("encoder.layer.0.intermediate.dense.bias", vec![ffn]),
            ("encoder.layer.0.output.dense.weight", vec![h, ffn]),
            ("encoder.layer.0.output.dense.bias", vec![h]),
            ("encoder.layer.0.output.LayerNorm.weight", vec![h]),
            ("encoder.layer.0.output.LayerNorm.bias", vec![h]),
        ];

        // Fake weights: every tensor is filled with the same constant.
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(key, shape)| {
                let size: usize = shape.iter().product();
                (key.to_string(), (vec![0.01f32; size], shape))
            })
            .collect();

        let mut wm = WeightMap::from_tensors(tensors);
        let (graph, params) = build_bert_graph(&cfg, &mut wm).unwrap();

        println!("{graph}");
        println!("Nodes: {}, Params: {}", graph.len(), params.len());

        // The produced IR must pass verification.
        let errors = rlx_ir::verify::verify(&graph);
        assert!(errors.is_empty(), "verification errors: {errors:?}");

        // Parameters should have been collected for the weights.
        assert!(
            params.len() >= 15,
            "expected 15+ params, got {}",
            params.len()
        );

        // At least one output must be declared.
        assert!(!graph.outputs.is_empty());
    }
}
340}