rlx-clinicalbert
ClinicalBERT encoder runner (Huang / Bio_ClinicalBERT / Bio_Discharge_Summary_BERT) on top of [rlx-bert].
Loads HF safetensors with the bert. prefix, presets each variant's (vocab_size, hidden, layers, heads), and exposes a high-level [ClinicalBertRunner] that compiles the encoder + optional heads to a single backend (CPU, Metal, MLX, CUDA, ROCm, WGPU, Vulkan).
Quick start
use ;
use Device;
let mut runner = builder
.weights
.config_path
.device
.batch.max_seq
.pooling
.with_pooler
.mlm_mode // pick CPU vs IR-folded per (device, batch)
.build?;
let hidden = runner.forward?;
let pooled = runner.pooler_output?;
let logits = runner.mlm_logits?;
Heads
Two optional heads behind the pooler and mlm features (or heads for both):
- Pooler: tanh(W·h_cls + b). Loads bert.pooler.*.
- MLM: dense(H→H) + GeLU + LN + tied decoder(H→V) + bias. Loads cls.predictions.*; the decoder weight is tied to the input embedding matrix when no explicit cls.predictions.decoder.weight is present (the common HF layout).
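As an illustration of the two heads described above, here is a minimal sketch in plain Rust. It is not this crate's implementation: the function names `pooler` and `resolve_decoder_weight` are hypothetical, tensors are modelled as flat `Vec<f32>` in row-major order, and the fallback key `bert.embeddings.word_embeddings.weight` is the usual HF BERT name for the input embedding matrix, assumed here rather than taken from the crate.

```rust
use std::collections::HashMap;

/// Pooler sketch: tanh(W·h_cls + b).
/// `w` is row-major with shape [hidden, hidden]; `b` and `h_cls` have length `hidden`.
/// (Hypothetical helper, not the crate's API.)
fn pooler(w: &[f32], b: &[f32], h_cls: &[f32]) -> Vec<f32> {
    let hidden = h_cls.len();
    assert_eq!(w.len(), hidden * hidden, "W must be hidden x hidden");
    assert_eq!(b.len(), hidden, "b must have length hidden");
    (0..hidden)
        .map(|i| {
            // Dot product of row i of W with the [CLS] hidden state.
            let row = &w[i * hidden..(i + 1) * hidden];
            let dot: f32 = row.iter().zip(h_cls).map(|(wij, hj)| wij * hj).sum();
            (dot + b[i]).tanh()
        })
        .collect()
}

/// Tied-decoder sketch: prefer an explicit decoder weight, otherwise fall back
/// to the input embedding matrix. The fallback key name is an assumption.
fn resolve_decoder_weight(tensors: &HashMap<String, Vec<f32>>) -> Option<&Vec<f32>> {
    tensors
        .get("cls.predictions.decoder.weight")
        .or_else(|| tensors.get("bert.embeddings.word_embeddings.weight"))
}
```

With an identity `W` and zero bias, `pooler` reduces to an element-wise `tanh` of the `[CLS]` state, which is a convenient sanity check. `resolve_decoder_weight` returns `None` only when neither tensor is present in the checkpoint.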
MLM head placement
The MLM head runs either as a CPU post-process or as part of the compiled encoder graph. Select the mode with .mlm_mode(MlmExecMode::Auto | Cpu | InGraph). Auto is the default and chooses per (device, batch):
| Device | Batch | Pick |
|---|---|---|
| CPU, Metal, MLX, WGPU, Vulkan, ROCm | any | InGraph |
| CUDA | ≤ 8 | InGraph |
| CUDA | > 8 | Cpu |
See MlmExecMode for measured numbers (RTX 4090 + Intel x86, Bio_ClinicalBERT @ seq=32). Override Auto only if you have measured your own (seq_len, vocab, host BLAS) combination; the crossover point shifts with all three.
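The selection table above can be written as a small decision function. This is a sketch only: the enums `BackendKind` and `MlmPlacement` and the function `auto_pick` are hypothetical stand-ins defined here for illustration, not the crate's `Device` or `MlmExecMode` types, and the threshold of 8 is simply the one shown in the table.

```rust
/// Hypothetical stand-in for the backends listed in the table.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BackendKind {
    Cpu,
    Metal,
    Mlx,
    Cuda,
    Rocm,
    Wgpu,
    Vulkan,
}

/// Hypothetical stand-in for the resolved (non-Auto) placement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MlmPlacement {
    Cpu,
    InGraph,
}

/// Mirrors the table: only CUDA with batch > 8 falls back to the CPU head;
/// every other (device, batch) pair keeps the head in the compiled graph.
fn auto_pick(device: BackendKind, batch: usize) -> MlmPlacement {
    match device {
        BackendKind::Cuda if batch > 8 => MlmPlacement::Cpu,
        _ => MlmPlacement::InGraph,
    }
}
```

For example, `auto_pick(BackendKind::Cuda, 8)` yields `MlmPlacement::InGraph`, while `auto_pick(BackendKind::Cuda, 16)` yields `MlmPlacement::Cpu`. Encoding the policy as a pure function of (device, batch) keeps it easy to test against your own measurements.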
Features
- heads = pooler + mlm
- tokenizer: WordPiece via tokenizers
- prepare, hf-download: convert / download HF snapshots into model.safetensors + tokenizer.json
- blas-mkl: link rlx-cpu against Intel MKL instead of OpenBLAS on Linux / Windows (macOS stays on Accelerate)
- backend: metal, mlx, cuda, rocm, gpu, vulkan, plus convenience bundles (all-backends, apple-silicon, nvidia-gpu, amd-gpu, portable-gpu)
Examples
# Parity vs HF reference
# IR-fold parity (encoder-graph MLM vs CPU post-process head)
# Latency / throughput sweep over batch sizes