use std::{
fs::File,
io::{
BufRead,
BufReader,
BufWriter,
Write,
},
path::Path,
};
use base64::{
Engine,
prelude::BASE64_STANDARD,
};
use crate::{
TokenType,
WCError,
WCResult,
prelude::*,
spanners::TextSpanningConfig,
vocab::{
SpanMapVocab,
UnifiedTokenVocab,
vocab_types::SpanTokenMap,
},
};
/// Build a [`UnifiedTokenVocab`] from the base64 vocabulary file at `path`.
///
/// The file is opened, wrapped in a [`BufReader`], and handed to
/// [`read_base64_unified_vocab`] together with `spanning`.
///
/// # Errors
/// Returns an error if the file cannot be opened, or whatever
/// [`read_base64_unified_vocab`] reports while parsing its contents.
pub fn load_base64_unified_vocab_path<T: TokenType>(
    path: impl AsRef<Path>,
    spanning: TextSpanningConfig<T>,
) -> WCResult<UnifiedTokenVocab<T>> {
    let file = File::open(path)?;
    let mut buffered = BufReader::new(file);
    read_base64_unified_vocab(&mut buffered, spanning)
}
/// Build a [`UnifiedTokenVocab`] from a base64 vocabulary stream.
///
/// The span map is read from `reader` via [`read_base64_span_map`], converted
/// with `.into()`, and combined with `spanning` by
/// `UnifiedTokenVocab::from_span_vocab`.
///
/// # Errors
/// Propagates any error from [`read_base64_span_map`] or from
/// `UnifiedTokenVocab::from_span_vocab`.
pub fn read_base64_unified_vocab<T: TokenType>(
    reader: &mut dyn BufRead,
    spanning: TextSpanningConfig<T>,
) -> WCResult<UnifiedTokenVocab<T>> {
    let span_map = read_base64_span_map(reader)?;
    UnifiedTokenVocab::from_span_vocab(spanning, span_map.into())
}
/// Load a [`SpanMapVocab`] from the base64 vocabulary file at `path`.
///
/// This is [`load_base64_span_map_path`] followed by a conversion of the
/// resulting span map into a [`SpanMapVocab`].
///
/// # Errors
/// Propagates any error from [`load_base64_span_map_path`].
pub fn load_base64_span_vocab_path<T, P>(path: P) -> WCResult<SpanMapVocab<T>>
where
    T: TokenType,
    P: AsRef<Path>,
{
    let span_map = load_base64_span_map_path(path)?;
    Ok(span_map.into())
}
/// Load a [`SpanTokenMap`] from the base64 vocabulary file at `path`.
///
/// The file is opened, wrapped in a [`BufReader`], and parsed by
/// [`read_base64_span_map`].
///
/// # Errors
/// Returns an error if the file cannot be opened, or whatever
/// [`read_base64_span_map`] reports while parsing its contents.
pub fn load_base64_span_map_path<T, P>(path: P) -> WCResult<SpanTokenMap<T>>
where
    T: TokenType,
    P: AsRef<Path>,
{
    let file = File::open(path)?;
    let mut buffered = BufReader::new(file);
    read_base64_span_map(&mut buffered)
}
/// Read a [`SpanTokenMap`] from a line-oriented base64 vocabulary stream.
///
/// Every non-empty line must have the form `<base64 span> <token id>`: the
/// standard-alphabet base64 encoding of the span bytes, a single space, and
/// the token id as a decimal integer. Empty lines are skipped.
///
/// # Errors
/// * Any I/O error raised while reading a line from `reader`.
/// * [`WCError::Parse`] if a line has no space separator, if the span is not
///   valid base64, or if the id is not a valid `u64`.
/// * [`WCError::TokenOutOfRange`] if the id cannot be represented as `T`.
pub fn read_base64_span_map<T>(reader: &mut dyn BufRead) -> WCResult<SpanTokenMap<T>>
where
    T: TokenType,
{
    let mut vocab = SpanTokenMap::default();
    for line in reader.lines() {
        let line = line?;
        // Tolerate blank lines (e.g. trailing newlines) instead of rejecting
        // the whole file.
        if line.is_empty() {
            continue;
        }
        // Malformed input is a recoverable parse error, not a reason to panic.
        let (encoded, id) = line.split_once(' ').ok_or_else(|| {
            WCError::Parse(format!(
                "expected `<base64> <token id>`, found {:?}",
                line
            ))
        })?;
        let span = BASE64_STANDARD
            .decode(encoded)
            .map_err(|e| WCError::Parse(e.to_string()))?;
        let id: u64 = id
            .parse()
            .map_err(|e: core::num::ParseIntError| WCError::Parse(e.to_string()))?;
        let token = T::from_u64(id).ok_or(WCError::TokenOutOfRange)?;
        vocab.insert(span, token);
    }
    Ok(vocab)
}
pub fn save_base64_span_map_path<T: TokenType, P: AsRef<Path>>(
span_map: &SpanTokenMap<T>,
path: P,
) -> WCResult<()> {
let mut writer = BufWriter::new(File::create(path)?);
write_base64_span_map(span_map, &mut writer)
}
pub fn write_base64_span_map<T>(
span_map: &SpanTokenMap<T>,
writer: &mut dyn Write,
) -> WCResult<()>
where
T: TokenType,
{
let mut items: Vec<(T, &Vec<u8>)> = span_map
.iter()
.map(|(chunk, &token)| (token, chunk))
.collect();
items.sort_by_key(|(t, _)| *t);
for (token, chunk) in items {
writeln!(
writer,
"{} {}",
BASE64_STANDARD.encode(chunk),
token.to_u64().unwrap()
)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Saving a span map to a temporary file and loading it back must yield
    /// an identical map.
    #[test]
    fn test_save_load_tiktoken() {
        type T = u32;

        let mut span_map: SpanTokenMap<T> = Default::default();
        span_map.insert(b"apple".to_vec(), 300);
        span_map.insert(b"banana".to_vec(), 301);
        span_map.insert(b"pear".to_vec(), 302);

        // The temporary directory lives until the end of the test.
        let dir = tempdir::TempDir::new("vocab_test").unwrap();
        let path = dir.path().join("vocab.tiktoken");

        save_base64_span_map_path(&span_map, &path).expect("Failed to save vocab");
        let loaded_vocab: SpanTokenMap<T> =
            load_base64_span_map_path(&path).expect("Failed to load vocab");

        assert_eq!(&loaded_vocab, &span_map);
    }
}