use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;
#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;
use log::debug;
use crate::Result;
use crate::error::Error;
/// Thin wrapper around a Hugging Face [`tokenizers::Tokenizer`] providing
/// loading helpers and query-oriented tokenization (token-id counts).
pub struct HfTokenizer {
    // Underlying tokenizer; constructed via `load`, `from_bytes`, or `from_source`.
    tokenizer: Tokenizer,
}
/// Describes where a tokenizer definition should be loaded from.
///
/// On `wasm32` only [`TokenizerSource::IndexDirectory`] exists; the Hub and
/// local-file variants are compiled out.
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// A Hugging Face Hub identifier (e.g. `org/model`).
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// An absolute path to a local tokenizer file.
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// A path relative to an index directory (selected via the `index://`
    /// prefix in [`TokenizerSource::parse`]).
    IndexDirectory(String),
}
impl TokenizerSource {
    /// Parses a tokenizer locator string into a [`TokenizerSource`].
    ///
    /// `index://<rel>` selects a path inside an index directory, an absolute
    /// path selects a local file, and anything else is treated as a Hugging
    /// Face Hub identifier.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        match path.strip_prefix("index://") {
            Some(relative) => TokenizerSource::IndexDirectory(relative.to_string()),
            None if path.starts_with('/') => TokenizerSource::LocalFile(path.to_string()),
            None => TokenizerSource::HuggingFace(path.to_string()),
        }
    }

    /// Parses a tokenizer locator string; on wasm every locator resolves to
    /// an index directory (an optional `index://` prefix is stripped).
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        let relative = path.strip_prefix("index://").unwrap_or(path);
        TokenizerSource::IndexDirectory(relative.to_string())
    }
}
impl HfTokenizer {
    /// Loads a tokenizer from a local file or from the Hugging Face Hub.
    ///
    /// The input is treated as a local file when it exists on disk, is an
    /// absolute or dot-relative path, or ends with `.json`; otherwise it is
    /// treated as a Hub identifier. (The previous slash-based heuristic sent
    /// bare Hub ids such as `bert-base-uncased` to the file loader and
    /// relative paths such as `models/tokenizer.json` to the Hub.)
    ///
    /// # Errors
    /// Returns [`Error::Tokenizer`] if the file cannot be parsed or the Hub
    /// download fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let looks_like_file = std::path::Path::new(name_or_path).is_file()
            || name_or_path.starts_with('/')
            || name_or_path.starts_with('.')
            || name_or_path.ends_with(".json");
        let tokenizer = if looks_like_file {
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        };
        Ok(Self { tokenizer })
    }

    /// Parses a tokenizer from raw `tokenizer.json` bytes.
    ///
    /// Works on every target (including wasm), since it touches neither the
    /// filesystem nor the network.
    ///
    /// # Errors
    /// Returns [`Error::Tokenizer`] if the bytes are not a valid tokenizer
    /// definition.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

    /// Loads a tokenizer from a parsed [`TokenizerSource`].
    ///
    /// # Errors
    /// Returns [`Error::Tokenizer`] on load failure, and always for
    /// [`TokenizerSource::IndexDirectory`], which must be resolved to bytes
    /// by the caller and fed to [`HfTokenizer::from_bytes`].
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => {
                Err(Error::Tokenizer(
                    "IndexDirectory source requires using from_bytes with Directory read"
                        .to_string(),
                ))
            }
        }
    }

    /// Tokenizes `text` (without special tokens) and returns
    /// `(token_id, occurrence_count)` pairs, sorted by token id.
    ///
    /// # Errors
    /// Returns [`Error::Tokenizer`] if encoding fails.
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }
        let mut result: Vec<(u32, u32)> = counts.into_iter().collect();
        // HashMap iteration order is randomized per process; sort by token id
        // so callers get deterministic, reproducible output.
        result.sort_unstable_by_key(|&(id, _)| id);
        debug!(
            "Tokenized query: text={:?} tokens={:?} token_ids={:?} unique_count={}",
            text,
            encoding.get_tokens(),
            encoding.get_ids(),
            result.len()
        );
        Ok(result)
    }

    /// Tokenizes `text` (without special tokens) and returns the sorted,
    /// deduplicated token ids.
    ///
    /// # Errors
    /// Returns [`Error::Tokenizer`] if encoding fails.
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        // dedup only removes consecutive duplicates, so sort first.
        ids.sort_unstable();
        ids.dedup();
        debug!(
            "Tokenized query (unique): text={:?} tokens={:?} token_ids={:?} unique_count={}",
            text,
            encoding.get_tokens(),
            encoding.get_ids(),
            ids.len()
        );
        Ok(ids)
    }
}
/// Process-wide cache of loaded tokenizers keyed by name or path, so
/// repeated lookups do not reload from disk or the network.
#[cfg(feature = "native")]
pub struct TokenizerCache {
    // Guarded map: tokenizer name/path -> shared loaded instance.
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}
#[cfg(feature = "native")]
impl Default for TokenizerCache {
    /// Starts with an empty cache; equivalent to [`TokenizerCache::new`].
    fn default() -> Self {
        TokenizerCache {
            cache: RwLock::new(HashMap::new()),
        }
    }
}
#[cfg(feature = "native")]
impl TokenizerCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the cached tokenizer for `name_or_path`, loading and caching
    /// it on first use.
    ///
    /// The load runs outside both locks (it may hit the filesystem or the
    /// network), so concurrent callers can race to load the same key. The
    /// first insert wins via the entry API, so every caller ends up sharing
    /// the same `Arc` — previously a late writer's `insert` clobbered the
    /// earlier entry and callers could hold distinct instances.
    ///
    /// # Errors
    /// Propagates [`HfTokenizer::load`] failures; nothing is cached on error.
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: shared read lock only.
        if let Some(tokenizer) = self.cache.read().get(name_or_path) {
            return Ok(Arc::clone(tokenizer));
        }
        // Slow path: load with no lock held.
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        let mut cache = self.cache.write();
        // Keep an entry inserted by a racing thread instead of replacing it.
        let shared = cache
            .entry(name_or_path.to_string())
            .or_insert_with(|| Arc::clone(&tokenizer));
        Ok(Arc::clone(shared))
    }

    /// Drops every cached tokenizer.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
// Lazily-initialized process-global cache; see `tokenizer_cache()`.
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Returns the process-global [`TokenizerCache`], creating it (empty) on
/// first access.
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::default)
}
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Requires network access to the Hugging Face Hub; run with `--ignored`.
    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tok = HfTokenizer::load("bert-base-uncased").unwrap();
        let counts = tok.tokenize("hello world").unwrap();
        assert!(!counts.is_empty());
    }

    // Requires network access to the Hugging Face Hub; run with `--ignored`.
    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tok = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tok.tokenize_unique("the quick brown fox").unwrap();
        // Output must already be sorted and deduplicated.
        let mut expected = ids.clone();
        expected.sort_unstable();
        expected.dedup();
        assert_eq!(ids.len(), expected.len());
    }

    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        assert!(cache.cache.read().is_empty());
    }
}