Skip to main content

rlx_bert/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 BERT encoder flow — native [`ModelFlow`] assembly.
17
18use anyhow::Result;
19use rlx_flow::{BertQkvStyle, BuiltModel, CompileProfile, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use rlx_core::config::BertConfig;
23use rlx_core::flow_util::WeightMapSource;
24use rlx_core::weight_map::WeightMap;
25
26#[derive(Debug, Clone)]
27pub struct BertFlow<'a> {
28    cfg: &'a BertConfig,
29    batch: usize,
30    seq: usize,
31    profile: CompileProfile,
32}
33
34impl<'a> BertFlow<'a> {
35    pub fn new(cfg: &'a BertConfig, batch: usize, seq: usize) -> Self {
36        Self {
37            cfg,
38            batch,
39            seq,
40            profile: CompileProfile::encoder(),
41        }
42    }
43
44    pub fn with_profile(mut self, profile: CompileProfile) -> Self {
45        self.profile = profile;
46        self
47    }
48
49    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
50        let prefix = if weights.has("bert.embeddings.word_embeddings.weight") {
51            "bert."
52        } else {
53            ""
54        };
55        let qkv_style = if weights.has(&format!(
56            "{prefix}encoder.layer.0.attention.self.query.weight"
57        )) {
58            BertQkvStyle::Bert
59        } else {
60            BertQkvStyle::Mpnet
61        };
62
63        let h = self.cfg.hidden_size;
64        let f = DType::F32;
65        let eps = self.cfg.layer_norm_eps as f32;
66
67        let flow = ModelFlow::new("bert")
68            .with_profile(self.profile)
69            .input("input_ids", Shape::new(&[self.batch, self.seq], DType::F32))
70            .input("attention_mask", Shape::new(&[self.batch, self.seq], f))
71            .input(
72                "token_type_ids",
73                Shape::new(&[self.batch, self.seq], DType::F32),
74            )
75            .input(
76                "position_ids",
77                Shape::new(&[self.batch, self.seq], DType::F32),
78            )
79            .embed(format!("{prefix}embeddings.word_embeddings.weight"))
80            .gather_add(
81                "position_ids",
82                format!("{prefix}embeddings.position_embeddings.weight"),
83            )
84            .gather_add(
85                "token_type_ids",
86                format!("{prefix}embeddings.token_type_embeddings.weight"),
87            )
88            .layer_norm(
89                format!("{prefix}embeddings.LayerNorm.weight"),
90                format!("{prefix}embeddings.LayerNorm.bias"),
91                eps,
92            )
93            .repeat_bert_layers(
94                self.cfg.num_hidden_layers,
95                prefix.trim_end_matches('.'),
96                qkv_style,
97                h,
98                self.cfg.num_attention_heads,
99                eps,
100            )
101            .output("hidden_states");
102
103        flow.build(&mut WeightMapSource(weights))
104    }
105}
106
107pub fn build_bert_built(
108    cfg: &BertConfig,
109    weights: &mut WeightMap,
110    batch: usize,
111    seq: usize,
112) -> Result<BuiltModel> {
113    BertFlow::new(cfg, batch, seq).build(weights)
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::config::BertConfig;
    use std::collections::HashMap;

    /// Builds a one-layer BERT from all-zero, unprefixed weights and checks
    /// that the resulting HIR has more than 10 entries.
    #[test]
    fn bert_flow_builds() {
        let cfg = BertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            hidden_act: "gelu".into(),
        };
        let hidden = cfg.hidden_size;
        let inter = cfg.intermediate_size;
        let layer = "encoder.layer.0";

        // (weight name, shape) for every tensor registered with the weight map.
        let specs: Vec<(String, Vec<usize>)> = vec![
            (
                "embeddings.word_embeddings.weight".into(),
                vec![cfg.vocab_size, hidden],
            ),
            (
                "embeddings.position_embeddings.weight".into(),
                vec![cfg.max_position_embeddings, hidden],
            ),
            (
                "embeddings.token_type_embeddings.weight".into(),
                vec![cfg.type_vocab_size, hidden],
            ),
            ("embeddings.LayerNorm.weight".into(), vec![hidden]),
            ("embeddings.LayerNorm.bias".into(), vec![hidden]),
            (
                format!("{layer}.attention.self.query.weight"),
                vec![hidden, hidden],
            ),
            (format!("{layer}.attention.self.query.bias"), vec![hidden]),
            (
                format!("{layer}.attention.self.key.weight"),
                vec![hidden, hidden],
            ),
            (format!("{layer}.attention.self.key.bias"), vec![hidden]),
            (
                format!("{layer}.attention.self.value.weight"),
                vec![hidden, hidden],
            ),
            (format!("{layer}.attention.self.value.bias"), vec![hidden]),
            (
                format!("{layer}.attention.output.dense.weight"),
                vec![hidden, hidden],
            ),
            (format!("{layer}.attention.output.dense.bias"), vec![hidden]),
            (
                format!("{layer}.attention.output.LayerNorm.weight"),
                vec![hidden],
            ),
            (
                format!("{layer}.attention.output.LayerNorm.bias"),
                vec![hidden],
            ),
            (
                format!("{layer}.intermediate.dense.weight"),
                vec![inter, hidden],
            ),
            (format!("{layer}.intermediate.dense.bias"), vec![inter]),
            (format!("{layer}.output.dense.weight"), vec![hidden, inter]),
            (format!("{layer}.output.dense.bias"), vec![hidden]),
            (format!("{layer}.output.LayerNorm.weight"), vec![hidden]),
            (format!("{layer}.output.LayerNorm.bias"), vec![hidden]),
        ];

        // Each tensor is a zero-filled buffer whose length is the product of
        // its shape dimensions.
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| {
                let numel: usize = shape.iter().product();
                (name, (vec![0.0f32; numel], shape))
            })
            .collect();

        let mut wm = WeightMap::from_tensors(tensors);
        let built = BertFlow::new(&cfg, 1, 4).build(&mut wm).unwrap();
        let hir = built.into_hir().unwrap();
        assert!(hir.len() > 10);
    }
}