#![cfg(any(feature = "pooler", feature = "mlm"))]
use anyhow::{Context, Result, bail};
use rlx_core::config::BertConfig;
use rlx_core::weight_map::WeightMap;
/// BERT pooler head: the first token of every sequence is pushed through a
/// `hidden_size -> hidden_size` dense layer followed by `tanh`.
#[cfg(feature = "pooler")]
pub struct PoolerHead {
    /// Dense weight, stored transposed so it can be fed directly to the GEMM.
    weight_t: Vec<f32>,
    /// Dense bias (expected to hold `hidden_size` entries).
    bias: Vec<f32>,
    /// Width of the hidden states this head consumes and produces.
    hidden_size: usize,
}
#[cfg(feature = "pooler")]
impl PoolerHead {
    /// Loads the pooler's dense layer (`pooler.dense.{weight,bias}`) from `weights`.
    ///
    /// Both the `bert.`-prefixed and the bare key layouts are accepted; the prefix
    /// is chosen by probing for `bert.pooler.dense.weight`. The weight is stored
    /// transposed so [`PoolerHead::apply`] can hand it straight to the GEMM.
    ///
    /// # Errors
    ///
    /// Returns an error if either tensor is missing, if the weight is not
    /// `[hidden_size, hidden_size]`, or if the bias length is not `hidden_size`.
    pub fn load(cfg: &BertConfig, weights: &mut WeightMap) -> Result<Self> {
        let h = cfg.hidden_size;
        let prefix = if weights.has("bert.pooler.dense.weight") {
            "bert."
        } else {
            ""
        };
        let (w, w_shape) = weights
            .take(&format!("{prefix}pooler.dense.weight"))
            .with_context(|| format!("loading {prefix}pooler.dense.weight"))?;
        let (b, _) = weights
            .take(&format!("{prefix}pooler.dense.bias"))
            .with_context(|| format!("loading {prefix}pooler.dense.bias"))?;
        if w_shape != [h, h] {
            bail!(
                "rlx-clinicalbert: pooler.dense.weight has shape {w_shape:?}, expected [{h}, {h}]"
            );
        }
        // The GEMM in `apply` is told the output width is `h`, so the bias must
        // supply exactly `h` entries; reject a mismatch here instead of inside
        // the kernel.
        if b.len() != h {
            bail!(
                "rlx-clinicalbert: pooler.dense.bias has length {}, expected {h}",
                b.len()
            );
        }
        Ok(Self {
            weight_t: transpose(&w, h, h),
            bias: b,
            hidden_size: h,
        })
    }

    /// Width of the hidden states this head consumes and produces.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Pools each sequence by computing `tanh(first_token · Wᵀ + b)`.
    ///
    /// `hidden` is read as `batch` contiguous blocks of `seq * hidden_size`
    /// floats; only the first `hidden_size` floats of each block (token 0) are
    /// used. Returns `batch * hidden_size` floats.
    ///
    /// # Errors
    ///
    /// Returns an error if `hidden.len() != batch * seq * hidden_size`, or if
    /// `seq == 0` while `batch > 0` (there is no first token to pool).
    pub fn apply(&self, hidden: &[f32], batch: usize, seq: usize) -> Result<Vec<f32>> {
        let h = self.hidden_size;
        if hidden.len() != batch * seq * h {
            bail!(
                "rlx-clinicalbert::PoolerHead: expected hidden of {} floats, got {}",
                batch * seq * h,
                hidden.len()
            );
        }
        // With seq == 0 the length check above passes on an empty `hidden`, and
        // slicing out token 0 would panic.
        if seq == 0 && batch > 0 {
            bail!("rlx-clinicalbert::PoolerHead: seq must be at least 1 to pool the first token");
        }
        let stride = seq * h;
        let mut cls = Vec::with_capacity(batch * h);
        for bi in 0..batch {
            let start = bi * stride;
            cls.extend_from_slice(&hidden[start..start + h]);
        }
        let mut out = vec![0f32; batch * h];
        rlx_cpu::blas::sgemm_bias_epilogue(
            &cls,
            &self.weight_t,
            &self.bias,
            &mut out,
            batch,
            h,
            h,
            |v| v.tanh(),
        );
        Ok(out)
    }
}
/// BERT masked-language-model head: dense + GELU + LayerNorm transform,
/// followed by a decoder projection onto the vocabulary.
#[cfg(feature = "mlm")]
pub struct MlmHead {
    /// Transform dense weight, stored transposed for the GEMM.
    transform_w_t: Vec<f32>,
    /// Transform dense bias.
    transform_b: Vec<f32>,
    /// LayerNorm scale.
    ln_w: Vec<f32>,
    /// LayerNorm shift.
    ln_b: Vec<f32>,
    /// Decoder (vocabulary projection) weight, stored transposed for the GEMM.
    decoder_w_t: Vec<f32>,
    /// Decoder bias, one entry per vocabulary item.
    decoder_b: Vec<f32>,
    /// Width of the incoming hidden states.
    hidden_size: usize,
    /// Number of logits produced per token.
    vocab_size: usize,
    /// Epsilon added to the variance inside the LayerNorm.
    eps: f32,
}
#[cfg(feature = "mlm")]
impl MlmHead {
    /// Loads the masked-language-model head (`cls.predictions.*`) from `weights`.
    ///
    /// The transform dense layer, its LayerNorm and the output bias are consumed
    /// with `WeightMap::take`. The decoder matrix is taken from
    /// `cls.predictions.decoder.weight` when that key exists. Otherwise the
    /// decoder is treated as tied to the word embeddings. Those are only read
    /// (via `get`) and left in the map. The `bert.` prefix for the embedding key
    /// is chosen by probing for `bert.embeddings.word_embeddings.weight`.
    ///
    /// # Errors
    ///
    /// Returns an error if a required tensor is missing, or if any tensor's shape
    /// disagrees with `cfg.hidden_size` / `cfg.vocab_size`.
    pub fn load(cfg: &BertConfig, weights: &mut WeightMap) -> Result<Self> {
        let h = cfg.hidden_size;
        let v = cfg.vocab_size;
        let (tw, tw_shape) = weights
            .take("cls.predictions.transform.dense.weight")
            .context("loading cls.predictions.transform.dense.weight")?;
        let (tb, _) = weights
            .take("cls.predictions.transform.dense.bias")
            .context("loading cls.predictions.transform.dense.bias")?;
        let (lnw, _) = weights
            .take("cls.predictions.transform.LayerNorm.weight")
            .context("loading cls.predictions.transform.LayerNorm.weight")?;
        let (lnb, _) = weights
            .take("cls.predictions.transform.LayerNorm.bias")
            .context("loading cls.predictions.transform.LayerNorm.bias")?;
        // The transform GEMM is told the matrix is `h x h` and the LayerNorm
        // indexes scale/shift per hidden unit, so a malformed checkpoint must be
        // rejected here rather than panic (or silently misbehave) at inference.
        if tw_shape != [h, h] {
            bail!(
                "rlx-clinicalbert: MLM transform weight has shape {tw_shape:?}, expected [{h}, {h}]"
            );
        }
        for (name, len) in [
            ("cls.predictions.transform.dense.bias", tb.len()),
            ("cls.predictions.transform.LayerNorm.weight", lnw.len()),
            ("cls.predictions.transform.LayerNorm.bias", lnb.len()),
        ] {
            if len != h {
                bail!("rlx-clinicalbert: {name} has length {len}, expected hidden_size {h}");
            }
        }
        // The output bias may live under either key; `cls.predictions.bias` wins.
        let decoder_b = if weights.has("cls.predictions.bias") {
            weights
                .take("cls.predictions.bias")
                .context("loading cls.predictions.bias")?
                .0
        } else if weights.has("cls.predictions.decoder.bias") {
            weights
                .take("cls.predictions.decoder.bias")
                .context("loading cls.predictions.decoder.bias")?
                .0
        } else {
            bail!("rlx-clinicalbert: MLM bias missing (cls.predictions.bias / .decoder.bias)");
        };
        if decoder_b.len() != v {
            bail!(
                "rlx-clinicalbert: MLM bias length {} != vocab_size {v}",
                decoder_b.len()
            );
        }
        let decoder_w_t = if weights.has("cls.predictions.decoder.weight") {
            let (w, shape) = weights
                .take("cls.predictions.decoder.weight")
                .context("loading cls.predictions.decoder.weight")?;
            Self::check_decoder_shape(&shape, v, h)?;
            transpose(&w, v, h)
        } else {
            let prefix = if weights.has("bert.embeddings.word_embeddings.weight") {
                "bert."
            } else {
                ""
            };
            let key = format!("{prefix}embeddings.word_embeddings.weight");
            let (data, shape) = weights
                .get(&key)
                .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: tied MLM decoder needs {key}"))?;
            Self::check_decoder_shape(&shape.to_vec(), v, h)?;
            // Transpose straight out of the borrowed embedding table instead of
            // cloning the full vocab x hidden matrix first.
            transpose(&data, v, h)
        };
        Ok(Self {
            transform_w_t: transpose(&tw, h, h),
            transform_b: tb,
            ln_w: lnw,
            ln_b: lnb,
            decoder_w_t,
            decoder_b,
            hidden_size: h,
            vocab_size: v,
            eps: cfg.layer_norm_eps as f32,
        })
    }

    /// Rejects a decoder weight whose shape is not `[vocab_size, hidden_size]`.
    fn check_decoder_shape(shape: &[usize], v: usize, h: usize) -> Result<()> {
        if shape != [v, h] {
            bail!("rlx-clinicalbert: MLM decoder weight has shape {shape:?}, expected [{v}, {h}]");
        }
        Ok(())
    }

    /// Width of the incoming hidden states.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Number of logits produced per token.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Computes MLM logits for every token into a caller-provided buffer.
    ///
    /// `hidden` must hold `batch * seq * hidden_size` floats. `logits` must hold
    /// `batch * seq * vocab_size` floats; see
    /// [`MlmHead::allocate_logits_buffer`].
    ///
    /// # Errors
    ///
    /// Returns an error if either buffer has the wrong length.
    pub fn apply_into(
        &self,
        hidden: &[f32],
        batch: usize,
        seq: usize,
        logits: &mut [f32],
    ) -> Result<()> {
        let bs = batch * seq;
        self.check_hidden(hidden, bs)?;
        let expected = bs * self.vocab_size;
        if logits.len() != expected {
            bail!(
                "rlx-clinicalbert::MlmHead: logits buffer must be {} floats, got {}",
                expected,
                logits.len()
            );
        }
        self.forward(hidden, bs, logits);
        Ok(())
    }

    /// Computes MLM logits for every token, returning a freshly allocated
    /// `batch * seq * vocab_size` buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if `hidden` does not hold `batch * seq * hidden_size` floats.
    pub fn apply(&self, hidden: &[f32], batch: usize, seq: usize) -> Result<Vec<f32>> {
        let bs = batch * seq;
        // Validate before allocating the (potentially large) logits buffer.
        self.check_hidden(hidden, bs)?;
        let mut logits = self.allocate_logits_buffer(batch, seq);
        self.forward(hidden, bs, &mut logits);
        Ok(logits)
    }

    /// Allocates a zeroed buffer sized for [`MlmHead::apply_into`].
    pub fn allocate_logits_buffer(&self, batch: usize, seq: usize) -> Vec<f32> {
        vec![0f32; batch * seq * self.vocab_size]
    }

    /// Ensures `hidden` holds exactly `rows * hidden_size` floats.
    fn check_hidden(&self, hidden: &[f32], rows: usize) -> Result<()> {
        let expected = rows * self.hidden_size;
        if hidden.len() != expected {
            bail!(
                "rlx-clinicalbert::MlmHead: expected hidden of {} floats, got {}",
                expected,
                hidden.len()
            );
        }
        Ok(())
    }

    /// Runs dense+GELU, LayerNorm and the decoder projection; buffers must
    /// already be validated by the caller.
    fn forward(&self, hidden: &[f32], rows: usize, logits: &mut [f32]) {
        let h = self.hidden_size;
        let mut x = vec![0f32; rows * h];
        rlx_cpu::blas::sgemm_bias_epilogue(
            hidden,
            &self.transform_w_t,
            &self.transform_b,
            &mut x,
            rows,
            h,
            h,
            gelu_exact,
        );
        self.layer_norm_rows(&mut x);
        rlx_cpu::blas::sgemm_bias(
            &x,
            &self.decoder_w_t,
            &self.decoder_b,
            logits,
            rows,
            h,
            self.vocab_size,
        );
    }

    /// Applies LayerNorm (scale `ln_w`, shift `ln_b`, epsilon `eps`)
    /// independently to each `hidden_size`-wide row of `x`, in place.
    fn layer_norm_rows(&self, x: &mut [f32]) {
        let h = self.hidden_size;
        if h == 0 {
            // Nothing to normalise, and `chunks_exact_mut(0)` would panic.
            return;
        }
        let gamma = &self.ln_w[..h];
        let beta = &self.ln_b[..h];
        let width = h as f32;
        for row in x.chunks_exact_mut(h) {
            let mean = row.iter().sum::<f32>() / width;
            let var = row
                .iter()
                .map(|&val| {
                    let d = val - mean;
                    d * d
                })
                .sum::<f32>()
                / width;
            let inv_std = (var + self.eps).sqrt().recip();
            for ((val, &g), &b) in row.iter_mut().zip(gamma).zip(beta) {
                *val = (*val - mean) * inv_std * g + b;
            }
        }
    }
}
/// Transposes a row-major `rows x cols` matrix into a row-major `cols x rows` one.
///
/// The output is produced in its own storage order: element `(c, r)` of the
/// result is `data[r * cols + c]`.
///
/// # Panics
///
/// Panics if `data` holds fewer than `rows * cols` elements.
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut transposed = Vec::with_capacity(rows * cols);
    transposed.extend(
        (0..cols).flat_map(move |c| (0..rows).map(move |r| data[r * cols + c])),
    );
    transposed
}
/// GELU in its erf form: `0.5 * x * (1 + erf(x / √2))`, i.e. `x * Φ(x)` with
/// `Φ` the standard normal CDF, using [`erf_approx`] for `erf`.
// Only the MLM head uses this, so `pooler`-only builds would otherwise warn.
#[allow(dead_code)]
fn gelu_exact(x: f32) -> f32 {
    let scaled = x / std::f32::consts::SQRT_2;
    let erf_term = erf_approx(scaled);
    0.5 * x * (1.0 + erf_term)
}
/// Rational approximation of the error function (Abramowitz & Stegun 7.1.26).
///
/// The formula is evaluated on `|x|` and the sign of `x` is restored at the
/// end, so `erf_approx(-x) == -erf_approx(x)`. The published absolute-error
/// bound of this formula is about `1.5e-7`, before `f32` rounding.
// Used by `gelu_exact` and the tests, so it is dead in `pooler`-only builds.
// The coefficients are quoted at their published precision, hence the clippy allow.
#[allow(dead_code, clippy::excessive_precision)]
fn erf_approx(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254829592;
    const A2: f32 = -0.284496736;
    const A3: f32 = 1.421413741;
    const A4: f32 = -1.453152027;
    const A5: f32 = 1.061405429;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = (((A5 * t + A4) * t + A3) * t + A2) * t + A1;
    let erf_of_magnitude = 1.0 - poly * t * (-(magnitude * magnitude)).exp();
    if x.is_sign_negative() {
        -erf_of_magnitude
    } else {
        erf_of_magnitude
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares `erf_approx` against reference `erf` values; each row is
    /// `(input, expected, tolerance)`.
    #[test]
    fn erf_matches_libm_ish_values() {
        let cases: [(f32, f32, f32); 4] = [
            (0.0, 0.0, 1e-7),
            (1.0, 0.842_700_8, 1e-6),
            (-1.0, -0.842_700_8, 1e-6),
            (2.0, 0.995_322_3, 1e-6),
        ];
        for (input, expected, tolerance) in cases {
            let got = erf_approx(input);
            assert!(
                (got - expected).abs() < tolerance,
                "erf_approx({input}) = {got}, expected {expected} ± {tolerance}"
            );
        }
    }
}