//! Thin wrapper around the Hugging Face `tokenizers` crate that turns
//! batches of text into `flodl` tensors ready for model input.

use std::path::Path;

use tokenizers::{
    EncodeInput, Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer,
};

use flodl::{Device, Result, Tensor, TensorError, Variable};
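
/// Wrapper around a Hugging Face [`Tokenizer`] that guarantees padding is
/// configured and turns batch encodings into `flodl` tensors.
///
/// A minimal usage sketch; marked `ignore` because it needs a real
/// `tokenizer.json` on disk (the path here is illustrative):
///
/// ```ignore
/// let tok = HfTokenizer::from_file("tokenizer.json")?;
/// let batch = tok.encode(&["hello world", "hi"])?;
/// // Every field of `batch` is a [2, seq] tensor on the CPU.
/// ```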
pub struct HfTokenizer {
    inner: Tokenizer,
}

/// Tensors produced from one tokenized batch; every field has shape `[batch, seq]`.
#[derive(Debug)]
pub struct EncodedBatch {
    /// Token ids, padded to the longest sequence in the batch.
    pub input_ids: Variable,
    /// 1 for real tokens, 0 for padding.
    pub attention_mask: Variable,
    /// Segment ids (0 for single sequences, 0/1 for pairs).
    pub token_type_ids: Variable,
    /// Plain `0..seq` positions, repeated for each batch row.
    pub position_ids: Variable,
    /// Source-sequence index per token; `-1` where a token has no source
    /// sequence (e.g. special tokens).
    pub sequence_ids: Variable,
}

impl HfTokenizer {
    /// Loads a tokenizer from a `tokenizer.json` file on disk.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let tok = Tokenizer::from_file(path.as_ref())
            .map_err(|e| TensorError::new(&format!("tokenizer load: {e}")))?;
        Ok(Self::from_inner(tok))
    }

    /// Wraps an existing tokenizer, installing batch-longest right padding
    /// if the tokenizer has no padding configured.
    pub fn from_inner(mut inner: Tokenizer) -> Self {
        if inner.get_padding().is_none() {
            // Fall back to id 0 when the vocabulary has no "[PAD]" token.
            let pad_id = inner.token_to_id("[PAD]").unwrap_or(0);
            inner.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));
        }
        Self { inner }
    }

    /// Returns a reference to the underlying `tokenizers::Tokenizer`.
    pub fn inner(&self) -> &Tokenizer {
        &self.inner
    }
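
    /// Encodes a batch of single texts into tensors on the CPU; use
    /// [`Self::encode_on_device`] to target another device.
    ///
    /// A hedged sketch of what the batch looks like (`seq` is the padded
    /// length of the longest text; marked `ignore` as it needs a tokenizer):
    ///
    /// ```ignore
    /// let batch = tok.encode(&["a cat sat", "hi"])?;
    /// // input_ids:      [2, seq], second row right-padded
    /// // attention_mask: [2, seq], 1 on real tokens, 0 on padding
    /// // sequence_ids:   [2, seq], -1 on special/padding positions
    /// ```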
    pub fn encode(&self, texts: &[&str]) -> Result<EncodedBatch> {
        self.encode_on_device(texts, Device::CPU)
    }

    /// Encodes a batch of single texts, placing the resulting tensors on `device`.
    pub fn encode_on_device(&self, texts: &[&str], device: Device) -> Result<EncodedBatch> {
        if texts.is_empty() {
            return Err(TensorError::new("tokenize: empty batch"));
        }
        let inputs: Vec<EncodeInput> = texts.iter().map(|s| (*s).into()).collect();
        let encodings = self
            .inner
            .encode_batch(inputs, true)
            .map_err(|e| TensorError::new(&format!("tokenize: {e}")))?;
        Self::batch_to_tensors(&encodings, device)
    }
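
    /// Encodes a batch of text pairs (e.g. query/passage) into CPU tensors;
    /// `token_type_ids` distinguishes the two segments.
    ///
    /// A minimal sketch, `ignore`d because it needs a real tokenizer file:
    ///
    /// ```ignore
    /// let batch = tok.encode_pairs(&[("a cat sat", "on the mat")])?;
    /// ```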
    pub fn encode_pairs(&self, pairs: &[(&str, &str)]) -> Result<EncodedBatch> {
        self.encode_pairs_on_device(pairs, Device::CPU)
    }

    /// Encodes a batch of text pairs, placing the resulting tensors on `device`.
    pub fn encode_pairs_on_device(
        &self,
        pairs: &[(&str, &str)],
        device: Device,
    ) -> Result<EncodedBatch> {
        if pairs.is_empty() {
            return Err(TensorError::new("tokenize pairs: empty batch"));
        }
        let inputs: Vec<EncodeInput> = pairs
            .iter()
            .map(|(a, b)| EncodeInput::Dual((*a).into(), (*b).into()))
            .collect();
        let encodings = self
            .inner
            .encode_batch(inputs, true)
            .map_err(|e| TensorError::new(&format!("tokenize pairs: {e}")))?;
        Self::batch_to_tensors(&encodings, device)
    }

    /// Flattens a batch of `tokenizers` encodings into `[batch, seq]` tensors.
    /// Callers guarantee a non-empty batch, and padding guarantees every
    /// encoding has the same length.
    fn batch_to_tensors(encodings: &[Encoding], device: Device) -> Result<EncodedBatch> {
        let batch = encodings.len() as i64;
        let seq = encodings[0].get_ids().len() as i64;
        let cap = (batch * seq) as usize;
        let mut input_ids = Vec::<i64>::with_capacity(cap);
        let mut attention_mask = Vec::<i64>::with_capacity(cap);
        let mut token_type_ids = Vec::<i64>::with_capacity(cap);
        let mut sequence_ids = Vec::<i64>::with_capacity(cap);
        for enc in encodings {
            debug_assert_eq!(enc.get_ids().len() as i64, seq);
            input_ids.extend(enc.get_ids().iter().map(|&x| x as i64));
            attention_mask.extend(enc.get_attention_mask().iter().map(|&x| x as i64));
            token_type_ids.extend(enc.get_type_ids().iter().map(|&x| x as i64));
            // Tokens without a source sequence (e.g. special tokens) become -1.
            sequence_ids.extend(
                enc.get_sequence_ids()
                    .iter()
                    .map(|opt| opt.map(|v| v as i64).unwrap_or(-1)),
            );
        }
        // Position ids are a plain 0..seq ramp repeated for each batch row.
        let mut position_ids = Vec::<i64>::with_capacity(cap);
        for _ in 0..batch {
            position_ids.extend(0i64..seq);
        }
        let shape = [batch, seq];
        Ok(EncodedBatch {
            input_ids: Variable::new(Tensor::from_i64(&input_ids, &shape, device)?, false),
            attention_mask: Variable::new(
                Tensor::from_i64(&attention_mask, &shape, device)?,
                false,
            ),
            token_type_ids: Variable::new(
                Tensor::from_i64(&token_type_ids, &shape, device)?,
                false,
            ),
            position_ids: Variable::new(Tensor::from_i64(&position_ids, &shape, device)?, false),
            sequence_ids: Variable::new(Tensor::from_i64(&sequence_ids, &shape, device)?, false),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_batch_errors() {
        use tokenizers::models::bpe::BPE;
        let bpe = BPE::default();
        let tok = Tokenizer::new(bpe);
        let hf = HfTokenizer::from_inner(tok);
        let err = hf.encode(&[]).unwrap_err();
        assert!(format!("{err}").contains("empty batch"));
    }
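
    // Added check, using only the API above: `from_inner` should install
    // default padding when the wrapped tokenizer has none configured.
    #[test]
    fn from_inner_installs_default_padding() {
        use tokenizers::models::bpe::BPE;
        let tok = Tokenizer::new(BPE::default());
        let hf = HfTokenizer::from_inner(tok);
        assert!(hf.inner().get_padding().is_some());
    }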
}