Skip to main content

rlx_flow/blocks/
bert_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::Shape;
7use rlx_ir::hir::HirMut;
8
9use super::BlockStage;
10use crate::context::FlowCtx;
11use crate::value::FlowValue;
12
/// QKV weight layout for BERT-family encoders.
///
/// Selects the checkpoint key names from which the separate query/key/value projection
/// weights and biases are loaded before they are fused into a single QKV matmul.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BertQkvStyle {
    /// `attention.self.{query,key,value}` (BERT / MiniLM).
    Bert,
    /// `attention.attn.{q,k,v}` (mpnet-style).
    Mpnet,
}
21
/// Configuration for one BERT-family encoder layer emitted by `BertEncoderLayerStage`.
#[derive(Debug, Clone)]
pub struct BertEncoderLayerSpec {
    /// Prefix prepended to every parameter key the layer loads
    /// (e.g. `{layer_prefix}.intermediate.dense.weight`).
    pub layer_prefix: String,
    /// Which key-naming scheme the checkpoint uses for the Q/K/V projections.
    pub qkv_style: BertQkvStyle,
    /// Hidden width `h`. Q, K and V are each fused as `h * h` weights and `h`-element
    /// biases into a `[h, 3h]` weight and `[3h]` bias.
    pub hidden_size: usize,
    /// Number of attention heads passed to the attention op.
    pub num_heads: usize,
    /// Per-head width passed to the attention op; `hf` sets it to `hidden_size / num_heads`.
    pub head_dim: usize,
    /// Epsilon passed to both `ln` calls of the layer.
    pub eps: f32,
    /// Name of the flow input that supplies the attention mask; `hf` uses `"attention_mask"`.
    pub attention_mask_input: String,
}
32
33impl BertEncoderLayerSpec {
34    pub fn hf(
35        layer_prefix: impl Into<String>,
36        qkv_style: BertQkvStyle,
37        hidden_size: usize,
38        num_heads: usize,
39        eps: f32,
40    ) -> Self {
41        Self {
42            layer_prefix: layer_prefix.into(),
43            qkv_style,
44            hidden_size,
45            num_heads,
46            head_dim: hidden_size / num_heads,
47            eps,
48            attention_mask_input: "attention_mask".into(),
49        }
50    }
51}
52
/// Block stage that emits one BERT-family encoder layer described by `spec`.
///
/// The layer has a self-attention sublayer and a feed-forward sublayer. Each ends with a
/// residual add followed by `ln`.
#[derive(Debug, Clone)]
pub struct BertEncoderLayerStage {
    /// Layer configuration: key prefix, dimensions, epsilon and mask input name.
    pub spec: BertEncoderLayerSpec,
}
57
58impl BertEncoderLayerStage {
59    pub fn new(spec: BertEncoderLayerSpec) -> Self {
60        Self { spec }
61    }
62}
63
64impl BlockStage for BertEncoderLayerStage {
65    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
66        let spec = &self.spec;
67        let h = spec.hidden_size;
68        let nh = spec.num_heads;
69        let dh = spec.head_dim;
70        let lp = &spec.layer_prefix;
71
72        let (qkv_w, qkv_b) = load_fused_qkv(ctx, lp, h, spec.qkv_style)?;
73
74        let out_w = ctx.load_param(&format!("{lp}.attention.output.dense.weight"), true)?;
75        let out_b = ctx.load_param(&format!("{lp}.attention.output.dense.bias"), false)?;
76        let ln1_g = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.weight"), false)?;
77        let ln1_b = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.bias"), false)?;
78        let ln2_g = ctx.load_param(&format!("{lp}.output.LayerNorm.weight"), false)?;
79        let ln2_b = ctx.load_param(&format!("{lp}.output.LayerNorm.bias"), false)?;
80        let int_w = ctx.load_param(&format!("{lp}.intermediate.dense.weight"), true)?;
81        let int_b = ctx.load_param(&format!("{lp}.intermediate.dense.bias"), false)?;
82        let out2_w = ctx.load_param(&format!("{lp}.output.dense.weight"), true)?;
83        let out2_b = ctx.load_param(&format!("{lp}.output.dense.bias"), false)?;
84
85        let mask_id = ctx
86            .state
87            .inputs
88            .get(&spec.attention_mask_input)
89            .map(|(id, _)| *id)
90            .ok_or_else(|| {
91                anyhow::anyhow!(
92                    "BertEncoderLayer requires input `{}`",
93                    spec.attention_mask_input
94                )
95            })?;
96
97        let mut gb = HirMut::new(ctx.hir());
98        let skip = input.id;
99
100        let qkv_mm = gb.mm(skip, qkv_w);
101        let qkv = gb.add(qkv_mm, qkv_b);
102        let last_ax = gb.shape(qkv).rank() - 1;
103        let q = gb.narrow_(qkv, last_ax, 0, h);
104        let k = gb.narrow_(qkv, last_ax, h, h);
105        let v = gb.narrow_(qkv, last_ax, 2 * h, h);
106        let attn = gb.attention_(q, k, v, mask_id, nh, dh);
107
108        let attn_mm = gb.mm(attn, out_w);
109        let attn_out = gb.add(attn_mm, out_b);
110        let res1 = gb.add(attn_out, skip);
111        let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
112
113        let int_mm = gb.mm(normed1, int_w);
114        let int_add = gb.add(int_mm, int_b);
115        let ffn_int = gb.gelu(int_add);
116        let out2_mm = gb.mm(ffn_int, out2_w);
117        let ffn_out = gb.add(out2_mm, out2_b);
118        let res2 = gb.add(ffn_out, normed1);
119        let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
120
121        Ok(Some(ctx.wrap(out, input.shape.clone())))
122    }
123}
124
125fn load_fused_qkv(
126    ctx: &mut FlowCtx<'_>,
127    layer_prefix: &str,
128    h: usize,
129    style: BertQkvStyle,
130) -> Result<(rlx_ir::HirNodeId, rlx_ir::HirNodeId)> {
131    let (wq_key, wk_key, wv_key, bq_key, bk_key, bv_key) = match style {
132        BertQkvStyle::Bert => (
133            format!("{layer_prefix}.attention.self.query.weight"),
134            format!("{layer_prefix}.attention.self.key.weight"),
135            format!("{layer_prefix}.attention.self.value.weight"),
136            format!("{layer_prefix}.attention.self.query.bias"),
137            format!("{layer_prefix}.attention.self.key.bias"),
138            format!("{layer_prefix}.attention.self.value.bias"),
139        ),
140        BertQkvStyle::Mpnet => (
141            format!("{layer_prefix}.attention.attn.q.weight"),
142            format!("{layer_prefix}.attention.attn.k.weight"),
143            format!("{layer_prefix}.attention.attn.v.weight"),
144            format!("{layer_prefix}.attention.attn.q.bias"),
145            format!("{layer_prefix}.attention.attn.k.bias"),
146            format!("{layer_prefix}.attention.attn.v.bias"),
147        ),
148    };
149
150    let wq_data = ctx.weights.take(&wq_key, true)?;
151    let wk_data = ctx.weights.take(&wk_key, true)?;
152    let wv_data = ctx.weights.take(&wv_key, true)?;
153    let (bq_data, _) = ctx.weights.take(&bq_key, false)?;
154    let (bk_data, _) = ctx.weights.take(&bk_key, false)?;
155    let (bv_data, _) = ctx.weights.take(&bv_key, false)?;
156
157    let w_name = format!("{layer_prefix}.attention.qkv.weight");
158    let b_name = format!("{layer_prefix}.attention.qkv.bias");
159
160    let (wq, _) = wq_data;
161    let (wk, _) = wk_data;
162    let (wv, _) = wv_data;
163
164    let mut fused_w = vec![0f32; h * 3 * h];
165    let mut fused_b = vec![0f32; 3 * h];
166    for row in 0..h {
167        fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
168        fused_w[row * 3 * h + h..row * 3 * h + 2 * h].copy_from_slice(&wk[row * h..(row + 1) * h]);
169        fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
170            .copy_from_slice(&wv[row * h..(row + 1) * h]);
171    }
172    fused_b[..h].copy_from_slice(&bq_data);
173    fused_b[h..2 * h].copy_from_slice(&bk_data);
174    fused_b[2 * h..].copy_from_slice(&bv_data);
175
176    let w_id = ctx
177        .hir()
178        .param(&w_name, Shape::new(&[h, 3 * h], rlx_ir::DType::F32));
179    let b_id = ctx
180        .hir()
181        .param(&b_name, Shape::new(&[3 * h], rlx_ir::DType::F32));
182    ctx.params.insert(w_name, fused_w);
183    ctx.params.insert(b_name, fused_b);
184    Ok((w_id, b_id))
185}