use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::{
Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer},
};
/// Thin wrapper around the HuggingFace `tokenizers` crate's [`HfTokenizer`],
/// adapting it to this crate's `Encoder`/`Decoder`/`Tokenizer` traits.
pub struct HuggingFaceTokenizer {
    // The wrapped HuggingFace tokenizer; all trait methods delegate to it.
    tokenizer: HfTokenizer,
}
impl HuggingFaceTokenizer {
pub fn from_file(model_name: &str) -> Result<Self> {
let tokenizer = HfTokenizer::from_file(model_name)
.map_err(|err| Error::msg(format!("Error loading tokenizer: {}", err)))?;
Ok(HuggingFaceTokenizer { tokenizer })
}
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
HuggingFaceTokenizer { tokenizer }
}
}
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
let encoding = self
.tokenizer
.encode(input, false)
.map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?;
Ok(Encoding::Hf(Box::new(encoding)))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let hf_encodings = self
.tokenizer
.encode_batch(inputs.to_vec(), false)
.map_err(|err| Error::msg(format!("Error batch tokenizing input: {err}")))?;
let encodings = hf_encodings
.into_iter()
.map(|enc| Encoding::Hf(Box::new(enc)))
.collect();
Ok(encodings)
}
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
let text = self
.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|err| Error::msg(format!("Error de-tokenizing input: {err}")))?;
Ok(text)
}
}
// Marker impl: the empty body suggests `Tokenizer` is satisfied by types that
// are both `Encoder` and `Decoder` (or has only default methods) — confirm
// against the trait definition in `super::traits`.
impl Tokenizer for HuggingFaceTokenizer {}
impl From<HfTokenizer> for HuggingFaceTokenizer {
fn from(tokenizer: HfTokenizer) -> Self {
HuggingFaceTokenizer { tokenizer }
}
}