use anyhow::{Context, Result};
use std::path::Path;
use tokenizers::Tokenizer;
/// ClinicalBERT tokenizer: a thin wrapper around a `tokenizers::Tokenizer`
/// whose encodings are packed into fixed-width `f32` batch buffers.
pub struct ClinicalBertTokenizer {
    /// Underlying tokenizer, loaded from a `tokenizer.json` file.
    inner: Tokenizer,
}
/// Fixed-width, model-ready encoding of a batch of inputs.
///
/// Every buffer holds `batch * seq` values in row-major `[batch, seq]` order:
/// element `(b, s)` lives at index `b * seq + s`. Positions past the end of an
/// encoding are padded with `0.0`. Values are stored as `f32`, presumably because
/// the downstream runtime consumes float tensors — confirm against the caller.
#[derive(Debug, Clone)]
pub struct EncodedBatch {
    /// Token ids, converted to `f32`.
    pub input_ids: Vec<f32>,
    /// Attention mask from the tokenizer; `0.0` on padded positions.
    pub attention_mask: Vec<f32>,
    /// Segment / token-type ids (distinguish the two halves of a pair input).
    pub token_type_ids: Vec<f32>,
    /// Position indices `0..seq` repeated for every row, padding included.
    pub position_ids: Vec<f32>,
    /// Number of rows (one per input sequence or pair).
    pub batch: usize,
    /// Fixed row width that every encoding was truncated or padded to.
    pub seq: usize,
}
impl ClinicalBertTokenizer {
    /// Loads a tokenizer from a serialized `tokenizer.json` file.
    ///
    /// # Errors
    ///
    /// Returns an error if `tokenizers` fails to read or parse `path`.
    pub fn from_file(path: &Path) -> Result<Self> {
        let inner = Tokenizer::from_file(path)
            .map_err(|e| anyhow::anyhow!("rlx-clinicalbert: tokenizer.from_file: {e}"))?;
        Ok(Self { inner })
    }

    /// Loads `tokenizer.json` from `path` when it is a directory, otherwise from
    /// the directory containing `path` (e.g. a sibling of a weights file).
    ///
    /// A bare relative file name, whose parent is the empty path, resolves to the
    /// current directory `.`.
    ///
    /// With the `prepare` feature enabled, `crate::prepare::prepare_clinicalbert_dir`
    /// is called first if `tokenizer.json` is not present in that directory
    /// (presumably to generate or fetch it — see that function).
    ///
    /// # Errors
    ///
    /// Returns an error if preparation fails (with `prepare`) or if the tokenizer
    /// file cannot be loaded; the latter is annotated with the file path.
    pub fn from_dir_or_sibling(path: &Path) -> Result<Self> {
        let dir = if path.is_dir() {
            path.to_path_buf()
        } else {
            // `Path::parent` yields `Some("")` for a bare relative name such as
            // `model.safetensors`; normalise that (and `None`) to `.` so the
            // directory handed to `prepare` is never the empty path.
            path.parent()
                .filter(|p| !p.as_os_str().is_empty())
                .map_or_else(|| std::path::PathBuf::from("."), Path::to_path_buf)
        };
        let tok = dir.join("tokenizer.json");
        #[cfg(feature = "prepare")]
        if !tok.is_file() {
            crate::prepare::prepare_clinicalbert_dir(&dir)?;
        }
        Self::from_file(&tok).with_context(|| format!("loading {tok:?}"))
    }

    /// Tokenizes a batch of single sequences (with `add_special_tokens = true`)
    /// into fixed-width `[batch, seq]` buffers; see `encode_inputs` for the
    /// truncation and padding rules.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails to encode any input.
    pub fn encode_batch(&self, texts: &[&str], seq: usize) -> Result<EncodedBatch> {
        let inputs: Vec<tokenizers::EncodeInput> = texts.iter().map(|&t| t.into()).collect();
        self.encode_inputs(inputs, seq)
    }

    /// Tokenizes a batch of sequence pairs (with `add_special_tokens = true`)
    /// into fixed-width `[batch, seq]` buffers; see `encode_inputs` for the
    /// truncation and padding rules.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails to encode any pair.
    pub fn encode_pairs_batch(&self, pairs: &[(&str, &str)], seq: usize) -> Result<EncodedBatch> {
        // Borrow the strings directly; `&str` converts into an input sequence
        // without allocating an owned `String` per half.
        let inputs: Vec<tokenizers::EncodeInput> =
            pairs.iter().map(|&(a, b)| (a, b).into()).collect();
        self.encode_inputs(inputs, seq)
    }

    /// Encodes `inputs` and packs the results into row-major `[batch, seq]`
    /// `f32` buffers, where `batch` is the number of inputs.
    ///
    /// For each row, the first `min(encoding_len, seq)` positions receive the
    /// encoding's token ids, attention mask and type ids; the remaining positions
    /// stay `0.0`. Encodings longer than `seq` are truncated here. That drops any
    /// trailing special token the post-processor appended (e.g. `[SEP]`) unless the
    /// tokenizer itself is configured to truncate — NOTE(review): confirm callers
    /// pick a large enough `seq`. `position_ids` are `0..seq` on every row,
    /// including padded positions.
    ///
    /// Ids are converted with `as f32`, which is exact only up to 2^24;
    /// presumably the vocabulary is far smaller — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails to encode the batch.
    fn encode_inputs(
        &self,
        inputs: Vec<tokenizers::EncodeInput>,
        seq: usize,
    ) -> Result<EncodedBatch> {
        let encodings = self
            .inner
            .encode_batch(inputs, true)
            .map_err(|e| anyhow::anyhow!("rlx-clinicalbert: encode_batch: {e}"))?;
        let batch = encodings.len();
        let total = batch * seq;
        let mut input_ids = vec![0f32; total];
        let mut attention_mask = vec![0f32; total];
        let mut token_type_ids = vec![0f32; total];
        for (bi, enc) in encodings.iter().enumerate() {
            let row = bi * seq..(bi + 1) * seq;
            Self::copy_prefix(&mut input_ids[row.clone()], enc.get_ids());
            Self::copy_prefix(&mut attention_mask[row.clone()], enc.get_attention_mask());
            Self::copy_prefix(&mut token_type_ids[row], enc.get_type_ids());
        }
        // When `seq == 0`, `total == 0` and the closure never runs, so the
        // modulo cannot divide by zero.
        let position_ids: Vec<f32> = (0..total).map(|i| (i % seq) as f32).collect();
        Ok(EncodedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            batch,
            seq,
        })
    }

    /// Writes `src` into the front of `dst` as `f32`, stopping at whichever
    /// slice ends first; entries of `dst` beyond `src.len()` are left untouched.
    fn copy_prefix(dst: &mut [f32], src: &[u32]) {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d = s as f32;
        }
    }
}