use anyhow::{Context, Result, bail};
use rlx_bert::flow::BertFlow;
use rlx_core::config::BertConfig;
use rlx_core::flow_util::graph_from_built;
use rlx_core::weight_map::WeightMap;
use rlx_flow::BuiltModel;
use rlx_ir::op::{Activation, BinaryOp};
use rlx_ir::{DType, Graph, Shape};
use std::collections::HashMap;
/// Builds the ClinicalBERT encoder as a [`BuiltModel`].
///
/// Constructs a [`BertFlow`] from `cfg`, `batch` and `seq`, then builds it against `weights`.
///
/// # Errors
///
/// Propagates any error returned by `BertFlow::build`.
pub fn build_clinicalbert_built(
    cfg: &BertConfig,
    weights: &mut WeightMap,
    batch: usize,
    seq: usize,
) -> Result<BuiltModel> {
    let encoder = BertFlow::new(cfg, batch, seq).build(weights)?;
    Ok(encoder)
}
/// Builds the ClinicalBERT encoder and converts it into a graph plus a name → data parameter map.
///
/// This is `build_clinicalbert_built` followed by `graph_from_built`.
///
/// # Errors
///
/// Propagates any error from building the encoder or from `graph_from_built`.
pub fn build_clinicalbert_graph(
    cfg: &BertConfig,
    weights: &mut WeightMap,
    batch: usize,
    seq: usize,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let encoder = build_clinicalbert_built(cfg, weights, batch, seq)?;
    graph_from_built(encoder)
}
/// Builds the ClinicalBERT encoder with a masked-language-model (MLM) head appended.
///
/// The encoder is built with [`BertFlow`]. The MLM head is attached to the encoder's last
/// graph output and computes `dense -> GELU -> LayerNorm -> decoder matmul + bias`.
/// The final graph outputs are `[encoder_hidden, mlm_logits]`, where the logits node is
/// declared with shape `[batch, seq, cfg.vocab_size]`.
///
/// If `cls.predictions.decoder.weight` is not present in `weights`, the word-embedding
/// matrix (looked up with an optional `bert.` prefix) is used as the decoder weight.
///
/// # Errors
///
/// Returns an error when:
/// - the encoder build or graph conversion fails, or the encoder has no outputs;
/// - any `cls.predictions.*` head weight is missing;
/// - the MLM bias length differs from `cfg.vocab_size`;
/// - the transform weight is not a `[hidden, hidden]` matrix, or the decoder weight is not a
///   `[vocab, hidden]` matrix (previously these cases panicked on indexing or produced
///   parameters inconsistent with the declared graph shapes).
pub fn build_clinicalbert_with_mlm_built(
    cfg: &BertConfig,
    weights: &mut WeightMap,
    batch: usize,
    seq: usize,
) -> Result<BuiltModel> {
    // Some checkpoints namespace the encoder weights under `bert.`.
    let prefix = if weights.has("bert.embeddings.word_embeddings.weight") {
        "bert."
    } else {
        ""
    };

    // Copy the embedding matrix before the encoder is built.
    // NOTE(review): presumably the encoder build removes entries from `weights`, which is why
    // the copy is taken up front — confirm against `BertFlow::build`.
    let tied_decoder_w: Option<(Vec<f32>, Vec<usize>)> =
        if weights.has("cls.predictions.decoder.weight") {
            None
        } else {
            let key = format!("{prefix}embeddings.word_embeddings.weight");
            let (data, shape) = weights
                .get(&key)
                .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: tied MLM decoder needs {key}"))?;
            Some((data.to_vec(), shape.to_vec()))
        };

    let built_encoder = BertFlow::new(cfg, batch, seq).build(weights)?;
    // Keep the encoder's profile so it can be restored on the rebuilt model below.
    let profile = built_encoder.profile().clone();
    let (mut graph, mut params) = graph_from_built(built_encoder)?;
    // The MLM head is attached to the last encoder output.
    let hidden_id = *graph.outputs.last().ok_or_else(|| {
        anyhow::anyhow!("build_clinicalbert_with_mlm_built: encoder has no outputs")
    })?;

    let h = cfg.hidden_size;
    let v = cfg.vocab_size;
    let eps = cfg.layer_norm_eps as f32;
    let f = DType::F32;

    let (transform_w, transform_w_shape) = weights
        .take("cls.predictions.transform.dense.weight")
        .context("loading cls.predictions.transform.dense.weight")?;
    let transform_b = weights
        .take("cls.predictions.transform.dense.bias")
        .context("loading cls.predictions.transform.dense.bias")?
        .0;
    let ln_w = weights
        .take("cls.predictions.transform.LayerNorm.weight")
        .context("loading cls.predictions.transform.LayerNorm.weight")?
        .0;
    let ln_b = weights
        .take("cls.predictions.transform.LayerNorm.bias")
        .context("loading cls.predictions.transform.LayerNorm.bias")?
        .0;

    // The output bias is accepted under either of two key names.
    let decoder_b = if weights.has("cls.predictions.bias") {
        weights
            .take("cls.predictions.bias")
            .context("loading cls.predictions.bias")?
            .0
    } else if weights.has("cls.predictions.decoder.bias") {
        weights
            .take("cls.predictions.decoder.bias")
            .context("loading cls.predictions.decoder.bias")?
            .0
    } else {
        bail!("rlx-clinicalbert: MLM bias missing (cls.predictions.bias / .decoder.bias)");
    };
    if decoder_b.len() != v {
        bail!(
            "rlx-clinicalbert: MLM bias length {} != vocab_size {v}",
            decoder_b.len()
        );
    }

    // Prefer an explicit decoder weight; otherwise fall back to the embedding copy made above.
    let (decoder_w_raw, decoder_w_shape) = if weights.has("cls.predictions.decoder.weight") {
        weights
            .take("cls.predictions.decoder.weight")
            .context("loading cls.predictions.decoder.weight")?
    } else {
        tied_decoder_w
            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: tied decoder clone missing"))?
    };

    // Validate shapes before transposing so malformed checkpoints yield an error, not a panic,
    // and so the stored data always matches the parameter shapes declared below.
    let transform_w_t = transpose_mlm_weight(
        "cls.predictions.transform.dense.weight",
        &transform_w,
        &transform_w_shape,
        h,
        h,
    )?;
    let decoder_w_t =
        transpose_mlm_weight("MLM decoder weight", &decoder_w_raw, &decoder_w_shape, v, h)?;

    let transform_w_id = graph.param("mlm.transform.weight_t", Shape::new(&[h, h], f));
    let transform_b_id = graph.param("mlm.transform.bias", Shape::new(&[h], f));
    let ln_w_id = graph.param("mlm.transform.LayerNorm.weight", Shape::new(&[h], f));
    let ln_b_id = graph.param("mlm.transform.LayerNorm.bias", Shape::new(&[h], f));
    let decoder_w_id = graph.param("mlm.decoder.weight_t", Shape::new(&[h, v], f));
    let decoder_b_id = graph.param("mlm.decoder.bias", Shape::new(&[v], f));

    params.insert("mlm.transform.weight_t".into(), transform_w_t);
    params.insert("mlm.transform.bias".into(), transform_b);
    params.insert("mlm.transform.LayerNorm.weight".into(), ln_w);
    params.insert("mlm.transform.LayerNorm.bias".into(), ln_b);
    params.insert("mlm.decoder.weight_t".into(), decoder_w_t);
    params.insert("mlm.decoder.bias".into(), decoder_b);

    let bsh = Shape::new(&[batch, seq, h], f);
    let bsv = Shape::new(&[batch, seq, v], f);

    // Head: dense -> GELU -> LayerNorm -> decoder projection + bias.
    let mm1 = graph.matmul(hidden_id, transform_w_id, bsh.clone());
    let mm1_bias = graph.binary(BinaryOp::Add, mm1, transform_b_id, bsh.clone());
    let gelu = graph.activation(Activation::Gelu, mm1_bias, bsh.clone());
    let normalized = graph.layer_norm(gelu, ln_w_id, ln_b_id, -1, eps, bsh.clone());
    let mm2 = graph.matmul(normalized, decoder_w_id, bsv.clone());
    let mlm_logits = graph.binary(BinaryOp::Add, mm2, decoder_b_id, bsv);
    graph.set_outputs(vec![hidden_id, mlm_logits]);

    // Optional diagnostics, enabled by setting RLX_CLINICALBERT_DEBUG to any value.
    if std::env::var("RLX_CLINICALBERT_DEBUG").is_ok() {
        eprintln!(
            "[rlx-clinicalbert::builder] graph.outputs = {:?}",
            graph
                .outputs
                .iter()
                .map(|&id| {
                    let shape = &graph.node(id).shape;
                    (id, shape.dims().to_vec(), shape.num_elements())
                })
                .collect::<Vec<_>>()
        );
    }

    let mut built = BuiltModel::from_graph(graph, params)?;
    built.profile = profile;
    Ok(built)
}

/// Checks that `data`/`shape` describe a row-major `[rows, cols]` matrix and returns its
/// transpose; `name` is only used in error messages.
fn transpose_mlm_weight(
    name: &str,
    data: &[f32],
    shape: &[usize],
    rows: usize,
    cols: usize,
) -> Result<Vec<f32>> {
    if shape != [rows, cols] {
        bail!("rlx-clinicalbert: {name} has shape {shape:?}, expected [{rows}, {cols}]");
    }
    if Some(data.len()) != rows.checked_mul(cols) {
        bail!(
            "rlx-clinicalbert: {name} has {} elements, expected {rows} x {cols}",
            data.len()
        );
    }
    Ok(transpose(data, rows, cols))
}
/// Returns the transpose of a row-major `rows x cols` matrix as a row-major `cols x rows` buffer.
///
/// Only the first `rows * cols` elements of `data` are read.
///
/// # Panics
///
/// Panics if `data` holds fewer than `rows * cols` elements (when both are non-zero).
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    // Walk the destination in order: element (col, row) of the result is source (row, col).
    (0..cols)
        .flat_map(|col| (0..rows).map(move |row| data[row * cols + col]))
        .collect()
}