1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use tiktoken_rs::CoreBPE;
static BPE: OnceLock<CoreBPE> = OnceLock::new();
fn get_bpe() -> &'static CoreBPE {
BPE.get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base tokenizer"))
}
const TOKEN_CACHE_MAX: usize = 256;
static TOKEN_CACHE: Mutex<Option<HashMap<u64, usize>>> = Mutex::new(None);
fn hash_text(text: &str) -> u64 {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
text.len().hash(&mut hasher);
if text.len() <= 512 {
text.hash(&mut hasher);
} else {
text[..256].hash(&mut hasher);
text[text.len() - 256..].hash(&mut hasher);
}
hasher.finish()
}
pub fn count_tokens(text: &str) -> usize {
if text.is_empty() {
return 0;
}
let key = hash_text(text);
if let Ok(guard) = TOKEN_CACHE.lock() {
if let Some(ref map) = *guard {
if let Some(&cached) = map.get(&key) {
return cached;
}
}
}
let count = get_bpe().encode_with_special_tokens(text).len();
if let Ok(mut guard) = TOKEN_CACHE.lock() {
let map = guard.get_or_insert_with(HashMap::new);
if map.len() >= TOKEN_CACHE_MAX {
map.clear();
}
map.insert(key, count);
}
count
}
pub fn encode_tokens(text: &str) -> Vec<u32> {
if text.is_empty() {
return Vec::new();
}
get_bpe().encode_with_special_tokens(text)
}