use crate::error::{Error, Result};
use crate::format::gguf::value::GgufValue;
use crate::format::{Gguf, GgufMetadata};
use splintr::{SentencePieceTokenizer, Tokenize, TokenizeError, WordPieceTokenizer};
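
/// A tokenizer built from the metadata embedded in a GGUF file.
///
/// Wraps whichever `splintr` backend matches the file's
/// `tokenizer.ggml.model` value behind a single `dyn Tokenize`.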
pub struct GgufTokenizer {
inner: Box<dyn Tokenize>,
}
impl GgufTokenizer {
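    /// Builds a tokenizer from the GGUF metadata, dispatching on the
    /// `tokenizer.ggml.model` key: `"bert"` selects WordPiece, and any
    /// other value (including a missing key) selects SentencePiece.
    ///
    /// A minimal usage sketch; `gguf` stands for a parsed [`Gguf`]
    /// obtained elsewhere in this crate:
    ///
    /// ```ignore
    /// let tokenizer = GgufTokenizer::from_gguf(&gguf)?;
    /// let ids = tokenizer.encode("Hello, world!");
    /// let text = tokenizer.decode(&ids)?;
    /// ```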
pub fn from_gguf(gguf: &Gguf) -> Result<Self> {
let metadata = gguf.metadata();
let model_type = metadata
.get_string("tokenizer.ggml.model")
.unwrap_or("llama");
let tokens = extract_tokens(metadata)?;
match model_type {
"bert" => Self::build_wordpiece(metadata, tokens),
_ => Self::build_sentencepiece(metadata, tokens),
}
}
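
    /// Assembles a WordPiece tokenizer for BERT-style vocabularies. The
    /// `[UNK]` id and the lowercasing behavior are inferred from the
    /// vocabulary and metadata, since the metadata carries no explicit
    /// lowercase flag.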
    fn build_wordpiece(metadata: &GgufMetadata, tokens: Vec<String>) -> Result<Self> {
        let unk_token_id = find_special_token_id(&tokens, metadata, "[UNK]", 0);
        // Infer lowercasing from the vocab: an uncased vocab has "the" but not "The".
        let do_lower_case =
            tokens.iter().any(|t| t == "the") && !tokens.iter().any(|t| t == "The");
        // 200 caps word length in characters (standard WordPiece maps longer words to [UNK]).
        let inner = WordPieceTokenizer::new(tokens, unk_token_id, 200, do_lower_case);
Ok(Self {
inner: Box::new(inner),
})
}
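
    /// Assembles a SentencePiece tokenizer, the default for LLaMA-family
    /// models, reading per-token scores and the BOS/EOS ids from the
    /// metadata when present.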
    fn build_sentencepiece(metadata: &GgufMetadata, tokens: Vec<String>) -> Result<Self> {
        // SentencePiece wants one score (log probability) per token; read
        // them from the metadata when present.
        let scores = if let Some(scores_array) = metadata.get_array("tokenizer.ggml.scores") {
            let mut out = Vec::with_capacity(scores_array.len());
            for (i, v) in scores_array.iter().enumerate() {
                let score = v.as_f32().ok_or_else(|| Error::ModelError {
                    reason: format!("tokenizer.ggml.scores[{i}] is not an f32"),
                })?;
                out.push(score);
            }
            out
        } else {
            // No scores in the file: fall back to a neutral score per token
            // so the vector length still matches the vocabulary.
            vec![0.0; tokens.len()]
        };
        // LLaMA-family GGUFs conventionally use BOS = 1 and EOS = 2; fall
        // back to those ids when the metadata omits them.
        let bos_token_id = metadata.get_u32("tokenizer.ggml.bos_token_id").unwrap_or(1);
        let eos_token_id = metadata.get_u32("tokenizer.ggml.eos_token_id").unwrap_or(2);
let inner = SentencePieceTokenizer::new(tokens, scores, bos_token_id, eos_token_id)
            .map_err(|e| Error::ModelError {
                reason: format!("Failed to create SentencePiece tokenizer: {e}"),
            })?;
Ok(Self {
inner: Box::new(inner),
})
}
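
    /// Encodes `text` into token ids with the underlying tokenizer.
    ///
    /// The ids depend entirely on the loaded vocabulary; a sketch of the
    /// expected round trip:
    ///
    /// ```ignore
    /// let ids = tokenizer.encode("hello");
    /// let text = tokenizer.decode(&ids)?; // usually round-trips to "hello"
    /// ```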
pub fn encode(&self, text: &str) -> Vec<u32> {
self.inner.encode(text)
}
pub fn decode(&self, ids: &[u32]) -> Result<String> {
        self.inner.decode(ids).map_err(|e| Error::ModelError {
            reason: format!("Decode error: {e}"),
        })
}
pub fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
}
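
// Implementing `Tokenize` by delegation lets a `GgufTokenizer` be used
// anywhere the rest of the crate expects a `dyn Tokenize`.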
impl Tokenize for GgufTokenizer {
fn encode(&self, text: &str) -> Vec<u32> {
self.inner.encode(text)
}
fn decode(&self, ids: &[u32]) -> std::result::Result<String, TokenizeError> {
self.inner.decode(ids)
}
fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
}
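
/// Reads the vocabulary out of `tokenizer.ggml.tokens`, failing if the
/// array is missing or any entry is not a string.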
fn extract_tokens(metadata: &GgufMetadata) -> Result<Vec<String>> {
    let tokens_array = metadata
        .get_array("tokenizer.ggml.tokens")
        .ok_or_else(|| Error::ModelError {
            reason: "GGUF missing tokenizer.ggml.tokens".into(),
        })?;
let mut tokens = Vec::with_capacity(tokens_array.len());
for (id, value) in tokens_array.iter().enumerate() {
match value {
GgufValue::String(s) => tokens.push(s.clone()),
_ => {
return Err(Error::ModelError {
reason: format!("tokenizer.ggml.tokens[{}] is not a string", id),
});
}
}
}
Ok(tokens)
}
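
/// Resolves the id of a special token such as `[UNK]` or `[SEP]`.
///
/// Lookup order: an exact match in the vocabulary, then the matching
/// `tokenizer.ggml.*_token_id` metadata key, then the caller-supplied
/// default.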
fn find_special_token_id(
    tokens: &[String],
    metadata: &GgufMetadata,
    token_str: &str,
    default: u32,
) -> u32 {
    // An exact match in the vocabulary wins.
    if let Some(id) = tokens.iter().position(|t| t == token_str) {
        return id as u32;
    }
    // Otherwise consult the metadata key for this special token, falling
    // back to the caller-supplied default.
    let key = match token_str {
        "[UNK]" => "tokenizer.ggml.unknown_token_id",
        "[PAD]" => "tokenizer.ggml.padding_token_id",
        "[CLS]" => "tokenizer.ggml.cls_token_id",
        "[SEP]" => "tokenizer.ggml.sep_token_id",
        _ => return default,
    };
    metadata.get_u32(key).unwrap_or(default)
}