use tokenizers::{
ModelWrapper::BPE,
PreTokenizerWrapper,
PreTokenizerWrapper::{
ByteLevel,
Sequence,
Split,
},
pre_tokenizers::split::SplitPattern,
tokenizer::Tokenizer,
};
use crate::{
LabeledVocab,
UnifiedTokenVocab,
VocabDescription,
VocabIndex,
VocabQuery,
WCError,
WCHashMap,
WCHashSet,
WCResult,
alloc::sync::Arc,
prelude::*,
pretrained::{
factory::{
VocabProvider,
VocabProviderInventoryHook,
},
openai::OA_GPT2_PATTERN,
},
spanners::TextSpanningConfig,
support::{
regex::RegexPattern,
resources::ResourceLoader,
},
vocab::{
ByteMapVocab,
SpanMapVocab,
SpanTokenMap,
},
};
/// Determines the word-splitting regex implied by a HuggingFace pre-tokenizer.
///
/// Supported shapes:
/// - a `Split` whose pattern is a regex: that regex is returned;
/// - a `ByteLevel` with `use_regex` set: [`OA_GPT2_PATTERN`] is returned;
/// - a `Sequence` made only of `Split` and `ByteLevel` members that contains
///   exactly one `Split`: that `Split`'s regex is returned. `ByteLevel`
///   members of a sequence are skipped without inspecting `use_regex`.
///
/// # Errors
/// Returns [`WCError::External`] for a missing pre-tokenizer and for every
/// other configuration.
fn extract_pattern(pt: Option<&PreTokenizerWrapper>) -> Result<RegexPattern, WCError> {
    /// Wraps a message in [`WCError::External`].
    fn external(msg: &str) -> WCError {
        WCError::External(msg.into())
    }

    /// Pulls the regex out of a `Split`; non-regex patterns are rejected.
    fn regex_of(split: &tokenizers::pre_tokenizers::split::Split) -> Result<RegexPattern, WCError> {
        if let SplitPattern::Regex(source) = &split.pattern {
            Ok(source.clone().into())
        } else {
            Err(external("Split without Regex pattern"))
        }
    }

    let Some(pre_tokenizer) = pt else {
        return Err(external("no pre-tokenizer"));
    };

    match pre_tokenizer {
        Split(split) => regex_of(split),
        ByteLevel(byte_level) => {
            if byte_level.use_regex {
                Ok(OA_GPT2_PATTERN.into())
            } else {
                Err(external(
                    "ByteLevel with use_regex=false has no splitting regex",
                ))
            }
        }
        Sequence(members) => {
            // Members are visited in order; the first offending member decides the error.
            let mut split_pattern: Option<RegexPattern> = None;
            for member in members.as_ref() {
                match member {
                    ByteLevel(_) => continue,
                    Split(_) if split_pattern.is_some() => {
                        return Err(external("Sequence has multiple Splits"));
                    }
                    Split(split) => split_pattern = Some(regex_of(split)?),
                    _ => return Err(external("unsupported member in Sequence")),
                }
            }
            split_pattern.ok_or_else(|| external("Sequence has no Split regex"))
        }
        _ => Err(external("unsupported pre-tokenizer")),
    }
}
/// Builds the byte-to-character table used by byte-level BPE vocabularies
/// (the same construction as GPT-2's `bytes_to_unicode`).
///
/// Bytes in `'!'..='~'`, `0xA1..=0xAC` and `0xAE..=0xFF` map to the character
/// with the same code point. Every other byte is remapped: taken in ascending
/// order, the k-th such byte (counting from zero) maps to `U+0100 + k`.
///
/// The returned map has exactly 256 entries and is injective.
fn bytes_char() -> WCHashMap<u8, char> {
    // Bytes that are represented by their own code point.
    let keeps_own_code_point =
        |b: u8| matches!(b, b'!'..=b'~' | 0xA1..=0xAC | 0xAE..=0xFF);

    // How many remapped bytes have been assigned so far.
    let mut remapped: u32 = 0;

    (0..=u8::MAX)
        .map(|b| {
            let code_point = if keeps_own_code_point(b) {
                u32::from(b)
            } else {
                let shifted = 0x100 + remapped;
                remapped += 1;
                shifted
            };
            // At most 68 bytes are remapped, so the largest value produced is
            // U+0143; the checked conversion can never fail.
            let ch = char::from_u32(code_point)
                .expect("code points below U+0144 are valid Unicode scalar values");
            (b, ch)
        })
        .collect()
}
pub fn vocab_from_hf_tokenizer(tok: &Tokenizer) -> WCResult<Arc<UnifiedTokenVocab<u32>>> {
type T = u32;
let pattern = extract_pattern(tok.get_pre_tokenizer())?;
let mut span_config: TextSpanningConfig<T> = TextSpanningConfig::from_pattern(pattern);
let BPE(bpe) = tok.get_model() else {
return Err(WCError::External(
"Tokenizer is not BPE compatible".to_string(),
));
};
if let Some(unk) = bpe.get_unk_token() {
return Err(WCError::External(format!("BPE has unk_token {unk:?}")));
}
let hf_vocab = bpe.get_vocab();
let mut special_tokens: WCHashSet<T> = Default::default();
let decoder = tok.get_added_tokens_decoder();
for (t, at) in decoder.iter() {
span_config.specials_mut().add_str_word(&at.content, *t);
special_tokens.insert(*t);
}
let b2c = bytes_char();
let c2b: WCHashMap<char, u8> = b2c.iter().map(|(&b, &c)| (c, b)).collect();
let mut span_map: SpanTokenMap<T> = SpanTokenMap::default();
for (s, id) in &hf_vocab {
if special_tokens.contains(id) {
continue;
} else {
let mut bytes = Vec::with_capacity(s.len());
for ch in s.chars() {
match c2b.get(&ch) {
Some(&b) => bytes.push(b),
None => {
return Err(WCError::External(format!(
"token {s:?} (id {id}) has non-byte-level codepoint {ch:?}"
)));
}
}
}
span_map.insert(bytes, *id);
}
}
if span_config.specials().len() != special_tokens.len() {
return Err(WCError::External(format!(
"hf vocab identifies {} special tokens, but only {} special tokens found in span_config",
special_tokens.len(),
span_config.specials().len()
)));
}
let byte_tokens: Vec<T> = (0u8..=255)
.map(|b| {
let key: String = std::iter::once(b2c[&b]).collect();
hf_vocab.get(&key).copied().ok_or(b)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|b| WCError::External(format!("missing byte token for 0x{b:02x}")))?;
let byte_map = ByteMapVocab::<T>::from_byte_to_token(&byte_tokens);
let span_vocab = SpanMapVocab::<T>::new(byte_map, span_map)?;
let expected_len = span_vocab.len() + span_config.specials().len();
let vocab: Arc<UnifiedTokenVocab<T>> =
Arc::new(UnifiedTokenVocab::from_span_vocab(span_config, span_vocab)?);
if vocab.len() + vocab.special_vocab().len() != expected_len {
return Err(WCError::External(format!(
"Expected {} tokens, got {}",
expected_len,
vocab.len()
)));
}
Ok(vocab)
}
/// Field-less [`VocabProvider`] implementation for HuggingFace vocabularies.
pub struct HFVocabProvider {}
// Submits a `VocabProviderInventoryHook` to `inventory`; the hook's closure
// constructs a fresh, `Arc`-wrapped `HFVocabProvider` each time it is called.
inventory::submit! {
VocabProviderInventoryHook::new(|| Arc::new(HFVocabProvider{}))
}
impl VocabProvider for HFVocabProvider {
    /// Returns the provider name, `"hf"`.
    fn name(&self) -> String {
        "hf".to_string()
    }

    /// Returns a short human-readable description of this provider.
    fn description(&self) -> String {
        "HuggingFace vocabularies".to_string()
    }

    /// Returns an empty list: this provider does not enumerate vocabularies.
    fn list_vocabs(&self) -> Vec<VocabDescription> {
        vec![]
    }

    /// Loads the vocabulary named by `query` through
    /// `Tokenizer::from_pretrained` and converts it with
    /// [`vocab_from_hf_tokenizer`].
    ///
    /// The query is accepted when it has no schema or its schema is `"hf"`.
    /// The schema is stripped before the query string is handed to
    /// `from_pretrained`, and set to `"hf"` on the returned description.
    /// `_loader` is not used.
    ///
    /// # Errors
    /// - [`WCError::ResourceNotFound`] when the schema is something other than
    ///   `"hf"`, or when `Tokenizer::from_pretrained` fails (the underlying
    ///   error is discarded);
    /// - any error produced by [`vocab_from_hf_tokenizer`].
    fn load_vocab(
        &self,
        query: &VocabQuery,
        _loader: &mut dyn ResourceLoader,
    ) -> WCResult<LabeledVocab<u32>> {
        if let Some(schema) = query.schema()
            && schema != "hf"
        {
            return Err(WCError::ResourceNotFound(query.to_string()));
        }

        let identifier = query.clone().with_schema(None).to_string();
        let tok = Tokenizer::from_pretrained(identifier, None)
            .map_err(|_| WCError::ResourceNotFound(query.to_string()))?;
        let vocab = vocab_from_hf_tokenizer(&tok)?;

        let id = query.clone().with_schema(Some("hf"));
        let context = id.to_context();
        let descr: VocabDescription =
            VocabDescription::new(id, &context, "Model loaded from hf");
        Ok(LabeledVocab::new(descr, vocab))
    }
}