use std::sync::atomic::Ordering;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;
use ahash::AHashMap;
use anyhow::Result;
use candle_core::quantized::gguf_file::Value;
use itertools::Itertools;
use tokenizers::pre_tokenizers::{
sequence::Sequence,
split::{Split, SplitPattern},
PreTokenizerWrapper,
};
use tokenizers::tokenizer::normalizer::SplitDelimiterBehavior;
use tokenizers::{
decoders::{
self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
},
models::{bpe::BpeBuilder, unigram::Unigram},
normalizers::{self, Prepend, Replace},
processors, AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;
use super::Content;
/// Output of `convert_gguf_to_hf_tokenizer`: the built tokenizer plus the text
/// of the special tokens that were resolved from the GGUF metadata.
pub(crate) struct GgufTokenizerConversion {
/// Tokenizer assembled from the GGUF vocabulary.
pub tokenizer: Tokenizer,
/// Text of the beginning-of-sequence token; `None` when the metadata has no `bos_token_id`.
pub bos: Option<String>,
/// Text of the end-of-sequence token (`eos_token_id` is a required metadata key).
pub eos: Option<String>,
/// Text of the unknown token; `None` when the metadata has no `unknown_token_id`.
pub unk: Option<String>,
}
/// Tokenizer properties extracted from GGUF metadata by the
/// `TryFrom<ContentMetadata>` implementation.
struct PropsGGUF {
/// Value of the `model` key; selects which tokenizer builder is used.
model: String,
/// Value of the `tokens` key; the token ids below are used as indices into this list.
tokens: Vec<String>,
/// Value of the optional `added_tokens` key (only its length is reported in the log).
added_tokens: Option<Vec<String>>,
/// Value of the optional `scores` key; required by the unigram builder.
scores: Option<Vec<f32>>,
/// Value of the optional `merges` key; required by the BPE builder.
merges: Option<Vec<String>>,
/// Value of the optional `unknown_token_id` key.
unk: Option<u32>,
/// Value of the optional `bos_token_id` key.
bos: Option<u32>,
/// Value of the required `eos_token_id` key.
eos: u32,
}
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    /// Reads the tokenizer properties out of the GGUF metadata.
    ///
    /// `model`, `tokens` and `eos_token_id` are mandatory: the conversion fails
    /// when one of them is missing or cannot be read. Every other key is
    /// optional and becomes `None` when its lookup fails.
    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        // Reject the metadata up front when a mandatory key is absent.
        c.has_required_keys(&["model", "tokens", "eos_token_id"])?;

        // Mandatory values: any failure here aborts the conversion.
        let model = c.get_value("model")?;
        let tokens = c.get_value("tokens")?;
        let eos = c.get_value("eos_token_id")?;

        // Optional values: a failed lookup is simply recorded as `None`.
        Ok(Self {
            model,
            tokens,
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            bos: c.get_value("bos_token_id").ok(),
            eos,
        })
    }
}
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
let metadata = ContentMetadata {
path_prefix: "tokenizer.ggml",
metadata: content.get_metadata(),
};
let md_get = |s: &str| match metadata.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let mut token_types = Vec::<i32>::new();
if metadata.metadata.contains_key("tokenizer.ggml.token_type") {
let vtypes: &Vec<Value> = md_get("tokenizer.ggml.token_type")
.unwrap()
.to_vec()
.unwrap();
let v: Vec<i32> = vtypes.iter().map(|v| v.to_i32().unwrap()).collect();
token_types.extend(v);
}
let props = PropsGGUF::try_from(metadata)?;
let (mut tokenizer, kind) = match props.model.as_str() {
"llama" | "replit" => unigram_tokenizer(&props)?,
"gpt2" => bpe_tokenizer(&props)?,
other => {
anyhow::bail!("Tokenizer model `{other}` not supported.");
}
};
let mut num_special_tokens = 0;
#[allow(clippy::needless_range_loop)]
if token_types.len() == props.tokens.len() {
for i in 0..props.tokens.len() {
if token_types[i] != 1i32 {
let tk = props.tokens[i].clone();
tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
num_special_tokens += 1;
}
}
}
info!(
"GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num special tokens {}, num added tokens: {}, num merges: {}, num scores: {}",
tokenizer.get_vocab_size(true),
num_special_tokens,
props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
model = props.model,
);
if DEBUG.load(Ordering::Relaxed) {
info!("Tokenizer: {tokenizer:?}");
}
let unk = match props.unk {
Some(u) => Some(props.tokens[u as usize].clone()),
_ => None,
};
let bos = match props.bos {
Some(b) => Some(props.tokens[b as usize].clone()),
None => None,
};
Ok(GgufTokenizerConversion {
tokenizer,
bos,
eos: Some(props.tokens[props.eos as usize].clone()),
unk,
})
}
/// Which tokenizer model a builder produced. The conversion function only uses
/// it in its log line, which is why it derives `Debug`.
#[derive(Debug)]
enum TokenizerKind {
/// Returned by `unigram_tokenizer`.
Unigram,
/// Returned by `bpe_tokenizer`.
Bpe,
}
fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
let PropsGGUF { unk, eos, bos, .. } = *p;
let unk = unk.unwrap_or(0);
let model = {
let vocab: Vec<(String, f64)> = {
let Some(s) = p.scores.as_ref() else {
anyhow::bail!(
"`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
);
};
let scores = s.iter().cloned().map(|f_32| f_32 as f64);
p.tokens.iter().cloned().zip(scores).collect()
};
Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
};
let decoder = Decoder::Sequence(vec![
Decoder::Replace("▁", " "),
Decoder::ByteFallback,
Decoder::Fuse,
Decoder::Strip(' ', 1, 0),
]);
let normalizer = Normalizer::Sequence(vec![
Normalizer::Prepend("▁"),
Normalizer::Replace(" ", "▁"),
]);
let mut tokenizer: Tokenizer = TokenizerX::new(
ModelWrapper::Unigram(model),
Some(decoder),
Some(normalizer),
)?;
for v in [bos, Some(eos), Some(unk)].iter().flatten() {
let tk = p.tokens[*v as usize].clone();
tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
}
Ok((tokenizer, TokenizerKind::Unigram))
}
fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
let merges = p
.merges
.as_ref()
.ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
.iter()
.map(|merge| {
let split: (&str, &str) = merge
.splitn(2, ' ')
.collect_tuple()
.expect("Failed to convert split into 2-tuple");
(split.0.to_string(), split.1.to_string())
})
.collect::<Vec<_>>();
let mut vocab = AHashMap::new();
for (i, token) in p.tokens.iter().enumerate() {
#[allow(clippy::cast_possible_truncation)]
vocab.insert(token.clone(), i as u32);
}
let PropsGGUF { bos, eos, unk, .. } = *p;
let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
if let Some(unk) = unk {
bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
};
let bpe = bpe.build().map_err(anyhow::Error::msg)?;
let mut tokenizer = TokenizerX::new(
ModelWrapper::BPE(bpe),
Some(Decoder::ByteLevel(true, true, true)),
None,
)?;
let split = Split::new(
SplitPattern::Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()),
SplitDelimiterBehavior::Isolated,
false,
).unwrap();
let pre_tokenizer = Sequence::new(vec![
PreTokenizerWrapper::Split(split),
PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, false, false)),
]);
tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
tokenizer.with_decoder(Some(decoders::byte_level::ByteLevel::new(
false, false, false,
)));
tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
false, false, false,
)));
for v in [bos, Some(eos), unk].iter().flatten() {
let tk = p.tokens[*v as usize].clone();
tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
}
Ok((tokenizer, TokenizerKind::Bpe))
}
/// Namespace for the tokenizer constructor shared by the unigram and BPE builders.
struct TokenizerX;

impl TokenizerX {
    /// Creates a [`Tokenizer`] around `model`, attaching the decoder and the
    /// normalizer when they are given.
    ///
    /// # Errors
    ///
    /// Fails when the decoder or the normalizer description cannot be converted
    /// into its wrapper type. The decoder is converted first.
    #[allow(clippy::new_ret_no_self)]
    fn new<'a>(
        model: ModelWrapper,
        decoder: Option<Decoder<'a>>,
        normalizer: Option<Normalizer<'a>>,
    ) -> Result<Tokenizer> {
        let mut built = Tokenizer::new(model);

        if let Some(spec) = decoder {
            built.with_decoder(Some(DecoderWrapper::try_from(spec)?));
        }

        if let Some(spec) = normalizer {
            built.with_normalizer(Some(NormalizerWrapper::try_from(spec)?));
        }

        Ok(built)
    }
}
/// Compact description of a decoder; turned into a `DecoderWrapper` by the
/// `TryFrom<Decoder>` implementation.
enum Decoder<'a> {
/// Converted to `ByteFallback::default()`.
ByteFallback,
/// Converted to `Fuse::default()`.
Fuse,
/// `(pattern, content)` arguments passed to `Replace::new`.
Replace(&'a str, &'a str),
/// `(content, start, stop)` arguments passed to `Strip::new`.
Strip(char, usize, usize),
/// Each element is converted in order and the results wrapped in a decoder sequence.
Sequence(Vec<Self>),
/// `(add_prefix_space, trim_offsets, use_regex)` arguments passed to `ByteLevel::new`.
ByteLevel(bool, bool, bool),
}
impl TryFrom<Decoder<'_>> for DecoderWrapper {
type Error = anyhow::Error;
fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
let value: DecoderWrapper = match variant {
Decoder::ByteFallback => ByteFallback::default().into(),
Decoder::Fuse => Fuse::default().into(),
Decoder::Replace(pattern, content) => Replace::new(pattern, content)
.map_err(anyhow::Error::msg)?
.into(),
Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
Decoder::Sequence(decoders) => {
let seq = decoders
.into_iter()
.map(DecoderWrapper::try_from)
.collect::<Result<Vec<DecoderWrapper>>>()?;
decoders::sequence::Sequence::new(seq).into()
}
Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
}
};
Ok(value)
}
}
/// Compact description of a normalizer; turned into a `NormalizerWrapper` by the
/// `TryFrom<Normalizer>` implementation.
enum Normalizer<'a> {
/// Text passed (as an owned string) to `Prepend::new`.
Prepend(&'a str),
/// `(pattern, content)` arguments passed to `Replace::new`.
Replace(&'a str, &'a str),
/// Each element is converted in order and the results wrapped in a normalizer sequence.
Sequence(Vec<Self>),
}
impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
type Error = anyhow::Error;
fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
let value: NormalizerWrapper = match variant {
Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
.map_err(anyhow::Error::msg)?
.into(),
Normalizer::Sequence(decoders) => {
let seq = decoders
.into_iter()
.map(NormalizerWrapper::try_from)
.collect::<Result<Vec<NormalizerWrapper>>>()?;
normalizers::Sequence::new(seq).into()
}
};
Ok(value)
}
}
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    /// Tokenizer families the tests know about. Only `Llama` and `Gpt2` have
    /// reference files; the other variants make the fetch helpers return an error.
    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        /// Mistral
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    /// Downloads `filename` from the `EricB/mistralrs_tests` model repository
    /// (revision `main`) and loads it as a tokenizer.
    ///
    /// Panics when the download or the load fails, which fails the calling test.
    fn fetch_test_tokenizer(filename: &str) -> Tokenizer {
        let api = ApiBuilder::new().with_progress(true).build().unwrap();
        let repo = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));
        let path = repo.get(filename).unwrap();
        Tokenizer::from_file(path).expect("Valid tokenizer")
    }

    /// Loads the reference tokenizer that was produced from GGUF metadata.
    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => Ok(fetch_test_tokenizer("llama_gguf_tokenizer.json")),
            TokenizerType::Gpt2 => Ok(fetch_test_tokenizer("gpt2_gguf_tokenizer.json")),
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    /// Loads the original Hugging Face tokenizer to compare against.
    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => Ok(fetch_test_tokenizer("tokenizer.json")),
            TokenizerType::Gpt2 => Ok(fetch_test_tokenizer("tokenizer_gpt2.json")),
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    /// Sample text mixing ASCII, emoji and CJK characters.
    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";
        passage.to_owned()
    }

    /// Encodes `passage` and decodes the resulting ids again.
    ///
    /// Special tokens are skipped while decoding exactly when they were not
    /// added while encoding.
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    /// Decodes `token_ids`, converting the tokenizer error into `anyhow::Error`.
    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    /// Checks that the GGUF-derived tokenizer behaves like the Hugging Face one:
    /// the test passage round-trips identically (and losslessly), and a shuffled
    /// list of every base-vocabulary id decodes identically with and without
    /// special tokens.
    fn assert_gguf_matches_hf(hf_tokenizer: &Tokenizer, gguf_tokenizer: &Tokenizer) -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_decoded = codec_roundtrip(hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        for skip_special_tokens in [false, true] {
            let hf_decoded = decode(hf_tokenizer, &tokens, skip_special_tokens)?;
            let gguf_decoded = decode(gguf_tokenizer, &tokens, skip_special_tokens)?;
            assert_eq!(hf_decoded, gguf_decoded);
        }
        Ok(())
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;
        assert_gguf_matches_hf(&hf_tokenizer, &gguf_tokenizer)
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;
        assert_gguf_matches_hf(&hf_tokenizer, &gguf_tokenizer)
    }
}