1use std::{
2 collections::HashMap,
3 fs::File,
4 io::{BufReader, BufWriter, Read as _, Write as _},
5 path::{Path, PathBuf},
6 sync::OnceLock,
7};
8
9use base64::{prelude::BASE64_STANDARD, Engine as _};
10
11use crate::tiktoken::{CoreBPE, Rank};
12use sha1::Sha1;
13use sha2::{Digest as _, Sha256};
14
/// Errors returned while resolving an [`Encoding`] and building its `CoreBPE`.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// The `TIKTOKEN_ENCODINGS_BASE` environment variable is missing or unusable.
    ///
    /// NOTE(review): nothing in the code visible here constructs this variant.
    #[error("the env var TIKTOKEN_ENCODINGS_BASE is not set, or invalid")]
    InvalidEncodingBaseDirEnvVar,

    /// The requested name matched no known encoding; carries the name that was asked for.
    #[error("unknown encoding name: {0}")]
    UnknownEncodingName(String),

    /// The vocab file could not be opened, read, parsed, or failed its hash check.
    #[error("invalid tiktoken vocab file: {0}")]
    InvalidTiktokenVocabFile(#[source] std::io::Error),

    /// `CoreBPE::new` rejected the vocab, special tokens, or pattern.
    #[error("failed to create CoreBPE: {0}")]
    CoreBPECreationFailed(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// Fetching the vocab file (or reusing its cached copy) failed.
    ///
    /// `#[from]` lets `?` convert a [`RemoteVocabFileError`] into this variant.
    #[error("error downloading or loading vocab file: {0}")]
    DownloadOrLoadVocabFile(
        #[source]
        #[from]
        RemoteVocabFileError,
    ),
}
36
/// Errors raised while downloading a vocab file or validating a local/cached copy.
#[derive(Debug, thiserror::Error)]
pub enum RemoteVocabFileError {
    /// The HTTP request failed, returned an error status, or its body could not be read.
    #[error("failed to download or load vocab file")]
    FailedToDownloadOrLoadVocabFile(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// A filesystem or stream operation failed.
    ///
    /// The `String` describes what was being attempted (it completes the phrase
    /// "while ..."); the second field is the underlying I/O error.
    #[error("an underlying IO error occurred while {0}: {1}")]
    IOError(String, #[source] std::io::Error),

    /// The SHA-256 of the fetched content differs from the expected value.
    ///
    /// Only `file_url` appears in the message; the two hashes are available as fields.
    #[error("hash mismatch for remote file {file_url}")]
    HashMismatch {
        /// URL the content was fetched from.
        file_url: String,
        /// Hex digest the caller expected.
        expected_hash: String,
        /// Hex digest computed from the received bytes.
        computed_hash: String,
    },
}
52
/// Name of the environment variable that points at a local directory of vocab files.
const TIKTOKEN_ENCODINGS_BASE_VAR: &str = "TIKTOKEN_ENCODINGS_BASE";
/// Base URL used for vocab downloads when no override has been installed.
/// Ends with `/` so a file name can be appended directly.
const DEFAULT_TIKTOKEN_BASE_URL: &str = "https://openaipublic.blob.core.windows.net/encodings/";

/// Process-wide override for the download base URL; a `OnceLock`, so it can be set at most once.
static TIKTOKEN_BASE_URL_OVERRIDE: OnceLock<String> = OnceLock::new();
57
58pub fn set_tiktoken_base_url(base_url: impl Into<String>) {
59 let mut base = base_url.into();
60 if !base.ends_with('/') {
61 base.push('/');
62 }
63 let _ = TIKTOKEN_BASE_URL_OVERRIDE.set(base);
64}
65
66fn tiktoken_base_url() -> &'static str {
67 TIKTOKEN_BASE_URL_OVERRIDE
68 .get()
69 .map(|s| s.as_str())
70 .unwrap_or(DEFAULT_TIKTOKEN_BASE_URL)
71}
72
/// The encodings this module can load.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// Named `"o200k_base"`.
    O200kBase,
    /// Named `"o200k_harmony"`; loads the same vocab file as `O200kBase` with its own special tokens.
    O200kHarmony,
    /// Named `"cl100k_base"`.
    Cl100kBase,
}
79
80impl Encoding {
81 pub fn all() -> &'static [Self] {
82 &[Self::O200kBase, Self::O200kHarmony, Self::Cl100kBase]
83 }
84
85 pub fn from_name(name: impl AsRef<str>) -> Option<Self> {
86 let name_str = name.as_ref();
87 for encoding in Self::all() {
88 if encoding.name() == name_str {
89 return Some(*encoding);
90 }
91 }
92 None
93 }
94
95 #[cfg(not(target_arch = "wasm32"))]
96 pub fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
97 let name = name.as_ref();
98 Self::from_name(name)
99 .ok_or_else(|| LoadError::UnknownEncodingName(name.to_string()))?
100 .load()
101 }
102
103 #[cfg(target_arch = "wasm32")]
104 pub async fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
105 let name = name.as_ref();
106 Self::from_name(name)
107 .ok_or_else(|| LoadError::UnknownEncodingName(name.to_string()))?
108 .load()
109 .await
110 }
111
112 pub fn name(&self) -> &'static str {
113 match self {
114 Self::O200kBase => "o200k_base",
115 Self::O200kHarmony => "o200k_harmony",
116 Self::Cl100kBase => "cl100k_base",
117 }
118 }
119
120 #[cfg(not(target_arch = "wasm32"))]
121 pub fn load(&self) -> Result<CoreBPE, LoadError> {
122 let (vocab_file_path, check_hash) =
123 if let Ok(base_dir) = std::env::var(TIKTOKEN_ENCODINGS_BASE_VAR) {
124 (PathBuf::from(base_dir).join(self.vocab_file_name()), true)
125 } else {
126 let url = self.public_vocab_file_url();
127 (
128 download_or_find_cached_file(&url, Some(self.expected_hash()))
129 .map_err(LoadError::DownloadOrLoadVocabFile)?,
130 false,
131 )
132 };
133
134 match self {
135 Self::O200kHarmony => {
136 let mut specials: Vec<(String, Rank)> = self
137 .special_tokens()
138 .iter()
139 .map(|(s, r)| ((*s).to_string(), *r))
140 .collect();
141 specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
142 load_encoding_from_file(
143 vocab_file_path,
144 check_hash.then(|| self.expected_hash()),
145 specials,
146 &self.pattern(),
147 )
148 }
149 Self::O200kBase => {
150 let mut specials: Vec<(String, Rank)> = self
151 .special_tokens()
152 .iter()
153 .map(|(s, r)| ((*s).to_string(), *r))
154 .collect();
155 specials.extend((199998..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
156 load_encoding_from_file(
157 vocab_file_path,
158 check_hash.then(|| self.expected_hash()),
159 specials,
160 &self.pattern(),
161 )
162 }
163 _ => load_encoding_from_file(
164 vocab_file_path,
165 check_hash.then(|| self.expected_hash()),
166 self.special_tokens().iter().cloned(),
167 &self.pattern(),
168 ),
169 }
170 }
171
172 #[cfg(target_arch = "wasm32")]
173 pub async fn load(&self) -> Result<CoreBPE, LoadError> {
174 let url = self.public_vocab_file_url();
175 let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash()))
176 .await
177 .map_err(LoadError::DownloadOrLoadVocabFile)?;
178
179 match self {
180 Self::O200kHarmony => {
181 let mut specials: Vec<(String, Rank)> = self
182 .special_tokens()
183 .iter()
184 .map(|(s, r)| ((*s).to_string(), *r))
185 .collect();
186 specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
187 load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern())
188 }
189 Self::O200kBase => {
190 let mut specials: Vec<(String, Rank)> = self
191 .special_tokens()
192 .iter()
193 .map(|(s, r)| ((*s).to_string(), *r))
194 .collect();
195 specials.extend((199998..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
196 load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern())
197 }
198 _ => load_encoding_from_bytes(
199 &vocab_bytes,
200 None,
201 self.special_tokens().iter().cloned(),
202 &self.pattern(),
203 ),
204 }
205 }
206
207 fn public_vocab_file_url(&self) -> String {
208 let base = tiktoken_base_url();
209 match self {
210 Self::O200kBase => format!("{base}o200k_base.tiktoken"),
211 Self::O200kHarmony => format!("{base}o200k_base.tiktoken"),
212 Self::Cl100kBase => format!("{base}cl100k_base.tiktoken"),
213 }
214 }
215
216 fn vocab_file_name(&self) -> &'static str {
217 match self {
218 Self::O200kBase => "o200k_base.tiktoken",
219 Self::O200kHarmony => "o200k_base.tiktoken",
220 Self::Cl100kBase => "cl100k_base.tiktoken",
221 }
222 }
223
224 fn expected_hash(&self) -> &'static str {
225 match self {
226 Self::O200kBase => "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
227 Self::O200kHarmony => {
228 "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"
229 }
230 Self::Cl100kBase => "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
231 }
232 }
233
234 fn special_tokens(&self) -> &'static [(&'static str, Rank)] {
235 match self {
236 Self::O200kBase => &[],
237 Self::O200kHarmony => &[
238 ("<|startoftext|>", 199998),
239 ("<|endoftext|>", 199999),
240 ("<|reserved_200000|>", 200000),
241 ("<|reserved_200001|>", 200001),
242 ("<|return|>", 200002),
243 ("<|constrain|>", 200003),
244 ("<|reserved_200004|>", 200004),
245 ("<|channel|>", 200005),
246 ("<|start|>", 200006),
247 ("<|end|>", 200007),
248 ("<|message|>", 200008),
249 ("<|reserved_200009|>", 200009),
250 ("<|reserved_200010|>", 200010),
251 ("<|reserved_200011|>", 200011),
252 ("<|call|>", 200012),
253 ("<|reserved_200013|>", 200013),
254 ],
255 Self::Cl100kBase => &[
256 ("<|endoftext|>", 100257),
257 ("<|fim_prefix|>", 100258),
258 ("<|fim_middle|>", 100259),
259 ("<|fim_suffix|>", 100260),
260 ("<|endofprompt|>", 100276),
261 ],
262 }
263 }
264
265 fn pattern(&self) -> String {
266 match self {
267 Self::O200kBase => {
268 [
269 "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
270 "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
271 "\\p{N}{1,3}",
272 " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
273 "\\s*[\\r\\n]+",
274 "\\s+(?!\\S)",
275 "\\s+",
276 ].join("|")
277 }
278 Self::O200kHarmony => {
279 [
280 "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
281 "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
282 "\\p{N}{1,3}",
283 " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
284 "\\s*[\\r\\n]+",
285 "\\s+(?!\\S)",
286 "\\s+",
287 ].join("|")
288 }
289 Self::Cl100kBase => {
290 "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()
291 }
292 }
293 }
294}
295
296fn load_tiktoken_vocab<R>(
297 mut reader: R,
298 expected_hash: Option<&str>,
299) -> std::result::Result<HashMap<Vec<u8>, Rank>, std::io::Error>
300where
301 R: std::io::BufRead,
302{
303 let mut hasher = expected_hash.map(|_| Sha256::new());
304 let mut bpe_ranks = HashMap::default();
305 let mut lin_no = 0;
306 let mut line_buffer = String::new();
307 while reader.read_line(&mut line_buffer)? > 0 {
308 lin_no += 1;
309 if let Some(hasher) = hasher.as_mut() {
310 hasher.update(line_buffer.as_bytes());
311 }
312 let line = line_buffer.trim_end();
313 let (token, rank) = line.split_once(' ').ok_or_else(|| {
314 std::io::Error::new(
315 std::io::ErrorKind::InvalidData,
316 format!("expected token and rank, could not split on ' ' at line {lin_no}"),
317 )
318 })?;
319 let bytes = BASE64_STANDARD.decode(token).map_err(|e| {
320 std::io::Error::new(
321 std::io::ErrorKind::InvalidData,
322 format!("failed to decode base64 token at line {lin_no}: {e}",),
323 )
324 })?;
325 let rank = rank.parse().map_err(|e| {
326 std::io::Error::new(
327 std::io::ErrorKind::InvalidData,
328 format!("failed to parse rank at line {lin_no}: {e}"),
329 )
330 })?;
331 bpe_ranks.insert(bytes, rank);
332 line_buffer.clear();
333 }
334 if let Some(hasher) = hasher {
335 let expected_hash = expected_hash.unwrap();
336 let computed_hash = format!("{:x}", hasher.finalize());
337 if computed_hash != expected_hash {
338 return Err(std::io::Error::new(
339 std::io::ErrorKind::InvalidData,
340 format!("hash mismatch: computed={computed_hash}, expected={expected_hash}"),
341 ));
342 }
343 }
344 Ok(bpe_ranks)
345}
346
347pub fn load_tiktoken_vocab_file<P>(
348 path: P,
349 expected_hash: Option<&str>,
350) -> std::result::Result<HashMap<Vec<u8>, Rank>, std::io::Error>
351where
352 P: AsRef<Path>,
353{
354 let file = std::fs::File::open(path)?;
355 let reader = std::io::BufReader::new(file);
356 load_tiktoken_vocab(reader, expected_hash)
357}
358
359pub fn load_encoding_from_file<P, S, TS>(
360 file_path: P,
361 expected_hash: Option<&str>,
362 special_tokens: S,
363 pattern: &str,
364) -> Result<CoreBPE, LoadError>
365where
366 P: AsRef<Path>,
367 S: IntoIterator<Item = (TS, Rank)>,
368 TS: Into<String>,
369{
370 let encoder = load_tiktoken_vocab_file(file_path, expected_hash)
371 .map_err(LoadError::InvalidTiktokenVocabFile)?;
372 CoreBPE::new(
373 encoder,
374 special_tokens.into_iter().map(|(k, v)| (k.into(), v)),
375 pattern,
376 )
377 .map_err(LoadError::CoreBPECreationFailed)
378}
379
/// Builds a `CoreBPE` from vocab data already held in memory (wasm32 only).
///
/// * `vocab_bytes` - raw contents of a vocab file.
/// * `expected_hash` - optional SHA-256 (hex) the bytes must match.
/// * `special_tokens` - `(token, rank)` pairs; each token is converted into a `String`.
/// * `pattern` - regex source handed to `CoreBPE::new`.
///
/// # Errors
///
/// [`LoadError::InvalidTiktokenVocabFile`] if the bytes cannot be parsed, or
/// [`LoadError::CoreBPECreationFailed`] if `CoreBPE::new` fails.
#[cfg(target_arch = "wasm32")]
pub fn load_encoding_from_bytes<S, TS>(
    vocab_bytes: &[u8],
    expected_hash: Option<&str>,
    special_tokens: S,
    pattern: &str,
) -> Result<CoreBPE, LoadError>
where
    S: IntoIterator<Item = (TS, Rank)>,
    TS: Into<String>,
{
    // `&[u8]` implements `BufRead` itself, so it can be parsed without a wrapper.
    let vocab = match load_tiktoken_vocab(vocab_bytes, expected_hash) {
        Ok(vocab) => vocab,
        Err(io_err) => return Err(LoadError::InvalidTiktokenVocabFile(io_err)),
    };
    let bpe = CoreBPE::new(
        vocab,
        special_tokens
            .into_iter()
            .map(|(token, rank)| (token.into(), rank)),
        pattern,
    );
    bpe.map_err(LoadError::CoreBPECreationFailed)
}
401
402#[cfg(not(target_arch = "wasm32"))]
403fn download_or_find_cached_file(
404 url: &str,
405 expected_hash: Option<&str>,
406) -> Result<PathBuf, RemoteVocabFileError> {
407 let cache_dir = resolve_cache_dir()?;
408 let cache_path = resolve_cache_path(&cache_dir, url);
409 if cache_path.exists() {
410 if verify_file_hash(&cache_path, expected_hash)? {
411 return Ok(cache_path);
412 }
413 let _ = std::fs::remove_file(&cache_path);
414 }
415 let hash = load_remote_file(url, &cache_path)?;
416 if let Some(expected_hash) = expected_hash {
417 if hash != expected_hash {
418 let _ = std::fs::remove_file(&cache_path);
419 return Err(RemoteVocabFileError::HashMismatch {
420 file_url: url.to_string(),
421 expected_hash: expected_hash.to_string(),
422 computed_hash: hash,
423 });
424 }
425 }
426 Ok(cache_path)
427}
428
/// Downloads `url` and returns its bytes, verifying them against
/// `expected_hash` (lowercase hex SHA-256) when one is supplied.
///
/// Despite the name, this variant always downloads; it consults no cache.
///
/// # Errors
///
/// Returns [`RemoteVocabFileError::HashMismatch`] when the digest differs, or
/// the error reported by the download.
#[cfg(target_arch = "wasm32")]
async fn download_or_find_cached_file_bytes(
    url: &str,
    expected_hash: Option<&str>,
) -> Result<Vec<u8>, RemoteVocabFileError> {
    let bytes = load_remote_file_bytes(url).await?;
    let Some(expected_hash) = expected_hash else {
        return Ok(bytes);
    };

    let computed_hash = format!("{:x}", Sha256::digest(&bytes));
    if computed_hash == expected_hash {
        Ok(bytes)
    } else {
        Err(RemoteVocabFileError::HashMismatch {
            file_url: url.to_string(),
            expected_hash: expected_hash.to_string(),
            computed_hash,
        })
    }
}
447
448fn resolve_cache_dir() -> Result<PathBuf, RemoteVocabFileError> {
449 let cache_dir_override = std::env::var("TIKTOKEN_RS_CACHE_DIR").ok();
450 if let Some(cache_dir_override) = cache_dir_override {
451 Ok(PathBuf::from(cache_dir_override))
452 } else {
453 let cache_dir = std::env::temp_dir().join("tiktoken-rs-cache");
454 std::fs::create_dir_all(&cache_dir).map_err(|e| {
455 RemoteVocabFileError::IOError(format!("creating cache dir {cache_dir:?}"), e)
456 })?;
457 Ok(cache_dir)
458 }
459}
460
461fn resolve_cache_path(cache_dir: &Path, url: &str) -> PathBuf {
462 let mut hasher = Sha1::new();
463 hasher.update(url.as_bytes());
464 let cache_key = format!("{:x}", hasher.finalize());
465 cache_dir.join(cache_key)
466}
467
468fn verify_file_hash(
469 file_path: &Path,
470 expected_hash: Option<&str>,
471) -> Result<bool, RemoteVocabFileError> {
472 let Some(expected_hash) = expected_hash else {
473 return Ok(true);
474 };
475 let file = File::open(file_path)
476 .map_err(|e| RemoteVocabFileError::IOError(format!("opening file {file_path:?}"), e))?;
477 let mut reader = BufReader::new(file);
478 let mut hasher = Sha256::new();
479 std::io::copy(&mut reader, &mut hasher).map_err(|e| {
480 RemoteVocabFileError::IOError(format!("copying file {file_path:?} contents to hasher"), e)
481 })?;
482 let computed_hash = format!("{:x}", hasher.finalize());
483 Ok(computed_hash == expected_hash)
484}
485
486#[cfg(not(target_arch = "wasm32"))]
487fn load_remote_file(url: &str, destination: &Path) -> Result<String, RemoteVocabFileError> {
488 let client = reqwest::blocking::Client::new();
489 let mut response = client
490 .get(url)
491 .send()
492 .and_then(|r| r.error_for_status())
493 .map_err(|e| RemoteVocabFileError::FailedToDownloadOrLoadVocabFile(Box::new(e)))?;
494
495 let file = File::create(destination)
496 .map_err(|e| RemoteVocabFileError::IOError(format!("creating file {destination:?}"), e))?;
497 let mut dest = BufWriter::new(file);
498 let mut hasher = Sha256::new();
499 let mut buffer = [0u8; 8192];
500 loop {
501 let bytes_read = response.read(&mut buffer).map_err(|e| {
502 RemoteVocabFileError::IOError(format!("reading from response {url}"), e)
503 })?;
504 if bytes_read == 0 {
505 break;
506 }
507 dest.write_all(&buffer[..bytes_read]).map_err(|e| {
508 RemoteVocabFileError::IOError(format!("writing to file {destination:?}"), e)
509 })?;
510 hasher.update(&buffer[..bytes_read]);
511 }
512 Ok(format!("{:x}", hasher.finalize()))
513}
514
/// Downloads `url` and returns the whole response body (wasm32 variant).
///
/// # Errors
///
/// Returns [`RemoteVocabFileError::FailedToDownloadOrLoadVocabFile`] if the
/// request fails, the server answers with an error status, or the body
/// cannot be read.
#[cfg(target_arch = "wasm32")]
async fn load_remote_file_bytes(url: &str) -> Result<Vec<u8>, RemoteVocabFileError> {
    let to_download_error =
        |e: reqwest::Error| RemoteVocabFileError::FailedToDownloadOrLoadVocabFile(Box::new(e));

    let response = reqwest::Client::new()
        .get(url)
        .send()
        .await
        .and_then(|r| r.error_for_status())
        .map_err(to_download_error)?;
    let body = response.bytes().await.map_err(to_download_error)?;
    Ok(body.to_vec())
}