//! Wrapper around `aprender`'s Hugging Face-style BPE tokenizer, adding
//! pad/EOS/BOS bookkeeping and batch-encoding helpers for training.
use super::error::{Result, TokenizerError};
// GPT-2 has no dedicated pad token; `<|endoftext|>` (id 50256, the last
// entry in its 50257-token vocabulary) is conventionally reused for both
// padding and EOS. Note 50256 is the token id, not the vocabulary size.
const GPT2_ENDOFTEXT_ID: u32 = 50256;
pub use aprender::text::bpe::{
bytes_to_unicode, load_from_files as load_hf_from_files, load_from_json as load_hf_from_json,
BpeConfig as HfBpeConfig, BpeTokenizer as HfBpeTokenizer, MergeRule, Qwen2BpeTokenizer,
};
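/// BPE tokenizer in the Hugging Face `tokenizer.json` style, with the pad,
/// EOS, and BOS ids resolved once at construction time so batching code
/// never has to look them up by string.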
#[derive(Debug, Clone)]
pub struct HfTokenizer {
inner: HfBpeTokenizer,
pad_id: u32,
eos_id: Option<u32>,
bos_id: Option<u32>,
}
impl HfTokenizer {
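    /// Tokenizer preconfigured for GPT-2. GPT-2 defines no dedicated pad or
    /// BOS token, so `<|endoftext|>` doubles as pad and EOS.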
#[must_use]
pub fn gpt2() -> Self {
Self {
inner: HfBpeTokenizer::gpt2_base(),
            pad_id: GPT2_ENDOFTEXT_ID,
            eos_id: Some(GPT2_ENDOFTEXT_ID),
bos_id: None,
}
}
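    /// Tokenizer preconfigured for Qwen2, using its ChatML markers:
    /// `<|im_start|>` as BOS, `<|im_end|>` as EOS, and `<|endoftext|>` as
    /// pad.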
#[must_use]
pub fn qwen2() -> Self {
Self {
inner: HfBpeTokenizer::new(HfBpeConfig::qwen2()),
pad_id: Qwen2BpeTokenizer::ENDOFTEXT_ID,
eos_id: Some(Qwen2BpeTokenizer::IM_END_ID),
bos_id: Some(Qwen2BpeTokenizer::IM_START_ID),
}
}
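    /// Loads a tokenizer from a Hugging Face `tokenizer.json` file.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or its contents cannot
    /// be parsed as a tokenizer definition.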
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
let json = std::fs::read_to_string(path.as_ref())?;
Self::from_json(&json)
}
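    /// Builds a tokenizer from the contents of a `tokenizer.json`.
    ///
    /// Special-token ids are resolved by convention: pad falls back from
    /// `<pad>` to `<|endoftext|>` to 0, EOS from `</s>` to `<|im_end|>` to
    /// `<|endoftext|>`, and BOS from `<s>` to `<|im_start|>`.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::Serialization`] if the JSON cannot be
    /// parsed.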
pub fn from_json(json: &str) -> Result<Self> {
let inner = load_hf_from_json(json).map_err(|e| {
TokenizerError::Serialization(format!("Failed to parse tokenizer JSON: {e}"))
})?;
let pad_id =
inner.token_to_id("<pad>").or_else(|| inner.token_to_id("<|endoftext|>")).unwrap_or(0);
let eos_id = inner
.token_to_id("</s>")
.or_else(|| inner.token_to_id("<|im_end|>"))
.or_else(|| inner.token_to_id("<|endoftext|>"));
let bos_id = inner.token_to_id("<s>").or_else(|| inner.token_to_id("<|im_start|>"));
Ok(Self { inner, pad_id, eos_id, bos_id })
}
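    /// Number of entries in the vocabulary.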
#[must_use]
pub fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
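    /// Encodes `text` into token ids, without special tokens.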
#[must_use]
pub fn encode(&self, text: &str) -> Vec<u32> {
self.inner.encode(text)
}
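    /// Encodes `text`, prepending the BOS id and appending the EOS id when
    /// the tokenizer defines them. A sketch of the GPT-2 case (no BOS, EOS
    /// appended):
    ///
    /// ```ignore
    /// let tokenizer = HfTokenizer::gpt2();
    /// let ids = tokenizer.encode_with_special("Hi");
    /// assert_eq!(ids.last().copied(), tokenizer.eos_id());
    /// ```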
#[must_use]
pub fn encode_with_special(&self, text: &str) -> Vec<u32> {
let mut tokens = Vec::new();
if let Some(bos) = self.bos_id {
tokens.push(bos);
}
tokens.extend(self.inner.encode(text));
if let Some(eos) = self.eos_id {
tokens.push(eos);
}
tokens
}
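    /// Decodes token ids back into text.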
#[must_use]
pub fn decode(&self, ids: &[u32]) -> String {
self.inner.decode(ids)
}
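    /// Id used for padding in `batch_encode`.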
#[must_use]
pub fn pad_id(&self) -> u32 {
self.pad_id
}
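    /// End-of-sequence id, if one was resolved.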
#[must_use]
pub fn eos_id(&self) -> Option<u32> {
self.eos_id
}
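    /// Beginning-of-sequence id, if one was resolved.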
#[must_use]
pub fn bos_id(&self) -> Option<u32> {
self.bos_id
}
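    /// Encodes each text with special tokens, truncates rows to `max_len`,
    /// and right-pads with `pad_id` so every row in the batch has the same
    /// length (that of the longest surviving row).
    ///
    /// ```ignore
    /// let tokenizer = HfTokenizer::gpt2();
    /// let rows = tokenizer.batch_encode(&["Hi", "Hello there"], 32);
    /// assert!(rows.iter().all(|row| row.len() == rows[0].len()));
    /// ```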
#[must_use]
pub fn batch_encode(&self, texts: &[&str], max_len: usize) -> Vec<Vec<u32>> {
let mut encoded: Vec<Vec<u32>> = texts
.iter()
.map(|text| {
let mut tokens = self.encode_with_special(text);
tokens.truncate(max_len);
tokens
})
.collect();
        // Every row was already truncated to `max_len` above, so the
        // longest row in the batch is the padding target.
        let pad_to = encoded.iter().map(Vec::len).max().unwrap_or(0);
        for tokens in &mut encoded {
            tokens.resize(pad_to, self.pad_id);
        }
encoded
}
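    /// Tokenizes `(input, target)` pairs and packs them into training
    /// batches of at most `batch_size` pairs. Inputs and targets are
    /// encoded and padded independently, then flattened row-major into 1-D
    /// `f32` tensors, so the two tensors in a batch may differ in length.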
pub fn create_batches(
&self,
pairs: &[(&str, &str)],
max_len: usize,
batch_size: usize,
) -> Vec<crate::train::Batch> {
use crate::Tensor;
pairs
.chunks(batch_size)
.map(|chunk| {
let inputs: Vec<&str> = chunk.iter().map(|(i, _)| *i).collect();
let targets: Vec<&str> = chunk.iter().map(|(_, t)| *t).collect();
let input_tokens = self.batch_encode(&inputs, max_len);
let target_tokens = self.batch_encode(&targets, max_len);
let input_data: Vec<f32> =
input_tokens.into_iter().flatten().map(|t| t as f32).collect();
let target_data: Vec<f32> =
target_tokens.into_iter().flatten().map(|t| t as f32).collect();
crate::train::Batch::new(
Tensor::from_vec(input_data, false),
Tensor::from_vec(target_data, false),
)
})
.collect()
}
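    /// Builds next-token-prediction batches: for each text, the input is
    /// the token sequence minus its last token and the target is the same
    /// sequence shifted left by one. Rows of one token or fewer contribute
    /// nothing, since they admit no (input, target) step.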
pub fn create_causal_batches(
&self,
texts: &[&str],
max_len: usize,
batch_size: usize,
) -> Vec<crate::train::Batch> {
use crate::Tensor;
texts
.chunks(batch_size)
.map(|chunk| {
let encoded = self.batch_encode(chunk, max_len);
let mut input_data: Vec<f32> = Vec::new();
let mut target_data: Vec<f32> = Vec::new();
for tokens in &encoded {
if tokens.len() > 1 {
input_data.extend(tokens[..tokens.len() - 1].iter().map(|&t| t as f32));
target_data.extend(tokens[1..].iter().map(|&t| t as f32));
}
}
crate::train::Batch::new(
Tensor::from_vec(input_data, false),
Tensor::from_vec(target_data, false),
)
})
.collect()
}
}
impl Default for HfTokenizer {
fn default() -> Self {
Self::gpt2()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hf_tokenizer_gpt2() {
let tokenizer = HfTokenizer::gpt2();
assert!(tokenizer.vocab_size() > 0);
        assert_eq!(tokenizer.pad_id(), GPT2_ENDOFTEXT_ID);
}
#[test]
fn test_hf_tokenizer_qwen2() {
let tokenizer = HfTokenizer::qwen2();
assert_eq!(tokenizer.eos_id(), Some(Qwen2BpeTokenizer::IM_END_ID));
}
#[test]
fn test_hf_tokenizer_encode() {
let tokenizer = HfTokenizer::gpt2();
let tokens = tokenizer.encode("Hello");
assert!(!tokens.is_empty());
}
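    #[test]
    fn test_hf_tokenizer_decode_roundtrip() {
        // Byte-level BPE is lossless by construction, so decoding the ids
        // produced by `encode` should reproduce the original text exactly
        // (assuming the GPT-2 base vocabulary covers all 256 bytes).
        let tokenizer = HfTokenizer::gpt2();
        let ids = tokenizer.encode("Hello world");
        assert_eq!(tokenizer.decode(&ids), "Hello world");
    }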
#[test]
fn test_hf_tokenizer_encode_with_special() {
let tokenizer = HfTokenizer::gpt2();
let tokens = tokenizer.encode_with_special("Hi");
        assert_eq!(tokens.last().copied(), tokenizer.eos_id());
}
#[test]
fn test_hf_tokenizer_batch_encode() {
let tokenizer = HfTokenizer::gpt2();
let texts = vec!["Hello", "Hi there"];
let encoded = tokenizer.batch_encode(&texts, 32);
assert_eq!(encoded.len(), 2);
assert_eq!(encoded[0].len(), encoded[1].len());
}
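    #[test]
    fn test_hf_tokenizer_batch_encode_truncates() {
        // Rows are clipped to `max_len` before padding, so no row may
        // exceed it.
        let tokenizer = HfTokenizer::gpt2();
        let encoded = tokenizer.batch_encode(&["Hello world, how are you?"], 3);
        assert!(encoded.iter().all(|row| row.len() <= 3));
    }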
#[test]
fn test_hf_tokenizer_create_batches() {
let tokenizer = HfTokenizer::gpt2();
let pairs = vec![("Hello", "World"), ("How are", "you")];
let batches = tokenizer.create_batches(&pairs, 16, 2);
assert_eq!(batches.len(), 1);
assert!(!batches[0].inputs.is_empty());
}
#[test]
fn test_hf_tokenizer_create_causal_batches() {
let tokenizer = HfTokenizer::gpt2();
let texts = vec!["Hello world", "Test text"];
let batches = tokenizer.create_causal_batches(&texts, 16, 2);
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].inputs.len(), batches[0].targets.len());
}
#[test]
fn test_hf_tokenizer_from_json() {
let json = r#"{
"model": {
"vocab": {"hello": 0, "world": 1, "<|endoftext|>": 2},
"merges": []
},
"added_tokens": []
}"#;
let result = HfTokenizer::from_json(json);
assert!(result.is_ok());
assert_eq!(result.expect("operation should succeed").vocab_size(), 3);
}
#[test]
fn test_hf_tokenizer_from_json_invalid() {
let result = HfTokenizer::from_json("invalid json");
assert!(result.is_err());
}
}