use std::path::Path;
use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// Output of [`BertTokenizer::encode_batch`]: one row per input text.
///
/// The three vectors are parallel: element `i` of each one belongs to the
/// `i`-th text of the batch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizedBatch {
    /// Token ids of each encoded text.
    pub input_ids: Vec<Vec<u32>>,
    /// Attention mask of each encoded text, as reported by the tokenizer.
    pub attention_mask: Vec<Vec<u32>>,
    /// Type ids of each encoded text, as reported by the tokenizer.
    pub token_type_ids: Vec<Vec<u32>>,
    /// Length of the first encoding in the batch.
    pub seq_len: usize,
}
/// Wrapper around a `tokenizers::Tokenizer` that is built by
/// `from_dir` / `from_bytes` and used by `encode_batch` to produce a
/// [`TokenizedBatch`].
pub struct BertTokenizer {
// The underlying tokenizer; `from_bytes` sets its padding and truncation
// parameters and registers the special tokens before storing it here.
inner: Tokenizer,
}
impl BertTokenizer {
pub fn from_dir(dir: &Path, max_length: usize) -> anyhow::Result<Self> {
let tokenizer_json = std::fs::read(dir.join("tokenizer.json"))?;
let config_json = std::fs::read(dir.join("config.json"))?;
let special_tokens_map = std::fs::read(dir.join("special_tokens_map.json"))?;
let tokenizer_config = std::fs::read(dir.join("tokenizer_config.json"))?;
Self::from_bytes(
&tokenizer_json,
&config_json,
&special_tokens_map,
&tokenizer_config,
max_length,
)
}
pub fn from_bytes(
tokenizer_json: &[u8],
config_json: &[u8],
special_tokens_map_json: &[u8],
tokenizer_config_json: &[u8],
max_length: usize,
) -> anyhow::Result<Self> {
let mut tokenizer = Tokenizer::from_bytes(tokenizer_json)
.map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
let config: serde_json::Value = serde_json::from_slice(config_json)?;
let tokenizer_config: serde_json::Value = serde_json::from_slice(tokenizer_config_json)?;
let special_tokens_map: serde_json::Value =
serde_json::from_slice(special_tokens_map_json)?;
let model_max_length = tokenizer_config
.get("model_max_length")
.and_then(|v| v.as_f64())
.map(|v| v.min(1e9) as usize)
.unwrap_or(512);
let effective_max_length = max_length.min(model_max_length);
let pad_token = tokenizer_config
.get("pad_token")
.and_then(|v| v.as_str())
.unwrap_or("[PAD]")
.to_string();
let pad_token_id = config
.get("pad_token_id")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32;
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
pad_token: pad_token.clone(),
pad_id: pad_token_id,
..PaddingParams::default()
}));
tokenizer
.with_truncation(Some(TruncationParams {
max_length: effective_max_length,
..TruncationParams::default()
}))
.map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
let mut special_tokens = Vec::new();
if let Some(map) = special_tokens_map.as_object() {
for (_key, value) in map {
match value {
serde_json::Value::String(s) => {
special_tokens.push(AddedToken::from(s.clone(), true));
}
serde_json::Value::Object(obj) => {
if let Some(content) = obj.get("content").and_then(|v| v.as_str()) {
special_tokens.push(AddedToken::from(content.to_string(), true));
}
}
serde_json::Value::Array(arr) => {
for item in arr {
match item {
serde_json::Value::String(s) => {
special_tokens.push(AddedToken::from(s.clone(), true));
}
serde_json::Value::Object(obj) => {
if let Some(content) =
obj.get("content").and_then(|v| v.as_str())
{
special_tokens
.push(AddedToken::from(content.to_string(), true));
}
}
_ => {}
}
}
}
_ => {}
}
}
}
if !special_tokens.is_empty() {
tokenizer.add_special_tokens(&special_tokens);
}
Ok(Self { inner: tokenizer })
}
/// Tokenizes `texts` as a single batch and returns the per-text ids,
/// attention masks and type ids.
///
/// The texts are handed to the inner tokenizer with its second argument set
/// to `true`. `seq_len` is taken from the first encoding of the result.
///
/// # Errors
///
/// Returns `"tokenization failed: …"` when the inner tokenizer reports an
/// error, and `"empty batch"` when it produces no encodings.
pub fn encode_batch(&self, texts: &[&str]) -> anyhow::Result<TokenizedBatch> {
    let rows = match self.inner.encode_batch(texts.to_vec(), true) {
        Ok(rows) => rows,
        Err(e) => return Err(anyhow::anyhow!("tokenization failed: {e}")),
    };

    // The first row determines the reported sequence length; no rows at all
    // is an error.
    let seq_len = match rows.first() {
        Some(row) => row.len(),
        None => return Err(anyhow::anyhow!("empty batch")),
    };

    let input_ids: Vec<Vec<u32>> = rows.iter().map(|row| row.get_ids().to_vec()).collect();
    let attention_mask: Vec<Vec<u32>> = rows
        .iter()
        .map(|row| row.get_attention_mask().to_vec())
        .collect();
    let token_type_ids: Vec<Vec<u32>> = rows
        .iter()
        .map(|row| row.get_type_ids().to_vec())
        .collect();

    Ok(TokenizedBatch {
        input_ids,
        attention_mask,
        token_type_ids,
        seq_len,
    })
}
}