1use std::path::Path;
19
20use anyhow::Result;
21use rlx_runtime::{CompiledGraph, Device, Precision, PrecisionPolicy, Session};
22
23use rlx_bert::flow::build_bert_built;
24use rlx_core::config::BertConfig;
25use rlx_core::flow_bridge::compile_options_from_profile;
26use rlx_core::flow_util::{compile_built, graph_from_built};
27use rlx_core::weight_map::WeightMap;
28use rlx_ir::logical_kernel::KernelDispatchConfig;
29
30pub struct RlxBertModel {
32 compiled: CompiledGraph,
33 config: BertConfig,
34 weights_path: String,
35 compiled_bs: (usize, usize),
36 device: Device,
37 precision: Precision,
38 policy: Option<PrecisionPolicy>,
39}
40
41impl RlxBertModel {
42 pub fn load_sized(
43 config_path: &Path,
44 weights_path: &str,
45 batch: usize,
46 seq: usize,
47 ) -> Result<Self> {
48 Self::load_sized_on(config_path, weights_path, batch, seq, Device::Cpu)
49 }
50
51 pub fn load_sized_on(
52 config_path: &Path,
53 weights_path: &str,
54 batch: usize,
55 seq: usize,
56 device: Device,
57 ) -> Result<Self> {
58 Self::load_sized_with_policy(
59 config_path,
60 weights_path,
61 batch,
62 seq,
63 device,
64 Precision::F32,
65 None,
66 )
67 }
68
69 pub fn load_sized_with_policy(
70 config_path: &Path,
71 weights_path: &str,
72 batch: usize,
73 seq: usize,
74 device: Device,
75 precision: Precision,
76 policy: Option<PrecisionPolicy>,
77 ) -> Result<Self> {
78 let config = BertConfig::from_file(config_path)?;
79 let compiled = Self::compile_flow(
80 &config,
81 weights_path,
82 batch,
83 seq,
84 device,
85 precision,
86 &policy,
87 )?;
88 Ok(Self {
89 compiled,
90 config,
91 weights_path: weights_path.to_string(),
92 compiled_bs: (batch, seq),
93 device,
94 precision,
95 policy,
96 })
97 }
98
99 pub fn load(config_path: &Path, weights_path: &str) -> Result<Self> {
100 Self::load_sized(config_path, weights_path, 1, 1)
101 }
102
103 pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
104 if self.compiled_bs == (batch, seq) {
105 return Ok(());
106 }
107 self.compiled = Self::compile_flow(
108 &self.config,
109 &self.weights_path,
110 batch,
111 seq,
112 self.device,
113 self.precision,
114 &self.policy,
115 )?;
116 self.compiled_bs = (batch, seq);
117 Ok(())
118 }
119
120 fn compile_flow(
121 config: &BertConfig,
122 weights_path: &str,
123 batch: usize,
124 seq: usize,
125 device: Device,
126 precision: Precision,
127 policy: &Option<PrecisionPolicy>,
128 ) -> Result<CompiledGraph> {
129 let mut wm = WeightMap::from_file(weights_path)?;
130 let built = build_bert_built(config, &mut wm, batch, seq)?;
131 if device == Device::Cpu && precision == Precision::F32 && policy.is_none() {
132 return compile_built(built, device);
133 }
134 let profile = built.profile().clone();
135 let (graph, params) = graph_from_built(built)?;
136 let mut opts =
137 compile_options_from_profile(&profile, device, KernelDispatchConfig::default());
138 opts.precision = precision;
139 opts.policy = policy.clone();
140 let mut compiled = Session::new(device).compile_with(graph, &opts);
141 for (name, data) in params {
142 compiled.set_param(&name, &data);
143 }
144 Ok(compiled)
145 }
146
147 pub fn forward(
148 &mut self,
149 input_ids: &[f32],
150 attention_mask: &[f32],
151 token_type_ids: &[f32],
152 position_ids: &[f32],
153 ) -> Vec<f32> {
154 let batch = self.compiled_bs.0;
155 let seq = self.compiled_bs.1;
156 let _ = self.recompile(batch, seq);
157 let outputs = self.compiled.run(&[
158 ("input_ids", input_ids),
159 ("attention_mask", attention_mask),
160 ("token_type_ids", token_type_ids),
161 ("position_ids", position_ids),
162 ]);
163 outputs.into_iter().next().unwrap_or_default()
164 }
165
166 pub fn hidden_size(&self) -> usize {
167 self.config.hidden_size
168 }
169}