Skip to main content

rlx_models_core/
gguf_config.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Read Hugging Face–style config fields (`BertConfig`, `NomicBertConfig`) from GGUF
//! metadata (`{arch}.*` keys), and classify GGUF architecture tags by runner.

18use anyhow::{Context, Result, bail};
19use rlx_gguf::{GgufFile, MetaValue};
20
21use crate::config::{BertConfig, NomicBertConfig};
22use crate::gguf_support::gguf_architecture_str;
23
24fn arch_prefix(raw: &GgufFile) -> &str {
25    gguf_architecture_str(raw).unwrap_or("bert")
26}
27
28fn get_meta<'a>(raw: &'a GgufFile, _arch: &str, key: &str) -> Option<&'a MetaValue> {
29    raw.metadata.get(key)
30}
31
32fn meta_u32(raw: &GgufFile, arch: &str, key: &str) -> Result<u32> {
33    get_meta(raw, arch, key)
34        .and_then(MetaValue::as_u32)
35        .with_context(|| format!("missing or invalid GGUF metadata key: {key}"))
36}
37
38fn meta_u32_or(raw: &GgufFile, arch: &str, key: &str, default: u32) -> u32 {
39    get_meta(raw, arch, key)
40        .and_then(MetaValue::as_u32)
41        .unwrap_or(default)
42}
43
44fn meta_f64(raw: &GgufFile, arch: &str, key: &str, default: f64) -> f64 {
45    get_meta(raw, arch, key).map_or(default, |v| match v {
46        MetaValue::F32(x) => *x as f64,
47        MetaValue::F64(x) => *x,
48        _ => default,
49    })
50}
51
/// `general.architecture` tags for embedding models.
///
/// Validation happens in `rlx-embed`, not in the GGUF loader.
pub const EMBED_GGUF_ARCHES: &[&str] = &["bert", "modern-bert", "nomic-bert", "nomic-bert-moe"];

/// Tags for FLUX denoiser checkpoints.
///
/// Validation happens in `rlx-flux2`, not in the GGUF loader.
pub const FLUX_GGUF_ARCHES: &[&str] = &["flux"];

/// DINOv2 ViT tags (e.g. dinov2.cpp / community converters).
///
/// Tensors are drained as F32 via [`PrefixStripGgufResolver`].
pub const DINOV2_GGUF_ARCHES: &[&str] = &["dinov2"];

/// SAM v1 ViT-H (`sam`) and MobileSAM (`mobile-sam`) tags.
pub const SAM_GGUF_ARCHES: &[&str] = &["sam", "mobile-sam"];

/// Tag for SAM 2 Hiera checkpoints converted to GGUF.
pub const SAM2_GGUF_ARCHES: &[&str] = &["sam2"];

/// Tag for SAM 3 (e.g. rob-laz/sam3-gguf, sam3.cpp).
pub const SAM3_GGUF_ARCHES: &[&str] = &["sam3"];

/// V-JEPA 2 tags; no common Hub GGUF exists yet, so validate when one appears.
pub const VJEPA2_GGUF_ARCHES: &[&str] = &["vjepa2", "vjepa"];

/// Wav2Vec2-BERT and classic Wav2Vec2 tags.
///
/// `w2v-bert` comes from converters; ASR repos often use `wav2vec2`.
pub const W2V_BERT_GGUF_ARCHES: &[&str] = &["w2v-bert", "wav2vec2", "wav2vec"];

/// Exact, case-sensitive membership test shared by the `is_*_gguf_arch` predicates.
fn arch_listed(known: &[&str], arch: &str) -> bool {
    known.contains(&arch)
}

/// `true` when `arch` appears in [`FLUX_GGUF_ARCHES`].
pub fn is_flux_gguf_arch(arch: &str) -> bool {
    arch_listed(FLUX_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`EMBED_GGUF_ARCHES`].
pub fn is_embed_gguf_arch(arch: &str) -> bool {
    arch_listed(EMBED_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`DINOV2_GGUF_ARCHES`].
pub fn is_dinov2_gguf_arch(arch: &str) -> bool {
    arch_listed(DINOV2_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`SAM_GGUF_ARCHES`] (SAM v1 / MobileSAM only).
pub fn is_sam_gguf_arch(arch: &str) -> bool {
    arch_listed(SAM_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`SAM2_GGUF_ARCHES`].
pub fn is_sam2_gguf_arch(arch: &str) -> bool {
    arch_listed(SAM2_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`SAM3_GGUF_ARCHES`].
pub fn is_sam3_gguf_arch(arch: &str) -> bool {
    arch_listed(SAM3_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`VJEPA2_GGUF_ARCHES`].
pub fn is_vjepa2_gguf_arch(arch: &str) -> bool {
    arch_listed(VJEPA2_GGUF_ARCHES, arch)
}

/// `true` when `arch` appears in [`W2V_BERT_GGUF_ARCHES`].
pub fn is_w2v_bert_gguf_arch(arch: &str) -> bool {
    arch_listed(W2V_BERT_GGUF_ARCHES, arch)
}
107
/// Which embedding config family a GGUF file's architecture tag maps to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbedGgufKind {
    /// Plain BERT-style tags (`bert`, `modern-bert`).
    Bert,
    /// NomicBERT tags (`nomic-bert`, `nomic-bert-moe`).
    NomicBert,
}
114
115/// Suggested runner / crate for a GGUF architecture tag (for CLI and errors).
116pub fn gguf_runner_hint(arch: &str) -> &'static str {
117    if is_embed_gguf_arch(arch) {
118        return "rlx-embed (`RlxEmbed::from_weights`)";
119    }
120    if is_flux_gguf_arch(arch) {
121        return "rlx-flux2 denoiser (`Flux2Runner::builder().weights`) — VAE/TE stay safetensors";
122    }
123    if is_dinov2_gguf_arch(arch) {
124        return "rlx-dinov2 (`DinoV2Runner::builder().weights`)";
125    }
126    if is_sam3_gguf_arch(arch) {
127        return "rlx-sam3 (`Sam3::from_checkpoint_on`)";
128    }
129    if is_sam2_gguf_arch(arch) {
130        return "rlx-sam2 (`Sam2::from_safetensors_on`)";
131    }
132    if is_sam_gguf_arch(arch) {
133        return "rlx-sam (`Sam::from_safetensors_on`) — MobileSAM uses `mobile-sam` arch";
134    }
135    if is_vjepa2_gguf_arch(arch) {
136        return "rlx-vjepa2 (`Vjepa2Runner::builder().weights`)";
137    }
138    if is_w2v_bert_gguf_arch(arch) {
139        return "rlx-wav2vec2-bert (`Wav2Vec2BertRunner::builder().weights`; keep config.json beside GGUF)";
140    }
141    if let Some(fam) = crate::gguf_support::gguf_family_for_arch(arch) {
142        return match fam {
143            crate::gguf_support::GgufModelFamily::Qwen3 => {
144                "rlx-qwen3 (use `--packed` for large K-quant GGUF)"
145            }
146            crate::gguf_support::GgufModelFamily::Qwen35 => "rlx-qwen35 (`--packed` recommended)",
147            crate::gguf_support::GgufModelFamily::Llama32 => {
148                "rlx-llama32 (`--packed` for large K-quant GGUF)"
149            }
150            crate::gguf_support::GgufModelFamily::Gemma => {
151                "rlx-gemma (`--packed` for large K-quant GGUF)"
152            }
153            crate::gguf_support::GgufModelFamily::Lfm => "rlx-lfm (`LfmRunner::builder().weights`)",
154        };
155    }
156    "unknown — register a custom GgufTensorNameResolver or WeightFormatRegistration"
157}
158
/// Estimated RAM if every tensor is dequantized to F32, versus kept packed as on disk.
///
/// Derives `PartialEq`/`Eq`/`Default` so footprints can be compared and a zero
/// footprint (e.g. for a file with no tensors) is available without spelling out fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GgufMemoryFootprint {
    /// Total bytes if each tensor element is expanded to a 4-byte `f32`.
    pub f32_bytes: u64,
    /// Total bytes of the tensors as stored in the file (packed/quantized form).
    pub packed_file_bytes: u64,
}
165
166pub fn gguf_memory_footprint(raw: &GgufFile) -> GgufMemoryFootprint {
167    let mut f32_bytes = 0u64;
168    let mut packed_file_bytes = 0u64;
169    for t in raw.tensors.values() {
170        let n = t.n_elements() as u64;
171        f32_bytes += n * 4;
172        packed_file_bytes += raw.tensor_bytes(t).map(|b| b.len() as u64).unwrap_or(n * 4);
173    }
174    GgufMemoryFootprint {
175        f32_bytes,
176        packed_file_bytes,
177    }
178}
179
180/// Read `{arch}.{suffix}` as u32 when present.
181pub fn gguf_meta_u32(raw: &GgufFile, arch: &str, suffix: &str) -> Option<u32> {
182    let key = format!("{arch}.{suffix}");
183    raw.metadata.get(&key).and_then(MetaValue::as_u32)
184}
185
186pub fn embed_gguf_kind(raw: &GgufFile) -> Result<EmbedGgufKind> {
187    let arch = arch_prefix(raw);
188    if !is_embed_gguf_arch(arch) {
189        bail!(
190            "GGUF architecture {arch:?} is not supported for embeddings; \
191             expected one of: {}",
192            EMBED_GGUF_ARCHES.join(", ")
193        );
194    }
195    if matches!(arch, "nomic-bert" | "nomic-bert-moe") {
196        Ok(EmbedGgufKind::NomicBert)
197    } else {
198        Ok(EmbedGgufKind::Bert)
199    }
200}
201
202impl BertConfig {
203    pub fn from_gguf(raw: &GgufFile) -> Result<Self> {
204        let arch = arch_prefix(raw);
205        if !matches!(arch, "bert" | "modern-bert") {
206            bail!("BertConfig::from_gguf expected bert/modern-bert, got {arch}");
207        }
208        let hidden_size = meta_u32(raw, arch, &format!("{arch}.embedding_length"))? as usize;
209        let num_attention_heads =
210            meta_u32(raw, arch, &format!("{arch}.attention.head_count"))? as usize;
211        Ok(Self {
212            vocab_size: meta_u32(raw, arch, &format!("{arch}.vocab_size"))? as usize,
213            hidden_size,
214            num_hidden_layers: meta_u32(raw, arch, &format!("{arch}.block_count"))? as usize,
215            num_attention_heads,
216            intermediate_size: meta_u32(raw, arch, &format!("{arch}.feed_forward_length"))?
217                as usize,
218            max_position_embeddings: meta_u32_or(raw, arch, &format!("{arch}.context_length"), 512)
219                as usize,
220            type_vocab_size: meta_u32_or(raw, arch, "tokenizer.ggml.token_type_count", 2) as usize,
221            layer_norm_eps: meta_f64(
222                raw,
223                arch,
224                &format!("{arch}.attention.layer_norm_epsilon"),
225                1e-12,
226            ),
227            hidden_act: "gelu".into(),
228        })
229    }
230}
231
232impl NomicBertConfig {
233    pub fn from_gguf(raw: &GgufFile) -> Result<Self> {
234        let arch = arch_prefix(raw);
235        if !matches!(arch, "nomic-bert" | "nomic-bert-moe") {
236            bail!("NomicBertConfig::from_gguf expected nomic-bert, got {arch}");
237        }
238        let hidden_size = meta_u32(raw, arch, &format!("{arch}.embedding_length"))? as usize;
239        let num_attention_heads =
240            meta_u32(raw, arch, &format!("{arch}.attention.head_count"))? as usize;
241        let head_dim = meta_u32_or(
242            raw,
243            arch,
244            &format!("{arch}.attention.key_length"),
245            (hidden_size / num_attention_heads.max(1)) as u32,
246        ) as usize;
247        Ok(Self {
248            vocab_size: meta_u32(raw, arch, &format!("{arch}.vocab_size"))? as usize,
249            hidden_size,
250            num_hidden_layers: meta_u32(raw, arch, &format!("{arch}.block_count"))? as usize,
251            num_attention_heads,
252            intermediate_size: meta_u32(raw, arch, &format!("{arch}.feed_forward_length"))?
253                as usize,
254            max_position_embeddings: meta_u32_or(raw, arch, &format!("{arch}.context_length"), 8192)
255                as usize,
256            type_vocab_size: meta_u32_or(raw, arch, "tokenizer.ggml.token_type_count", 2) as usize,
257            layer_norm_eps: meta_f64(
258                raw,
259                arch,
260                &format!("{arch}.attention.layer_norm_epsilon"),
261                1e-12,
262            ),
263            head_dim,
264            rotary_emb_base: meta_f64(raw, arch, &format!("{arch}.rope.freq_base"), 1000.0),
265        })
266    }
267}