Skip to main content

rlx_embed/
bert.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! RLX-compiled BERT encoder for text embeddings.
17
18use std::path::Path;
19
20use anyhow::Result;
21use rlx_runtime::{CompiledGraph, Device, Precision, PrecisionPolicy, Session};
22
23use rlx_bert::flow::build_bert_built;
24use rlx_core::config::BertConfig;
25use rlx_core::flow_bridge::compile_options_from_profile;
26use rlx_core::flow_util::{compile_built, graph_from_built};
27use rlx_core::weight_map::WeightMap;
28use rlx_ir::logical_kernel::KernelDispatchConfig;
29
30/// RLX-compiled BERT model ready for inference.
31pub struct RlxBertModel {
32    compiled: CompiledGraph,
33    config: BertConfig,
34    weights_path: String,
35    compiled_bs: (usize, usize),
36    device: Device,
37    precision: Precision,
38    policy: Option<PrecisionPolicy>,
39}
40
41impl RlxBertModel {
42    pub fn load_sized(
43        config_path: &Path,
44        weights_path: &str,
45        batch: usize,
46        seq: usize,
47    ) -> Result<Self> {
48        Self::load_sized_on(config_path, weights_path, batch, seq, Device::Cpu)
49    }
50
51    pub fn load_sized_on(
52        config_path: &Path,
53        weights_path: &str,
54        batch: usize,
55        seq: usize,
56        device: Device,
57    ) -> Result<Self> {
58        Self::load_sized_with_policy(
59            config_path,
60            weights_path,
61            batch,
62            seq,
63            device,
64            Precision::F32,
65            None,
66        )
67    }
68
69    pub fn load_sized_with_policy(
70        config_path: &Path,
71        weights_path: &str,
72        batch: usize,
73        seq: usize,
74        device: Device,
75        precision: Precision,
76        policy: Option<PrecisionPolicy>,
77    ) -> Result<Self> {
78        let config = BertConfig::from_file(config_path)?;
79        let compiled = Self::compile_flow(
80            &config,
81            weights_path,
82            batch,
83            seq,
84            device,
85            precision,
86            &policy,
87        )?;
88        Ok(Self {
89            compiled,
90            config,
91            weights_path: weights_path.to_string(),
92            compiled_bs: (batch, seq),
93            device,
94            precision,
95            policy,
96        })
97    }
98
99    pub fn load(config_path: &Path, weights_path: &str) -> Result<Self> {
100        Self::load_sized(config_path, weights_path, 1, 1)
101    }
102
103    pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
104        if self.compiled_bs == (batch, seq) {
105            return Ok(());
106        }
107        self.compiled = Self::compile_flow(
108            &self.config,
109            &self.weights_path,
110            batch,
111            seq,
112            self.device,
113            self.precision,
114            &self.policy,
115        )?;
116        self.compiled_bs = (batch, seq);
117        Ok(())
118    }
119
120    fn compile_flow(
121        config: &BertConfig,
122        weights_path: &str,
123        batch: usize,
124        seq: usize,
125        device: Device,
126        precision: Precision,
127        policy: &Option<PrecisionPolicy>,
128    ) -> Result<CompiledGraph> {
129        let mut wm = WeightMap::from_file(weights_path)?;
130        let built = build_bert_built(config, &mut wm, batch, seq)?;
131        if device == Device::Cpu && precision == Precision::F32 && policy.is_none() {
132            return compile_built(built, device);
133        }
134        let profile = built.profile().clone();
135        let (graph, params) = graph_from_built(built)?;
136        let mut opts =
137            compile_options_from_profile(&profile, device, KernelDispatchConfig::default());
138        opts.precision = precision;
139        opts.policy = policy.clone();
140        let mut compiled = Session::new(device).compile_with(graph, &opts);
141        for (name, data) in params {
142            compiled.set_param(&name, &data);
143        }
144        Ok(compiled)
145    }
146
147    pub fn forward(
148        &mut self,
149        input_ids: &[f32],
150        attention_mask: &[f32],
151        token_type_ids: &[f32],
152        position_ids: &[f32],
153    ) -> Vec<f32> {
154        let batch = self.compiled_bs.0;
155        let seq = self.compiled_bs.1;
156        let _ = self.recompile(batch, seq);
157        let outputs = self.compiled.run(&[
158            ("input_ids", input_ids),
159            ("attention_mask", attention_mask),
160            ("token_type_ids", token_type_ids),
161            ("position_ids", position_ids),
162        ]);
163        outputs.into_iter().next().unwrap_or_default()
164    }
165
166    pub fn hidden_size(&self) -> usize {
167        self.config.hidden_size
168    }
169}