#[cfg(feature = "std")]
use std::{
io::BufRead,
path::Path,
};
#[cfg(feature = "std")]
use crate::support::resources::ResourceLoader;
use crate::{
TokenType,
prelude::*,
pretrained::openai::{
OA_CL100K_BASE_PATTERN,
OA_O200K_BASE_PATTERN,
OA_P50K_BASE_PATTERN,
OA_R50K_BASE_PATTERN,
resources::{
OA_CL100K_BASE_TIKTOKEN_RESOURCE,
OA_O200K_BASE_TIKTOKEN_RESOURCE,
OA_P50K_BASE_TIKTOKEN_RESOURCE,
OA_R50K_BASE_TIKTOKEN_RESOURCE,
},
specials::{
oa_cl100k_base_special_tokens,
oa_o200k_base_special_tokens,
oa_o200k_harmony_special_tokens,
oa_p50k_base_special_tokens,
oa_p50k_edit_special_tokens,
oa_r50k_base_special_tokens,
},
},
spanners::TextSpanningConfig,
support::{
regex::RegexPattern,
resources::ConstKeyedResource,
},
vocab::utility::factories::ConstVocabularyFactory,
};
/// GPT-2 vocabulary loading from the datagym-format resource pair
/// (the GPT-2 vocab-BPE and encoder-JSON keyed resources).
///
/// The whole module is gated on the `std` and `datagym` features.
#[cfg(all(feature = "std", feature = "datagym"))]
mod datagym {
    use crate::{
        TokenType,
        UnifiedTokenVocab,
        VocabDescription,
        pretrained::factory::BuiltinPretrainedVocabHook,
        support::resources::ResourceLoader,
    };

    /// Builds the GPT-2 vocabulary from its two datagym resources.
    ///
    /// Both resources are resolved to local paths through `loader`, opened as
    /// buffered files, and handed to `read_datagym_vocab`. The resulting span
    /// map is converted to token type `T` and combined with the `r50k_base`
    /// spanning configuration.
    ///
    /// # Errors
    ///
    /// Propagates any failure from resolving a resource path, opening either
    /// file, reading the datagym data, or converting to token type `T`.
    pub fn load_gpt2_vocab<T: TokenType>(
        loader: &mut dyn ResourceLoader
    ) -> crate::WCResult<UnifiedTokenVocab<T>> {
        use std::{
            fs::File,
            io::BufReader,
        };

        use crate::{
            pretrained::openai::{
                oa_r50k_base_spanning_config,
                resources::{
                    OA_GPT2_ENCODER_JSON_KEYED_RESOURCE,
                    OA_GPT2_VOCAB_BPE_KEYED_RESOURCE,
                },
            },
            vocab::{
                SpanMapVocab,
                io::read_datagym_vocab,
            },
        };

        // Resolve and open the vocab-BPE resource first, then the encoder JSON;
        // the order matters because each step can fail (or fetch) independently.
        let mut bpe_reader = BufReader::new(File::open(
            loader.load_resource_path(&OA_GPT2_VOCAB_BPE_KEYED_RESOURCE.into())?,
        )?);
        let mut json_reader = BufReader::new(File::open(
            loader.load_resource_path(&OA_GPT2_ENCODER_JSON_KEYED_RESOURCE.into())?,
        )?);

        // The trailing `false` flag is interpreted by `read_datagym_vocab`.
        let datagym_spans = read_datagym_vocab(&mut bpe_reader, &mut json_reader, false)?;

        let spanning = oa_r50k_base_spanning_config();
        let span_vocab = SpanMapVocab::from_span_map(datagym_spans).to_token_type()?;
        UnifiedTokenVocab::from_span_vocab(spanning, span_vocab)
    }

    // Register the GPT-2 loader with the builtin pretrained-vocab registry.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:gpt2",
            |id| VocabDescription::new(
                id,
                &["openai", "gpt2"],
                "GPT-2 `gpt2` vocabulary",
            ),
            |_, loader| load_gpt2_vocab::<u32>(loader),
        )
    }
}
#[cfg(all(feature = "std", feature = "datagym"))]
pub use self::datagym::*;
/// Leading key segment shared by every OpenAI resource key declared in this file.
const OA_KEY: &str = "openai";
/// Registry hooks for the tiktoken-format OpenAI vocabularies.
///
/// Each `inventory::submit!` below registers one `BuiltinPretrainedVocabHook`
/// consisting of a registry id, a description builder, and a loader closure.
/// The description path and text must name the same tokenizer as the id.
#[cfg(feature = "std")]
mod builtin_loaders {
    use super::*;
    use crate::{
        UnifiedTokenVocab,
        VocabDescription,
        WCResult,
        pretrained::{
            factory::BuiltinPretrainedVocabHook,
            openai::resources::{
                OA_CL100K_BASE_TIKTOKEN_KEYED_RESOURCE,
                OA_O200K_BASE_TIKTOKEN_KEYED_RESOURCE,
                OA_P50K_BASE_TIKTOKEN_KEYED_RESOURCE,
                OA_R50K_BASE_TIKTOKEN_KEYED_RESOURCE,
            },
        },
        support::resources::KeyedResource,
        vocab::io::load_base64_unified_vocab_path,
    };

    /// Resolves `resource` to a local path via `loader` and loads it as a
    /// base64 unified vocabulary using the given `spanning` configuration.
    ///
    /// # Errors
    ///
    /// Propagates failures from resolving the resource path and from
    /// `load_base64_unified_vocab_path`.
    fn load_tiktoken_vocab(
        loader: &mut dyn ResourceLoader,
        resource: &KeyedResource,
        spanning: TextSpanningConfig<u32>,
    ) -> WCResult<UnifiedTokenVocab<u32>> {
        let path = loader.load_resource_path(resource)?;
        load_base64_unified_vocab_path::<u32>(path, spanning)
    }

    // `r50k_base`: own resource, pattern, and special tokens.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:r50k_base",
            |id| VocabDescription::new(
                id,
                &["openai", "r50k_base"],
                "GPT-2 `r50k_base` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_R50K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_R50K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_r50k_base_special_tokens::<u32>()),
                )
            }
        )
    }

    // `p50k_base`: own resource, pattern, and special tokens.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:p50k_base",
            |id| VocabDescription::new(
                id,
                &["openai", "p50k_base"],
                "GPT-2 `p50k_base` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_P50K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_P50K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_p50k_base_special_tokens::<u32>()),
                )
            }
        )
    }

    // `p50k_edit`: shares the `p50k_base` resource and pattern, but installs
    // its own special tokens. It is described under its own name so it does
    // not collide with the `p50k_base` entry.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:p50k_edit",
            |id| VocabDescription::new(
                id,
                &["openai", "p50k_edit"],
                "GPT-2 `p50k_edit` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_P50K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_P50K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_p50k_edit_special_tokens::<u32>()),
                )
            }
        )
    }

    // `cl100k_base`: own resource, pattern, and special tokens.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:cl100k_base",
            |id| VocabDescription::new(
                id,
                &["openai", "cl100k_base"],
                "GPT-3 `cl100k_base` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_CL100K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_CL100K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_cl100k_base_special_tokens::<u32>()),
                )
            }
        )
    }

    // `o200k_base`: own resource, pattern, and special tokens.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:o200k_base",
            |id| VocabDescription::new(
                id,
                &["openai", "o200k_base"],
                "GPT-3 `o200k_base` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_O200K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_O200K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_o200k_base_special_tokens::<u32>()),
                )
            }
        )
    }

    // `o200k_harmony`: shares the `o200k_base` resource and pattern, but
    // installs its own special tokens.
    inventory::submit! {
        BuiltinPretrainedVocabHook::new(
            "openai:o200k_harmony",
            |id| VocabDescription::new(
                id,
                &["openai", "o200k_harmony"],
                "GPT-3 `o200k_harmony` vocabulary; remote",
            ),
            |_, loader| {
                load_tiktoken_vocab(
                    loader,
                    &OA_O200K_BASE_TIKTOKEN_KEYED_RESOURCE.into(),
                    TextSpanningConfig::<u32>::from(OA_O200K_BASE_PATTERN.to_pattern())
                        .with_special_words(oa_o200k_harmony_special_tokens::<u32>()),
                )
            }
        )
    }
}
/// The OpenAI tokenizers known to this module.
///
/// Each variant parses from and displays as its lowercase tokenizer name
/// (for example `"cl100k_base"`), and maps to one `*_VOCAB_FACTORY` constant
/// via [`OATokenizer::factory`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the enum can be used
/// as a key in hashed collections.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    strum::EnumString,
    strum::EnumIter,
    strum::Display,
)]
#[non_exhaustive]
pub enum OATokenizer {
    /// The `r50k_base` tokenizer.
    #[strum(serialize = "r50k_base")]
    R50kBase,

    /// The `p50k_base` tokenizer.
    #[strum(serialize = "p50k_base")]
    P50kBase,

    /// The `p50k_edit` tokenizer; its factory reuses the `p50k_base` resource
    /// and pattern with a different special-token set.
    #[strum(serialize = "p50k_edit")]
    P50kEdit,

    /// The `cl100k_base` tokenizer.
    #[strum(serialize = "cl100k_base")]
    Cl100kBase,

    /// The `o200k_base` tokenizer.
    #[strum(serialize = "o200k_base")]
    O200kBase,

    /// The `o200k_harmony` tokenizer; its factory reuses the `o200k_base`
    /// resource and pattern with a different special-token set.
    #[strum(serialize = "o200k_harmony")]
    O200kHarmony,
}
impl OATokenizer {
pub fn factory(&self) -> &ConstVocabularyFactory {
use OATokenizer::*;
match self {
R50kBase => &OA_R50K_BASE_VOCAB_FACTORY,
P50kBase => &OA_P50K_BASE_VOCAB_FACTORY,
P50kEdit => &OA_P50K_EDIT_VOCAB_FACTORY,
Cl100kBase => &OA_CL100K_BASE_VOCAB_FACTORY,
O200kBase => &OA_O200K_BASE_VOCAB_FACTORY,
O200kHarmony => &OA_O200K_HARMONY_VOCAB_FACTORY,
}
}
pub fn pattern(&self) -> RegexPattern {
self.factory().pattern()
}
pub fn special_tokens<T: TokenType>(&self) -> Vec<(String, T)> {
self.factory().special_tokens()
}
pub fn spanning_config<T: TokenType>(&self) -> TextSpanningConfig<T> {
self.factory().spanning_config()
}
}
/// Vocabulary loading entry points; available only with the `std` feature.
#[cfg(feature = "std")]
impl OATokenizer {
    /// Loads this tokenizer's vocabulary through `loader`.
    ///
    /// Delegates to the backing factory's `load_vocab`.
    pub fn load_vocab<T: TokenType>(
        &self,
        loader: &mut dyn ResourceLoader,
    ) -> crate::WCResult<crate::UnifiedTokenVocab<T>> {
        let backing = self.factory();
        backing.load_vocab(loader)
    }

    /// Loads this tokenizer's vocabulary from the file at `path`.
    ///
    /// Delegates to the backing factory's `load_vocab_path`.
    pub fn load_path<T: TokenType>(
        &self,
        path: impl AsRef<Path>,
    ) -> crate::WCResult<crate::UnifiedTokenVocab<T>> {
        let backing = self.factory();
        backing.load_vocab_path(path)
    }

    /// Reads this tokenizer's vocabulary from an already-open buffered reader.
    ///
    /// Delegates to the backing factory's `read_vocab`.
    pub fn read_vocab<T: TokenType>(
        &self,
        reader: &mut dyn BufRead,
    ) -> crate::WCResult<crate::UnifiedTokenVocab<T>> {
        let backing = self.factory();
        backing.read_vocab(reader)
    }
}
/// Vocabulary factory for the `r50k_base` tokenizer.
///
/// Pairs the `r50k_base` tiktoken resource (keyed under `["openai", "r50k_base"]`)
/// with the `r50k_base` pattern and special-token builder.
pub const OA_R50K_BASE_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "r50k_base",
resource: ConstKeyedResource {
key: &[OA_KEY, "r50k_base"],
resource: OA_R50K_BASE_TIKTOKEN_RESOURCE,
},
pattern: OA_R50K_BASE_PATTERN,
special_builder: &oa_r50k_base_special_tokens,
};
/// Vocabulary factory for the `p50k_base` tokenizer.
///
/// Pairs the `p50k_base` tiktoken resource (keyed under `["openai", "p50k_base"]`)
/// with the `p50k_base` pattern and special-token builder.
pub const OA_P50K_BASE_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "p50k_base",
resource: ConstKeyedResource {
key: &[OA_KEY, "p50k_base"],
resource: OA_P50K_BASE_TIKTOKEN_RESOURCE,
},
pattern: OA_P50K_BASE_PATTERN,
special_builder: &oa_p50k_base_special_tokens,
};
/// Vocabulary factory for the `p50k_edit` tokenizer.
///
/// Reuses the resource and pattern of [`OA_P50K_BASE_VOCAB_FACTORY`]; only the
/// name and the special-token builder differ.
pub const OA_P50K_EDIT_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "p50k_edit",
resource: OA_P50K_BASE_VOCAB_FACTORY.resource,
pattern: OA_P50K_BASE_VOCAB_FACTORY.pattern,
special_builder: &oa_p50k_edit_special_tokens,
};
/// Vocabulary factory for the `cl100k_base` tokenizer.
///
/// Pairs the `cl100k_base` tiktoken resource (keyed under
/// `["openai", "cl100k_base"]`) with the `cl100k_base` pattern and
/// special-token builder.
pub const OA_CL100K_BASE_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "cl100k_base",
resource: ConstKeyedResource {
key: &[OA_KEY, "cl100k_base"],
resource: OA_CL100K_BASE_TIKTOKEN_RESOURCE,
},
pattern: OA_CL100K_BASE_PATTERN,
special_builder: &oa_cl100k_base_special_tokens,
};
/// Vocabulary factory for the `o200k_base` tokenizer.
///
/// Pairs the `o200k_base` tiktoken resource (keyed under
/// `["openai", "o200k_base"]`) with the `o200k_base` pattern and
/// special-token builder.
pub const OA_O200K_BASE_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "o200k_base",
resource: ConstKeyedResource {
key: &[OA_KEY, "o200k_base"],
resource: OA_O200K_BASE_TIKTOKEN_RESOURCE,
},
pattern: OA_O200K_BASE_PATTERN,
special_builder: &oa_o200k_base_special_tokens,
};
/// Vocabulary factory for the `o200k_harmony` tokenizer.
///
/// Reuses the resource and pattern of [`OA_O200K_BASE_VOCAB_FACTORY`]; only the
/// name and the special-token builder differ.
pub const OA_O200K_HARMONY_VOCAB_FACTORY: ConstVocabularyFactory = ConstVocabularyFactory {
name: "o200k_harmony",
resource: OA_O200K_BASE_VOCAB_FACTORY.resource,
pattern: OA_O200K_BASE_VOCAB_FACTORY.pattern,
special_builder: &oa_o200k_harmony_special_tokens,
};
#[cfg(test)]
mod test {
    /// Every `OATokenizer` variant must display as its tokenizer name and
    /// parse back from that same name.
    #[test]
    fn test_oa_tokenizer() {
        use core::str::FromStr;

        use super::*;

        let expected_names = [
            (OATokenizer::R50kBase, "r50k_base"),
            (OATokenizer::P50kBase, "p50k_base"),
            (OATokenizer::P50kEdit, "p50k_edit"),
            (OATokenizer::Cl100kBase, "cl100k_base"),
            (OATokenizer::O200kBase, "o200k_base"),
            (OATokenizer::O200kHarmony, "o200k_harmony"),
        ];

        // Display direction.
        for (tokenizer, name) in expected_names {
            assert_eq!(tokenizer.to_string(), name);
        }

        // Parse direction.
        for (tokenizer, name) in expected_names {
            assert_eq!(OATokenizer::from_str(name).unwrap(), tokenizer);
        }
    }

    /// Loads the GPT-2 vocabulary through the disk cache and runs the shared
    /// encoder test suite against an encoder built from it.
    #[test]
    #[cfg(all(feature = "std", feature = "datagym", feature = "download"))]
    fn test_load_gpt2_vocab() {
        use crate::{
            TokenEncoder,
            TokenEncoderOptions,
            UnifiedTokenVocab,
            alloc::sync::Arc,
            encoders::testing::common_encoder_tests,
        };

        let mut cache: crate::disk_cache::WordchipperDiskCache = Default::default();

        // NOTE(review): this goes through `crate::vocab::io::load_gpt2_vocab`,
        // not the `load_gpt2_vocab` re-exported by this file — confirm that is
        // the intended entry point.
        let gpt2_vocab: Arc<UnifiedTokenVocab<u32>> =
            crate::vocab::io::load_gpt2_vocab(&mut cache)
                .unwrap()
                .into();

        let gpt2_encoder: Arc<dyn TokenEncoder<u32>> =
            TokenEncoderOptions::default().build(gpt2_vocab.clone());

        common_encoder_tests(gpt2_vocab, gpt2_encoder);
    }
}