use std::path::Path;
use tokenizers::{EncodeInput, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
use flodl::{Device, Result, Tensor, TensorError, Variable};
/// Wrapper around a Hugging Face `tokenizers::Tokenizer` that guarantees
/// batch padding is configured (see `from_inner`), so batched encodings
/// can be flattened into rectangular `[batch, seq]` tensors.
pub struct HfTokenizer {
// The wrapped tokenizer; `from_inner` enables right-padding when absent.
inner: Tokenizer,
}
/// A tokenized batch flattened into `[batch, seq]` i64 tensors, each wrapped
/// in a non-trainable `Variable` (gradient tracking disabled).
#[derive(Debug)]
pub struct EncodedBatch {
// Token ids, padded to the longest sequence in the batch.
pub input_ids: Variable,
// Mask values copied from the tokenizer (typically 1 = real token, 0 = pad).
pub attention_mask: Variable,
// Type/segment ids from the tokenizer (distinguish paired inputs).
pub token_type_ids: Variable,
// Positions `0..seq`, identical for every row of the batch.
pub position_ids: Variable,
// Source-sequence index per token; -1 marks special/padding tokens.
pub sequence_ids: Variable,
}
impl HfTokenizer {
    /// Loads a tokenizer from a `tokenizer.json` file on disk.
    ///
    /// # Errors
    /// Returns a `TensorError` if the file cannot be read or parsed.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let tok = Tokenizer::from_file(path.as_ref())
            .map_err(|e| TensorError::new(&format!("tokenizer load: {e}")))?;
        Ok(Self::from_inner(tok))
    }

    /// Wraps an existing `Tokenizer`, enabling batch-longest right padding
    /// if none is configured. Padding makes every encoding in a batch the
    /// same length, which the flattening in `batch_to_variables` relies on.
    pub fn from_inner(mut inner: Tokenizer) -> Self {
        if inner.get_padding().is_none() {
            // Fall back to id 0 when the vocab has no [PAD] token.
            let pad_id = inner.token_to_id("[PAD]").unwrap_or(0);
            inner.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));
        }
        Self { inner }
    }

    /// Borrows the underlying `tokenizers::Tokenizer`.
    pub fn inner(&self) -> &Tokenizer {
        &self.inner
    }

    /// Saves the tokenizer (pretty-printed) to `path`.
    ///
    /// # Errors
    /// Returns a `TensorError` naming the destination path on failure.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        self.inner
            .save(path.as_ref(), true)
            .map_err(|e| TensorError::new(&format!(
                "tokenizer save to {}: {e}",
                path.as_ref().display(),
            )))
    }

    /// Encodes single texts into a CPU-resident `EncodedBatch`.
    pub fn encode(&self, texts: &[&str]) -> Result<EncodedBatch> {
        self.encode_on_device(texts, Device::CPU)
    }

    /// Encodes single texts into `[batch, seq]` tensors on `device`.
    ///
    /// # Errors
    /// Fails on an empty batch, a tokenization error, ragged encoding
    /// lengths, or tensor construction failure.
    pub fn encode_on_device(&self, texts: &[&str], device: Device) -> Result<EncodedBatch> {
        if texts.is_empty() {
            return Err(TensorError::new("tokenize: empty batch"));
        }
        let inputs: Vec<EncodeInput> = texts.iter().map(|s| (*s).into()).collect();
        let encodings = self
            .inner
            .encode_batch(inputs, true)
            .map_err(|e| TensorError::new(&format!("tokenize: {e}")))?;
        Self::batch_to_variables(&encodings, device)
    }

    /// Encodes (text_a, text_b) pairs into a CPU-resident `EncodedBatch`.
    pub fn encode_pairs(&self, pairs: &[(&str, &str)]) -> Result<EncodedBatch> {
        self.encode_pairs_on_device(pairs, Device::CPU)
    }

    /// Encodes sentence pairs into `[batch, seq]` tensors on `device`.
    ///
    /// # Errors
    /// Fails on an empty batch, a tokenization error, ragged encoding
    /// lengths, or tensor construction failure.
    pub fn encode_pairs_on_device(
        &self,
        pairs: &[(&str, &str)],
        device: Device,
    ) -> Result<EncodedBatch> {
        if pairs.is_empty() {
            return Err(TensorError::new("tokenize pairs: empty batch"));
        }
        let inputs: Vec<EncodeInput> = pairs
            .iter()
            .map(|(a, b)| EncodeInput::Dual((*a).into(), (*b).into()))
            .collect();
        let encodings = self
            .inner
            .encode_batch(inputs, true)
            .map_err(|e| TensorError::new(&format!("tokenize pairs: {e}")))?;
        Self::batch_to_variables(&encodings, device)
    }

    /// Flattens a non-empty batch of encodings into `[batch, seq]` i64
    /// tensors on `device`. Shared backend of `encode_on_device` and
    /// `encode_pairs_on_device`, which previously duplicated this logic.
    ///
    /// Callers guarantee `encodings` is non-empty (they error out on an
    /// empty input batch before tokenizing), so indexing `encodings[0]`
    /// is safe.
    fn batch_to_variables(
        encodings: &[tokenizers::Encoding],
        device: Device,
    ) -> Result<EncodedBatch> {
        let batch = encodings.len() as i64;
        let seq = encodings[0].get_ids().len() as i64;
        let cap = (batch * seq) as usize;
        let mut input_ids = Vec::<i64>::with_capacity(cap);
        let mut attention_mask = Vec::<i64>::with_capacity(cap);
        let mut token_type_ids = Vec::<i64>::with_capacity(cap);
        let mut sequence_ids = Vec::<i64>::with_capacity(cap);
        for enc in encodings {
            // Padding is configured in `from_inner`, so every row should have
            // length `seq`. Previously this was only a debug_assert, letting a
            // ragged batch silently corrupt the tensor shape in release
            // builds; now it is a hard error.
            if enc.get_ids().len() as i64 != seq {
                return Err(TensorError::new("tokenize: ragged batch lengths"));
            }
            input_ids.extend(enc.get_ids().iter().map(|&x| x as i64));
            attention_mask.extend(enc.get_attention_mask().iter().map(|&x| x as i64));
            token_type_ids.extend(enc.get_type_ids().iter().map(|&x| x as i64));
            // -1 marks tokens that belong to no source sequence (specials/pad).
            sequence_ids.extend(
                enc.get_sequence_ids()
                    .iter()
                    .map(|opt| opt.map(|v| v as i64).unwrap_or(-1)),
            );
        }
        // Plain 0..seq positions, repeated for every row.
        let mut position_ids = Vec::<i64>::with_capacity(cap);
        for _ in 0..batch {
            position_ids.extend(0i64..seq);
        }
        let shape = [batch, seq];
        Ok(EncodedBatch {
            input_ids: Variable::new(Tensor::from_i64(&input_ids, &shape, device)?, false),
            attention_mask: Variable::new(
                Tensor::from_i64(&attention_mask, &shape, device)?,
                false,
            ),
            token_type_ids: Variable::new(
                Tensor::from_i64(&token_type_ids, &shape, device)?,
                false,
            ),
            position_ids: Variable::new(Tensor::from_i64(&position_ids, &shape, device)?, false),
            sequence_ids: Variable::new(Tensor::from_i64(&sequence_ids, &shape, device)?, false),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding an empty slice must fail with a descriptive error.
    #[test]
    fn empty_batch_errors() {
        use tokenizers::models::bpe::BPE;

        let tokenizer = HfTokenizer::from_inner(Tokenizer::new(BPE::default()));
        let err = tokenizer.encode(&[]).unwrap_err();
        assert!(format!("{err}").contains("empty batch"));
    }

    /// Saving to disk and reloading must yield identical encodings.
    #[test]
    fn save_and_reload_round_trips() {
        use std::collections::HashMap;
        use std::process;
        use tokenizers::models::wordlevel::WordLevel;
        use tokenizers::pre_tokenizers::whitespace::Whitespace;

        // Tiny fixed vocabulary so encodings are deterministic.
        let vocab: HashMap<String, u32> = ["[UNK]", "hello", "world", "rust"]
            .iter()
            .enumerate()
            .map(|(idx, word)| ((*word).to_string(), idx as u32))
            .collect();
        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".into())
            .build()
            .expect("WordLevel build");
        let mut raw = Tokenizer::new(model);
        raw.with_pre_tokenizer(Some(Whitespace {}));
        let original = HfTokenizer::from_inner(raw);

        // Write to a pid-scoped temp directory to avoid cross-run collisions.
        let dir = std::env::temp_dir()
            .join(format!("flodl_hf_tokenizer_save_{}", process::id()));
        std::fs::create_dir_all(&dir).expect("create_dir_all");
        let file = dir.join("tokenizer.json");

        original.save(&file).expect("save");
        assert!(file.is_file(), "tokenizer.json was not written");

        let reloaded = HfTokenizer::from_file(&file).expect("from_file");
        let before = original.encode(&["hello world"]).unwrap();
        let after = reloaded.encode(&["hello world"]).unwrap();
        assert_eq!(
            before.input_ids.data().to_i64_vec().unwrap(),
            after.input_ids.data().to_i64_vec().unwrap(),
            "round-trip changed input_ids",
        );

        // Best-effort cleanup; failure here should not fail the test.
        let _ = std::fs::remove_dir_all(&dir);
    }
}