use rlx_gguf::GgufFile;
use std::sync::{Mutex, OnceLock};
use crate::weight_loader::{gguf_to_hf_name, hf_to_gguf_name};
/// Maps a tensor name requested by model-loading code onto a tensor name that
/// actually exists in a particular GGUF file.
///
/// Implementations are stored in process-wide statics (a `Mutex<Vec<Box<dyn _>>>`
/// and a `OnceLock<Vec<Box<dyn _>>>`), which is why the trait requires
/// `Send + Sync`.
pub trait GgufTensorNameResolver: Send + Sync {
    /// Returns `true` if this resolver should be consulted for the architecture
    /// string passed to `resolve_gguf_tensor_name`.
    fn matches_arch(&self, arch: &str) -> bool;
    /// Returns the name of a tensor present in `file` that corresponds to
    /// `requested_key`, or `None` if this resolver cannot find one.
    fn resolve(&self, file: &GgufFile, requested_key: &str) -> Option<String>;
}
/// Resolver for llama-style architectures whose GGUF tensor names can be
/// converted to and from HF-style names by `hf_to_gguf_name` / `gguf_to_hf_name`.
///
/// Lookup order: the key verbatim, then its HF→GGUF conversion, then its
/// GGUF→HF conversion. The first candidate present in the file wins.
pub struct LlamaFamilyGgufResolver;

impl GgufTensorNameResolver for LlamaFamilyGgufResolver {
    fn matches_arch(&self, arch: &str) -> bool {
        const SUPPORTED: &[&str] = &[
            "llama",
            "llama4",
            "qwen3",
            "qwen2",
            "qwen35",
            "qwen35moe",
            "qwen36",
            "gemma",
            "gemma2",
            "mistral",
        ];
        SUPPORTED.contains(&arch)
    }

    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
        let present = |name: &str| file.tensors.contains_key(name);

        if present(key) {
            return Some(key.to_owned());
        }

        // Try the name translated in either direction; the second conversion
        // is only computed when the first one does not hit.
        hf_to_gguf_name(key)
            .filter(|gguf| present(gguf.as_str()))
            .or_else(|| gguf_to_hf_name(key).filter(|hf| present(hf.as_str())))
    }
}
/// Catch-all resolver (matches every architecture) that removes one well-known
/// wrapper prefix such as `model.diffusion_model.` before looking the name up.
pub struct PrefixStripGgufResolver;

/// Alternate name for [`PrefixStripGgufResolver`] (presumably kept for backward
/// compatibility — confirm before removing).
pub type PassThroughGgufResolver = PrefixStripGgufResolver;

impl GgufTensorNameResolver for PrefixStripGgufResolver {
    fn matches_arch(&self, _arch: &str) -> bool {
        true
    }

    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
        // Only the first matching prefix is removed, so the longer prefixes
        // must precede the shorter ones they contain (e.g. "model.").
        const STRIP_PREFIXES: [&str; 4] = [
            "model.diffusion_model.",
            "diffusion_model.",
            "transformer.",
            "model.",
        ];

        let stripped = STRIP_PREFIXES
            .iter()
            .find_map(|&prefix| key.strip_prefix(prefix))
            .unwrap_or(key);

        // Prefer the stripped name; fall back to the key exactly as requested.
        if file.tensors.contains_key(stripped) {
            return Some(stripped.to_owned());
        }
        file.tensors.contains_key(key).then(|| key.to_owned())
    }
}
/// Resolver for Gemma 2+ architectures.
///
/// These architectures carry extra per-layer norm tensors that need an explicit
/// HF→GGUF rename. Anything not covered here is delegated to
/// [`LlamaFamilyGgufResolver`].
pub struct Gemma2GgufResolver;

impl GgufTensorNameResolver for Gemma2GgufResolver {
    fn matches_arch(&self, arch: &str) -> bool {
        ["gemma2", "gemma3", "gemma3n", "gemma4", "gemma4moe"].contains(&arch)
    }

    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
        if file.tensors.contains_key(key) {
            return Some(key.to_owned());
        }

        // "model.layers.<idx>.<hf_norm>" -> "blk.<idx>.<gguf_norm>" for the
        // three norms handled specially here.
        let renamed = key
            .strip_prefix("model.layers.")
            .and_then(|rest| rest.split_once('.'))
            .and_then(|(layer, tail)| {
                let gguf_tail = match tail {
                    "post_attention_layernorm.weight" => "post_attention_norm.weight",
                    "pre_feedforward_layernorm.weight" => "ffn_norm.weight",
                    "post_feedforward_layernorm.weight" => "post_ffw_norm.weight",
                    _ => return None,
                };
                Some(format!("blk.{layer}.{gguf_tail}"))
            })
            .filter(|candidate| file.tensors.contains_key(candidate.as_str()));

        renamed.or_else(|| LlamaFamilyGgufResolver.resolve(file, key))
    }
}
/// Resolver for the qwen35 / qwen35moe / qwen36 architectures.
///
/// Its lookup is currently identical to [`LlamaFamilyGgufResolver`]'s. That
/// resolver already returns the key verbatim when it is present, before trying
/// any name conversion.
pub struct Qwen35NativeGgufResolver;

impl GgufTensorNameResolver for Qwen35NativeGgufResolver {
    fn matches_arch(&self, arch: &str) -> bool {
        ["qwen35", "qwen35moe", "qwen36"].contains(&arch)
    }

    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
        LlamaFamilyGgufResolver.resolve(file, key)
    }
}
/// Resolvers added at runtime through `register_gguf_tensor_resolver`.
/// `resolve_gguf_tensor_name` consults them only after every built-in resolver
/// has failed to find a tensor.
static CUSTOM_RESOLVERS: Mutex<Vec<Box<dyn GgufTensorNameResolver>>> = Mutex::new(Vec::new());
/// Built-in resolvers, built once on first use by `builtin_resolvers`; the
/// order of the Vec is the order in which they are tried.
static BUILTIN_RESOLVERS: OnceLock<Vec<Box<dyn GgufTensorNameResolver>>> = OnceLock::new();
/// Returns the built-in resolver chain, initializing it on first call.
///
/// Order matters: `resolve_gguf_tensor_name` returns the first hit from a
/// matching resolver. The architecture-specific resolvers therefore come before
/// the catch-all `PrefixStripGgufResolver`.
fn builtin_resolvers() -> &'static Vec<Box<dyn GgufTensorNameResolver>> {
    BUILTIN_RESOLVERS.get_or_init(|| {
        let mut chain: Vec<Box<dyn GgufTensorNameResolver>> = Vec::with_capacity(4);
        chain.push(Box::new(Qwen35NativeGgufResolver));
        chain.push(Box::new(Gemma2GgufResolver));
        chain.push(Box::new(LlamaFamilyGgufResolver));
        chain.push(Box::new(PrefixStripGgufResolver));
        chain
    })
}
/// Eagerly builds the built-in resolver chain.
///
/// Calling this is optional: `resolve_gguf_tensor_name` initializes the chain
/// lazily on first use.
pub fn ensure_builtin_resolvers() {
    builtin_resolvers();
}
/// Adds a custom resolver to the process-wide registry.
///
/// Custom resolvers are tried, in registration order, only after every
/// built-in resolver has failed to find a tensor.
pub fn register_gguf_tensor_resolver(resolver: Box<dyn GgufTensorNameResolver>) {
    // `resolve_gguf_tensor_name` runs custom resolvers while holding this lock,
    // so a panicking resolver poisons the mutex. A push or an iteration can
    // never leave the Vec half-updated. Recovering the guard is therefore sound,
    // and it avoids turning one faulty resolver into a permanent panic here.
    CUSTOM_RESOLVERS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .push(resolver);
}
pub fn resolve_gguf_tensor_name(
file: &GgufFile,
arch: &str,
requested_key: &str,
) -> Option<String> {
for r in builtin_resolvers().iter() {
if r.matches_arch(arch) {
if let Some(name) = r.resolve(file, requested_key) {
return Some(name);
}
}
}
let custom = CUSTOM_RESOLVERS
.lock()
.expect("gguf resolver registry lock");
for r in custom.iter() {
if r.matches_arch(arch) {
if let Some(name) = r.resolve(file, requested_key) {
return Some(name);
}
}
}
if file.tensors.contains_key(requested_key) {
return Some(requested_key.to_string());
}
if let Some(g) = hf_to_gguf_name(requested_key) {
if file.tensors.contains_key(&g) {
return Some(g);
}
}
if let Some(h) = gguf_to_hf_name(requested_key) {
if file.tensors.contains_key(&h) {
return Some(h);
}
}
None
}