use crate::tokenizer::base::Tokenizer;
use tokenizers::tokenizer::Tokenizer as HFTokenizerEngine;
/// Tokenizer backed by a HuggingFace `tokenizers` engine.
///
/// Holds a single [`HFTokenizerEngine`] (the `tokenizers::tokenizer::Tokenizer`
/// type imported at the top of this file) and forwards encoding/decoding to it.
#[derive(Debug, Clone)]
pub struct HFTokenizer {
/// The wrapped engine that performs the actual tokenization work.
engine: HFTokenizerEngine,
}
impl HFTokenizer {
pub fn new(engine: HFTokenizerEngine) -> Self {
Self { engine }
}
pub fn from_pretrained(model_name: &str) -> Self {
let engine: HFTokenizerEngine =
HFTokenizerEngine::from_pretrained(model_name, None).expect("Failed to load tokenizer");
Self { engine }
}
}
impl Tokenizer for HFTokenizer {
    /// Encodes `text` into token ids.
    ///
    /// The engine is called with its boolean flag set to `false`
    /// (`add_special_tokens` in the `tokenizers` API), so no special tokens
    /// are inserted into the result.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to encode text" if the engine returns an error.
    fn encode(&self, text: &str) -> Vec<usize> {
        let encoding = self
            .engine
            .encode(text, false)
            .expect("Failed to encode text");
        encoding
            .get_ids()
            .iter()
            // Widening `u32` -> `usize`: lossless wherever `usize` is at least 32 bits.
            .map(|&id| id as usize)
            .collect()
    }

    /// Decodes token `ids` back into a string.
    ///
    /// The engine is called with its boolean flag set to `true`
    /// (`skip_special_tokens` in the `tokenizers` API).
    ///
    /// # Panics
    ///
    /// Panics with "Token id does not fit in u32" if any id exceeds
    /// `u32::MAX` (previously such ids were silently truncated and decoded as
    /// a different token), and with "Failed to decode ids" if the engine
    /// returns an error.
    fn decode(&self, ids: &[usize]) -> String {
        let u32_ids: Vec<u32> = ids
            .iter()
            .map(|&id| {
                // Checked narrowing instead of `as`, which would wrap on 64-bit
                // targets. Fully qualified so it compiles on editions where
                // `TryFrom` is not in the prelude.
                <u32 as std::convert::TryFrom<usize>>::try_from(id)
                    .expect("Token id does not fit in u32")
            })
            .collect();
        self.engine
            .decode(&u32_ids, true)
            .expect("Failed to decode ids")
    }
}
#[cfg(all(test, feature = "tokenizers"))]
mod tests {
    use super::*;

    /// Round-trips a short sentence through `encode` and `decode` and checks
    /// that the decoded text equals the input.
    #[test]
    fn test_hf_tokenizer() {
        let original = String::from("Hello, world!");
        let hf = HFTokenizer::from_pretrained("bert-base-cased");

        let ids = hf.encode(&original);
        let round_tripped = hf.decode(&ids);

        assert_eq!(original, round_tripped);
    }
}