use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, OnceLock, RwLock};
pub mod download;
pub mod error;
pub mod hf;
pub mod registry;
pub use error::TokenizerError;
pub use hf::HfTokenizer;
pub use registry::Tokenizer;
/// Process-wide table of parsed tokenizers, keyed by tokenizer family.
///
/// The table is created lazily on first access; every call returns the same
/// `RwLock`, so all readers and writers in the process share one map.
fn registry() -> &'static RwLock<HashMap<Tokenizer, Arc<HfTokenizer>>> {
    static TABLE: OnceLock<RwLock<HashMap<Tokenizer, Arc<HfTokenizer>>>> = OnceLock::new();
    // `RwLock<HashMap<..>>: Default` yields an empty map behind an unlocked lock.
    TABLE.get_or_init(Default::default)
}
/// Directory under which tokenizer files are stored: `<data_dir>/tokenizers`.
///
/// `<data_dir>` comes from `crate::paths::data_dir()` (presumably the XDG data
/// directory, given this function's name; confirm in `crate::paths`).
///
/// # Errors
///
/// The body as written never returns `Err`; the `Result` is part of the public
/// signature that callers propagate with `?`.
pub fn xdg_root() -> Result<PathBuf, TokenizerError> {
    let base = crate::paths::data_dir();
    Ok(base.join("tokenizers"))
}
pub async fn ensure_loaded(family: Tokenizer) -> Result<(), TokenizerError> {
if registry()
.read()
.expect("tokenizer registry rwlock poisoned")
.contains_key(&family)
{
return Ok(());
}
let root = xdg_root()?;
let path: PathBuf =
tokio::task::spawn_blocking(move || download::ensure_on_disk(&root, family))
.await
.map_err(|e| TokenizerError::Io {
path: "<spawn_blocking>".to_string(),
source: std::io::Error::other(format!("spawn_blocking join failed: {e}")),
})??;
let parsed = tokio::task::spawn_blocking(move || HfTokenizer::from_path(&path, family))
.await
.map_err(|e| TokenizerError::Io {
path: "<spawn_blocking>".to_string(),
source: std::io::Error::other(format!("spawn_blocking join failed: {e}")),
})??;
registry()
.write()
.expect("tokenizer registry rwlock poisoned")
.insert(family, Arc::new(parsed));
Ok(())
}
/// Count the tokens `text` encodes to under the registered tokenizer for `family`.
///
/// This never loads anything itself; call [`ensure_loaded`] for `family` first.
///
/// # Errors
///
/// Returns [`TokenizerError::NotLoaded`] if `family` is not in the registry, or any
/// error produced by the tokenizer's own `count`.
///
/// # Panics
///
/// Panics if the registry `RwLock` is poisoned.
pub fn count(text: &str, family: Tokenizer) -> Result<usize, TokenizerError> {
    // Clone the `Arc` and let the read guard (a temporary of this statement) drop
    // before tokenizing. Otherwise a long `text` would hold the lock for the whole
    // encode and block writers such as `ensure_loaded`'s insert.
    let tk = registry()
        .read()
        .expect("tokenizer registry rwlock poisoned")
        .get(&family)
        .cloned()
        .ok_or(TokenizerError::NotLoaded(family))?;
    tk.count(text)
}
/// Test-only: remove every tokenizer from the shared registry so a test starts
/// from an empty table.
///
/// Panics if the registry `RwLock` is poisoned.
#[cfg(test)]
pub(crate) fn _clear_registry_for_tests() {
    let mut table = registry()
        .write()
        .expect("tokenizer registry rwlock poisoned");
    table.clear();
}
/// Test-only: one process-wide mutex that tests which mutate the shared registry
/// hold for their whole body, so they run one at a time.
#[cfg(test)]
pub(crate) fn _test_mutex() -> &'static std::sync::Mutex<()> {
    static SERIAL: OnceLock<std::sync::Mutex<()>> = OnceLock::new();
    SERIAL.get_or_init(Default::default)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Path to the tiny tokenizer JSON fixture under the crate's `tests/` directory.
    fn fixture() -> std::path::PathBuf {
        let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        p.push("tests/fixtures/tokenizer/tiny.json");
        p
    }

    /// Take the shared test lock so registry-mutating tests run one at a time.
    ///
    /// Recovers from poisoning: the guarded value is `()`, so a panic in an earlier
    /// test cannot leave it inconsistent. Without this, one failing test would make
    /// every later test panic on `lock()` and bury the real failure.
    fn serial() -> std::sync::MutexGuard<'static, ()> {
        _test_mutex()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn count_without_load_errors() {
        let _guard = serial();
        _clear_registry_for_tests();
        let err = count("abab", Tokenizer::Llama3).unwrap_err();
        assert!(matches!(err, TokenizerError::NotLoaded(Tokenizer::Llama3)));
    }

    #[tokio::test]
    async fn manual_insert_then_count_works() {
        let _guard = serial();
        _clear_registry_for_tests();
        let tk = HfTokenizer::from_path(&fixture(), Tokenizer::Cl100k).unwrap();
        // The write guard is a temporary and is dropped at the end of this
        // statement, before `count` takes the read lock below.
        registry()
            .write()
            .unwrap()
            .insert(Tokenizer::Cl100k, Arc::new(tk));
        assert_eq!(count("abab", Tokenizer::Cl100k).unwrap(), 2);
    }
}