use crate::config::ModelAsset;
use crate::error::TtsError;
use std::path::Path;
/// Thin wrapper around a HuggingFace [`tokenizers::Tokenizer`] used to turn
/// input text into token ids for TTS inference (and ids back into text).
///
/// Construct via `from_file`, `from_bytes`, `from_asset`, or `from_model_dir`.
pub struct TextTokenizer {
// The underlying HuggingFace tokenizer; kept private so all loading goes
// through the fallible constructors that map errors into `TtsError`.
inner: tokenizers::Tokenizer,
}
impl TextTokenizer {
/// Loads a tokenizer from a `tokenizer.json` file on disk.
///
/// # Errors
/// Returns [`TtsError::TokenizerError`] if the file cannot be read or parsed.
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TtsError> {
let inner = tokenizers::Tokenizer::from_file(path.as_ref())
.map_err(|e| TtsError::TokenizerError(format!("Failed to load tokenizer: {}", e)))?;
Ok(Self { inner })
}
/// Loads a tokenizer from in-memory `tokenizer.json` bytes.
///
/// # Errors
/// Returns [`TtsError::TokenizerError`] if the bytes are not a valid tokenizer definition.
pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, TtsError> {
let inner = tokenizers::Tokenizer::from_bytes(bytes)
.map_err(|e| TtsError::TokenizerError(format!("Failed to load tokenizer: {}", e)))?;
Ok(Self { inner })
}
/// Loads a tokenizer from a [`ModelAsset`].
///
/// Prefers the on-disk path when the asset is file-backed (avoids buffering
/// the whole file); otherwise reads the asset's bytes and parses them.
///
/// # Errors
/// Propagates asset read failures and tokenizer parse failures.
pub fn from_asset(asset: &ModelAsset) -> Result<Self, TtsError> {
if let Some(path) = asset.as_path() {
return Self::from_file(path);
}
let bytes = asset.read_bytes()?;
Self::from_bytes(bytes.as_ref())
}
/// Loads a tokenizer from a model directory.
///
/// Resolution order:
/// 1. `tokenizer.json` (the unified HuggingFace format), if present.
/// 2. A legacy BPE pair: `vocab.json` + `merges.txt`.
///
/// # Errors
/// Returns [`TtsError::TokenizerError`] if neither layout is found or parsing fails.
pub fn from_model_dir(dir: impl AsRef<Path>) -> Result<Self, TtsError> {
let dir = dir.as_ref();
let tokenizer_json = dir.join("tokenizer.json");
if tokenizer_json.exists() {
return Self::from_file(tokenizer_json);
}
let vocab_path = dir.join("vocab.json");
let merges_path = dir.join("merges.txt");
if vocab_path.exists() && merges_path.exists() {
// Fix: the previous code fed vocab.json to `Tokenizer::from_file`, which
// expects the full tokenizer.json schema and never used merges.txt.
// Build a BPE model from the legacy vocab/merges pair instead.
let bpe = tokenizers::models::bpe::BPE::from_file(
&vocab_path.to_string_lossy(),
&merges_path.to_string_lossy(),
)
.build()
.map_err(|e| {
TtsError::TokenizerError(format!(
"Failed to load BPE tokenizer from vocab.json/merges.txt: {}",
e
))
})?;
Ok(Self {
inner: tokenizers::Tokenizer::new(bpe),
})
} else {
Err(TtsError::TokenizerError(format!(
"No tokenizer files found in {}",
dir.display()
)))
}
}
/// Encodes `text` into token ids (no special tokens added).
///
/// # Errors
/// Returns [`TtsError::TokenizerError`] if the underlying encoder fails.
pub fn encode(&self, text: &str) -> Result<Vec<u32>, TtsError> {
let encoding = self
.inner
.encode(text, false)
.map_err(|e| TtsError::TokenizerError(format!("Encoding failed: {}", e)))?;
Ok(encoding.get_ids().to_vec())
}
/// Decodes token ids back into text, skipping special tokens.
///
/// # Errors
/// Returns [`TtsError::TokenizerError`] if the underlying decoder fails.
pub fn decode(&self, ids: &[u32]) -> Result<String, TtsError> {
self.inner
.decode(ids, true)
.map_err(|e| TtsError::TokenizerError(format!("Decoding failed: {}", e)))
}
/// Vocabulary size, including added (special) tokens.
pub fn vocab_size(&self) -> usize {
self.inner.get_vocab_size(true)
}
/// Looks up the id of `token`, or `None` if it is not in the vocabulary.
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.inner.token_to_id(token)
}
}