1use anyhow::Result;
19use rlx_core::config::BertConfig;
20use rlx_core::weight_map::WeightMap;
21use rlx_ir::*;
22use std::collections::HashMap;
23
24pub fn build_bert_graph(
35 cfg: &BertConfig,
36 weights: &mut WeightMap,
37) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
38 build_bert_graph_sized(cfg, weights, 1, 1)
39}
40
41pub fn build_bert_graph_sized(
42 cfg: &BertConfig,
43 weights: &mut WeightMap,
44 batch: usize,
45 seq: usize,
46) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
47 rlx_core::flow_util::graph_from_built(crate::flow::build_bert_built(cfg, weights, batch, seq)?)
48}
49
50#[allow(dead_code)]
52fn load_param(
53 g: &mut Graph,
54 params: &mut HashMap<String, Vec<f32>>,
55 weights: &mut WeightMap,
56 key: &str,
57 _expected_shape: &[usize],
58 transpose: bool,
59) -> Result<NodeId> {
60 let (data, shape) = if transpose {
61 weights.take_transposed(key)?
62 } else {
63 weights.take(key)?
64 };
65 let name = key.to_string();
66 let ir_shape = Shape::new(&shape, DType::F32);
67 let id = g.param(&name, ir_shape);
68 params.insert(name, data);
69 Ok(id)
70}
71
72#[allow(dead_code)]
74fn load_fused_qkv(
75 g: &mut Graph,
76 params: &mut HashMap<String, Vec<f32>>,
77 weights: &mut WeightMap,
78 layer_prefix: &str,
79 h: usize,
80 _nh: usize,
81 _dh: usize,
82) -> Result<(NodeId, NodeId)> {
83 let (wq, _) =
84 weights.take_transposed(&format!("{layer_prefix}.attention.self.query.weight"))?;
85 let (wk, _) = weights.take_transposed(&format!("{layer_prefix}.attention.self.key.weight"))?;
86 let (wv, _) =
87 weights.take_transposed(&format!("{layer_prefix}.attention.self.value.weight"))?;
88
89 let bq = weights
90 .take(&format!("{layer_prefix}.attention.self.query.bias"))?
91 .0;
92 let bk = weights
93 .take(&format!("{layer_prefix}.attention.self.key.bias"))?
94 .0;
95 let bv = weights
96 .take(&format!("{layer_prefix}.attention.self.value.bias"))?
97 .0;
98
99 let mut fused_w = vec![0f32; h * 3 * h];
101 let mut fused_b = vec![0f32; 3 * h];
102 for row in 0..h {
103 fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
104 fused_w[row * 3 * h + h..row * 3 * h + 2 * h].copy_from_slice(&wk[row * h..(row + 1) * h]);
105 fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
106 .copy_from_slice(&wv[row * h..(row + 1) * h]);
107 }
108 fused_b[..h].copy_from_slice(&bq);
109 fused_b[h..2 * h].copy_from_slice(&bk);
110 fused_b[2 * h..].copy_from_slice(&bv);
111
112 let w_name = format!("{layer_prefix}.attention.qkv.weight");
113 let b_name = format!("{layer_prefix}.attention.qkv.bias");
114 let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
115 let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
116 params.insert(w_name, fused_w);
117 params.insert(b_name, fused_b);
118
119 Ok((w_id, b_id))
120}
121
122#[allow(dead_code)]
124fn load_fused_qkv_mpnet(
125 g: &mut Graph,
126 params: &mut HashMap<String, Vec<f32>>,
127 weights: &mut WeightMap,
128 layer_prefix: &str,
129 h: usize,
130 nh: usize,
131 dh: usize,
132) -> Result<(NodeId, NodeId)> {
133 let q_key = format!("{layer_prefix}.attention.attn.q.weight");
135 if weights.has(&q_key) {
136 let (wq, _) = weights.take_transposed(&q_key)?;
137 let (wk, _) =
138 weights.take_transposed(&format!("{layer_prefix}.attention.attn.k.weight"))?;
139 let (wv, _) =
140 weights.take_transposed(&format!("{layer_prefix}.attention.attn.v.weight"))?;
141 let bq = weights
142 .take(&format!("{layer_prefix}.attention.attn.q.bias"))?
143 .0;
144 let bk = weights
145 .take(&format!("{layer_prefix}.attention.attn.k.bias"))?
146 .0;
147 let bv = weights
148 .take(&format!("{layer_prefix}.attention.attn.v.bias"))?
149 .0;
150
151 let mut fused_w = vec![0f32; h * 3 * h];
152 let mut fused_b = vec![0f32; 3 * h];
153 for row in 0..h {
154 fused_w[row * 3 * h..row * 3 * h + h].copy_from_slice(&wq[row * h..(row + 1) * h]);
155 fused_w[row * 3 * h + h..row * 3 * h + 2 * h]
156 .copy_from_slice(&wk[row * h..(row + 1) * h]);
157 fused_w[row * 3 * h + 2 * h..row * 3 * h + 3 * h]
158 .copy_from_slice(&wv[row * h..(row + 1) * h]);
159 }
160 fused_b[..h].copy_from_slice(&bq);
161 fused_b[h..2 * h].copy_from_slice(&bk);
162 fused_b[2 * h..].copy_from_slice(&bv);
163
164 let w_name = format!("{layer_prefix}.attention.qkv.weight");
165 let b_name = format!("{layer_prefix}.attention.qkv.bias");
166 let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
167 let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
168 params.insert(w_name, fused_w);
169 params.insert(b_name, fused_b);
170 return Ok((w_id, b_id));
171 }
172
173 let fused_key = format!("{layer_prefix}.attention.self.qkv.weight");
175 if weights.has(&fused_key) {
176 let (data, _) = weights.take_transposed(&fused_key)?;
177 let bias = weights
178 .take(&format!("{layer_prefix}.attention.self.qkv.bias"))?
179 .0;
180 let w_name = format!("{layer_prefix}.attention.qkv.weight");
181 let b_name = format!("{layer_prefix}.attention.qkv.bias");
182 let w_id = g.param(&w_name, Shape::new(&[h, 3 * h], DType::F32));
183 let b_id = g.param(&b_name, Shape::new(&[3 * h], DType::F32));
184 params.insert(w_name, data);
185 params.insert(b_name, bias);
186 return Ok((w_id, b_id));
187 }
188
189 load_fused_qkv(g, params, weights, layer_prefix, h, nh, dh)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-layer BERT graph from constant-filled tensors and checks that the
    /// graph passes `rlx_ir::verify::verify`, exposes at least 15 parameters, and has at
    /// least one output.
    #[test]
    fn build_tiny_bert_graph() {
        let cfg = BertConfig {
            vocab_size: 100,
            hidden_size: 64,
            num_hidden_layers: 1,
            num_attention_heads: 2,
            intermediate_size: 256,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            hidden_act: "gelu".into(),
        };

        let hidden = cfg.hidden_size;
        let ffn = cfg.intermediate_size;

        // Every tensor the builder is given, as (name, shape).
        let layout: Vec<(&str, Vec<usize>)> = vec![
            // Embeddings.
            ("embeddings.word_embeddings.weight", vec![100, hidden]),
            ("embeddings.position_embeddings.weight", vec![32, hidden]),
            ("embeddings.token_type_embeddings.weight", vec![2, hidden]),
            ("embeddings.LayerNorm.weight", vec![hidden]),
            ("embeddings.LayerNorm.bias", vec![hidden]),
            // Layer 0 attention.
            (
                "encoder.layer.0.attention.self.query.weight",
                vec![hidden, hidden],
            ),
            ("encoder.layer.0.attention.self.query.bias", vec![hidden]),
            (
                "encoder.layer.0.attention.self.key.weight",
                vec![hidden, hidden],
            ),
            ("encoder.layer.0.attention.self.key.bias", vec![hidden]),
            (
                "encoder.layer.0.attention.self.value.weight",
                vec![hidden, hidden],
            ),
            ("encoder.layer.0.attention.self.value.bias", vec![hidden]),
            (
                "encoder.layer.0.attention.output.dense.weight",
                vec![hidden, hidden],
            ),
            ("encoder.layer.0.attention.output.dense.bias", vec![hidden]),
            (
                "encoder.layer.0.attention.output.LayerNorm.weight",
                vec![hidden],
            ),
            (
                "encoder.layer.0.attention.output.LayerNorm.bias",
                vec![hidden],
            ),
            // Layer 0 feed-forward.
            (
                "encoder.layer.0.intermediate.dense.weight",
                vec![ffn, hidden],
            ),
            ("encoder.layer.0.intermediate.dense.bias", vec![ffn]),
            ("encoder.layer.0.output.dense.weight", vec![hidden, ffn]),
            ("encoder.layer.0.output.dense.bias", vec![hidden]),
            ("encoder.layer.0.output.LayerNorm.weight", vec![hidden]),
            ("encoder.layer.0.output.LayerNorm.bias", vec![hidden]),
        ];

        // Each tensor is filled with 0.01, sized to the product of its dimensions.
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = layout
            .into_iter()
            .map(|(name, dims)| {
                let count: usize = dims.iter().product();
                (name.to_string(), (vec![0.01f32; count], dims))
            })
            .collect();

        let mut wm = WeightMap::from_tensors(tensors);
        let (graph, params) = build_bert_graph(&cfg, &mut wm).unwrap();

        println!("{graph}");
        println!("Nodes: {}, Params: {}", graph.len(), params.len());

        let errors = rlx_ir::verify::verify(&graph);
        assert!(errors.is_empty(), "verification errors: {errors:?}");

        assert!(
            params.len() >= 15,
            "expected 15+ params, got {}",
            params.len()
        );

        assert!(!graph.outputs.is_empty());
    }
}