1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::Shape;
7use rlx_ir::hir::HirMut;
8
9use super::BlockStage;
10use crate::context::FlowCtx;
11use crate::value::FlowValue;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum BertQkvStyle {
16 Bert,
18 Mpnet,
20}
21
22#[derive(Debug, Clone)]
23pub struct BertEncoderLayerSpec {
24 pub layer_prefix: String,
25 pub qkv_style: BertQkvStyle,
26 pub hidden_size: usize,
27 pub num_heads: usize,
28 pub head_dim: usize,
29 pub eps: f32,
30 pub attention_mask_input: String,
31}
32
33impl BertEncoderLayerSpec {
34 pub fn hf(
35 layer_prefix: impl Into<String>,
36 qkv_style: BertQkvStyle,
37 hidden_size: usize,
38 num_heads: usize,
39 eps: f32,
40 ) -> Self {
41 Self {
42 layer_prefix: layer_prefix.into(),
43 qkv_style,
44 hidden_size,
45 num_heads,
46 head_dim: hidden_size / num_heads,
47 eps,
48 attention_mask_input: "attention_mask".into(),
49 }
50 }
51}
52
53#[derive(Debug, Clone)]
54pub struct BertEncoderLayerStage {
55 pub spec: BertEncoderLayerSpec,
56}
57
58impl BertEncoderLayerStage {
59 pub fn new(spec: BertEncoderLayerSpec) -> Self {
60 Self { spec }
61 }
62}
63
64impl BlockStage for BertEncoderLayerStage {
65 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
66 let spec = &self.spec;
67 let h = spec.hidden_size;
68 let nh = spec.num_heads;
69 let dh = spec.head_dim;
70 let lp = &spec.layer_prefix;
71
72 let (qkv_w, qkv_b) = load_fused_qkv(ctx, lp, h, spec.qkv_style)?;
73
74 let out_w = ctx.load_param(&format!("{lp}.attention.output.dense.weight"), true)?;
75 let out_b = ctx.load_param(&format!("{lp}.attention.output.dense.bias"), false)?;
76 let ln1_g = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.weight"), false)?;
77 let ln1_b = ctx.load_param(&format!("{lp}.attention.output.LayerNorm.bias"), false)?;
78 let ln2_g = ctx.load_param(&format!("{lp}.output.LayerNorm.weight"), false)?;
79 let ln2_b = ctx.load_param(&format!("{lp}.output.LayerNorm.bias"), false)?;
80 let int_w = ctx.load_param(&format!("{lp}.intermediate.dense.weight"), true)?;
81 let int_b = ctx.load_param(&format!("{lp}.intermediate.dense.bias"), false)?;
82 let out2_w = ctx.load_param(&format!("{lp}.output.dense.weight"), true)?;
83 let out2_b = ctx.load_param(&format!("{lp}.output.dense.bias"), false)?;
84
85 let mask_id = ctx
86 .state
87 .inputs
88 .get(&spec.attention_mask_input)
89 .map(|(id, _)| *id)
90 .ok_or_else(|| {
91 anyhow::anyhow!(
92 "BertEncoderLayer requires input `{}`",
93 spec.attention_mask_input
94 )
95 })?;
96
97 let mut gb = HirMut::new(ctx.hir());
98 let skip = input.id;
99
100 let qkv_mm = gb.mm(skip, qkv_w);
101 let qkv = gb.add(qkv_mm, qkv_b);
102 let last_ax = gb.shape(qkv).rank() - 1;
103 let q = gb.narrow_(qkv, last_ax, 0, h);
104 let k = gb.narrow_(qkv, last_ax, h, h);
105 let v = gb.narrow_(qkv, last_ax, 2 * h, h);
106 let attn = gb.attention_(q, k, v, mask_id, nh, dh);
107
108 let attn_mm = gb.mm(attn, out_w);
109 let attn_out = gb.add(attn_mm, out_b);
110 let res1 = gb.add(attn_out, skip);
111 let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
112
113 let int_mm = gb.mm(normed1, int_w);
114 let int_add = gb.add(int_mm, int_b);
115 let ffn_int = gb.gelu(int_add);
116 let out2_mm = gb.mm(ffn_int, out2_w);
117 let ffn_out = gb.add(out2_mm, out2_b);
118 let res2 = gb.add(ffn_out, normed1);
119 let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
120
121 Ok(Some(ctx.wrap(out, input.shape.clone())))
122 }
123}
124
125fn load_fused_qkv(
126 ctx: &mut FlowCtx<'_>,
127 layer_prefix: &str,
128 h: usize,
129 style: BertQkvStyle,
130) -> Result<(rlx_ir::HirNodeId, rlx_ir::HirNodeId)> {
131 let (wq_key, wk_key, wv_key, bq_key, bk_key, bv_key) = match style {
132 BertQkvStyle::Bert => (
133 format!("{layer_prefix}.attention.self.query.weight"),
134 format!("{layer_prefix}.attention.self.key.weight"),
135 format!("{layer_prefix}.attention.self.value.weight"),
136 format!("{layer_prefix}.attention.self.query.bias"),
137 format!("{layer_prefix}.attention.self.key.bias"),
138 format!("{layer_prefix}.attention.self.value.bias"),
139 ),
140 BertQkvStyle::Mpnet => (
141 format!("{layer_prefix}.attention.attn.q.weight"),
142 format!("{layer_prefix}.attention.attn.k.weight"),
143 format!("{layer_prefix}.attention.attn.v.weight"),
144 format!("{layer_prefix}.attention.attn.q.bias"),
145 format!("{layer_prefix}.attention.attn.k.bias"),
146 format!("{layer_prefix}.attention.attn.v.bias"),
147 ),
148 };
149
150 let wq_data = ctx.weights.take(&wq_key, true)?;
151 let wk_data = ctx.weights.take(&wk_key, true)?;
152 let wv_data = ctx.weights.take(&wv_key, true)?;
153 let (bq_data, _) = ctx.weights.take(&bq_key, false)?;
154 let (bk_data, _) = ctx.weights.take(&bk_key, false)?;
155 let (bv_data, _) = ctx.weights.take(&bv_key, false)?;
156
157 let w_name = format!("{layer_prefix}.attention.qkv.weight");
158 let b_name = format!("{layer_prefix}.attention.qkv.bias");
159
160 let (wq, _) = wq_data;
161 let (wk, _) = wk_data;
162 let (wv, _) = wv_data;
163
164 let mut fused_w = vec![0f32; h * 3 * h];
165 let mut fused_b = vec![0f32; 3 * h];
166 for row in 0..h {
167 fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
168 fused_w[row * 3 * h + h..row * 3 * h + 2 * h].copy_from_slice(&wk[row * h..(row + 1) * h]);
169 fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
170 .copy_from_slice(&wv[row * h..(row + 1) * h]);
171 }
172 fused_b[..h].copy_from_slice(&bq_data);
173 fused_b[h..2 * h].copy_from_slice(&bk_data);
174 fused_b[2 * h..].copy_from_slice(&bv_data);
175
176 let w_id = ctx
177 .hir()
178 .param(&w_name, Shape::new(&[h, 3 * h], rlx_ir::DType::F32));
179 let b_id = ctx
180 .hir()
181 .param(&b_name, Shape::new(&[3 * h], rlx_ir::DType::F32));
182 ctx.params.insert(w_name, fused_w);
183 ctx.params.insert(b_name, fused_b);
184 Ok((w_id, b_id))
185}