use std::path::Path;
pub struct Tokenizer {
inner: tokenizers::Tokenizer,
}
#[derive(Debug, thiserror::Error)]
pub enum TokenizerError {
#[error("failed to load tokenizer: {0}")]
Load(String),
#[error("encoding failed: {0}")]
Encode(String),
#[error("decoding failed: {0}")]
Decode(String),
}
impl Tokenizer {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TokenizerError> {
let inner = tokenizers::Tokenizer::from_file(path.as_ref())
.map_err(|e| TokenizerError::Load(e.to_string()))?;
Ok(Self { inner })
}
pub fn from_json(json: &str) -> Result<Self, TokenizerError> {
let inner: tokenizers::Tokenizer =
json.parse()
.map_err(|e: Box<dyn std::error::Error + Send + Sync>| {
TokenizerError::Load(e.to_string())
})?;
Ok(Self { inner })
}
pub fn encode(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
let encoding = self
.inner
.encode(text, false)
.map_err(|e| TokenizerError::Encode(e.to_string()))?;
Ok(encoding.get_ids().to_vec())
}
pub fn encode_with_special(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
let encoding = self
.inner
.encode(text, true)
.map_err(|e| TokenizerError::Encode(e.to_string()))?;
Ok(encoding.get_ids().to_vec())
}
pub fn decode(&self, ids: &[u32]) -> Result<String, TokenizerError> {
self.inner
.decode(ids, true)
.map_err(|e| TokenizerError::Decode(e.to_string()))
}
pub fn decode_one(&self, id: u32) -> Result<String, TokenizerError> {
self.decode(&[id])
}
pub fn vocab_size(&self) -> usize {
self.inner.get_vocab_size(true)
}
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.inner.token_to_id(token)
}
pub fn bos_token_id(&self) -> Option<u32> {
self.token_to_id("<s>")
.or_else(|| self.token_to_id("<|begin_of_text|>"))
.or_else(|| self.token_to_id("<|startoftext|>"))
}
pub fn eos_token_id(&self) -> Option<u32> {
self.token_to_id("</s>")
.or_else(|| self.token_to_id("<|end_of_text|>"))
.or_else(|| self.token_to_id("<|endoftext|>"))
.or_else(|| self.token_to_id("<|im_end|>"))
}
pub fn stop_token_ids(&self) -> Vec<u32> {
let candidates = [
"</s>",
"<|end_of_text|>",
"<|endoftext|>",
"<|im_end|>",
"<|eot_id|>",
"<|end|>",
];
candidates
.iter()
.filter_map(|&token| self.token_to_id(token))
.collect()
}
pub fn is_stop_token(&self, token_id: u32) -> bool {
self.stop_token_ids().contains(&token_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn minimal_tokenizer_json() -> String {
r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{"id": 0, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
{"id": 1, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
{"id": 2, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}
],
"normalizer": null,
"pre_tokenizer": {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": true},
"post_processor": null,
"decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true},
"model": {
"type": "BPE",
"dropout": null,
"unk_token": "<unk>",
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"ignore_merges": false,
"vocab": {
"<s>": 0, "</s>": 1, "<unk>": 2,
"h": 3, "e": 4, "l": 5, "o": 6,
"he": 7, "ll": 8, "lo": 9,
"hel": 10, "llo": 11
},
"merges": [
"h e", "l l", "l o", "he l", "ll o"
]
}
}"#
.to_string()
}
#[test]
fn load_from_json() {
let json = minimal_tokenizer_json();
let tok = Tokenizer::from_json(&json).unwrap();
assert!(tok.vocab_size() > 0);
}
#[test]
fn encode_decode_roundtrip() {
let json = minimal_tokenizer_json();
let tok = Tokenizer::from_json(&json).unwrap();
let ids = tok.encode("hello").unwrap();
assert!(!ids.is_empty());
let text = tok.decode(&ids).unwrap();
assert_eq!(text, "hello");
}
#[test]
fn special_tokens() {
let json = minimal_tokenizer_json();
let tok = Tokenizer::from_json(&json).unwrap();
assert_eq!(tok.bos_token_id(), Some(0));
assert_eq!(tok.eos_token_id(), Some(1));
}
#[test]
fn decode_single_token() {
let json = minimal_tokenizer_json();
let tok = Tokenizer::from_json(&json).unwrap();
let text = tok.decode_one(3).unwrap();
assert!(!text.is_empty());
}
}