use crate::error::{RuntimeError, RuntimeResult};
/// Thin wrapper around a `tokenizers::Tokenizer` (compiled only when the
/// `tokenizer-onig` or `tokenizer-wasm` feature is enabled).
#[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
pub struct TokenizerBridge {
    /// The loaded tokenizer that every method delegates to.
    tokenizer: tokenizers::Tokenizer,
    /// Lazily computed `(token id, bytes)` table, memoized by `vocab_bytes_cached`.
    cached_vocab: std::sync::OnceLock<Vec<(u32, Vec<u8>)>>,
}

/// Compact `Debug` output: reports the vocabulary size and whether the byte table
/// has been computed, instead of dumping the full tokenizer configuration.
/// Having `Debug` here keeps `{:?}` on `RuntimeResult<TokenizerBridge>` compiling in
/// both the real and the stub build.
#[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
impl std::fmt::Debug for TokenizerBridge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenizerBridge")
            .field("vocab_size", &self.tokenizer.get_vocab_size(true))
            .field("vocab_cached", &self.cached_vocab.get().is_some())
            .finish_non_exhaustive()
    }
}
#[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
impl TokenizerBridge {
    /// Wraps a freshly loaded tokenizer; the vocabulary byte table starts uncomputed.
    fn wrap(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            cached_vocab: std::sync::OnceLock::new(),
        }
    }

    /// Loads a tokenizer definition from the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::TokenizerError`], whose message names `path`, when the
    /// underlying library fails to load the file.
    pub fn from_file(path: &str) -> RuntimeResult<Self> {
        match tokenizers::Tokenizer::from_file(path) {
            Ok(tokenizer) => Ok(Self::wrap(tokenizer)),
            Err(e) => Err(RuntimeError::TokenizerError {
                message: format!("failed to load tokenizer from {path}: {e}"),
            }),
        }
    }

    /// Parses a tokenizer definition from in-memory JSON bytes.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::TokenizerError`] when `json` is not accepted by the
    /// underlying library.
    pub fn from_bytes(json: &[u8]) -> RuntimeResult<Self> {
        match tokenizers::Tokenizer::from_bytes(json) {
            Ok(tokenizer) => Ok(Self::wrap(tokenizer)),
            Err(e) => Err(RuntimeError::TokenizerError {
                message: format!("failed to parse tokenizer JSON: {e}"),
            }),
        }
    }

    /// Encodes `text` into token ids without asking the tokenizer to add special tokens.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::TokenizerError`] if encoding fails.
    pub fn encode(&self, text: &str) -> RuntimeResult<Vec<u32>> {
        let add_special_tokens = false;
        match self.tokenizer.encode(text, add_special_tokens) {
            Ok(encoding) => Ok(encoding.get_ids().to_vec()),
            Err(e) => Err(RuntimeError::TokenizerError {
                message: format!("encoding failed: {e}"),
            }),
        }
    }

    /// Decodes `tokens` back into text, dropping special tokens from the output.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::TokenizerError`] if decoding fails.
    pub fn decode(&self, tokens: &[u32]) -> RuntimeResult<String> {
        let skip_special_tokens = true;
        self.tokenizer
            .decode(tokens, skip_special_tokens)
            .map_err(|e| RuntimeError::TokenizerError {
                message: format!("decoding failed: {e}"),
            })
    }

    /// Number of vocabulary entries; the `true` argument asks the library to count
    /// added tokens as well.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    /// Id of the beginning-of-sequence token, trying `<s>` first and then
    /// `<|begin_of_text|>`; `None` if neither string is in the vocabulary.
    pub fn bos_token_id(&self) -> Option<u32> {
        const CANDIDATES: [&str; 2] = ["<s>", "<|begin_of_text|>"];
        CANDIDATES
            .iter()
            .find_map(|name| self.tokenizer.token_to_id(name))
    }

    /// Id of the end-of-sequence token, trying `</s>`, `<|end_of_text|>` and
    /// `<|endoftext|>` in that order; `None` if none of them is in the vocabulary.
    pub fn eos_token_id(&self) -> Option<u32> {
        const CANDIDATES: [&str; 3] = ["</s>", "<|end_of_text|>", "<|endoftext|>"];
        CANDIDATES
            .iter()
            .find_map(|name| self.tokenizer.token_to_id(name))
    }

    /// Vocabulary string for `id`, if the tokenizer knows it.
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }

    /// Bytes obtained by decoding `id` on its own, with special tokens kept.
    ///
    /// Returns `None` only when decoding fails. The bytes come from a `String`, so they
    /// are always valid UTF-8. A token that stands for an incomplete UTF-8 sequence
    /// therefore cannot be represented exactly.
    /// NOTE(review): the tokenizer's configured decoder is applied, so the result may
    /// differ from the raw vocabulary entry. Confirm this is what callers expect.
    pub fn token_to_bytes(&self, id: u32) -> Option<Vec<u8>> {
        let text = self.tokenizer.decode(&[id], false).ok()?;
        Some(text.into_bytes())
    }

    /// Builds a `(token id, bytes)` table for every vocabulary entry (added tokens
    /// included), sorted by id.
    ///
    /// Ids whose `token_to_bytes` is `None` are omitted. The table is rebuilt on every
    /// call; use `vocab_bytes_cached` to reuse it.
    pub fn vocab_bytes(&self) -> Vec<(u32, Vec<u8>)> {
        let mut table: Vec<(u32, Vec<u8>)> = Vec::new();
        for id in self.tokenizer.get_vocab(true).into_values() {
            if let Some(bytes) = self.token_to_bytes(id) {
                table.push((id, bytes));
            }
        }
        table.sort_unstable_by_key(|entry| entry.0);
        table
    }

    /// Same table as `vocab_bytes`, computed on the first call and then memoized in
    /// `cached_vocab`.
    pub fn vocab_bytes_cached(&self) -> &[(u32, Vec<u8>)] {
        self.cached_vocab
            .get_or_init(|| self.vocab_bytes())
            .as_slice()
    }
}
/// Placeholder used when neither the `tokenizer-onig` nor the `tokenizer-wasm` feature
/// is enabled. It has no state.
///
/// `Debug` is required because callers and tests format `RuntimeResult<TokenizerBridge>`
/// with `{:?}`. Without it, the no-feature test build does not compile.
#[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
#[derive(Debug)]
pub struct TokenizerBridge;
/// Fallback implementation used when no tokenizer backend feature is enabled.
///
/// Constructors and `encode`/`decode` always fail with
/// [`RuntimeError::TokenizerNotAvailable`]. Every query method returns an empty or zero
/// value.
#[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
impl TokenizerBridge {
    /// The single error every fallible stub method reports.
    fn not_available<T>() -> RuntimeResult<T> {
        Err(RuntimeError::TokenizerNotAvailable)
    }

    /// Always fails: no backend is compiled in.
    pub fn from_file(_path: &str) -> RuntimeResult<Self> {
        Self::not_available()
    }

    /// Always fails: no backend is compiled in.
    pub fn from_bytes(_json: &[u8]) -> RuntimeResult<Self> {
        Self::not_available()
    }

    /// Always fails: no backend is compiled in.
    pub fn encode(&self, _text: &str) -> RuntimeResult<Vec<u32>> {
        Self::not_available()
    }

    /// Always fails: no backend is compiled in.
    pub fn decode(&self, _tokens: &[u32]) -> RuntimeResult<String> {
        Self::not_available()
    }

    /// The stub has no vocabulary, so this is always zero.
    pub fn vocab_size(&self) -> usize {
        0
    }

    /// No vocabulary, so there is no BOS token.
    pub fn bos_token_id(&self) -> Option<u32> {
        None
    }

    /// No vocabulary, so there is no EOS token.
    pub fn eos_token_id(&self) -> Option<u32> {
        None
    }

    /// No vocabulary, so no id maps to a token string.
    pub fn id_to_token(&self, _id: u32) -> Option<String> {
        None
    }

    /// No vocabulary, so no id maps to bytes.
    pub fn token_to_bytes(&self, _id: u32) -> Option<Vec<u8>> {
        None
    }

    /// Always an empty table.
    pub fn vocab_bytes(&self) -> Vec<(u32, Vec<u8>)> {
        vec![]
    }

    /// Always an empty slice.
    pub fn vocab_bytes_cached(&self) -> &[(u32, Vec<u8>)] {
        &[]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_file_nonexistent_errors() {
        let result = TokenizerBridge::from_file("/nonexistent/path/tokenizer_test.json");
        assert!(result.is_err(), "missing tokenizer file should error");
    }

    #[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
    #[test]
    fn test_stub_from_file_returns_not_available() {
        let result = TokenizerBridge::from_file("/any/path.json");
        assert!(
            matches!(result, Err(RuntimeError::TokenizerNotAvailable)),
            "stub should return TokenizerNotAvailable, got {result:?}"
        );
    }

    #[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
    #[test]
    fn test_stub_from_bytes_returns_not_available() {
        let result = TokenizerBridge::from_bytes(b"{}");
        assert!(
            matches!(result, Err(RuntimeError::TokenizerNotAvailable)),
            "stub should return TokenizerNotAvailable, got {result:?}"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_from_bytes_invalid_json_errors() {
        let result = TokenizerBridge::from_bytes(b"not valid json {{{{");
        assert!(
            result.is_err(),
            "invalid tokenizer JSON should return an error"
        );
    }

    /// The stub is a unit struct, so it can be built directly to exercise its accessors.
    /// (The previous version of this test only re-checked `from_bytes`.)
    #[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
    #[test]
    fn test_stub_vocab_size_is_zero() {
        let bridge = TokenizerBridge;
        assert_eq!(bridge.vocab_size(), 0, "stub vocab_size should be zero");
        assert_eq!(bridge.bos_token_id(), None);
        assert_eq!(bridge.eos_token_id(), None);
        assert_eq!(bridge.id_to_token(0), None);
        assert_eq!(bridge.token_to_bytes(0), None);
        assert!(bridge.vocab_bytes().is_empty());
        assert!(bridge.vocab_bytes_cached().is_empty());
    }

    #[cfg(not(any(feature = "tokenizer-onig", feature = "tokenizer-wasm")))]
    #[test]
    fn test_stub_encode_decode_return_not_available() {
        let bridge = TokenizerBridge;
        assert!(
            matches!(bridge.encode("hello"), Err(RuntimeError::TokenizerNotAvailable)),
            "stub encode should return TokenizerNotAvailable"
        );
        assert!(
            matches!(bridge.decode(&[1, 2]), Err(RuntimeError::TokenizerNotAvailable)),
            "stub decode should return TokenizerNotAvailable"
        );
    }

    /// Character-level BPE with no merges, normalizer, pre-tokenizer or decoder.
    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    const MINIMAL_TOKENIZER_JSON: &str = r#"{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {"id": 0, "special": true, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false},
    {"id": 1, "special": true, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false},
    {"id": 2, "special": true, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}
  ],
  "normalizer": null,
  "pre_tokenizer": null,
  "post_processor": null,
  "decoder": null,
  "model": {
    "type": "BPE",
    "dropout": null,
    "unk_token": "<unk>",
    "continuing_subword_prefix": null,
    "end_of_word_suffix": null,
    "fuse_unk": false,
    "byte_fallback": false,
    "vocab": {
      "<unk>": 0,
      "<s>": 1,
      "</s>": 2,
      "h": 3,
      "e": 4,
      "l": 5,
      "o": 6,
      " ": 7,
      "w": 8,
      "r": 9,
      "d": 10,
      "a": 11,
      "b": 12
    },
    "merges": []
  }
}"#;

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    fn minimal_bridge() -> TokenizerBridge {
        TokenizerBridge::from_bytes(MINIMAL_TOKENIZER_JSON.as_bytes())
            .expect("test: valid tokenizer JSON should parse")
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_from_bytes_valid_json_succeeds() {
        let bridge = minimal_bridge();
        assert!(
            bridge.vocab_size() > 0,
            "vocab_size should be positive after loading"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_bos_token_id_found() {
        let bridge = minimal_bridge();
        assert_eq!(
            bridge.bos_token_id(),
            Some(1),
            "BOS token <s> should have id=1"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_eos_token_id_found() {
        let bridge = minimal_bridge();
        assert_eq!(
            bridge.eos_token_id(),
            Some(2),
            "EOS token </s> should have id=2"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_encode_produces_tokens() {
        let bridge = minimal_bridge();
        let tokens = bridge.encode("hello").expect("test: encode should succeed");
        assert!(
            !tokens.is_empty(),
            "encoding 'hello' should produce at least one token"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_decode_empty_slice_returns_empty_string() {
        let bridge = minimal_bridge();
        let decoded = bridge
            .decode(&[])
            .expect("test: decoding empty slice should succeed");
        assert_eq!(
            decoded, "",
            "decoding empty token list should return empty string"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_encode_decode_roundtrip() {
        let bridge = minimal_bridge();
        let tokens = bridge.encode("hello").expect("test: encode should succeed");
        let decoded = bridge.decode(&tokens).expect("test: decode should succeed");
        // With no decoder configured, the library may join tokens with separators, so
        // compare the non-whitespace characters only.
        let compact: String = decoded.chars().filter(|c| !c.is_whitespace()).collect();
        assert_eq!(compact, "hello", "round trip should preserve the text, got {decoded:?}");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_decode_skips_special_tokens() {
        let bridge = minimal_bridge();
        let decoded = bridge
            .decode(&[1, 3, 2])
            .expect("test: decode should succeed");
        assert_eq!(decoded, "h", "<s> and </s> should be dropped by decode");
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_vocab_size_matches_json() {
        let bridge = minimal_bridge();
        assert_eq!(
            bridge.vocab_size(),
            13,
            "vocab_size should match the number of defined tokens"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_token_to_bytes_special_token() {
        let bridge = minimal_bridge();
        // token_to_bytes keeps special tokens, unlike decode().
        assert_eq!(bridge.token_to_bytes(0), Some(b"<unk>".to_vec()));
        assert_eq!(bridge.token_to_bytes(1), Some(b"<s>".to_vec()));
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_vocab_bytes_is_sorted() {
        let bridge = minimal_bridge();
        let pairs = bridge.vocab_bytes();
        assert!(!pairs.is_empty(), "vocab_bytes should not be empty");
        for window in pairs.windows(2) {
            assert!(
                window[0].0 <= window[1].0,
                "vocab_bytes should be sorted by token id, got {} > {}",
                window[0].0,
                window[1].0
            );
        }
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_vocab_bytes_cached_matches_uncached() {
        let bridge = minimal_bridge();
        let first = bridge.vocab_bytes_cached();
        let fresh = bridge.vocab_bytes();
        assert_eq!(first, fresh.as_slice(), "cached table should equal a fresh build");
        let second = bridge.vocab_bytes_cached();
        assert!(
            std::ptr::eq(first, second),
            "second call should return the memoized table"
        );
    }

    #[cfg(any(feature = "tokenizer-onig", feature = "tokenizer-wasm"))]
    #[test]
    fn test_from_bytes_invalid_json_structure_errors() {
        let result = TokenizerBridge::from_bytes(b"{\"not\": \"a tokenizer\"}");
        assert!(result.is_err(), "non-tokenizer JSON should return an error");
    }
}