1use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21use rlx_core::config::BertConfig;
22use rlx_core::flow_util::compile_built;
23use rlx_core::validate_standard_device;
24use rlx_core::weight_map::WeightMap;
25use rlx_runtime::{CompiledGraph, Device};
26
27use crate::builder::build_clinicalbert_built;
28#[cfg(feature = "mlm")]
29use crate::builder::build_clinicalbert_with_mlm_built;
30use crate::config::{ClinicalBertConfig, ClinicalBertVariant, validate_hf_config};
31#[cfg(feature = "mlm")]
32use crate::heads::MlmHead;
33#[cfg(feature = "pooler")]
34use crate::heads::PoolerHead;
35
/// Reduction applied by [`ClinicalBertRunner::embed`] to the per-token hidden states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pooling {
    /// Keep the hidden vector of the first token of every sequence.
    Cls,
    /// Average the hidden vectors of the tokens whose attention mask is positive.
    Mean,
    /// No reduction: the hidden states are returned unchanged.
    None,
}

impl Pooling {
    /// Parses a pooling name, ignoring ASCII case.
    ///
    /// Accepted spellings are `cls`; `mean`, `avg`, `average`; `none`, `raw`.
    /// The input is not trimmed. Anything else yields `None`.
    pub fn from_str_opt(s: &str) -> Option<Self> {
        const ALIASES: [(&str, Pooling); 6] = [
            ("cls", Pooling::Cls),
            ("mean", Pooling::Mean),
            ("avg", Pooling::Mean),
            ("average", Pooling::Mean),
            ("none", Pooling::None),
            ("raw", Pooling::None),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, pooling)| pooling)
    }
}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
73pub enum MlmExecMode {
74 Cpu,
77 InGraph,
80 #[default]
83 Auto,
84}
85
86impl MlmExecMode {
87 pub fn resolve(self, device: Device, batch: usize) -> MlmExecMode {
90 match self {
91 MlmExecMode::Cpu | MlmExecMode::InGraph => self,
92 MlmExecMode::Auto => match device {
93 Device::Cuda if batch > 8 => MlmExecMode::Cpu,
94 _ => MlmExecMode::InGraph,
95 },
96 }
97 }
98
99 pub fn from_str_opt(s: &str) -> Option<Self> {
100 match s.to_ascii_lowercase().as_str() {
101 "cpu" | "post" | "host" => Some(MlmExecMode::Cpu),
102 "ingraph" | "in-graph" | "in_graph" | "graph" | "fold" | "folded" => {
103 Some(MlmExecMode::InGraph)
104 }
105 "auto" | "default" => Some(MlmExecMode::Auto),
106 _ => None,
107 }
108 }
109}
110
111pub struct ClinicalBertRunner {
117 config: ClinicalBertConfig,
118 weights_path: PathBuf,
119 compiled: CompiledGraph,
120 compiled_bs: (usize, usize),
121 device: Device,
122 pooling: Pooling,
123 #[cfg(feature = "pooler")]
124 pooler_head: Option<PoolerHead>,
125 #[cfg(feature = "mlm")]
126 mlm_head: Option<MlmHead>,
127 #[cfg(feature = "mlm")]
132 mlm_in_graph: bool,
133 #[cfg(feature = "mlm")]
135 cached_mlm_logits: Option<Vec<f32>>,
136}
137
138impl ClinicalBertRunner {
139 pub fn builder() -> ClinicalBertRunnerBuilder {
140 ClinicalBertRunnerBuilder::default()
141 }
142
143 pub fn config(&self) -> &ClinicalBertConfig {
144 &self.config
145 }
146
147 pub fn hidden_size(&self) -> usize {
148 self.config.bert.hidden_size
149 }
150
151 pub fn device(&self) -> Device {
152 self.device
153 }
154
155 pub fn pooling(&self) -> Pooling {
156 self.pooling
157 }
158
159 pub fn compiled_shape(&self) -> (usize, usize) {
160 self.compiled_bs
161 }
162
163 #[cfg(feature = "pooler")]
166 pub fn has_pooler(&self) -> bool {
167 self.pooler_head.is_some()
168 }
169
170 #[cfg(feature = "mlm")]
173 pub fn has_mlm(&self) -> bool {
174 self.mlm_head.is_some() || self.mlm_in_graph
175 }
176
177 #[cfg(feature = "mlm")]
179 pub fn mlm_in_graph(&self) -> bool {
180 self.mlm_in_graph
181 }
182
183 #[cfg(feature = "mlm")]
186 pub fn mlm_mode(&self) -> Option<MlmExecMode> {
187 if self.mlm_in_graph {
188 Some(MlmExecMode::InGraph)
189 } else if self.mlm_head.is_some() {
190 Some(MlmExecMode::Cpu)
191 } else {
192 None
193 }
194 }
195
196 #[cfg(feature = "pooler")]
202 pub fn pooler_output(&self, hidden: &[f32]) -> Result<Vec<f32>> {
203 let head = self.pooler_head.as_ref().ok_or_else(|| {
204 anyhow::anyhow!(
205 "rlx-clinicalbert: pooler not enabled — call .with_pooler() on the builder"
206 )
207 })?;
208 let (b, s) = self.compiled_bs;
209 head.apply(hidden, b, s)
210 }
211
212 #[cfg(feature = "mlm")]
217 pub fn mlm_logits(&self, hidden: &[f32]) -> Result<Vec<f32>> {
218 if self.mlm_in_graph {
219 return self.cached_mlm_logits.clone().ok_or_else(|| {
220 anyhow::anyhow!(
221 "rlx-clinicalbert: call forward() first to populate the in-graph MLM logits"
222 )
223 });
224 }
225 let head = self
226 .mlm_head
227 .as_ref()
228 .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled — call .with_mlm() or .with_mlm_in_graph() on the builder"))?;
229 let (b, s) = self.compiled_bs;
230 head.apply(hidden, b, s)
231 }
232
233 #[cfg(feature = "mlm")]
237 pub fn mlm_logits_into(&self, hidden: &[f32], logits: &mut [f32]) -> Result<()> {
238 if self.mlm_in_graph {
239 let src = self.cached_mlm_logits.as_ref().ok_or_else(|| {
240 anyhow::anyhow!(
241 "rlx-clinicalbert: call forward() first to populate the in-graph MLM logits"
242 )
243 })?;
244 if logits.len() != src.len() {
245 bail!(
246 "rlx-clinicalbert: mlm_logits_into expected buffer of {} floats, got {}",
247 src.len(),
248 logits.len()
249 );
250 }
251 logits.copy_from_slice(src);
252 return Ok(());
253 }
254 let head = self
255 .mlm_head
256 .as_ref()
257 .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled — call .with_mlm() or .with_mlm_in_graph() on the builder"))?;
258 let (b, s) = self.compiled_bs;
259 head.apply_into(hidden, b, s, logits)
260 }
261
262 #[cfg(feature = "mlm")]
264 pub fn allocate_mlm_logits(&self) -> Result<Vec<f32>> {
265 if self.mlm_in_graph {
266 let (b, s) = self.compiled_bs;
267 return Ok(vec![0f32; b * s * self.config.bert.vocab_size]);
268 }
269 let head = self
270 .mlm_head
271 .as_ref()
272 .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled"))?;
273 let (b, s) = self.compiled_bs;
274 Ok(head.allocate_logits_buffer(b, s))
275 }
276
277 pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
279 if self.compiled_bs == (batch, seq) {
280 return Ok(());
281 }
282 let mut wm = if self.weights_path.is_dir() {
283 WeightMap::from_resolved_path(&self.weights_path)
284 } else {
285 WeightMap::from_file(self.weights_path.to_str().ok_or_else(|| {
286 anyhow::anyhow!(
287 "rlx-clinicalbert: non-UTF8 weights path {:?}",
288 self.weights_path
289 )
290 })?)
291 }?;
292 let built = build_clinicalbert_built(&self.config.bert, &mut wm, batch, seq)?;
293 self.compiled = compile_built(built, self.device)?;
294 self.compiled_bs = (batch, seq);
295 Ok(())
296 }
297
298 pub fn forward(
304 &mut self,
305 input_ids: &[f32],
306 attention_mask: &[f32],
307 token_type_ids: &[f32],
308 position_ids: &[f32],
309 ) -> Result<Vec<f32>> {
310 let (b, s) = self.compiled_bs;
311 let expected = b * s;
312 if input_ids.len() != expected
313 || attention_mask.len() != expected
314 || token_type_ids.len() != expected
315 || position_ids.len() != expected
316 {
317 bail!(
318 "rlx-clinicalbert: forward expects each input of length {expected} \
319 (batch={b}, seq={s}); got {}, {}, {}, {}",
320 input_ids.len(),
321 attention_mask.len(),
322 token_type_ids.len(),
323 position_ids.len()
324 );
325 }
326 let outputs = self.compiled.run(&[
327 ("input_ids", input_ids),
328 ("attention_mask", attention_mask),
329 ("token_type_ids", token_type_ids),
330 ("position_ids", position_ids),
331 ]);
332 if std::env::var("RLX_CLINICALBERT_DEBUG").is_ok() {
333 let sizes: Vec<usize> = outputs.iter().map(|o| o.len()).collect();
334 eprintln!("[rlx-clinicalbert] forward outputs: {sizes:?}");
335 }
336 #[cfg(feature = "mlm")]
340 if self.mlm_in_graph {
341 if outputs.len() >= 2 {
342 self.cached_mlm_logits = Some(outputs[1].clone());
343 } else {
344 bail!(
345 "rlx-clinicalbert: with_mlm_in_graph but compiled graph returned {} outputs",
346 outputs.len()
347 );
348 }
349 }
350 outputs
351 .into_iter()
352 .next()
353 .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: compiled graph returned no outputs"))
354 }
355
356 pub fn embed(
361 &mut self,
362 input_ids: &[f32],
363 attention_mask: &[f32],
364 token_type_ids: &[f32],
365 position_ids: &[f32],
366 ) -> Result<Vec<f32>> {
367 let hidden = self.forward(input_ids, attention_mask, token_type_ids, position_ids)?;
368 let (b, s) = self.compiled_bs;
369 let h = self.hidden_size();
370 Ok(match self.pooling {
371 Pooling::None => hidden,
372 Pooling::Cls => pool_cls(&hidden, b, s, h),
373 Pooling::Mean => pool_mean(&hidden, attention_mask, b, s, h),
374 })
375 }
376}
377
378#[derive(Debug, Clone, Default)]
379pub struct ClinicalBertRunnerBuilder {
380 weights: Option<PathBuf>,
381 config: Option<ClinicalBertConfig>,
382 config_path: Option<PathBuf>,
383 variant: Option<ClinicalBertVariant>,
384 device: Option<Device>,
385 batch: Option<usize>,
386 seq: Option<usize>,
387 pooling: Option<Pooling>,
388 #[cfg(feature = "pooler")]
389 enable_pooler: bool,
390 #[cfg(feature = "mlm")]
391 enable_mlm: bool,
392 #[cfg(feature = "mlm")]
393 enable_mlm_in_graph: bool,
394 #[cfg(feature = "mlm")]
395 mlm_mode: Option<MlmExecMode>,
396}
397
398impl ClinicalBertRunnerBuilder {
399 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
400 self.weights = Some(path.into());
401 self
402 }
403
404 pub fn config(mut self, cfg: BertConfig) -> Self {
405 self.config = Some(ClinicalBertConfig::new(cfg));
406 self
407 }
408
409 pub fn config_path(mut self, path: impl Into<PathBuf>) -> Self {
410 self.config_path = Some(path.into());
411 self
412 }
413
414 pub fn variant(mut self, v: ClinicalBertVariant) -> Self {
415 self.variant = Some(v);
416 self
417 }
418
419 pub fn device(mut self, d: Device) -> Self {
420 self.device = Some(d);
421 self
422 }
423
424 pub fn batch(mut self, b: usize) -> Self {
425 self.batch = Some(b);
426 self
427 }
428
429 pub fn max_seq(mut self, s: usize) -> Self {
430 self.seq = Some(s);
431 self
432 }
433
434 pub fn pooling(mut self, p: Pooling) -> Self {
435 self.pooling = Some(p);
436 self
437 }
438
439 #[cfg(feature = "pooler")]
442 pub fn with_pooler(mut self) -> Self {
443 self.enable_pooler = true;
444 self
445 }
446
447 #[cfg(feature = "mlm")]
450 pub fn with_mlm(mut self) -> Self {
451 self.enable_mlm = true;
452 self
453 }
454
455 #[cfg(feature = "mlm")]
459 pub fn with_mlm_in_graph(mut self) -> Self {
460 self.enable_mlm_in_graph = true;
461 self
462 }
463
464 #[cfg(feature = "mlm")]
470 pub fn mlm_mode(mut self, mode: MlmExecMode) -> Self {
471 self.mlm_mode = Some(mode);
472 self.enable_mlm = false;
473 self.enable_mlm_in_graph = false;
474 self
475 }
476
477 pub fn build(self) -> Result<ClinicalBertRunner> {
478 let weights = self
479 .weights
480 .clone()
481 .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: weights path required"))?;
482 let device = self.device.unwrap_or(Device::Cpu);
483 validate_standard_device("clinicalbert", device)?;
484
485 let mut config = if let Some(cfg) = self.config {
486 cfg
487 } else if let Some(variant) = self.variant {
488 ClinicalBertConfig::new(variant.preset()).with_variant(variant)
489 } else {
490 let cfg_path = self
491 .config_path
492 .clone()
493 .unwrap_or_else(|| ClinicalBertConfig::config_json_path(&weights));
494 if cfg_path.is_file() {
495 validate_hf_config(cfg_path.parent().unwrap_or(Path::new(".")))?;
496 ClinicalBertConfig::from_file(&cfg_path)?
497 } else {
498 bail!(
499 "rlx-clinicalbert: no config supplied — call `.config(..)`, \
500 `.config_path(..)`, or `.variant(..)`, or place `config.json` next \
501 to {weights:?}"
502 );
503 }
504 };
505
506 if config.variant.is_none() {
507 config.variant = self.variant;
508 }
509
510 let batch = self.batch.unwrap_or(1);
511 let seq = self
512 .seq
513 .unwrap_or_else(|| config.bert.max_position_embeddings.min(512));
514
515 let weights_str = weights.to_str().ok_or_else(|| {
516 anyhow::anyhow!("rlx-clinicalbert: non-UTF8 weights path {weights:?}")
517 })?;
518 let mut wm = if weights.is_dir() {
519 WeightMap::from_resolved_path(&weights)
520 } else {
521 WeightMap::from_file(weights_str)
522 }
523 .with_context(|| format!("rlx-clinicalbert: loading {weights_str}"))?;
524
525 #[cfg(feature = "mlm")]
528 if self.enable_mlm && self.enable_mlm_in_graph {
529 bail!("rlx-clinicalbert: .with_mlm() and .with_mlm_in_graph() are mutually exclusive");
530 }
531
532 #[cfg(feature = "mlm")]
539 let resolved_mlm: Option<MlmExecMode> = match self.mlm_mode {
540 Some(MlmExecMode::Auto) => Some(MlmExecMode::Auto.resolve(device, batch)),
541 Some(m) => Some(m),
542 None => {
543 if self.enable_mlm {
544 Some(MlmExecMode::Cpu)
545 } else if self.enable_mlm_in_graph {
546 Some(MlmExecMode::InGraph)
547 } else {
548 None
549 }
550 }
551 };
552
553 #[cfg(feature = "mlm")]
556 let mlm_head: Option<MlmHead> = if resolved_mlm == Some(MlmExecMode::Cpu) {
557 Some(MlmHead::load(&config.bert, &mut wm)?)
558 } else {
559 None
560 };
561 #[cfg(feature = "pooler")]
562 let pooler_head: Option<PoolerHead> = if self.enable_pooler {
563 Some(PoolerHead::load(&config.bert, &mut wm)?)
564 } else {
565 None
566 };
567
568 #[cfg(feature = "mlm")]
573 let built = if resolved_mlm == Some(MlmExecMode::InGraph) {
574 build_clinicalbert_with_mlm_built(&config.bert, &mut wm, batch, seq)?
575 } else {
576 build_clinicalbert_built(&config.bert, &mut wm, batch, seq)?
577 };
578 #[cfg(not(feature = "mlm"))]
579 let built = build_clinicalbert_built(&config.bert, &mut wm, batch, seq)?;
580 let compiled = compile_built(built, device)?;
581
582 Ok(ClinicalBertRunner {
583 config,
584 weights_path: weights,
585 compiled,
586 compiled_bs: (batch, seq),
587 device,
588 pooling: self.pooling.unwrap_or(Pooling::Cls),
589 #[cfg(feature = "pooler")]
590 pooler_head,
591 #[cfg(feature = "mlm")]
592 mlm_head,
593 #[cfg(feature = "mlm")]
594 mlm_in_graph: resolved_mlm == Some(MlmExecMode::InGraph),
595 #[cfg(feature = "mlm")]
596 cached_mlm_logits: None,
597 })
598 }
599}
600
/// Collects the first token's hidden vector of every sequence.
///
/// `hidden` is read as `batch` consecutive rows of `seq * h` values; the
/// result holds `batch * h` values. Panics if `hidden` is too short for that
/// layout.
fn pool_cls(hidden: &[f32], batch: usize, seq: usize, h: usize) -> Vec<f32> {
    let row_stride = seq * h;
    let mut pooled = Vec::with_capacity(batch * h);
    for row in 0..batch {
        let first_token = row * row_stride;
        pooled.extend_from_slice(&hidden[first_token..first_token + h]);
    }
    pooled
}
609
/// Averages, per sequence, the hidden vectors of the tokens whose attention
/// mask is positive.
///
/// `hidden` is read as `batch * seq` consecutive vectors of `h` values and
/// `attention_mask` as `batch * seq` values. A sequence with no positive mask
/// entry yields zeros, because the divisor is clamped to at least 1. The
/// result holds `batch * h` values.
fn pool_mean(
    hidden: &[f32],
    attention_mask: &[f32],
    batch: usize,
    seq: usize,
    h: usize,
) -> Vec<f32> {
    let mut pooled = vec![0f32; batch * h];
    for row in 0..batch {
        let acc = &mut pooled[row * h..(row + 1) * h];
        let mut kept = 0.0f32;
        for token in row * seq..(row + 1) * seq {
            if attention_mask[token] > 0.0 {
                kept += 1.0;
                let start = token * h;
                for (sum, &value) in acc.iter_mut().zip(&hidden[start..start + h]) {
                    *sum += value;
                }
            }
        }
        let scale = 1.0 / kept.max(1.0);
        for sum in acc.iter_mut() {
            *sum *= scale;
        }
    }
    pooled
}