Skip to main content

rlx_models_core/
gguf_resolve.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pluggable GGUF tensor-name resolution per `general.architecture`.
17
18use rlx_gguf::GgufFile;
19use std::sync::{Mutex, OnceLock};
20
21use crate::weight_loader::{gguf_to_hf_name, hf_to_gguf_name};
22
23/// Resolve a builder-requested tensor name to the name stored in a GGUF file.
24pub trait GgufTensorNameResolver: Send + Sync {
25    fn matches_arch(&self, arch: &str) -> bool;
26    fn resolve(&self, file: &GgufFile, requested_key: &str) -> Option<String>;
27}
28
29/// HF `model.layers.N.*` ↔ GGUF `blk.N.*` (Llama, Qwen3, Qwen35, …).
30pub struct LlamaFamilyGgufResolver;
31
32impl GgufTensorNameResolver for LlamaFamilyGgufResolver {
33    fn matches_arch(&self, arch: &str) -> bool {
34        matches!(
35            arch,
36            "llama"
37                | "llama4"
38                | "qwen3"
39                | "qwen2"
40                | "qwen35"
41                | "qwen35moe"
42                | "qwen36"
43                | "gemma"
44                | "gemma2"
45                | "mistral"
46        )
47    }
48
49    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
50        if file.tensors.contains_key(key) {
51            return Some(key.to_string());
52        }
53        if let Some(g) = hf_to_gguf_name(key) {
54            if file.tensors.contains_key(&g) {
55                return Some(g);
56            }
57        }
58        if let Some(h) = gguf_to_hf_name(key) {
59            if file.tensors.contains_key(&h) {
60                return Some(h);
61            }
62        }
63        None
64    }
65}
66
67/// Strip common HF prefixes and match verbatim tensor names (architecture-agnostic fallback).
68pub struct PrefixStripGgufResolver;
69
70/// Alias for [`PrefixStripGgufResolver`] (older name).
71pub type PassThroughGgufResolver = PrefixStripGgufResolver;
72
73impl GgufTensorNameResolver for PrefixStripGgufResolver {
74    fn matches_arch(&self, _arch: &str) -> bool {
75        true
76    }
77
78    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
79        let mut k = key.to_string();
80        for prefix in [
81            "model.diffusion_model.",
82            "diffusion_model.",
83            "transformer.",
84            "model.",
85        ] {
86            if let Some(rest) = k.strip_prefix(prefix) {
87                k = rest.to_string();
88                break;
89            }
90        }
91        if file.tensors.contains_key(&k) {
92            return Some(k);
93        }
94        if file.tensors.contains_key(key) {
95            return Some(key.to_string());
96        }
97        None
98    }
99}
100
101/// Gemma 2/3/4: 4 RMSNorms per layer disagree with the Llama 2-norm convention.
102///
103/// Llama treats `post_attention_layernorm` as the pre-FFN norm and aliases it
104/// to `ffn_norm`. Gemma 2/3/4 (V2/V3/V4 layer styles) have a dedicated
105/// `post_attention_norm` between the attention output and the residual add,
106/// *and* a separate `ffn_norm` / `post_ffw_norm` pair around the MLP. Without
107/// this resolver, the Llama mapper would alias `post_attention_layernorm` to
108/// `ffn_norm`, collide with `pre_feedforward_layernorm`, and silently load
109/// the wrong tensor. The tail map is identical across these arches — only
110/// the GGUF arch tag (and runtime details like sliding-window stride) differ.
111pub struct Gemma2GgufResolver;
112
113impl GgufTensorNameResolver for Gemma2GgufResolver {
114    fn matches_arch(&self, arch: &str) -> bool {
115        matches!(
116            arch,
117            "gemma2" | "gemma3" | "gemma3n" | "gemma4" | "gemma4moe"
118        )
119    }
120
121    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
122        // Identity hit first — accept native GGUF names verbatim.
123        if file.tensors.contains_key(key) {
124            return Some(key.to_string());
125        }
126        // Handle the four-norm-per-layer scheme explicitly. The Llama mapper
127        // is wrong for `post_attention_layernorm` (it aliases to `ffn_norm`,
128        // which Gemma 2 reserves for the pre-FFN norm) and has no entry at
129        // all for the `pre_feedforward_layernorm`/`post_feedforward_layernorm`
130        // pair.
131        if let Some(rest) = key.strip_prefix("model.layers.") {
132            if let Some((idx, tail)) = rest.split_once('.') {
133                let gguf_tail = match tail {
134                    "post_attention_layernorm.weight" => Some("post_attention_norm.weight"),
135                    "pre_feedforward_layernorm.weight" => Some("ffn_norm.weight"),
136                    "post_feedforward_layernorm.weight" => Some("post_ffw_norm.weight"),
137                    _ => None,
138                };
139                if let Some(t) = gguf_tail {
140                    let g = format!("blk.{idx}.{t}");
141                    if file.tensors.contains_key(&g) {
142                        return Some(g);
143                    }
144                }
145            }
146        }
147        // Fall through to Llama-family mapping for everything else
148        // (input_layernorm, attn/mlp weights, lm_head, embeddings, …).
149        LlamaFamilyGgufResolver.resolve(file, key)
150    }
151}
152
153/// Qwen3.5 native `blk.N.*` names; also accept HF aliases via the Llama mapper.
154pub struct Qwen35NativeGgufResolver;
155
156impl GgufTensorNameResolver for Qwen35NativeGgufResolver {
157    fn matches_arch(&self, arch: &str) -> bool {
158        matches!(arch, "qwen35" | "qwen35moe" | "qwen36")
159    }
160
161    fn resolve(&self, file: &GgufFile, key: &str) -> Option<String> {
162        if file.tensors.contains_key(key) {
163            return Some(key.to_string());
164        }
165        LlamaFamilyGgufResolver.resolve(file, key)
166    }
167}
168
169static CUSTOM_RESOLVERS: Mutex<Vec<Box<dyn GgufTensorNameResolver>>> = Mutex::new(Vec::new());
170static BUILTIN_RESOLVERS: OnceLock<Vec<Box<dyn GgufTensorNameResolver>>> = OnceLock::new();
171
172fn builtin_resolvers() -> &'static Vec<Box<dyn GgufTensorNameResolver>> {
173    BUILTIN_RESOLVERS.get_or_init(|| {
174        vec![
175            Box::new(Qwen35NativeGgufResolver),
176            Box::new(Gemma2GgufResolver),
177            Box::new(LlamaFamilyGgufResolver),
178            Box::new(PrefixStripGgufResolver),
179        ]
180    })
181}
182
183/// Register built-in GGUF resolvers (idempotent). Called automatically on first resolve; \
184/// call from `main` if you register custom resolvers and need ordering guarantees.
185pub fn ensure_builtin_resolvers() {
186    let _ = builtin_resolvers();
187}
188
189/// Register a custom resolver (call before first GGUF load). Later registrations win
190/// among resolvers that match the same architecture.
191pub fn register_gguf_tensor_resolver(resolver: Box<dyn GgufTensorNameResolver>) {
192    CUSTOM_RESOLVERS
193        .lock()
194        .expect("gguf resolver registry lock")
195        .push(resolver);
196}
197
198/// Resolve `requested_key` against tensors in `file` using registered resolvers.
199pub fn resolve_gguf_tensor_name(
200    file: &GgufFile,
201    arch: &str,
202    requested_key: &str,
203) -> Option<String> {
204    for r in builtin_resolvers().iter() {
205        if r.matches_arch(arch) {
206            if let Some(name) = r.resolve(file, requested_key) {
207                return Some(name);
208            }
209        }
210    }
211    let custom = CUSTOM_RESOLVERS
212        .lock()
213        .expect("gguf resolver registry lock");
214    for r in custom.iter() {
215        if r.matches_arch(arch) {
216            if let Some(name) = r.resolve(file, requested_key) {
217                return Some(name);
218            }
219        }
220    }
221    if file.tensors.contains_key(requested_key) {
222        return Some(requested_key.to_string());
223    }
224    if let Some(g) = hf_to_gguf_name(requested_key) {
225        if file.tensors.contains_key(&g) {
226            return Some(g);
227        }
228    }
229    if let Some(h) = gguf_to_hf_name(requested_key) {
230        if file.tensors.contains_key(&h) {
231            return Some(h);
232        }
233    }
234    None
235}