use crate::core::Result;
use base64ct::{Base64, Encoding};
use bstr::ByteSlice;
use serde_json::{Map, Value};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::Path;
use std::{env, fs};
use uuid::Uuid;
/// Name of the environment variable that `get_cache_dir` reads first when
/// resolving the cache directory.
const TIKTOKEN_CACHE_DIR: &str = "TIKTOKEN_CACHE_DIR";
/// Name of the environment variable that `get_cache_dir` reads when
/// `TIKTOKEN_CACHE_DIR` cannot be read.
const DATA_GYM_CACHE_DIR: &str = "DATA_GYM_CACHE_DIR";
/// Directory name joined onto the system temporary directory by
/// `get_cache_dir` when neither environment variable can be read.
const DATA_GYM_TMP_DIR: &str = "data-gym-cache";
/// Fetches `blob_path` with a blocking GET request and returns the response
/// body as text.
///
/// # Errors
///
/// Returns an error when the request fails, when the server answers with a
/// 4xx/5xx status, or when the body cannot be read as text.
fn read_file_remote(blob_path: &str) -> Result<String> {
    let body = reqwest::blocking::get(blob_path)?
        // Reject error statuses so that an error page is never handed to the
        // parsers or written into the cache as if it were the real file.
        .error_for_status()?
        .text()?;
    Ok(body)
}
/// Returns the contents of `blob_path`, going through the on-disk cache.
///
/// When `get_cache_dir` yields an empty string the cache is bypassed and the
/// file is always fetched remotely. Otherwise the cache entry named by
/// `generate_cache_filename` is returned if it exists; if not, the file is
/// fetched, stored in the cache directory and then returned.
///
/// # Errors
///
/// Propagates failures from the remote fetch and from every filesystem
/// operation (reading the entry, creating the directory, writing, renaming).
fn read_file_cached(blob_path: &str) -> Result<String> {
    let cache_dir = get_cache_dir();
    if cache_dir.is_empty() {
        return read_file_remote(blob_path);
    }

    let cache_root = Path::new(&cache_dir);
    let entry_name = generate_cache_filename(blob_path);
    let entry_path = cache_root.join(&entry_name);
    if entry_path.exists() {
        return Ok(fs::read_to_string(&entry_path)?);
    }

    let body = read_file_remote(blob_path)?;

    // Write to a uniquely named temporary file first and rename it into
    // place afterwards, so the final path only ever appears fully written.
    let staging_path = cache_root.join(format!("{}.{}.tmp", entry_name, Uuid::new_v4()));
    fs::create_dir_all(cache_root)?;
    fs::write(&staging_path, &body)?;
    fs::rename(&staging_path, &entry_path)?;

    Ok(body)
}
/// Loads a `.tiktoken` rank file into a map from token bytes to rank.
///
/// Each line is split at its first space into a base64 token and a decimal
/// rank. Lines without a space are skipped. If the file cannot be read
/// (remotely or from the cache) an empty map is returned.
///
/// # Panics
///
/// Panics when a line's token is not valid base64 or its rank does not parse
/// as a `usize`.
pub fn load_tiktoken_bpe(tiktoken_bpe_file: &str) -> HashMap<Vec<u8>, usize> {
    let contents = read_file_cached(tiktoken_bpe_file).unwrap_or_default();
    contents
        .lines()
        .filter_map(|line| line.split_once(' '))
        .map(|(b64, num)| {
            let key = Base64::decode_vec(b64).expect("tiktoken entry must be valid base64");
            let rank: usize = num
                .parse()
                .expect("tiktoken rank must be an unsigned integer");
            (key, rank)
        })
        .collect()
}
/// Builds a byte-sequence -> rank table from a data-gym style `vocab.bpe`
/// merges file.
///
/// Ranks 0..=255 are assigned to the single bytes: first the bytes
/// 0x21..=0x7E, 0xA1..=0xAC and 0xAE..=0xFF in that order, then every
/// remaining byte in ascending order. Each merge line of `vocab_bpe_file`
/// (the first line is skipped, lines without a space are ignored) receives
/// the next rank, keyed by the decoded bytes of both halves concatenated.
///
/// `encoder_json_file` is loaded and decoded as well, but the returned map
/// is built from the merges file only. A file that cannot be read is treated
/// as empty.
///
/// # Panics
///
/// Panics when a merge entry or encoder key contains a character that has no
/// byte mapping, when the encoder JSON root is not an object, or when one of
/// its values is not an unsigned integer.
pub fn data_gym_to_mergeable_bpe_ranks(
    vocab_bpe_file: &str,
    encoder_json_file: &str,
) -> HashMap<Vec<u8>, usize> {
    // Bytes that are represented by the character with the same code point.
    let mut rank_to_intbyte: Vec<u8> = vec![];
    rank_to_intbyte.extend(0x21..=0x7E);
    rank_to_intbyte.extend(0xA1..0xAD);
    rank_to_intbyte.extend(0xAE..=0xFF);

    let mut data_gym_byte_to_byte: HashMap<char, u8> = rank_to_intbyte
        .iter()
        .map(|&b| (char::from(b), b))
        .collect();

    // Every other byte is represented by a character starting at U+0100,
    // handed out in ascending byte order.
    let mut extra = 0u32;
    for b in 0..=255u8 {
        if !rank_to_intbyte.contains(&b) {
            rank_to_intbyte.push(b);
            data_gym_byte_to_byte.insert(
                char::from_u32(256 + extra).expect("U+0100..U+01FF are valid chars"),
                b,
            );
            extra += 1;
        }
    }
    assert_eq!(rank_to_intbyte.len(), 256);

    // Single bytes take the first 256 ranks, in the order established above.
    let mut bpe_ranks: HashMap<Vec<u8>, usize> = rank_to_intbyte
        .into_iter()
        .enumerate()
        .map(|(i, b)| (vec![b], i))
        .collect();

    // Merges follow, ranked in file order. The first line is skipped.
    let vocab_bpe_contents = read_file_cached(vocab_bpe_file).unwrap_or_default();
    let merges = vocab_bpe_contents
        .lines()
        .skip(1)
        .filter_map(|line| line.split_once(' '));
    let mut next_rank = bpe_ranks.len();
    for (first, second) in merges {
        let mut key = decode_data_gym(first, &data_gym_byte_to_byte);
        key.extend(decode_data_gym(second, &data_gym_byte_to_byte));
        bpe_ranks.insert(key, next_rank);
        next_rank += 1;
    }

    // Load the encoder table; an unreadable or unparsable file is treated as
    // an empty JSON object.
    let content = read_file_cached(encoder_json_file).unwrap_or_else(|_| "{}".to_string());
    let encoder_json: Value =
        serde_json::from_str(&content).unwrap_or_else(|_| Value::Object(Map::default()));
    let mut encoder_json_loaded: HashMap<Vec<u8>, usize> = encoder_json
        .as_object()
        .expect("encoder JSON root must be an object")
        .iter()
        .map(|(key, val)| {
            (
                decode_data_gym(key, &data_gym_byte_to_byte),
                val.as_u64()
                    .expect("encoder JSON values must be unsigned integers") as usize,
            )
        })
        .collect();

    // Drop both special-token entries (previously `<|endoftext|>` was removed
    // twice and `<|startoftext|>` never).
    // NOTE(review): the reference Python loader compares this table with
    // `bpe_ranks` after the removals; that comparison is not performed here,
    // so the table currently has no influence on the result.
    encoder_json_loaded.remove(b"<|endoftext|>".as_bytes());
    encoder_json_loaded.remove(b"<|startoftext|>".as_bytes());

    bpe_ranks
}
/// Translates every character of `value` into its byte via `dict` and
/// returns the bytes in the same order.
///
/// # Panics
///
/// Panics when `value` contains a character that is not a key of `dict`.
fn decode_data_gym(value: &str, dict: &HashMap<char, u8>) -> Vec<u8> {
    let mut decoded = Vec::new();
    for symbol in value.chars() {
        let byte = *dict.get(&symbol).unwrap();
        decoded.push(byte);
    }
    decoded
}
/// Resolves the directory used for cached downloads.
///
/// The value of the `TIKTOKEN_CACHE_DIR` environment variable is used when
/// it can be read, otherwise the value of `DATA_GYM_CACHE_DIR`, otherwise
/// `data-gym-cache` inside the system temporary directory. A variable that
/// is set to the empty string is returned as-is.
fn get_cache_dir() -> String {
    // The closure forms keep the second lookup and the temp-dir path
    // construction from running when an earlier source already succeeded.
    env::var(TIKTOKEN_CACHE_DIR)
        .or_else(|_| env::var(DATA_GYM_CACHE_DIR))
        .unwrap_or_else(|_| {
            env::temp_dir()
                .join(DATA_GYM_TMP_DIR)
                .to_string_lossy()
                .into_owned()
        })
}
/// Derives the cache file name for `blob_path`: the SHA-256 digest of the
/// path string, rendered as upper-case hexadecimal with two digits per byte.
fn generate_cache_filename(blob_path: &str) -> String {
    Sha256::digest(blob_path)
        .iter()
        .map(|byte| format!("{:02X}", byte))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Walks through the three sources of the cache directory in priority
    /// order, clearing each one before checking the next.
    #[test]
    fn test_get_cache_dir() {
        // First source: TIKTOKEN_CACHE_DIR.
        env::set_var(TIKTOKEN_CACHE_DIR, "/tiktoken/cache/dir/");
        assert_eq!(get_cache_dir(), "/tiktoken/cache/dir/");
        env::remove_var(TIKTOKEN_CACHE_DIR);

        // Second source: DATA_GYM_CACHE_DIR.
        env::set_var(DATA_GYM_CACHE_DIR, "/data/gym/cache/dir/");
        assert_eq!(get_cache_dir(), "/data/gym/cache/dir/");
        env::remove_var(DATA_GYM_CACHE_DIR);

        // Last resort: a directory below the system temp dir.
        assert!(get_cache_dir().ends_with(DATA_GYM_TMP_DIR));
    }

    /// Pins the cache file name produced for a known URL.
    #[test]
    fn test_generate_cache_path() {
        let url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
        let expected = "26B9C229141B3D34DCAC6D3728F94F1E40ABB67EF4A84CA1351ABC0A20E6B701";
        let actual = generate_cache_filename(url);
        assert_eq!(actual, expected);
    }
}