use anyhow::{Context, Result, anyhow};
use rlx_gguf::MetaValue;
use std::collections::HashSet;
use std::path::Path;
/// Derives the first block index that belongs to the MTP ("next-N predict")
/// tail of the model, based purely on GGUF metadata.
///
/// Reads `general.architecture`, then `<arch>.block_count` and
/// `<arch>.nextn_predict_layers`. Returns `None` when any of those keys is
/// missing or has the wrong type, or when the file declares zero next-N
/// layers. Otherwise the result is `block_count - nextn`, clamped at zero.
fn compute_mtp_layer_threshold(file: &rlx_gguf::GgufFile) -> Option<u32> {
    let meta = &file.metadata;
    let arch = meta
        .get("general.architecture")
        .and_then(MetaValue::as_str)?;
    // Every per-architecture key is spelled `<arch>.<suffix>`.
    let arch_u32 = |suffix: &str| {
        let key = format!("{arch}.{suffix}");
        meta.get(&key).and_then(MetaValue::as_u32)
    };
    let total_blocks = arch_u32("block_count")?;
    match arch_u32("nextn_predict_layers")? {
        0 => None,
        trailing => Some(total_blocks.saturating_sub(trailing)),
    }
}
use crate::gguf_resolve::resolve_gguf_tensor_name;
use crate::gguf_support::gguf_architecture_str;
use crate::weight_map::PackedWeightTensor;
use crate::weight_map::WeightMap;
use rlx_ir::quant::QuantScheme;
/// Translates a Hugging Face tensor name into its GGUF counterpart.
///
/// Three top-level names are mapped directly. Anything else must have the
/// form `model.layers.<index>.<suffix>` with an index that parses as `usize`
/// and a suffix from the table below; it becomes `blk.<index>.<gguf suffix>`.
/// The index is re-printed from the parsed number, so `007` comes out as `7`.
///
/// Returns `None` for every name outside that vocabulary.
pub fn hf_to_gguf_name(hf: &str) -> Option<String> {
    let top_level = match hf {
        "model.embed_tokens.weight" => Some("token_embd.weight"),
        "model.norm.weight" => Some("output_norm.weight"),
        "lm_head.weight" => Some("output.weight"),
        _ => None,
    };
    if let Some(name) = top_level {
        return Some(name.to_owned());
    }

    // `<index>.<suffix>`: split at the first dot after the layer prefix.
    let (layer, hf_suffix) = hf.strip_prefix("model.layers.")?.split_once('.')?;
    let layer: usize = layer.parse().ok()?;
    let gguf_suffix = match hf_suffix {
        "input_layernorm.weight" => "attn_norm.weight",
        "post_attention_layernorm.weight" => "ffn_norm.weight",
        "self_attn.q_proj.weight" => "attn_q.weight",
        "self_attn.k_proj.weight" => "attn_k.weight",
        "self_attn.v_proj.weight" => "attn_v.weight",
        "self_attn.o_proj.weight" => "attn_output.weight",
        "self_attn.q_proj.bias" => "attn_q.bias",
        "self_attn.k_proj.bias" => "attn_k.bias",
        "self_attn.v_proj.bias" => "attn_v.bias",
        "self_attn.q_norm.weight" => "attn_q_norm.weight",
        "self_attn.k_norm.weight" => "attn_k_norm.weight",
        "mlp.gate_proj.weight" => "ffn_gate.weight",
        "mlp.up_proj.weight" => "ffn_up.weight",
        "mlp.down_proj.weight" => "ffn_down.weight",
        _ => return None,
    };
    Some(format!("blk.{layer}.{gguf_suffix}"))
}
/// Translates a GGUF tensor name into the Hugging Face name used for the
/// same tensor (the inverse of the HF → GGUF table in this file).
///
/// Top-level names are looked up directly; per-layer names must have the form
/// `blk.<index>.<suffix>` with an index that parses as `usize` and a known
/// suffix, and are rewritten to `model.layers.<index>.<hf suffix>`.
///
/// Returns `None` for anything not covered by the two tables.
pub fn gguf_to_hf_name(gguf: &str) -> Option<String> {
    // (GGUF name, HF name) for tensors that live outside any layer.
    const TOP_LEVEL: &[(&str, &str)] = &[
        ("token_embd.weight", "model.embed_tokens.weight"),
        ("output_norm.weight", "model.norm.weight"),
        ("output.weight", "lm_head.weight"),
    ];
    // (GGUF suffix, HF suffix) for tensors inside `blk.<index>.`.
    const PER_LAYER: &[(&str, &str)] = &[
        ("attn_norm.weight", "input_layernorm.weight"),
        ("ffn_norm.weight", "post_attention_layernorm.weight"),
        ("attn_q.weight", "self_attn.q_proj.weight"),
        ("attn_k.weight", "self_attn.k_proj.weight"),
        ("attn_v.weight", "self_attn.v_proj.weight"),
        ("attn_output.weight", "self_attn.o_proj.weight"),
        ("attn_q.bias", "self_attn.q_proj.bias"),
        ("attn_k.bias", "self_attn.k_proj.bias"),
        ("attn_v.bias", "self_attn.v_proj.bias"),
        ("attn_q_norm.weight", "self_attn.q_norm.weight"),
        ("attn_k_norm.weight", "self_attn.k_norm.weight"),
        ("ffn_gate.weight", "mlp.gate_proj.weight"),
        ("ffn_up.weight", "mlp.up_proj.weight"),
        ("ffn_down.weight", "mlp.down_proj.weight"),
    ];

    if let Some(&(_, hf)) = TOP_LEVEL.iter().find(|&&(name, _)| name == gguf) {
        return Some(hf.to_owned());
    }

    let (layer, gguf_suffix) = gguf.strip_prefix("blk.")?.split_once('.')?;
    let layer: usize = layer.parse().ok()?;
    let &(_, hf_suffix) = PER_LAYER
        .iter()
        .find(|&&(suffix, _)| suffix == gguf_suffix)?;
    Some(format!("model.layers.{layer}.{hf_suffix}"))
}
/// Architecture-aware variant of [`gguf_to_hf_name`].
///
/// For the architectures `gemma2`, `gemma3`, `gemma3n`, `gemma4` and
/// `gemma4moe` a dedicated per-layer table is used, in which `ffn_norm`
/// maps to `pre_feedforward_layernorm` and the two `post_*` norms are
/// recognized. Every other architecture is delegated to
/// [`gguf_to_hf_name`] unchanged.
///
/// Returns `None` when the name is not in the applicable table.
///
/// NOTE(review): the Gemma table has no entries for the `attn_q_norm` /
/// `attn_k_norm` weights or the `attn_{q,k,v}.bias` tensors that the generic
/// table covers, so those names yield `None` here — confirm this is intended.
pub fn gguf_to_hf_name_for_arch(gguf: &str, arch: &str) -> Option<String> {
    let gemma_layout = ["gemma2", "gemma3", "gemma3n", "gemma4", "gemma4moe"].contains(&arch);
    if !gemma_layout {
        return gguf_to_hf_name(gguf);
    }

    let top_level = match gguf {
        "token_embd.weight" => Some("model.embed_tokens.weight"),
        "output_norm.weight" => Some("model.norm.weight"),
        "output.weight" => Some("lm_head.weight"),
        _ => None,
    };
    if let Some(name) = top_level {
        return Some(name.to_owned());
    }

    let (layer, gguf_suffix) = gguf.strip_prefix("blk.")?.split_once('.')?;
    let layer: usize = layer.parse().ok()?;
    let hf_suffix = match gguf_suffix {
        // The four per-layer norms.
        "attn_norm.weight" => "input_layernorm.weight",
        "post_attention_norm.weight" => "post_attention_layernorm.weight",
        "ffn_norm.weight" => "pre_feedforward_layernorm.weight",
        "post_ffw_norm.weight" => "post_feedforward_layernorm.weight",
        // Attention projections.
        "attn_q.weight" => "self_attn.q_proj.weight",
        "attn_k.weight" => "self_attn.k_proj.weight",
        "attn_v.weight" => "self_attn.v_proj.weight",
        "attn_output.weight" => "self_attn.o_proj.weight",
        // Feed-forward projections.
        "ffn_gate.weight" => "mlp.gate_proj.weight",
        "ffn_up.weight" => "mlp.up_proj.weight",
        "ffn_down.weight" => "mlp.down_proj.weight",
        _ => return None,
    };
    Some(format!("model.layers.{layer}.{hf_suffix}"))
}
/// Reports whether `name` is one of the norm weights singled out for Gemma
/// handling, under either GGUF or Hugging Face naming.
///
/// Accepted names are the final norm (`output_norm.weight` /
/// `model.norm.weight`) and the four per-layer norms below a
/// `blk.<x>.` or `model.layers.<x>.` prefix. The `<x>` segment is not
/// checked for being numeric.
fn is_gemma_norm_weight(name: &str) -> bool {
    const GGUF_LAYER_NORMS: [&str; 4] = [
        "attn_norm.weight",
        "post_attention_norm.weight",
        "ffn_norm.weight",
        "post_ffw_norm.weight",
    ];
    const HF_LAYER_NORMS: [&str; 4] = [
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
    ];

    /// Strips `prefix` plus the following `<x>.` segment, yielding the rest.
    fn layer_suffix<'a>(name: &'a str, prefix: &str) -> Option<&'a str> {
        let after_prefix = name.strip_prefix(prefix)?;
        let (_, suffix) = after_prefix.split_once('.')?;
        Some(suffix)
    }

    match name {
        "output_norm.weight" | "model.norm.weight" => true,
        _ => {
            if let Some(suffix) = layer_suffix(name, "blk.") {
                GGUF_LAYER_NORMS.contains(&suffix)
            } else if let Some(suffix) = layer_suffix(name, "model.layers.") {
                HF_LAYER_NORMS.contains(&suffix)
            } else {
                false
            }
        }
    }
}
/// Name-based check for MTP tensors: true when `name` starts with `mtp`, or
/// contains `mtp_` or `.mtp` anywhere.
pub fn is_mtp_weight(name: &str) -> bool {
    const MARKERS: [&str; 2] = ["mtp_", ".mtp"];
    name.starts_with("mtp") || MARKERS.iter().any(|&marker| name.contains(marker))
}
/// Maps a GGML tensor dtype onto the packed [`QuantScheme`] that represents
/// it, or `None` when the dtype has no entry in the table below.
pub fn ggml_type_to_quant_scheme(dtype: rlx_gguf::GgmlType) -> Option<QuantScheme> {
    use rlx_gguf::GgmlType;
    let scheme = match dtype {
        // `_0` block formats.
        GgmlType::Q4_0 => QuantScheme::GgufQ4_0,
        GgmlType::Q8_0 => QuantScheme::GgufQ8_0,
        // K-quant formats.
        GgmlType::Q2K => QuantScheme::GgufQ2K,
        GgmlType::Q3K => QuantScheme::GgufQ3K,
        GgmlType::Q4K => QuantScheme::GgufQ4K,
        GgmlType::Q5K => QuantScheme::GgufQ5K,
        GgmlType::Q6K => QuantScheme::GgufQ6K,
        GgmlType::Q8K => QuantScheme::GgufQ8K,
        _ => return None,
    };
    Some(scheme)
}
/// Common interface over the weight containers in this module.
///
/// Only `len`, `take`, `take_transposed` and `remaining_keys` are required;
/// the other methods have defaults that report "nothing available".
pub trait WeightLoader: Send {
    /// Short identifier of the backing format. Defaults to `"unknown"`.
    fn format_id(&self) -> &'static str {
        "unknown"
    }

    /// Number of tensors the loader currently reports.
    fn len(&self) -> usize;

    /// `true` exactly when [`len`](Self::len) is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Fetches the tensor named `key` as `f32` values plus a `Vec<usize>`
    /// (the shape, in the implementations visible in this file).
    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;

    /// Same contract as [`take`](Self::take), but for the transposed tensor.
    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;

    /// Fetches `key` in packed form. The default never has packed data and
    /// returns `Ok(None)`.
    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
        let _ = key;
        Ok(None)
    }

    /// Borrows the raw bytes stored for `key`. The default has none to offer
    /// and returns `None`.
    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
        let _ = key;
        None
    }

    /// Names of the tensors the loader still exposes.
    fn remaining_keys(&self) -> Vec<String>;

    /// Model architecture, when the loader knows it. Defaults to `None`.
    fn arch_hint(&self) -> Option<&str> {
        None
    }
}
impl WeightLoader for WeightMap {
fn format_id(&self) -> &'static str {
"safetensors"
}
fn len(&self) -> usize {
Self::len(self)
}
fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
Self::take(self, key)
}
fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
Self::take_transposed(self, key)
}
fn remaining_keys(&self) -> Vec<String> {
self.keys().map(|s| s.to_string()).collect()
}
}
/// Wrapper around another [`WeightLoader`]. Its `WeightLoader` impl retries
/// failed lookups under the Hugging Face spelling of a GGUF tensor name.
pub struct HfTranslatingLoader<L>
where
    L: WeightLoader,
{
    /// The loader that actually owns the tensors.
    inner: L,
}
impl<L: WeightLoader> HfTranslatingLoader<L> {
pub fn new(inner: L) -> Self {
Self { inner }
}
pub fn into_inner(self) -> L {
self.inner
}
pub fn inner(&self) -> &L {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut L {
&mut self.inner
}
}
/// Forwards every [`WeightLoader`] method to the wrapped loader.
///
/// `take` and `take_transposed` additionally accept GGUF-style names: when the
/// wrapped loader rejects `key`, the key is translated with
/// [`gguf_to_hf_name`] and the lookup is retried once under that name.
impl<L: WeightLoader> WeightLoader for HfTranslatingLoader<L> {
    fn format_id(&self) -> &'static str {
        self.inner.format_id()
    }

    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Looks up `key` as given, then under its Hugging Face translation.
    ///
    /// # Errors
    /// If no translation exists, the error from the direct lookup is returned
    /// as-is. If the translated lookup fails too, its error is returned.
    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        let direct_err = match self.inner.take(key) {
            Ok(found) => return Ok(found),
            Err(err) => err,
        };
        match gguf_to_hf_name(key) {
            Some(hf) => self.inner.take(&hf),
            // Nothing else to try: report the failure we already have rather
            // than repeating the same lookup just to rebuild it.
            None => Err(direct_err),
        }
    }

    /// Transposed counterpart of [`take`](Self::take), with the same fallback
    /// and error behavior.
    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        let direct_err = match self.inner.take_transposed(key) {
            Ok(found) => return Ok(found),
            Err(err) => err,
        };
        match gguf_to_hf_name(key) {
            Some(hf) => self.inner.take_transposed(&hf),
            None => Err(direct_err),
        }
    }

    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
        self.inner.take_packed(key)
    }

    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
        self.inner.tensor_bytes_borrowed(key)
    }

    fn remaining_keys(&self) -> Vec<String> {
        self.inner.remaining_keys()
    }

    /// Passes the wrapped loader's hint through. Without this override the
    /// trait default would answer `None` regardless of what `inner` knows.
    fn arch_hint(&self) -> Option<&str> {
        self.inner.arch_hint()
    }
}
/// Opens the weight file at `path` by handing it to
/// `weight_registry::open_weight_loader`, whose result (loader or error) is
/// returned unchanged.
pub fn load_from_path(path: &str) -> Result<Box<dyn WeightLoader>> {
    let location = Path::new(path);
    crate::weight_registry::open_weight_loader(location)
}
/// Weight loader backed by a single GGUF file.
///
/// Tracks which tensors have already been handed out and hides MTP tensors
/// unless they are explicitly enabled.
pub struct GgufLoader {
/// The loaded GGUF file that tensors and metadata are read from.
file: rlx_gguf::GgufFile,
/// Architecture string of the file; `"unknown"` when the file declares none.
arch: String,
/// Resolved names of tensors that have already been taken.
taken: HashSet<String>,
/// When `false`, MTP tensors are refused by the take methods and left out of
/// `len` / `remaining_keys`.
include_mtp: bool,
/// First `blk.<index>` treated as MTP (indices at or above it count as MTP);
/// `None` disables the index-based check.
mtp_layer_threshold: Option<u32>,
}
impl GgufLoader {
pub fn from_file(path: &str) -> Result<Self> {
let file = crate::gguf_support::load_gguf_file(std::path::Path::new(path))?;
let arch = gguf_architecture_str(&file)
.unwrap_or("unknown")
.to_string();
let mtp_layer_threshold = compute_mtp_layer_threshold(&file);
Ok(Self {
file,
arch,
taken: HashSet::new(),
include_mtp: false,
mtp_layer_threshold,
})
}
pub fn architecture(&self) -> &str {
&self.arch
}
pub fn mtp_layer_threshold(&self) -> Option<u32> {
self.mtp_layer_threshold
}
pub fn file(&self) -> &rlx_gguf::GgufFile {
&self.file
}
pub fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
let real = self.resolve(key).ok()?;
let t = self.file.get(&real)?;
self.file.tensor_bytes(t).ok()
}
pub fn take_packed_metadata(
&mut self,
key: &str,
) -> Result<Option<(rlx_ir::quant::QuantScheme, Vec<usize>)>> {
let real = self.resolve(key)?;
if self.taken.contains(&real) {
return Err(anyhow!("weight already taken: {key} (→ {real})"));
}
if !self.include_mtp && self.is_mtp_tensor(&real) {
return Err(anyhow!(
"refusing to take MTP weight `{real}` without include_mtp(true)"
));
}
let t = self
.file
.get(&real)
.ok_or_else(|| anyhow!("tensor missing: {real}"))?;
let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
return Ok(None);
};
let mut shape = t.shape.clone();
shape.reverse();
self.taken.insert(real);
Ok(Some((scheme, shape)))
}
pub fn is_mtp_tensor(&self, name: &str) -> bool {
if is_mtp_weight(name) {
return true;
}
if let Some(thresh) = self.mtp_layer_threshold {
if let Some(rest) = name.strip_prefix("blk.") {
if let Some(dot) = rest.find('.') {
if let Ok(idx) = rest[..dot].parse::<u32>() {
if idx >= thresh {
return true;
}
}
}
}
}
false
}
pub fn include_mtp(&mut self, include: bool) -> &mut Self {
self.include_mtp = include;
self
}
pub fn take_packed(&mut self, key: &str) -> Result<Option<PackedWeightTensor>> {
let real = self.resolve(key)?;
if self.taken.contains(&real) {
return Err(anyhow!("weight already taken: {key} (→ {real})"));
}
if !self.include_mtp && self.is_mtp_tensor(&real) {
return Err(anyhow!(
"refusing to take MTP weight `{real}` without include_mtp(true)"
));
}
let t = self
.file
.get(&real)
.ok_or_else(|| anyhow!("tensor missing: {real}"))?;
let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
return Ok(None);
};
let bytes = self
.file
.tensor_bytes(t)
.with_context(|| format!("read packed bytes for {real}"))?
.to_vec();
let mut shape = t.shape.clone();
shape.reverse();
self.taken.insert(real);
Ok(Some((bytes, scheme, shape)))
}
pub fn take_mtp(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
if !self.is_mtp_tensor(key) {
return Err(anyhow!("not an MTP weight under this file's scheme: {key}"));
}
if !self.file.tensors.contains_key(key) {
return Err(anyhow!("MTP weight not found in GGUF: {key}"));
}
if self.taken.contains(key) {
return Err(anyhow!("MTP weight already taken: {key}"));
}
let (data, raw_shape) = self.file.dequant_f32(key)?;
self.taken.insert(key.to_string());
let mut shape = raw_shape;
shape.reverse();
Ok((data, shape))
}
}
impl GgufLoader {
    /// Turns a caller-supplied tensor name into the name stored in the file,
    /// via `resolve_gguf_tensor_name` for this loader's architecture.
    ///
    /// # Errors
    /// Fails with a "weight not found" error when no stored name matches.
    fn resolve(&self, key: &str) -> Result<String> {
        match resolve_gguf_tensor_name(&self.file, &self.arch, key) {
            Some(real) => Ok(real),
            None => Err(anyhow!(
                "weight not found in GGUF (arch={}): {key}",
                self.arch
            )),
        }
    }
}
/// [`WeightLoader`] over a GGUF file. Tensors already taken are excluded from
/// counts and listings, and so are MTP tensors while `include_mtp` is off.
impl WeightLoader for GgufLoader {
    fn format_id(&self) -> &'static str {
        "gguf"
    }

    fn arch_hint(&self) -> Option<&str> {
        Some(self.arch.as_str())
    }

    /// Delegates to the inherent `GgufLoader::take_packed`.
    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
        GgufLoader::take_packed(self, key)
    }

    /// Delegates to the inherent `GgufLoader::tensor_bytes_borrowed`.
    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
        GgufLoader::tensor_bytes_borrowed(self, key)
    }

    /// Counts the tensors that are neither taken nor hidden as MTP.
    fn len(&self) -> usize {
        let mut visible = 0;
        for name in self.file.tensors.keys() {
            let hidden_mtp = !self.include_mtp && self.is_mtp_tensor(name);
            if !hidden_mtp && !self.taken.contains(name) {
                visible += 1;
            }
        }
        visible
    }

    /// Dequantizes the tensor `key` resolves to and marks it as taken.
    ///
    /// The returned shape is the file's dimension list in reverse order. For
    /// the architectures listed below, norm weights recognized by
    /// `is_gemma_norm_weight` have `1.0` subtracted from every element.
    ///
    /// NOTE(review): this list contains `gemma` but not `gemma4moe`, whereas
    /// `gguf_to_hf_name_for_arch` lists `gemma4moe` but not `gemma` — confirm
    /// the difference is deliberate.
    ///
    /// # Errors
    /// Fails when the name does not resolve, the tensor was already taken, it
    /// is a hidden MTP tensor, or dequantization fails.
    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        let real = self.resolve(key)?;
        if self.taken.contains(&real) {
            return Err(anyhow!("weight already taken: {key} (→ {real})"));
        }
        let hidden_mtp = !self.include_mtp && self.is_mtp_tensor(&real);
        if hidden_mtp {
            return Err(anyhow!(
                "refusing to take MTP weight `{real}` without include_mtp(true); \
                 use loader.take_mtp(...) for explicit MTP grabs or \
                 loader.include_mtp(true) to include them in drains"
            ));
        }
        let (mut data, mut shape) = self.file.dequant_f32(&real)?;
        let gemma_family = matches!(
            self.arch.as_str(),
            "gemma" | "gemma2" | "gemma3" | "gemma3n" | "gemma4"
        );
        let shift_norm = gemma_family && is_gemma_norm_weight(&real);
        self.taken.insert(real);
        if shift_norm {
            data.iter_mut().for_each(|value| *value -= 1.0);
        }
        shape.reverse();
        Ok((data, shape))
    }

    /// Takes the tensor like [`take`](Self::take) and returns its transpose,
    /// with the two dimensions swapped.
    ///
    /// # Errors
    /// Propagates errors from `take`, and fails when the tensor is not 2-D.
    /// The tensor has already been marked as taken when the 2-D check runs.
    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        let (data, shape) = self.take(key)?;
        let &[rows, cols] = shape.as_slice() else {
            return Err(anyhow!("transpose requires 2D, got {shape:?}"));
        };
        let mut transposed = vec![0f32; data.len()];
        for row in 0..rows {
            for col in 0..cols {
                // Element (row, col) of the source lands at (col, row).
                transposed[col * rows + row] = data[row * cols + col];
            }
        }
        Ok((transposed, vec![cols, rows]))
    }

    /// Names of the tensors that are neither taken nor hidden as MTP.
    fn remaining_keys(&self) -> Vec<String> {
        let mut names = Vec::new();
        for name in self.file.tensors.keys() {
            let hidden_mtp = !self.include_mtp && self.is_mtp_tensor(name);
            if !hidden_mtp && !self.taken.contains(name) {
                names.push(name.clone());
            }
        }
        names
    }
}
impl GgufLoader {
    /// Names of every tensor in the file that `is_mtp_tensor` accepts,
    /// whether or not it has been taken and regardless of `include_mtp`.
    pub fn mtp_keys(&self) -> Vec<String> {
        let mut names = Vec::new();
        for name in self.file.tensors.keys() {
            if self.is_mtp_tensor(name) {
                names.push(name.clone());
            }
        }
        names
    }
}
// Unit tests for the name translation tables, MTP detection, and `GgufLoader`.
// The two loader tests hand-assemble small GGUF files byte by byte and write
// them to the system temp directory.
#[cfg(test)]
mod tests {
use super::*;
// A path with an unrecognized extension must be rejected with an error whose
// message contains "unsupported".
#[test]
fn unknown_extension_errors() {
let r = load_from_path("/tmp/no-such-thing.bin");
match r {
Err(e) => assert!(e.to_string().contains("unsupported")),
Ok(_) => panic!("expected error"),
}
}
// The three top-level HF names map to their GGUF equivalents.
#[test]
fn hf_to_gguf_top_level() {
assert_eq!(
hf_to_gguf_name("model.embed_tokens.weight").as_deref(),
Some("token_embd.weight")
);
assert_eq!(
hf_to_gguf_name("model.norm.weight").as_deref(),
Some("output_norm.weight")
);
assert_eq!(
hf_to_gguf_name("lm_head.weight").as_deref(),
Some("output.weight")
);
}
// Per-layer HF names map to `blk.<index>.<suffix>` with the index preserved.
#[test]
fn hf_to_gguf_per_layer() {
let cases = [
(
"model.layers.0.self_attn.q_proj.weight",
"blk.0.attn_q.weight",
),
(
"model.layers.7.self_attn.o_proj.weight",
"blk.7.attn_output.weight",
),
(
"model.layers.3.mlp.gate_proj.weight",
"blk.3.ffn_gate.weight",
),
(
"model.layers.12.mlp.down_proj.weight",
"blk.12.ffn_down.weight",
),
(
"model.layers.4.input_layernorm.weight",
"blk.4.attn_norm.weight",
),
(
"model.layers.4.post_attention_layernorm.weight",
"blk.4.ffn_norm.weight",
),
(
"model.layers.0.self_attn.q_norm.weight",
"blk.0.attn_q_norm.weight",
),
];
for (hf, gguf) in cases {
assert_eq!(
hf_to_gguf_name(hf).as_deref(),
Some(gguf),
"mismatch for {hf}"
);
}
}
// Unknown suffixes and non-numeric layer indices yield `None`.
#[test]
fn hf_to_gguf_unknown_returns_none() {
assert!(hf_to_gguf_name("model.layers.0.some_new_thing.weight").is_none());
assert!(hf_to_gguf_name("model.layers.foo.input_layernorm.weight").is_none());
}
// Name-based MTP detection: `mtp` prefix, `mtp_` or `.mtp` anywhere.
#[test]
fn mtp_detection() {
assert!(is_mtp_weight("mtp_blk.0.attn_q.weight"));
assert!(is_mtp_weight("output_mtp_0.weight"));
assert!(is_mtp_weight("model.layers.0.mtp_head.weight"));
assert!(!is_mtp_weight("blk.0.attn_q.weight"));
assert!(!is_mtp_weight("output.weight"));
}
// Q4_0 and Q8_0 dtypes map to their packed quantization schemes.
#[test]
fn ggml_q4_0_maps_to_packed_scheme() {
use rlx_gguf::GgmlType;
assert_eq!(
ggml_type_to_quant_scheme(GgmlType::Q4_0),
Some(rlx_ir::quant::QuantScheme::GgufQ4_0)
);
assert_eq!(
ggml_type_to_quant_scheme(GgmlType::Q8_0),
Some(rlx_ir::quant::QuantScheme::GgufQ8_0)
);
}
// With block_count = 25 and nextn_predict_layers = 1 in the metadata, the
// threshold is 24, so `blk.24.*` is classified as MTP by index alone.
#[test]
fn gguf_loader_threshold_based_mtp_detection() {
let mut buf: Vec<u8> = Vec::new();
// Header: magic, version 3, then two u64 counts, both 3 here (three tensors
// and three metadata entries). The second test writes the tensor count first.
buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
buf.extend_from_slice(&3u32.to_le_bytes());
buf.extend_from_slice(&3u64.to_le_bytes()); buf.extend_from_slice(&3u64.to_le_bytes()); let write_string = |buf: &mut Vec<u8>, k: &str, v: &str| {
// Metadata entry: length-prefixed key, value-type tag 8 (presumably the GGUF
// string type — confirm against the format spec), length-prefixed value.
buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
buf.extend_from_slice(k.as_bytes());
buf.extend_from_slice(&8u32.to_le_bytes());
buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
buf.extend_from_slice(v.as_bytes());
};
// Metadata entry with value-type tag 4 (presumably the GGUF u32 type) followed
// by the 4-byte little-endian value.
let write_u32 = |buf: &mut Vec<u8>, k: &str, v: u32| {
buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
buf.extend_from_slice(k.as_bytes());
buf.extend_from_slice(&4u32.to_le_bytes()); buf.extend_from_slice(&v.to_le_bytes());
};
write_string(&mut buf, "general.architecture", "qwen35");
write_u32(&mut buf, "qwen35.block_count", 25);
write_u32(&mut buf, "qwen35.nextn_predict_layers", 1);
// Tensor info record: length-prefixed name, dimension count, each dimension
// as u64, dtype tag 0 (presumably F32, matching the f32 payload below), and
// the tensor's byte offset into the data section.
let write_tensor = |buf: &mut Vec<u8>, name: &str, shape: &[usize], off: u64| {
buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
for &d in shape {
buf.extend_from_slice(&(d as u64).to_le_bytes());
}
buf.extend_from_slice(&0u32.to_le_bytes()); buf.extend_from_slice(&off.to_le_bytes());
};
// Each 4x4 f32 tensor occupies 64 bytes, hence offsets 0, 64, 128.
write_tensor(&mut buf, "blk.0.attn_q.weight", &[4, 4], 0);
write_tensor(&mut buf, "blk.24.attn_q.weight", &[4, 4], 64);
write_tensor(&mut buf, "token_embd.weight", &[4, 4], 128);
// Pad with zeros up to the default alignment before the tensor data.
while !buf
.len()
.is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
{
buf.push(0);
}
// Payload: 3 tensors x 16 elements, every element 0.5.
for _ in 0..(3 * 16) {
buf.extend_from_slice(&0.5f32.to_le_bytes());
}
let path = std::env::temp_dir().join("rlx_mtp_threshold_test.gguf");
std::fs::write(&path, &buf).unwrap();
let loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
assert_eq!(loader.mtp_layer_threshold(), Some(24));
assert!(!loader.is_mtp_tensor("blk.0.attn_q.weight"));
assert!(loader.is_mtp_tensor("blk.24.attn_q.weight"));
assert!(!loader.is_mtp_tensor("token_embd.weight"));
let mtp = loader.mtp_keys();
assert_eq!(mtp, vec!["blk.24.attn_q.weight".to_string()]);
// Best-effort cleanup; not reached if an assertion above panics.
std::fs::remove_file(&path).ok();
}
// End-to-end check on a file with no metadata: HF names resolve to GGUF
// tensors, shapes come back reversed, and the name-detected MTP tensor is
// hidden unless `include_mtp(true)` or `take_mtp` is used.
#[test]
fn gguf_loader_resolves_hf_names_and_skips_mtp() {
// Each entry is (name, shape, offset into `data` counted in f32 elements).
let mut tensors = Vec::new();
let mut data: Vec<f32> = Vec::new();
let t1: Vec<f32> = (0..12).map(|x| x as f32).collect();
tensors.push(("token_embd.weight", vec![3usize, 4], data.len()));
data.extend_from_slice(&t1);
let t2: Vec<f32> = (100..116).map(|x| x as f32).collect();
tensors.push(("blk.0.attn_q.weight", vec![4usize, 4], data.len()));
data.extend_from_slice(&t2);
let t3: Vec<f32> = vec![0.5f32; 8];
tensors.push(("output_mtp_0.weight", vec![2usize, 4], data.len()));
data.extend_from_slice(&t3);
let mut buf: Vec<u8> = Vec::new();
// Header: magic, version 3, tensor count, then a metadata count of zero.
buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
buf.extend_from_slice(&3u32.to_le_bytes()); buf.extend_from_slice(&(tensors.len() as u64).to_le_bytes());
buf.extend_from_slice(&0u64.to_le_bytes());
// Tensor info records, written with a placeholder offset of 0 that is patched
// in the second pass below.
for (name, shape, _) in &tensors {
buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
buf.extend_from_slice(name.as_bytes());
buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
for &d in shape {
buf.extend_from_slice(&(d as u64).to_le_bytes());
}
buf.extend_from_slice(&0u32.to_le_bytes()); buf.extend_from_slice(&0u64.to_le_bytes());
}
// Pad with zeros up to the default alignment before the tensor data.
while !buf
.len()
.is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
{
buf.push(0);
}
let data_start = buf.len();
for v in &data {
buf.extend_from_slice(&v.to_le_bytes());
}
// Second pass: walk the info records again (they start right after the
// 24-byte header) and overwrite each trailing 8-byte offset field.
let header = (4 + 4 + 8 + 8) as usize; let mut cursor = header;
for (name, shape, byte_off) in &tensors {
let name_len_bytes = 8;
let name_bytes = name.len();
let n_dims_bytes = 4;
let dims_bytes = shape.len() * 8;
let dtype_bytes = 4;
let off_bytes = 8;
let info_size =
name_len_bytes + name_bytes + n_dims_bytes + dims_bytes + dtype_bytes + off_bytes;
let off_field_at = cursor + info_size - off_bytes;
// Despite its name, `byte_off` is an element index; x4 converts it to bytes.
// The offset is then stored little-endian, one byte at a time.
let final_off = (*byte_off * 4) as u64; for i in 0..8 {
buf[off_field_at + i] = (final_off >> (i * 8)) as u8;
}
cursor += info_size;
}
// `data_start` is otherwise unused; this silences the unused-variable warning.
let _ = data_start;
let path = std::env::temp_dir().join("rlx_test_qwen3_mini.gguf");
std::fs::write(&path, &buf).unwrap();
let mut loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
// Three tensors in the file, but the MTP one is hidden by default.
assert_eq!(loader.len(), 2);
let (out, shape) = loader
.take("model.embed_tokens.weight")
.expect("hf-named token_embd should resolve");
// Stored as [3, 4]; the loader reports dimensions in reverse order.
assert_eq!(shape, vec![4, 3]);
assert_eq!(&out, &t1);
let (out, shape) = loader
.take("model.layers.0.self_attn.q_proj.weight")
.expect("hf-named attn_q should resolve");
assert_eq!(shape, vec![4, 4]);
assert_eq!(&out, &t2);
// Both visible tensors are taken; the MTP tensor is still listed by mtp_keys.
assert_eq!(loader.remaining_keys(), Vec::<String>::new());
assert_eq!(loader.mtp_keys(), vec!["output_mtp_0.weight".to_string()]);
// A fresh loader with MTP enabled exposes all three tensors.
let mut loader2 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
loader2.include_mtp(true);
let visible: std::collections::HashSet<String> =
loader2.remaining_keys().into_iter().collect();
assert!(visible.contains("token_embd.weight"));
assert!(visible.contains("blk.0.attn_q.weight"));
assert!(
visible.contains("output_mtp_0.weight"),
"MTP weight should be visible with include_mtp(true)"
);
let (mtp_data, mtp_shape) = loader2.take_mtp("output_mtp_0.weight").unwrap();
assert_eq!(mtp_shape, vec![4, 2]);
assert_eq!(mtp_data, t3);
// A default loader refuses a plain `take` of the MTP tensor.
let mut loader3 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
let err = loader3.take("output_mtp_0.weight").unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("include_mtp(true)"),
"expected MTP guard error, got: {msg}"
);
// Best-effort cleanup; not reached if an assertion above panics.
std::fs::remove_file(&path).ok();
}
// Opening a `.gguf` path that does not exist must fail.
#[test]
fn missing_gguf_file_errors() {
let r = load_from_path("/tmp/no-such-thing-rlx-gguf-test.gguf");
assert!(r.is_err());
}
}