harmony_protocol/tiktoken_ext/
public_encodings.rs

1use std::{
2    collections::HashMap,
3    fs::File,
4    io::{BufReader, BufWriter, Read as _, Write as _},
5    path::{Path, PathBuf},
6    sync::OnceLock,
7};
8
9use base64::{prelude::BASE64_STANDARD, Engine as _};
10
11use crate::tiktoken::{CoreBPE, Rank};
12use sha1::Sha1;
13use sha2::{Digest as _, Sha256};
14
15#[derive(Debug, thiserror::Error)]
16pub enum LoadError {
17    #[error("the env var TIKTOKEN_ENCODINGS_BASE is not set, or invalid")]
18    InvalidEncodingBaseDirEnvVar,
19
20    #[error("unknown encoding name: {0}")]
21    UnknownEncodingName(String),
22
23    #[error("invalid tiktoken vocab file: {0}")]
24    InvalidTiktokenVocabFile(#[source] std::io::Error),
25
26    #[error("failed to create CoreBPE: {0}")]
27    CoreBPECreationFailed(#[source] Box<dyn std::error::Error + Send + Sync>),
28
29    #[error("error downloading or loading vocab file: {0}")]
30    DownloadOrLoadVocabFile(
31        #[source]
32        #[from]
33        RemoteVocabFileError,
34    ),
35}
36
37#[derive(Debug, thiserror::Error)]
38pub enum RemoteVocabFileError {
39    #[error("failed to download or load vocab file")]
40    FailedToDownloadOrLoadVocabFile(#[source] Box<dyn std::error::Error + Send + Sync>),
41
42    #[error("an underlying IO error occurred while {0}: {1}")]
43    IOError(String, #[source] std::io::Error),
44
45    #[error("hash mismatch for remote file {file_url}")]
46    HashMismatch {
47        file_url: String,
48        expected_hash: String,
49        computed_hash: String,
50    },
51}
52
/// Environment variable naming a local directory that holds `.tiktoken` vocab files.
const TIKTOKEN_ENCODINGS_BASE_VAR: &str = "TIKTOKEN_ENCODINGS_BASE";
/// Download base URL used when no override has been installed (always ends in `/`).
const DEFAULT_TIKTOKEN_BASE_URL: &str = "https://openaipublic.blob.core.windows.net/encodings/";

/// Process-wide base URL override; the first successful `set` wins.
static TIKTOKEN_BASE_URL_OVERRIDE: OnceLock<String> = OnceLock::new();
57
58pub fn set_tiktoken_base_url(base_url: impl Into<String>) {
59    let mut base = base_url.into();
60    if !base.ends_with('/') {
61        base.push('/');
62    }
63    let _ = TIKTOKEN_BASE_URL_OVERRIDE.set(base);
64}
65
66fn tiktoken_base_url() -> &'static str {
67    TIKTOKEN_BASE_URL_OVERRIDE
68        .get()
69        .map(|s| s.as_str())
70        .unwrap_or(DEFAULT_TIKTOKEN_BASE_URL)
71}
72
/// The tokenizer encodings this module can load.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// `o200k_base`.
    O200kBase,
    /// `o200k_harmony` (shares the `o200k_base` vocab file, adds its own special tokens).
    O200kHarmony,
    /// `cl100k_base`.
    Cl100kBase,
}
79
80impl Encoding {
81    pub fn all() -> &'static [Self] {
82        &[Self::O200kBase, Self::O200kHarmony, Self::Cl100kBase]
83    }
84
85    pub fn from_name(name: impl AsRef<str>) -> Option<Self> {
86        let name_str = name.as_ref();
87        for encoding in Self::all() {
88            if encoding.name() == name_str {
89                return Some(*encoding);
90            }
91        }
92        None
93    }
94
95    #[cfg(not(target_arch = "wasm32"))]
96    pub fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
97        let name = name.as_ref();
98        Self::from_name(name)
99            .ok_or_else(|| LoadError::UnknownEncodingName(name.to_string()))?
100            .load()
101    }
102
103    #[cfg(target_arch = "wasm32")]
104    pub async fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
105        let name = name.as_ref();
106        Self::from_name(name)
107            .ok_or_else(|| LoadError::UnknownEncodingName(name.to_string()))?
108            .load()
109            .await
110    }
111
112    pub fn name(&self) -> &'static str {
113        match self {
114            Self::O200kBase => "o200k_base",
115            Self::O200kHarmony => "o200k_harmony",
116            Self::Cl100kBase => "cl100k_base",
117        }
118    }
119
120    #[cfg(not(target_arch = "wasm32"))]
121    pub fn load(&self) -> Result<CoreBPE, LoadError> {
122        let (vocab_file_path, check_hash) =
123            if let Ok(base_dir) = std::env::var(TIKTOKEN_ENCODINGS_BASE_VAR) {
124                (PathBuf::from(base_dir).join(self.vocab_file_name()), true)
125            } else {
126                let url = self.public_vocab_file_url();
127                (
128                    download_or_find_cached_file(&url, Some(self.expected_hash()))
129                        .map_err(LoadError::DownloadOrLoadVocabFile)?,
130                    false,
131                )
132            };
133
134        match self {
135            Self::O200kHarmony => {
136                let mut specials: Vec<(String, Rank)> = self
137                    .special_tokens()
138                    .iter()
139                    .map(|(s, r)| ((*s).to_string(), *r))
140                    .collect();
141                specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
142                load_encoding_from_file(
143                    vocab_file_path,
144                    check_hash.then(|| self.expected_hash()),
145                    specials,
146                    &self.pattern(),
147                )
148            }
149            Self::O200kBase => {
150                let mut specials: Vec<(String, Rank)> = self
151                    .special_tokens()
152                    .iter()
153                    .map(|(s, r)| ((*s).to_string(), *r))
154                    .collect();
155                specials.extend((199998..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
156                load_encoding_from_file(
157                    vocab_file_path,
158                    check_hash.then(|| self.expected_hash()),
159                    specials,
160                    &self.pattern(),
161                )
162            }
163            _ => load_encoding_from_file(
164                vocab_file_path,
165                check_hash.then(|| self.expected_hash()),
166                self.special_tokens().iter().cloned(),
167                &self.pattern(),
168            ),
169        }
170    }
171
172    #[cfg(target_arch = "wasm32")]
173    pub async fn load(&self) -> Result<CoreBPE, LoadError> {
174        let url = self.public_vocab_file_url();
175        let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash()))
176            .await
177            .map_err(LoadError::DownloadOrLoadVocabFile)?;
178
179        match self {
180            Self::O200kHarmony => {
181                let mut specials: Vec<(String, Rank)> = self
182                    .special_tokens()
183                    .iter()
184                    .map(|(s, r)| ((*s).to_string(), *r))
185                    .collect();
186                specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
187                load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern())
188            }
189            Self::O200kBase => {
190                let mut specials: Vec<(String, Rank)> = self
191                    .special_tokens()
192                    .iter()
193                    .map(|(s, r)| ((*s).to_string(), *r))
194                    .collect();
195                specials.extend((199998..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
196                load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern())
197            }
198            _ => load_encoding_from_bytes(
199                &vocab_bytes,
200                None,
201                self.special_tokens().iter().cloned(),
202                &self.pattern(),
203            ),
204        }
205    }
206
207    fn public_vocab_file_url(&self) -> String {
208        let base = tiktoken_base_url();
209        match self {
210            Self::O200kBase => format!("{base}o200k_base.tiktoken"),
211            Self::O200kHarmony => format!("{base}o200k_base.tiktoken"),
212            Self::Cl100kBase => format!("{base}cl100k_base.tiktoken"),
213        }
214    }
215
216    fn vocab_file_name(&self) -> &'static str {
217        match self {
218            Self::O200kBase => "o200k_base.tiktoken",
219            Self::O200kHarmony => "o200k_base.tiktoken",
220            Self::Cl100kBase => "cl100k_base.tiktoken",
221        }
222    }
223
224    fn expected_hash(&self) -> &'static str {
225        match self {
226            Self::O200kBase => "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
227            Self::O200kHarmony => {
228                "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"
229            }
230            Self::Cl100kBase => "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
231        }
232    }
233
234    fn special_tokens(&self) -> &'static [(&'static str, Rank)] {
235        match self {
236            Self::O200kBase => &[],
237            Self::O200kHarmony => &[
238                ("<|startoftext|>", 199998),
239                ("<|endoftext|>", 199999),
240                ("<|reserved_200000|>", 200000),
241                ("<|reserved_200001|>", 200001),
242                ("<|return|>", 200002),
243                ("<|constrain|>", 200003),
244                ("<|reserved_200004|>", 200004),
245                ("<|channel|>", 200005),
246                ("<|start|>", 200006),
247                ("<|end|>", 200007),
248                ("<|message|>", 200008),
249                ("<|reserved_200009|>", 200009),
250                ("<|reserved_200010|>", 200010),
251                ("<|reserved_200011|>", 200011),
252                ("<|call|>", 200012),
253                ("<|reserved_200013|>", 200013),
254            ],
255            Self::Cl100kBase => &[
256                ("<|endoftext|>", 100257),
257                ("<|fim_prefix|>", 100258),
258                ("<|fim_middle|>", 100259),
259                ("<|fim_suffix|>", 100260),
260                ("<|endofprompt|>", 100276),
261            ],
262        }
263    }
264
265    fn pattern(&self) -> String {
266        match self {
267            Self::O200kBase => {
268                [
269                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
270                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
271                    "\\p{N}{1,3}",
272                    " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
273                    "\\s*[\\r\\n]+",
274                    "\\s+(?!\\S)",
275                    "\\s+",
276                ].join("|")
277            }
278            Self::O200kHarmony => {
279                [
280                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
281                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
282                    "\\p{N}{1,3}",
283                    " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
284                    "\\s*[\\r\\n]+",
285                    "\\s+(?!\\S)",
286                    "\\s+",
287                ].join("|")
288            }
289            Self::Cl100kBase => {
290                "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()
291            }
292        }
293    }
294}
295
296fn load_tiktoken_vocab<R>(
297    mut reader: R,
298    expected_hash: Option<&str>,
299) -> std::result::Result<HashMap<Vec<u8>, Rank>, std::io::Error>
300where
301    R: std::io::BufRead,
302{
303    let mut hasher = expected_hash.map(|_| Sha256::new());
304    let mut bpe_ranks = HashMap::default();
305    let mut lin_no = 0;
306    let mut line_buffer = String::new();
307    while reader.read_line(&mut line_buffer)? > 0 {
308        lin_no += 1;
309        if let Some(hasher) = hasher.as_mut() {
310            hasher.update(line_buffer.as_bytes());
311        }
312        let line = line_buffer.trim_end();
313        let (token, rank) = line.split_once(' ').ok_or_else(|| {
314            std::io::Error::new(
315                std::io::ErrorKind::InvalidData,
316                format!("expected token and rank, could not split on ' ' at line {lin_no}"),
317            )
318        })?;
319        let bytes = BASE64_STANDARD.decode(token).map_err(|e| {
320            std::io::Error::new(
321                std::io::ErrorKind::InvalidData,
322                format!("failed to decode base64 token at line {lin_no}: {e}",),
323            )
324        })?;
325        let rank = rank.parse().map_err(|e| {
326            std::io::Error::new(
327                std::io::ErrorKind::InvalidData,
328                format!("failed to parse rank at line {lin_no}: {e}"),
329            )
330        })?;
331        bpe_ranks.insert(bytes, rank);
332        line_buffer.clear();
333    }
334    if let Some(hasher) = hasher {
335        let expected_hash = expected_hash.unwrap();
336        let computed_hash = format!("{:x}", hasher.finalize());
337        if computed_hash != expected_hash {
338            return Err(std::io::Error::new(
339                std::io::ErrorKind::InvalidData,
340                format!("hash mismatch: computed={computed_hash}, expected={expected_hash}"),
341            ));
342        }
343    }
344    Ok(bpe_ranks)
345}
346
347pub fn load_tiktoken_vocab_file<P>(
348    path: P,
349    expected_hash: Option<&str>,
350) -> std::result::Result<HashMap<Vec<u8>, Rank>, std::io::Error>
351where
352    P: AsRef<Path>,
353{
354    let file = std::fs::File::open(path)?;
355    let reader = std::io::BufReader::new(file);
356    load_tiktoken_vocab(reader, expected_hash)
357}
358
359pub fn load_encoding_from_file<P, S, TS>(
360    file_path: P,
361    expected_hash: Option<&str>,
362    special_tokens: S,
363    pattern: &str,
364) -> Result<CoreBPE, LoadError>
365where
366    P: AsRef<Path>,
367    S: IntoIterator<Item = (TS, Rank)>,
368    TS: Into<String>,
369{
370    let encoder = load_tiktoken_vocab_file(file_path, expected_hash)
371        .map_err(LoadError::InvalidTiktokenVocabFile)?;
372    CoreBPE::new(
373        encoder,
374        special_tokens.into_iter().map(|(k, v)| (k.into(), v)),
375        pattern,
376    )
377    .map_err(LoadError::CoreBPECreationFailed)
378}
379
/// Builds a [`CoreBPE`] from in-memory vocab bytes plus the given special tokens and pattern.
///
/// # Errors
/// [`LoadError::InvalidTiktokenVocabFile`] if the bytes cannot be parsed or fail the hash
/// check; [`LoadError::CoreBPECreationFailed`] if [`CoreBPE::new`] returns an error.
#[cfg(target_arch = "wasm32")]
pub fn load_encoding_from_bytes<S, TS>(
    vocab_bytes: &[u8],
    expected_hash: Option<&str>,
    special_tokens: S,
    pattern: &str,
) -> Result<CoreBPE, LoadError>
where
    S: IntoIterator<Item = (TS, Rank)>,
    TS: Into<String>,
{
    // A byte slice already implements `BufRead`, so it can be parsed directly.
    let ranks = match load_tiktoken_vocab(vocab_bytes, expected_hash) {
        Ok(ranks) => ranks,
        Err(io_error) => return Err(LoadError::InvalidTiktokenVocabFile(io_error)),
    };
    CoreBPE::new(
        ranks,
        special_tokens
            .into_iter()
            .map(|(token, rank)| (token.into(), rank)),
        pattern,
    )
    .map_err(LoadError::CoreBPECreationFailed)
}
401
402#[cfg(not(target_arch = "wasm32"))]
403fn download_or_find_cached_file(
404    url: &str,
405    expected_hash: Option<&str>,
406) -> Result<PathBuf, RemoteVocabFileError> {
407    let cache_dir = resolve_cache_dir()?;
408    let cache_path = resolve_cache_path(&cache_dir, url);
409    if cache_path.exists() {
410        if verify_file_hash(&cache_path, expected_hash)? {
411            return Ok(cache_path);
412        }
413        let _ = std::fs::remove_file(&cache_path);
414    }
415    let hash = load_remote_file(url, &cache_path)?;
416    if let Some(expected_hash) = expected_hash {
417        if hash != expected_hash {
418            let _ = std::fs::remove_file(&cache_path);
419            return Err(RemoteVocabFileError::HashMismatch {
420                file_url: url.to_string(),
421                expected_hash: expected_hash.to_string(),
422                computed_hash: hash,
423            });
424        }
425    }
426    Ok(cache_path)
427}
428
/// Downloads the vocab file at `url` into memory and, if `expected_hash` is given, checks its
/// SHA-256 (lowercase hex) before returning the bytes.
#[cfg(target_arch = "wasm32")]
async fn download_or_find_cached_file_bytes(
    url: &str,
    expected_hash: Option<&str>,
) -> Result<Vec<u8>, RemoteVocabFileError> {
    let bytes = load_remote_file_bytes(url).await?;
    let Some(expected) = expected_hash else {
        return Ok(bytes);
    };
    let computed_hash = format!("{:x}", Sha256::digest(&bytes));
    if computed_hash == expected {
        return Ok(bytes);
    }
    Err(RemoteVocabFileError::HashMismatch {
        file_url: url.to_string(),
        expected_hash: expected.to_string(),
        computed_hash,
    })
}
447
448fn resolve_cache_dir() -> Result<PathBuf, RemoteVocabFileError> {
449    let cache_dir_override = std::env::var("TIKTOKEN_RS_CACHE_DIR").ok();
450    if let Some(cache_dir_override) = cache_dir_override {
451        Ok(PathBuf::from(cache_dir_override))
452    } else {
453        let cache_dir = std::env::temp_dir().join("tiktoken-rs-cache");
454        std::fs::create_dir_all(&cache_dir).map_err(|e| {
455            RemoteVocabFileError::IOError(format!("creating cache dir {cache_dir:?}"), e)
456        })?;
457        Ok(cache_dir)
458    }
459}
460
461fn resolve_cache_path(cache_dir: &Path, url: &str) -> PathBuf {
462    let mut hasher = Sha1::new();
463    hasher.update(url.as_bytes());
464    let cache_key = format!("{:x}", hasher.finalize());
465    cache_dir.join(cache_key)
466}
467
468fn verify_file_hash(
469    file_path: &Path,
470    expected_hash: Option<&str>,
471) -> Result<bool, RemoteVocabFileError> {
472    let Some(expected_hash) = expected_hash else {
473        return Ok(true);
474    };
475    let file = File::open(file_path)
476        .map_err(|e| RemoteVocabFileError::IOError(format!("opening file {file_path:?}"), e))?;
477    let mut reader = BufReader::new(file);
478    let mut hasher = Sha256::new();
479    std::io::copy(&mut reader, &mut hasher).map_err(|e| {
480        RemoteVocabFileError::IOError(format!("copying file {file_path:?} contents to hasher"), e)
481    })?;
482    let computed_hash = format!("{:x}", hasher.finalize());
483    Ok(computed_hash == expected_hash)
484}
485
486#[cfg(not(target_arch = "wasm32"))]
487fn load_remote_file(url: &str, destination: &Path) -> Result<String, RemoteVocabFileError> {
488    let client = reqwest::blocking::Client::new();
489    let mut response = client
490        .get(url)
491        .send()
492        .and_then(|r| r.error_for_status())
493        .map_err(|e| RemoteVocabFileError::FailedToDownloadOrLoadVocabFile(Box::new(e)))?;
494
495    let file = File::create(destination)
496        .map_err(|e| RemoteVocabFileError::IOError(format!("creating file {destination:?}"), e))?;
497    let mut dest = BufWriter::new(file);
498    let mut hasher = Sha256::new();
499    let mut buffer = [0u8; 8192];
500    loop {
501        let bytes_read = response.read(&mut buffer).map_err(|e| {
502            RemoteVocabFileError::IOError(format!("reading from response {url}"), e)
503        })?;
504        if bytes_read == 0 {
505            break;
506        }
507        dest.write_all(&buffer[..bytes_read]).map_err(|e| {
508            RemoteVocabFileError::IOError(format!("writing to file {destination:?}"), e)
509        })?;
510        hasher.update(&buffer[..bytes_read]);
511    }
512    Ok(format!("{:x}", hasher.finalize()))
513}
514
/// Fetches `url` and returns the whole response body (wasm32 only).
///
/// # Errors
/// [`RemoteVocabFileError::FailedToDownloadOrLoadVocabFile`] if the request fails, the server
/// returns an error status, or the body cannot be read.
#[cfg(target_arch = "wasm32")]
async fn load_remote_file_bytes(url: &str) -> Result<Vec<u8>, RemoteVocabFileError> {
    let wrap = |e: reqwest::Error| {
        RemoteVocabFileError::FailedToDownloadOrLoadVocabFile(Box::new(e))
    };

    let response = reqwest::Client::new()
        .get(url)
        .send()
        .await
        .and_then(|r| r.error_for_status())
        .map_err(wrap)?;
    let body = response.bytes().await.map_err(wrap)?;
    Ok(body.to_vec())
}