use anyhow::Result;
use std::path::Path;
use tokenizers::{Encoding, Tokenizer as HFTokenizer};
/// Thin wrapper around a Hugging Face `tokenizers::Tokenizer` whose methods
/// report failures as `anyhow` errors instead of the library's boxed error type.
pub struct Tokenizer {
    /// The wrapped Hugging Face tokenizer instance.
    tokenizer: HFTokenizer,
}
impl Tokenizer {
    /// Loads a serialized tokenizer (for example a `tokenizer.json` file) from `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if `HFTokenizer::from_file` cannot read or parse the
    /// file. The message includes `path`, so a misplaced asset is easy to spot.
    pub fn load(path: &Path) -> Result<Self> {
        // `&Path` is `Copy`, so `path` is still usable inside the error closure.
        let tokenizer = HFTokenizer::from_file(path)
            .map_err(|e| anyhow::anyhow!("加载分词器失败 ({}): {}", path.display(), e))?;
        Ok(Self { tokenizer })
    }

    /// Encodes a single piece of text into an [`Encoding`].
    ///
    /// `add_special_tokens` is forwarded unchanged to the underlying tokenizer.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails to encode `text`.
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Encoding> {
        self.tokenizer
            .encode(text, add_special_tokens)
            .map_err(|e| anyhow::anyhow!("编码失败: {}", e))
    }

    /// Encodes several texts with a single call to the underlying tokenizer.
    ///
    /// `add_special_tokens` is forwarded unchanged and applies to every input.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying batch encode fails.
    pub fn encode_batch(
        &self,
        texts: Vec<&str>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        self.tokenizer
            .encode_batch(texts, add_special_tokens)
            .map_err(|e| anyhow::anyhow!("批量编码失败: {}", e))
    }

    /// Converts token `ids` back into a string.
    ///
    /// `skip_special_tokens` is forwarded unchanged to the underlying decoder.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tokenizer fails to decode `ids`.
    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(ids, skip_special_tokens)
            .map_err(|e| anyhow::anyhow!("解码失败: {}", e))
    }

    /// Returns the size of the base vocabulary, **excluding** tokens added on
    /// top of it (this calls `get_vocab_size(false)`).
    ///
    /// Added or special tokens may carry ids that are not counted here. Use
    /// [`Tokenizer::vocab_size_with_added_tokens`] when they must be included.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    /// Returns the vocabulary size **including** added tokens
    /// (this calls `get_vocab_size(true)`).
    pub fn vocab_size_with_added_tokens(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizer fixture, relative to the crate root (the working directory
    /// `cargo test` uses for unit tests).
    const TOKENIZER_PATH: &str = "assets/tokenizer.json";

    /// Loads the on-disk fixture, panicking with the path and full error chain
    /// if it is missing or invalid.
    fn load_fixture() -> Tokenizer {
        Tokenizer::load(Path::new(TOKENIZER_PATH))
            .unwrap_or_else(|e| panic!("failed to load {}: {:#}", TOKENIZER_PATH, e))
    }

    // Needs no fixture, so it runs in a plain `cargo test`.
    #[test]
    fn load_missing_file_fails() {
        let result = Tokenizer::load(Path::new("assets/__definitely_missing_tokenizer__.json"));
        assert!(result.is_err());
    }

    #[test]
    #[ignore = "requires assets/tokenizer.json"]
    fn test_encode() {
        let tokenizer = load_fixture();
        let encoding = tokenizer.encode("Hello, world!", true).unwrap();
        assert!(!encoding.get_ids().is_empty());
    }

    #[test]
    #[ignore = "requires assets/tokenizer.json"]
    fn test_encode_batch() {
        let tokenizer = load_fixture();
        let encodings = tokenizer.encode_batch(vec!["Hello", "world"], true).unwrap();
        assert_eq!(encodings.len(), 2);
        assert!(encodings.iter().all(|enc| !enc.get_ids().is_empty()));
    }

    #[test]
    #[ignore = "requires assets/tokenizer.json"]
    fn test_decode() {
        let tokenizer = load_fixture();
        let encoding = tokenizer.encode("Hello, world!", true).unwrap();
        let decoded = tokenizer.decode(encoding.get_ids(), true).unwrap();
        assert!(decoded.contains("Hello"));
    }

    #[test]
    #[ignore = "requires assets/tokenizer.json"]
    fn test_vocab_size() {
        assert!(load_fixture().vocab_size() > 0);
    }
}