Skip to main content

rlx_flow/blocks/
bert_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::Shape;
19use rlx_ir::hir::HirMut;
20
21use super::BlockStage;
22use crate::context::FlowCtx;
23use crate::value::FlowValue;
24
/// Naming scheme a BERT-family checkpoint uses for the separate Q/K/V
/// projection tensors of an encoder layer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BertQkvStyle {
    /// Projections are stored as `attention.self.{query,key,value}` (BERT / MiniLM).
    Bert,
    /// Projections are stored as `attention.attn.{q,k,v}` (mpnet-style).
    Mpnet,
}
33
34#[derive(Debug, Clone)]
35pub struct BertEncoderLayerSpec {
36    pub layer_prefix: String,
37    pub qkv_style: BertQkvStyle,
38    pub hidden_size: usize,
39    pub num_heads: usize,
40    pub head_dim: usize,
41    pub eps: f32,
42    pub attention_mask_input: String,
43}
44
45impl BertEncoderLayerSpec {
46    pub fn hf(
47        layer_prefix: impl Into<String>,
48        qkv_style: BertQkvStyle,
49        hidden_size: usize,
50        num_heads: usize,
51        eps: f32,
52    ) -> Self {
53        Self {
54            layer_prefix: layer_prefix.into(),
55            qkv_style,
56            hidden_size,
57            num_heads,
58            head_dim: hidden_size / num_heads,
59            eps,
60            attention_mask_input: "attention_mask".into(),
61        }
62    }
63}
64
65#[derive(Debug, Clone)]
66pub struct BertEncoderLayerStage {
67    pub spec: BertEncoderLayerSpec,
68}
69
70impl BertEncoderLayerStage {
71    pub fn new(spec: BertEncoderLayerSpec) -> Self {
72        Self { spec }
73    }
74}
75
76impl BlockStage for BertEncoderLayerStage {
77    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
78        let spec = &self.spec;
79        let h = spec.hidden_size;
80        let nh = spec.num_heads;
81        let dh = spec.head_dim;
82        let lp = &spec.layer_prefix;
83
84        let (qkv_w, qkv_b) = load_fused_qkv(ctx, lp, h, spec.qkv_style)?;
85
86        let out_w = ctx.load_param(&format!("{lp}.attention.output.dense.weight"), true)?;
87        let out_b = ctx.load_param(&format!("{lp}.attention.output.dense.bias"), false)?;
88        let ln1_g = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.weight"), false)?;
89        let ln1_b = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.bias"), false)?;
90        let ln2_g = ctx.load_param(&format!("{lp}.output.LayerNorm.weight"), false)?;
91        let ln2_b = ctx.load_param(&format!("{lp}.output.LayerNorm.bias"), false)?;
92        let int_w = ctx.load_param(&format!("{lp}.intermediate.dense.weight"), true)?;
93        let int_b = ctx.load_param(&format!("{lp}.intermediate.dense.bias"), false)?;
94        let out2_w = ctx.load_param(&format!("{lp}.output.dense.weight"), true)?;
95        let out2_b = ctx.load_param(&format!("{lp}.output.dense.bias"), false)?;
96
97        let mask_id = ctx
98            .state
99            .inputs
100            .get(&spec.attention_mask_input)
101            .map(|(id, _)| *id)
102            .ok_or_else(|| {
103                anyhow::anyhow!(
104                    "BertEncoderLayer requires input `{}`",
105                    spec.attention_mask_input
106                )
107            })?;
108
109        let mut gb = HirMut::new(ctx.hir());
110        let skip = input.id;
111
112        let qkv_mm = gb.mm(skip, qkv_w);
113        let qkv = gb.add(qkv_mm, qkv_b);
114        let last_ax = gb.shape(qkv).rank() - 1;
115        let q = gb.narrow_(qkv, last_ax, 0, h);
116        let k = gb.narrow_(qkv, last_ax, h, h);
117        let v = gb.narrow_(qkv, last_ax, 2 * h, h);
118        let attn = gb.attention_(q, k, v, mask_id, nh, dh);
119
120        let attn_mm = gb.mm(attn, out_w);
121        let attn_out = gb.add(attn_mm, out_b);
122        let res1 = gb.add(attn_out, skip);
123        let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
124
125        let int_mm = gb.mm(normed1, int_w);
126        let int_add = gb.add(int_mm, int_b);
127        let ffn_int = gb.gelu(int_add);
128        let out2_mm = gb.mm(ffn_int, out2_w);
129        let ffn_out = gb.add(out2_mm, out2_b);
130        let res2 = gb.add(ffn_out, normed1);
131        let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
132
133        Ok(Some(ctx.wrap(out, input.shape.clone())))
134    }
135}
136
137fn load_fused_qkv(
138    ctx: &mut FlowCtx<'_>,
139    layer_prefix: &str,
140    h: usize,
141    style: BertQkvStyle,
142) -> Result<(rlx_ir::HirNodeId, rlx_ir::HirNodeId)> {
143    let (wq_key, wk_key, wv_key, bq_key, bk_key, bv_key) = match style {
144        BertQkvStyle::Bert => (
145            format!("{layer_prefix}.attention.self.query.weight"),
146            format!("{layer_prefix}.attention.self.key.weight"),
147            format!("{layer_prefix}.attention.self.value.weight"),
148            format!("{layer_prefix}.attention.self.query.bias"),
149            format!("{layer_prefix}.attention.self.key.bias"),
150            format!("{layer_prefix}.attention.self.value.bias"),
151        ),
152        BertQkvStyle::Mpnet => (
153            format!("{layer_prefix}.attention.attn.q.weight"),
154            format!("{layer_prefix}.attention.attn.k.weight"),
155            format!("{layer_prefix}.attention.attn.v.weight"),
156            format!("{layer_prefix}.attention.attn.q.bias"),
157            format!("{layer_prefix}.attention.attn.k.bias"),
158            format!("{layer_prefix}.attention.attn.v.bias"),
159        ),
160    };
161
162    let wq_data = ctx.weights.take(&wq_key, true)?;
163    let wk_data = ctx.weights.take(&wk_key, true)?;
164    let wv_data = ctx.weights.take(&wv_key, true)?;
165    let (bq_data, _) = ctx.weights.take(&bq_key, false)?;
166    let (bk_data, _) = ctx.weights.take(&bk_key, false)?;
167    let (bv_data, _) = ctx.weights.take(&bv_key, false)?;
168
169    let w_name = format!("{layer_prefix}.attention.qkv.weight");
170    let b_name = format!("{layer_prefix}.attention.qkv.bias");
171
172    let (wq, _) = wq_data;
173    let (wk, _) = wk_data;
174    let (wv, _) = wv_data;
175
176    let mut fused_w = vec![0f32; h * 3 * h];
177    let mut fused_b = vec![0f32; 3 * h];
178    for row in 0..h {
179        fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
180        fused_w[row * 3 * h + h..row * 3 * h + 2 * h].copy_from_slice(&wk[row * h..(row + 1) * h]);
181        fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
182            .copy_from_slice(&wv[row * h..(row + 1) * h]);
183    }
184    fused_b[..h].copy_from_slice(&bq_data);
185    fused_b[h..2 * h].copy_from_slice(&bk_data);
186    fused_b[2 * h..].copy_from_slice(&bv_data);
187
188    let w_id = ctx
189        .hir()
190        .param(&w_name, Shape::new(&[h, 3 * h], rlx_ir::DType::F32));
191    let b_id = ctx
192        .hir()
193        .param(&b_name, Shape::new(&[3 * h], rlx_ir::DType::F32));
194    ctx.params.insert(w_name, fused_w);
195    ctx.params.insert(b_name, fused_b);
196    Ok((w_id, b_id))
197}