use crate::local_model::metadata::tokenizer::{GgmlTokenizerMetadata, GgmlTokenizerModel};
use anyhow::Context;
use tokenizers::{
AddedToken, DecoderWrapper, NormalizerWrapper, Tokenizer,
decoders::{
self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
},
models::unigram::Unigram,
normalizers::{self, Prepend, Replace},
};
/// Tokenizer-related values pulled out of a [`GgmlTokenizerMetadata`] for the
/// duration of one conversion.
///
/// The vocabulary is borrowed from the metadata; every other field is an owned
/// copy or clone.
struct GgufTokenizerProps<'a> {
    /// Tokenizer model kind declared by the metadata.
    model: GgmlTokenizerModel,
    /// Vocabulary; token IDs are used as indices into this list.
    tokens: &'a Vec<String>,
    /// Optional per-token scores, paired positionally with `tokens`.
    scores: Option<Vec<f32>>,
    /// ID of the unknown token, if the metadata provides one.
    unk: Option<u32>,
    /// ID of the end-of-sequence token.
    eos: u32,
    /// ID of the beginning-of-sequence token.
    bos: u32,
    /// Added tokens from the metadata; stored but not read by the code in this file.
    _added_tokens: Option<Vec<String>>,
    /// Merges from the metadata; stored but not read by the code in this file.
    _merges: Option<Vec<String>>,
}
impl<'a> GgufTokenizerProps<'a> {
    /// Gathers the tokenizer properties from `ggml`.
    ///
    /// The token list is borrowed for `'a`; the model kind, added tokens,
    /// scores and merges are cloned, and the special-token IDs are copied.
    ///
    /// This constructor currently always returns `Ok`; the `Result` return
    /// type is part of the signature callers use with `?`.
    fn new(ggml: &'a GgmlTokenizerMetadata) -> crate::Result<Self> {
        Ok(Self {
            model: ggml.model.clone(),
            tokens: &ggml.tokens,
            scores: ggml.scores.clone(),
            unk: ggml.unknown_token_id,
            eos: ggml.eos_token_id,
            bos: ggml.bos_token_id,
            _added_tokens: ggml.added_tokens.clone(),
            _merges: ggml.merges.clone(),
        })
    }
}
/// Converts GGUF tokenizer metadata into a Hugging Face [`Tokenizer`].
///
/// The `Llama`, `Replit` and `Gpt2` model kinds are all built through
/// `unigram_tokenizer`.
///
/// # Errors
///
/// Returns an error when the metadata declares any other model kind, or when
/// building the tokenizer fails.
pub fn convert_gguf_to_hf_tokenizer(ggml: &GgmlTokenizerMetadata) -> crate::Result<Tokenizer> {
    let props = GgufTokenizerProps::new(ggml)?;

    let is_supported = matches!(
        props.model,
        GgmlTokenizerModel::Llama | GgmlTokenizerModel::Replit | GgmlTokenizerModel::Gpt2
    );
    if !is_supported {
        anyhow::bail!("Tokenizer model `{:?}` not supported.", props.model);
    }

    unigram_tokenizer(&props)
}
/// Registers the BOS, EOS and (when given) unknown tokens as special tokens
/// on `tokenizer`.
///
/// Each ID is resolved to its token text via `get_token` and registered with
/// its own `Tokenizer::add_special_tokens` call, in the order BOS, EOS, UNK.
///
/// # Errors
///
/// Returns the lookup error for the first ID that has no entry in the
/// vocabulary; tokens registered before that point stay registered.
fn add_special_tokens(
    p: &GgufTokenizerProps,
    tokenizer: &mut Tokenizer,
    bos: u32,
    eos: u32,
    unk: Option<u32>,
) -> crate::Result<()> {
    // `unk` contributes zero or one ID to the end of the sequence.
    let special_ids = Some(bos).into_iter().chain(Some(eos)).chain(unk);

    for id in special_ids {
        let content = get_token(p, id)?;
        let special = AddedToken::from(content, true);
        tokenizer.add_special_tokens(&[special]);
    }

    Ok(())
}
/// Returns an owned copy of the vocabulary entry at index `token_id`.
///
/// # Errors
///
/// Returns an error naming the ID when it is outside the vocabulary.
fn get_token(p: &GgufTokenizerProps, token_id: u32) -> crate::Result<String> {
    let index = token_id as usize;
    let entry = p.tokens.get(index).cloned();
    entry.with_context(|| format!("Token not found for ID: {}", token_id))
}
/// Builds a [`Tokenizer`] backed by a [`Unigram`] model from the GGUF
/// tokenizer properties.
///
/// The vocabulary pairs each token with its score; when the metadata carries
/// no scores, every token gets the placeholder score `-10000.0`. The unknown
/// token ID falls back to `0` when the metadata does not provide one.
///
/// The tokenizer is configured with:
/// - a decoder sequence: replace `"▁"` with `" "`, byte fallback, fuse, then
///   strip `' '` with `start = 1`, `stop = 0`;
/// - a normalizer sequence: prepend `"▁"`, then replace `" "` with `"▁"`;
/// - BOS, EOS and the unknown token registered as special tokens.
///
/// # Errors
///
/// Returns an error when scores are present but their count differs from the
/// token count, when the `Unigram` model, decoder or normalizer cannot be
/// built, or when a special-token ID is not in the vocabulary.
fn unigram_tokenizer(p: &GgufTokenizerProps) -> crate::Result<Tokenizer> {
    let GgufTokenizerProps { unk, eos, bos, .. } = *p;
    // Fall back to token ID 0 when the metadata names no unknown token.
    let unk = unk.unwrap_or(0);

    let model = {
        let vocab: Vec<(String, f64)> = match p.scores.as_ref() {
            Some(scores) => {
                // `zip` stops at the shorter input, so a length mismatch would
                // silently truncate the vocabulary and leave it out of step
                // with the token IDs. Reject it instead.
                anyhow::ensure!(
                    scores.len() == p.tokens.len(),
                    "Tokenizer metadata has {} tokens but {} scores.",
                    p.tokens.len(),
                    scores.len()
                );
                let scores = scores.iter().map(|&score| f64::from(score));
                p.tokens.iter().cloned().zip(scores).collect()
            }
            None => p
                .tokens
                .iter()
                .cloned()
                .map(|token| (token, -10000.0))
                .collect(),
        };
        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
    };

    let decoder = Decoder::Sequence(vec![
        Decoder::Replace("▁", " "),
        Decoder::ByteFallback,
        Decoder::Fuse,
        Decoder::Strip(' ', 1, 0),
    ]);
    let normalizer = Normalizer::Sequence(vec![
        Normalizer::Prepend("▁"),
        Normalizer::Replace(" ", "▁"),
    ]);

    let mut tokenizer = Tokenizer::new(model);
    let decoder = DecoderWrapper::try_from(decoder)?;
    tokenizer.with_decoder(Some(decoder));
    let normalizer = NormalizerWrapper::try_from(normalizer)?;
    tokenizer.with_normalizer(Some(normalizer));

    add_special_tokens(p, &mut tokenizer, bos, eos, Some(unk))?;
    Ok(tokenizer)
}
/// Lightweight description of a decoder, converted into a `DecoderWrapper`
/// through `TryFrom`.
// `ByteLevel` is handled by the conversion but is not constructed in the code
// visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
enum Decoder<'a> {
    /// Byte-fallback decoder; takes no configuration.
    ByteFallback,
    /// Fuse decoder; takes no configuration.
    Fuse,
    /// Replace decoder: `(pattern, content)`.
    Replace(&'a str, &'a str),
    /// Strip decoder: `(content, start, stop)`.
    Strip(char, usize, usize),
    /// Ordered list of decoders combined into one sequence decoder.
    Sequence(Vec<Decoder<'a>>),
    /// Byte-level decoder: `(add_prefix_space, trim_offsets, use_regex)`.
    ByteLevel(bool, bool, bool),
}
impl TryFrom<Decoder<'_>> for DecoderWrapper {
type Error = anyhow::Error;
fn try_from(variant: Decoder) -> crate::Result<Self, Self::Error> {
let value: DecoderWrapper = match variant {
Decoder::ByteFallback => ByteFallback::default().into(),
Decoder::Fuse => Fuse::default().into(),
Decoder::Replace(pattern, content) => Replace::new(pattern, content)
.map_err(anyhow::Error::msg)?
.into(),
Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
Decoder::Sequence(decoders) => {
let seq = decoders
.into_iter()
.map(DecoderWrapper::try_from)
.collect::<crate::Result<Vec<DecoderWrapper>>>()?;
decoders::sequence::Sequence::new(seq).into()
}
Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
}
};
Ok(value)
}
}
/// Lightweight description of a normalizer, converted into a
/// `NormalizerWrapper` through `TryFrom`.
enum Normalizer<'a> {
    /// Prepend normalizer: the text to prepend.
    Prepend(&'a str),
    /// Replace normalizer: `(pattern, content)`.
    Replace(&'a str, &'a str),
    /// Ordered list of normalizers combined into one sequence normalizer.
    Sequence(Vec<Normalizer<'a>>),
}
impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
type Error = anyhow::Error;
fn try_from(variant: Normalizer) -> crate::Result<Self, Self::Error> {
let value: NormalizerWrapper = match variant {
Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
.map_err(anyhow::Error::msg)?
.into(),
Normalizer::Sequence(decoders) => {
let seq = decoders
.into_iter()
.map(NormalizerWrapper::try_from)
.collect::<crate::Result<Vec<NormalizerWrapper>>>()?;
normalizers::Sequence::new(seq).into()
}
};
Ok(value)
}
}