use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
path::PathBuf,
};
use crate::{
array::Array,
dtype::Dtype,
error::{
ArithmeticOverflowPayload, Error, InvariantViolationPayload, MissingKeyPayload,
OutOfRangePayload, ParsePayload, Result, UnknownEnumValuePayload,
},
io::GgufMetadata,
lm::load::{Config, Weights},
};
/// Per-token category written (as `u32`) into the `tokenizer.ggml.token_type` metadata array.
///
/// Discriminants start at 1 and increase by one per variant (`Normal = 1` … `Byte = 6`).
/// NOTE(review): the numbering presumably mirrors llama.cpp's token-type enum — confirm before
/// adding or reordering variants, since the raw value is what lands in the file.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Ordinary vocabulary entry.
    Normal = 1,
    /// Unknown-token placeholder.
    Unknown,
    /// Special / control token (e.g. ids flagged special by the tokenizer).
    Control,
    /// Added token that is not special.
    UserDefined,
    /// Unused slot.
    Unused,
    /// Raw byte token of the form `<0xNN>`.
    Byte,
}
/// GGML file-type code stored (as `u32`) in the `general.file_type` metadata entry.
///
/// Only the dense half-precision variant is needed by this exporter.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlFileType {
    /// Half-precision tensors; raw value `1`.
    F16 = 1,
}
/// Snapshot of a Hugging Face tokenizer's vocabulary, shaped for GGUF export.
///
/// The exported token list is the base vocabulary (`0..vocab_size_base`) followed by the
/// added tokens in `added_tokens_list`, so `vocab_size == vocab_size_base + added count`.
pub struct HfVocab {
    /// Number of ids in the tokenizer's base (model) vocabulary.
    vocab_size_base: u32,
    /// Base size plus the number of added tokens appended after it.
    vocab_size: u32,
    /// `reverse_base_vocab[id]` is the text of base token `id`, if the tokenizer reported one.
    reverse_base_vocab: Vec<Option<String>>,
    /// `(id, text)` of added tokens whose id lies beyond the base vocabulary, sorted by id.
    added_tokens_list: Vec<(u32, String)>,
    /// The ids appearing in `added_tokens_list`.
    added_tokens_ids: BTreeSet<u32>,
    /// Ids treated as special/control tokens.
    special_ids: HashSet<u32>,
    // Populated from the tokenizer's special added tokens but not read anywhere yet; kept so
    // the text -> id mapping is available without re-querying the tokenizer.
    #[allow(dead_code)]
    specials: HashMap<String, u32>,
    /// Beginning-of-sequence token id, if the tokenizer defines one.
    bos_token_id: Option<u32>,
    /// End-of-sequence token id, if the tokenizer defines one.
    eos_token_id: Option<u32>,
    /// Unknown-token id, if the tokenizer defines one.
    unk_token_id: Option<u32>,
}
impl HfVocab {
    /// Build an [`HfVocab`] snapshot from a loaded tokenizer.
    ///
    /// * Base tokens are ids `0..hf.get_vocab_size(false)`; their texts come from
    ///   `hf.get_vocab(true)`.
    /// * Added tokens are entries of the added vocabulary whose id is at or beyond the base
    ///   size, kept sorted by id.
    /// * Special ids are the added tokens flagged `special`, the BOS/EOS/UNK/PAD ids, and the
    ///   tokenizer's additional special token ids.
    ///
    /// # Errors
    ///
    /// `Error::OutOfRange` if the base vocabulary size or the added-token count does not fit
    /// in `u32`; `Error::ArithmeticOverflow` if their sum overflows `u32`.
    pub fn from_tokenizer(tokenizer: &crate::tokenizer::Tokenizer) -> Result<Self> {
        let hf = tokenizer.hf();
        let vocab_size_base_usize = hf.get_vocab_size(false);
        let vocab_size_base = u32::try_from(vocab_size_base_usize).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "HfVocab: tokenizer base vocab size",
                "must fit in u32",
                vocab_size_base_usize.to_string(),
            ))
        })?;

        // Added tokens that extend past the base vocabulary, ordered by id. The list is built
        // once and the id set is derived from it, instead of copying every string twice.
        let mut added_tokens_list: Vec<(u32, String)> = hf
            .get_added_vocabulary()
            .get_vocab()
            .iter()
            .filter(|&(_, &id)| id >= vocab_size_base)
            .map(|(name, &id)| (id, name.clone()))
            .collect();
        added_tokens_list.sort_by_key(|(id, _)| *id);
        let added_tokens_ids: BTreeSet<u32> =
            added_tokens_list.iter().map(|(id, _)| *id).collect();

        // Special tokens: added tokens flagged `special`, plus the tokenizer's named and
        // additional special ids.
        let mut specials: HashMap<String, u32> = HashMap::new();
        let mut special_ids: HashSet<u32> = HashSet::new();
        for (id, tok) in hf.get_added_tokens_decoder() {
            if tok.special {
                specials.insert(tok.content.clone(), id);
                special_ids.insert(id);
            }
        }
        special_ids.extend(
            [
                tokenizer.bos_token_id(),
                tokenizer.eos_token_id(),
                tokenizer.unk_token_id(),
                tokenizer.pad_token_id(),
            ]
            .into_iter()
            .flatten(),
        );
        special_ids.extend(tokenizer.additional_special_token_ids());

        // id -> text for the base vocabulary. Ids at or beyond the base size are added tokens
        // and are tracked in `added_tokens_list` instead.
        let mut reverse_base_vocab: Vec<Option<String>> = vec![None; vocab_size_base_usize];
        for (text, &id) in &hf.get_vocab(true) {
            if let Some(slot) = reverse_base_vocab.get_mut(id as usize) {
                *slot = Some(text.clone());
            }
        }

        let added_u32 = u32::try_from(added_tokens_list.len()).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "HfVocab: added token count",
                "must fit in u32",
                added_tokens_list.len().to_string(),
            ))
        })?;
        let vocab_size = vocab_size_base.checked_add(added_u32).ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "vocab_size_base + added",
                "u32",
                [
                    ("vocab_size_base", u64::from(vocab_size_base)),
                    ("added", u64::from(added_u32)),
                ],
            ))
        })?;

        Ok(HfVocab {
            added_tokens_list,
            added_tokens_ids,
            specials,
            special_ids,
            vocab_size_base,
            vocab_size,
            reverse_base_vocab,
            bos_token_id: tokenizer.bos_token_id(),
            eos_token_id: tokenizer.eos_token_id(),
            unk_token_id: tokenizer.unk_token_id(),
        })
    }

    /// Total exported vocabulary size: base tokens plus appended added tokens.
    pub fn vocab_size(&self) -> u32 {
        self.vocab_size
    }

    /// Size of the tokenizer's base vocabulary.
    pub fn vocab_size_base(&self) -> u32 {
        self.vocab_size_base
    }

    /// Beginning-of-sequence token id, if any.
    pub fn bos_token_id(&self) -> Option<u32> {
        self.bos_token_id
    }

    /// End-of-sequence token id, if any.
    pub fn eos_token_id(&self) -> Option<u32> {
        self.eos_token_id
    }

    /// Unknown-token id, if any.
    pub fn unk_token_id(&self) -> Option<u32> {
        self.unk_token_id
    }

    /// Classify a token: `<0xNN>` byte tokens first, then special ids, otherwise normal.
    fn get_token_type(&self, token_id: u32, token_text: &str) -> TokenType {
        if is_byte_token(token_text) {
            TokenType::Byte
        } else if self.special_ids.contains(&token_id) {
            TokenType::Control
        } else {
            TokenType::Normal
        }
    }

    /// Score assigned to every token; a constant placeholder, independent of the id.
    fn get_token_score(&self, _token_id: u32) -> f32 {
        -1000.0
    }

    /// Every exported token as `(text, score, type)`, base tokens first, then added tokens.
    ///
    /// The result has `vocab_size()` entries when every base id has a text.
    ///
    /// # Errors
    ///
    /// `Error::MissingKey` if a base id has no text in the tokenizer vocabulary.
    pub fn all_tokens(&self) -> Result<Vec<(String, f32, TokenType)>> {
        let mut out = Vec::with_capacity(self.vocab_size as usize);
        for id in 0..self.vocab_size_base {
            if self.added_tokens_ids.contains(&id) {
                continue;
            }
            let text = self
                .reverse_base_vocab
                .get(id as usize)
                .and_then(|slot| slot.as_deref())
                .ok_or_else(|| {
                    Error::MissingKey(MissingKeyPayload::new(
                        "HfVocab: base vocab token",
                        id.to_string(),
                    ))
                })?;
            out.push((
                text.to_owned(),
                self.get_token_score(id),
                self.get_token_type(id, text),
            ));
        }
        for (id, text) in &self.added_tokens_list {
            // Special added tokens are classified with empty text, so they can never be
            // mistaken for byte tokens; non-special added tokens are user-defined.
            let toktype = if self.special_ids.contains(id) {
                self.get_token_type(*id, "")
            } else {
                TokenType::UserDefined
            };
            out.push((text.clone(), self.get_token_score(*id), toktype));
        }
        Ok(out)
    }

    /// Whether `tokenizer`'s vocabulary (including added tokens) has a newline token, either
    /// as the byte token `<0x0A>` or a literal `"\n"`. Does not consult `self`.
    pub fn has_newline_token(&self, tokenizer: &crate::tokenizer::Tokenizer) -> bool {
        let vocab = tokenizer.hf().get_vocab(true);
        vocab.contains_key("<0x0A>") || vocab.contains_key("\n")
    }
}
/// True when `text` is exactly a byte token of the form `<0xNN>`, where `NN` are two ASCII
/// hex digits (either case). The `x` must be lowercase.
fn is_byte_token(text: &str) -> bool {
    match text.as_bytes() {
        [b'<', b'0', b'x', high, low, b'>'] => high.is_ascii_hexdigit() && low.is_ascii_hexdigit(),
        _ => false,
    }
}
pub fn translate_weight_names(name: &str) -> String {
let mut s = name.replace("model.layers.", "blk.");
s = s.replace("block_sparse_moe.gate", "ffn_gate_inp");
s = remap_moe_expert(&s, "w1", "ffn_gate");
s = remap_moe_expert(&s, "w2", "ffn_down");
s = remap_moe_expert(&s, "w3", "ffn_up");
s = s.replace("mlp.gate_proj", "ffn_gate");
s = s.replace("mlp.down_proj", "ffn_down");
s = s.replace("mlp.up_proj", "ffn_up");
s = s.replace("self_attn.q_proj", "attn_q");
s = s.replace("self_attn.k_proj", "attn_k");
s = s.replace("self_attn.v_proj", "attn_v");
s = s.replace("self_attn.o_proj", "attn_output");
s = s.replace("input_layernorm", "attn_norm");
s = s.replace("post_attention_layernorm", "ffn_norm");
s = s.replace("model.embed_tokens", "token_embd");
s = s.replace("model.norm", "output_norm");
s = s.replace("lm_head", "output");
s
}
/// Rewrite every `block_sparse_moe.experts.<N>.<wk>.weight` in `s` to
/// `<replacement>.<N>.weight`, where `<N>` is one or more ASCII digits.
///
/// Occurrences of the prefix that are not followed by digits and then `.<wk>.weight` are
/// copied through unchanged, and scanning resumes right after that prefix.
fn remap_moe_expert(s: &str, wk: &str, replacement: &str) -> String {
    const PREFIX: &str = "block_sparse_moe.experts.";
    let suffix = format!(".{wk}.weight");
    let mut result = String::with_capacity(s.len());
    let mut cursor = s;
    loop {
        let Some(at) = cursor.find(PREFIX) else {
            result.push_str(cursor);
            return result;
        };
        let (before, from_prefix) = cursor.split_at(at);
        result.push_str(before);
        let after_prefix = &from_prefix[PREFIX.len()..];
        // Digits are ASCII, so the count is also a valid byte/char boundary.
        let n_digits = after_prefix
            .bytes()
            .take_while(|b| b.is_ascii_digit())
            .count();
        match after_prefix[n_digits..].strip_prefix(suffix.as_str()) {
            Some(remaining) if n_digits > 0 => {
                result.push_str(replacement);
                result.push('.');
                result.push_str(&after_prefix[..n_digits]);
                result.push_str(".weight");
                cursor = remaining;
            }
            _ => {
                result.push_str(PREFIX);
                cursor = after_prefix;
            }
        }
    }
}
/// Reorder the rows of a Q/K projection weight for GGUF's rotary-embedding layout.
///
/// The effective head count `h` is `n_head_kv` when it is given and differs from `n_head`,
/// otherwise `n_head`. The leading axis of `weights` (length `d0`) is viewed as
/// `[h, 2, d0 / (2 * h), ...rest]`, axes 1 and 2 are swapped, and the result is reshaped back
/// to the original shape.
///
/// # Errors
///
/// * `Error::OutOfRange` if `h <= 0` or a dimension does not fit in `i32`.
/// * `Error::InvariantViolation` if `weights` is 0-dimensional.
/// * `Error::ArithmeticOverflow` if `2 * h` overflows `i32`.
/// * `Error::DivisibilityConstraint` if `d0` is not a multiple of `2 * h`.
/// * Any error from `reshape` / `swapaxes`.
pub fn permute_weights(weights: &Array, n_head: i32, n_head_kv: Option<i32>) -> Result<Array> {
    let heads = n_head_kv.filter(|&kv| kv != n_head).unwrap_or(n_head);
    if heads <= 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "permute_weights: n_head",
            "must be positive",
            format!("{heads}"),
        )));
    }

    let shape = weights.shape();
    let dims: Vec<i32> = shape
        .iter()
        .map(|&d| {
            i32::try_from(d).map_err(|_| {
                Error::OutOfRange(OutOfRangePayload::new(
                    "permute_weights: shape dim",
                    "must fit in i32",
                    d.to_string(),
                ))
            })
        })
        .collect::<Result<_>>()?;
    let Some((&leading, trailing)) = dims.split_first() else {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "permute_weights: weights rank",
            "must be >= 1 (requires at least 1-D weights)",
        )));
    };

    let pair_rows = 2_i32.checked_mul(heads).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "permute_weights: 2 * n_head",
            "i32",
            [("two", 2u64), ("n_head", heads as u64)],
        ))
    })?;
    if leading % pair_rows != 0 {
        return Err(Error::DivisibilityConstraint(
            crate::error::DivisibilityConstraintPayload::new(
                "permute_weights: leading dim must be divisible by 2 * n_head",
                "leading_dim",
                leading as u64,
                "2*n_head",
                pair_rows as u64,
            ),
        ));
    }

    // [heads, 2, d0 / (2 * heads), ...trailing]
    let mut split_shape = vec![heads, 2, leading / pair_rows];
    split_shape.extend_from_slice(trailing);
    let split = weights.reshape(&&split_shape[..])?;
    let interleaved = split.swapaxes(1, 2)?;
    interleaved.reshape(&&dims[..])
}
pub fn prepare_metadata(
config: &Config,
raw_config: &serde_json::Value,
vocab: &HfVocab,
) -> Result<HashMap<String, GgufMetadata>> {
let mut metadata: HashMap<String, GgufMetadata> = HashMap::new();
let get_u32 = |key: &str| -> Option<u32> {
raw_config
.get(key)
.and_then(|v| v.as_u64())
.and_then(|n| u32::try_from(n).ok())
};
let get_f32 = |key: &str| -> Option<f32> {
raw_config
.get(key)
.and_then(|v| v.as_f64())
.map(|f| f as f32)
};
metadata.insert(
"general.name".to_string(),
GgufMetadata::String("llama".to_string()),
);
if let Some(v) = get_u32("max_position_embeddings") {
metadata.insert(
"llama.context_length".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let Some(v) = get_u32("hidden_size") {
metadata.insert(
"llama.embedding_length".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let Some(v) = get_u32("num_hidden_layers") {
metadata.insert(
"llama.block_count".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let Some(v) = get_u32("intermediate_size") {
metadata.insert(
"llama.feed_forward_length".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let (Some(hidden), Some(heads)) = (get_u32("hidden_size"), get_u32("num_attention_heads"))
&& heads > 0
{
metadata.insert(
"llama.rope.dimension_count".to_string(),
GgufMetadata::Array(scalar_u32(hidden / heads)?),
);
}
if let Some(v) = get_u32("num_attention_heads") {
metadata.insert(
"llama.attention.head_count".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
let kv = get_u32("num_key_value_heads").unwrap_or(v);
metadata.insert(
"llama.attention.head_count_kv".to_string(),
GgufMetadata::Array(scalar_u32(kv)?),
);
}
if let Some(v) = get_u32("num_local_experts") {
metadata.insert(
"llama.expert_count".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let Some(v) = get_u32("num_experts_per_tok") {
metadata.insert(
"llama.expert_used_count".to_string(),
GgufMetadata::Array(scalar_u32(v)?),
);
}
if let Some(v) = get_f32("rms_norm_eps") {
metadata.insert(
"llama.attention.layer_norm_rms_epsilon".to_string(),
GgufMetadata::Array(scalar_f32(v)?),
);
}
if let Some(v) = get_f32("rope_theta") {
metadata.insert(
"llama.rope.freq_base".to_string(),
GgufMetadata::Array(scalar_f32(v)?),
);
}
if let Some(rope_scaling) = raw_config.get("rope_scaling").and_then(|v| v.as_object())
&& let Some(typ) = rope_scaling.get("type").and_then(|v| v.as_str())
&& typ == "linear"
{
metadata.insert(
"llama.rope.scaling.type".to_string(),
GgufMetadata::String("linear".to_string()),
);
if let Some(factor) = rope_scaling.get("factor").and_then(|v| v.as_f64()) {
metadata.insert(
"llama.rope.scaling.factor".to_string(),
GgufMetadata::Array(scalar_f32(factor as f32)?),
);
}
}
metadata.insert(
"general.file_type".to_string(),
GgufMetadata::Array(scalar_u32(GgmlFileType::F16 as u32)?),
);
metadata.insert(
"general.quantization_version".to_string(),
GgufMetadata::Array(scalar_u32(GgmlFileType::F16 as u32)?),
);
let name_or_path = raw_config
.get("_name_or_path")
.and_then(|v| v.as_str())
.unwrap_or("llama");
let base_name = name_or_path
.rsplit('/')
.next()
.unwrap_or("llama")
.to_owned();
metadata.insert("general.name".to_string(), GgufMetadata::String(base_name));
metadata.insert(
"general.architecture".to_string(),
GgufMetadata::String("llama".to_string()),
);
metadata.insert(
"general.alignment".to_string(),
GgufMetadata::Array(scalar_u32(32)?),
);
metadata.insert(
"tokenizer.ggml.model".to_string(),
GgufMetadata::String("llama".to_string()),
);
let triples = vocab.all_tokens()?;
if triples.len() as u32 != vocab.vocab_size {
return Err(Error::LengthMismatch(
crate::error::LengthMismatchPayload::new(
"prepare_metadata: emitted tokens vs vocab.vocab_size",
vocab.vocab_size as usize,
triples.len(),
),
));
}
let mut tokens = Vec::with_capacity(triples.len());
let mut scores = Vec::with_capacity(triples.len());
let mut toktypes = Vec::with_capacity(triples.len());
for (text, score, toktype) in triples {
tokens.push(text);
scores.push(score);
toktypes.push(toktype as u32);
}
metadata.insert(
"tokenizer.ggml.tokens".to_string(),
GgufMetadata::StringList(tokens),
);
metadata.insert(
"tokenizer.ggml.scores".to_string(),
GgufMetadata::Array(Array::from_slice::<f32>(&scores, &(scores.len(),))?),
);
metadata.insert(
"tokenizer.ggml.token_type".to_string(),
GgufMetadata::Array(Array::from_slice::<u32>(&toktypes, &(toktypes.len(),))?),
);
if let Some(id) = vocab.bos_token_id() {
metadata.insert(
"tokenizer.ggml.bos_token_id".to_string(),
GgufMetadata::Array(scalar_u32(id)?),
);
}
if let Some(id) = vocab.eos_token_id() {
metadata.insert(
"tokenizer.ggml.eos_token_id".to_string(),
GgufMetadata::Array(scalar_u32(id)?),
);
}
if let Some(id) = vocab.unk_token_id() {
metadata.insert(
"tokenizer.ggml.unknown_token_id".to_string(),
GgufMetadata::Array(scalar_u32(id)?),
);
}
let _ = config;
Ok(metadata)
}
/// Wrap a single `u32` in a 1-element, shape-`(1,)` array (GGUF metadata scalar).
fn scalar_u32(value: u32) -> Result<Array> {
    let data = [value];
    Array::from_slice::<u32>(&data, &(data.len(),))
}
/// Wrap a single `f32` in a 1-element, shape-`(1,)` array (GGUF metadata scalar).
fn scalar_f32(value: f32) -> Result<Array> {
    let data = [value];
    Array::from_slice::<f32>(&data, &(data.len(),))
}
/// `model_type` values accepted by [`convert_to_gguf`]; also reported (in this order) in the
/// error for unsupported types.
const SUPPORTED_MODEL_TYPES: &[&str] = &[
    "llama",
    "mistral",
    "mixtral",
];
/// Inputs for [`convert_to_gguf`].
#[derive(Debug, Clone)]
pub struct ConvertToGgufArgs {
    /// Directory of the source checkpoint (config, tokenizer and weights are loaded from here).
    pub model_path: PathBuf,
    /// Destination path of the GGUF file to write.
    pub gguf_path: PathBuf,
}
/// Convert the checkpoint at `args.model_path` into a GGUF file at `args.gguf_path`.
///
/// Pipeline:
/// 1. Load `config.json`. The `model_type` must be in [`SUPPORTED_MODEL_TYPES`], and the
///    checkpoint must not be quantized.
/// 2. Load the tokenizer and weights. Permute `self_attn.q_proj.weight` and
///    `self_attn.k_proj.weight` with [`permute_weights`], then rename tensors with
///    [`translate_weight_names`].
/// 3. Build the vocabulary and metadata.
/// 4. Cast BF16 tensors to F16 (through F32) and any other tensor whose name contains `norm`
///    to F32, then write the file.
///
/// # Errors
///
/// `Error::UnknownEnumValue` for an unsupported model type, `Error::Parse` if `config.json`
/// cannot be re-parsed, and `Error::InvariantViolation` for quantized checkpoints. Errors from
/// loading, permuting, dtype casts, metadata construction and saving are propagated.
pub fn convert_to_gguf(args: &ConvertToGgufArgs) -> Result<()> {
    let model_path = &args.model_path;
    let (config, raw_json) = crate::lm::load::load_config(model_path)?;

    let model_type = config.model_type();
    if !SUPPORTED_MODEL_TYPES.contains(&model_type) {
        return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
            "convert_to_gguf: model_type (LM-side GGUF exporter supported set)",
            model_type.to_string(),
            SUPPORTED_MODEL_TYPES,
        )));
    }

    let raw_config: serde_json::Value = serde_json::from_str(&raw_json).map_err(|e| {
        Error::Parse(ParsePayload::new(
            "convert_to_gguf: cannot re-parse config.json",
            "JSON",
            e,
        ))
    })?;

    let is_quantized =
        config.quantization.is_some() || raw_config.get("quantization_config").is_some();
    if is_quantized {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "convert_to_gguf: checkpoint quantization",
            "must be None (the GGUF LM export targets dense F16/F32 GGUF; dequantize first \
             via lm::convert)",
        )));
    }

    let tokenizer = crate::lm::load::load_tokenizer(model_path, &config)?;
    let q_heads = config.num_attention_heads;
    let kv_heads = config.num_key_value_heads;

    // Q uses its own head count; K uses the KV head count (they differ under GQA).
    let weights = crate::lm::load::load_weights(model_path)?;
    let mut permuted: Weights = HashMap::with_capacity(weights.len());
    for (key, val) in weights {
        let head_override = if key.contains("self_attn.q_proj.weight") {
            Some(q_heads)
        } else if key.contains("self_attn.k_proj.weight") {
            Some(kv_heads)
        } else {
            None
        };
        let val = match head_override {
            Some(heads) => permute_weights(&val, q_heads, Some(heads))?,
            None => val,
        };
        permuted.insert(key, val);
    }

    // Sorted by GGUF name so the dtype pass below visits tensors deterministically.
    let renamed: BTreeMap<String, Array> = permuted
        .into_iter()
        .map(|(name, tensor)| (translate_weight_names(&name), tensor))
        .collect();

    let vocab = HfVocab::from_tokenizer(&tokenizer)?;
    let metadata = prepare_metadata(&config, &raw_config, &vocab)?;

    let normalized: HashMap<String, Array> = renamed
        .into_iter()
        .map(|(key, val)| -> Result<(String, Array)> {
            let dtype = val.dtype()?;
            let cast = if dtype == Dtype::BF16 {
                val.astype(Dtype::F32)?.astype(Dtype::F16)?
            } else if key.contains("norm") {
                val.astype(Dtype::F32)?
            } else {
                val
            };
            Ok((key, cast))
        })
        .collect::<Result<HashMap<String, Array>>>()?;

    crate::io::save_gguf(&args.gguf_path, &normalized, &metadata)
}
#[cfg(test)]
mod tests;