1use anyhow::Result;
19use rlx_flow::{BertQkvStyle, BuiltModel, CompileProfile, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use rlx_core::config::BertConfig;
23use rlx_core::flow_util::WeightMapSource;
24use rlx_core::weight_map::WeightMap;
25
26#[derive(Debug, Clone)]
27pub struct BertFlow<'a> {
28 cfg: &'a BertConfig,
29 batch: usize,
30 seq: usize,
31 profile: CompileProfile,
32}
33
34impl<'a> BertFlow<'a> {
35 pub fn new(cfg: &'a BertConfig, batch: usize, seq: usize) -> Self {
36 Self {
37 cfg,
38 batch,
39 seq,
40 profile: CompileProfile::encoder(),
41 }
42 }
43
44 pub fn with_profile(mut self, profile: CompileProfile) -> Self {
45 self.profile = profile;
46 self
47 }
48
49 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
50 let prefix = if weights.has("bert.embeddings.word_embeddings.weight") {
51 "bert."
52 } else {
53 ""
54 };
55 let qkv_style = if weights.has(&format!(
56 "{prefix}encoder.layer.0.attention.self.query.weight"
57 )) {
58 BertQkvStyle::Bert
59 } else {
60 BertQkvStyle::Mpnet
61 };
62
63 let h = self.cfg.hidden_size;
64 let f = DType::F32;
65 let eps = self.cfg.layer_norm_eps as f32;
66
67 let flow = ModelFlow::new("bert")
68 .with_profile(self.profile)
69 .input("input_ids", Shape::new(&[self.batch, self.seq], DType::F32))
70 .input("attention_mask", Shape::new(&[self.batch, self.seq], f))
71 .input(
72 "token_type_ids",
73 Shape::new(&[self.batch, self.seq], DType::F32),
74 )
75 .input(
76 "position_ids",
77 Shape::new(&[self.batch, self.seq], DType::F32),
78 )
79 .embed(format!("{prefix}embeddings.word_embeddings.weight"))
80 .gather_add(
81 "position_ids",
82 format!("{prefix}embeddings.position_embeddings.weight"),
83 )
84 .gather_add(
85 "token_type_ids",
86 format!("{prefix}embeddings.token_type_embeddings.weight"),
87 )
88 .layer_norm(
89 format!("{prefix}embeddings.LayerNorm.weight"),
90 format!("{prefix}embeddings.LayerNorm.bias"),
91 eps,
92 )
93 .repeat_bert_layers(
94 self.cfg.num_hidden_layers,
95 prefix.trim_end_matches('.'),
96 qkv_style,
97 h,
98 self.cfg.num_attention_heads,
99 eps,
100 )
101 .output("hidden_states");
102
103 flow.build(&mut WeightMapSource(weights))
104 }
105}
106
107pub fn build_bert_built(
108 cfg: &BertConfig,
109 weights: &mut WeightMap,
110 batch: usize,
111 seq: usize,
112) -> Result<BuiltModel> {
113 BertFlow::new(cfg, batch, seq).build(weights)
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::config::BertConfig;
    use std::collections::HashMap;

    /// Lowers a one-layer, un-prefixed BERT checkpoint of all-zero tensors and checks
    /// that the resulting HIR is non-trivial.
    #[test]
    fn bert_flow_builds() {
        let cfg = BertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            hidden_act: "gelu".into(),
        };
        let h = cfg.hidden_size;
        let inter = cfg.intermediate_size;
        let layer = "encoder.layer.0";

        // (tensor name, shape) for every parameter the flow resolves. Each tensor is
        // zero-filled with exactly product(shape) elements.
        let specs: Vec<(String, Vec<usize>)> = vec![
            (
                "embeddings.word_embeddings.weight".into(),
                vec![cfg.vocab_size, h],
            ),
            (
                "embeddings.position_embeddings.weight".into(),
                vec![cfg.max_position_embeddings, h],
            ),
            (
                "embeddings.token_type_embeddings.weight".into(),
                vec![cfg.type_vocab_size, h],
            ),
            ("embeddings.LayerNorm.weight".into(), vec![h]),
            ("embeddings.LayerNorm.bias".into(), vec![h]),
            (format!("{layer}.attention.self.query.weight"), vec![h, h]),
            (format!("{layer}.attention.self.query.bias"), vec![h]),
            (format!("{layer}.attention.self.key.weight"), vec![h, h]),
            (format!("{layer}.attention.self.key.bias"), vec![h]),
            (format!("{layer}.attention.self.value.weight"), vec![h, h]),
            (format!("{layer}.attention.self.value.bias"), vec![h]),
            (format!("{layer}.attention.output.dense.weight"), vec![h, h]),
            (format!("{layer}.attention.output.dense.bias"), vec![h]),
            (format!("{layer}.attention.output.LayerNorm.weight"), vec![h]),
            (format!("{layer}.attention.output.LayerNorm.bias"), vec![h]),
            (format!("{layer}.intermediate.dense.weight"), vec![inter, h]),
            (format!("{layer}.intermediate.dense.bias"), vec![inter]),
            (format!("{layer}.output.dense.weight"), vec![h, inter]),
            (format!("{layer}.output.dense.bias"), vec![h]),
            (format!("{layer}.output.LayerNorm.weight"), vec![h]),
            (format!("{layer}.output.LayerNorm.bias"), vec![h]),
        ];
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| {
                let numel: usize = shape.iter().product();
                (name, (vec![0.0f32; numel], shape))
            })
            .collect();

        let mut wm = WeightMap::from_tensors(tensors);
        let built = BertFlow::new(&cfg, 1, 4).build(&mut wm).unwrap();
        assert!(built.into_hir().unwrap().len() > 10);
    }
}