Skip to main content

rlx_llama32/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// LLaMA-3.2 configuration — HF `config.json` and GGUF `llama.*` metadata.
5
6use rlx_gguf::{GgufFile, MetaValue};
7use serde::Deserialize;
8use std::path::Path;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
11#[serde(rename_all = "lowercase")]
12pub enum Llama32RopeType {
13    #[default]
14    Default,
15    #[serde(rename = "llama3")]
16    Llama3,
17}
18
19#[derive(Debug, Clone, Deserialize)]
20pub struct Llama32RopeScaling {
21    pub factor: f32,
22    #[serde(default = "default_low_freq_factor")]
23    pub low_freq_factor: f32,
24    #[serde(default = "default_high_freq_factor")]
25    pub high_freq_factor: f32,
26    pub original_max_position_embeddings: usize,
27    #[serde(default)]
28    pub rope_type: Llama32RopeType,
29}
30
/// Serde fallback for `Llama32RopeScaling::low_freq_factor`.
fn default_low_freq_factor() -> f32 {
    1.0_f32
}
/// Serde fallback for `Llama32RopeScaling::high_freq_factor`.
fn default_high_freq_factor() -> f32 {
    4.0_f32
}
37
38#[derive(Debug, Clone, Deserialize)]
39pub struct Llama32Config {
40    pub vocab_size: usize,
41    pub hidden_size: usize,
42    pub intermediate_size: usize,
43    pub num_hidden_layers: usize,
44    pub num_attention_heads: usize,
45    pub num_key_value_heads: usize,
46    pub max_position_embeddings: usize,
47
48    #[serde(default = "default_rms_norm_eps")]
49    pub rms_norm_eps: f64,
50    #[serde(default = "default_rope_theta")]
51    pub rope_theta: f64,
52    #[serde(default = "default_hidden_act")]
53    pub hidden_act: String,
54    #[serde(default)]
55    pub tie_word_embeddings: bool,
56    #[serde(default)]
57    pub attention_bias: bool,
58    /// Explicit head dim (Llama 3.x); when absent, derived from hidden/heads.
59    #[serde(default)]
60    pub head_dim: Option<usize>,
61    #[serde(default)]
62    pub rope_scaling: Option<Llama32RopeScaling>,
63}
64
/// Serde fallback for `Llama32Config::rms_norm_eps`.
fn default_rms_norm_eps() -> f64 {
    1e-5_f64
}
/// Serde fallback for `Llama32Config::rope_theta`.
fn default_rope_theta() -> f64 {
    5.0e5_f64
}
/// Serde fallback for `Llama32Config::hidden_act`.
fn default_hidden_act() -> String {
    String::from("silu")
}
74
75impl Llama32Config {
76    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
77        let data = std::fs::read_to_string(path)?;
78        Ok(serde_json::from_str(&data)?)
79    }
80
81    pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
82        llama32_cfg_from_gguf(raw)
83    }
84
85    pub fn head_dim(&self) -> usize {
86        self.head_dim
87            .unwrap_or(self.hidden_size / self.num_attention_heads)
88    }
89
90    pub fn kv_group_size(&self) -> usize {
91        self.num_attention_heads / self.num_key_value_heads
92    }
93
94    pub fn q_proj_dim(&self) -> usize {
95        self.num_attention_heads * self.head_dim()
96    }
97
98    pub fn kv_proj_dim(&self) -> usize {
99        self.num_key_value_heads * self.head_dim()
100    }
101
102    #[cfg(test)]
103    pub(crate) fn tiny_test() -> Self {
104        Self {
105            vocab_size: 32,
106            hidden_size: 16,
107            intermediate_size: 32,
108            num_hidden_layers: 2,
109            num_attention_heads: 4,
110            num_key_value_heads: 2,
111            max_position_embeddings: 16,
112            rms_norm_eps: 1e-5,
113            rope_theta: 500_000.0,
114            hidden_act: "silu".into(),
115            tie_word_embeddings: false,
116            attention_bias: false,
117            head_dim: None,
118            rope_scaling: None,
119        }
120    }
121}
122
123pub fn llama32_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<Llama32Config> {
124    let arch_prefix = raw
125        .metadata
126        .get("general.architecture")
127        .and_then(MetaValue::as_str)
128        .unwrap_or("llama");
129    let get_meta = |k: &str| -> Option<&MetaValue> {
130        raw.metadata.get(k).or_else(|| {
131            let suffix = k.strip_prefix("llama.")?;
132            if arch_prefix == "llama" {
133                None
134            } else {
135                let arch_key = format!("{arch_prefix}.{suffix}");
136                raw.metadata.get(&arch_key)
137            }
138        })
139    };
140    let get_u32 = |k: &str| -> anyhow::Result<u32> {
141        get_meta(k)
142            .and_then(MetaValue::as_u32)
143            .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
144    };
145    let get_f32 = |k: &str| -> Option<f32> {
146        get_meta(k).and_then(|v| match v {
147            MetaValue::F32(x) => Some(*x),
148            _ => None,
149        })
150    };
151    let get_bool = |k: &str| -> Option<bool> {
152        get_meta(k).and_then(|v| match v {
153            MetaValue::Bool(b) => Some(*b),
154            _ => None,
155        })
156    };
157
158    let hidden_size = get_u32("llama.embedding_length")? as usize;
159    let num_attention_heads = get_u32("llama.attention.head_count")? as usize;
160    let head_dim = get_u32("llama.attention.key_length")
161        .ok()
162        .or_else(|| get_u32("llama.rope.dimension_count").ok())
163        .map(|v| v as usize);
164
165    let rope_scaling = match get_meta("llama.rope.scaling.type").and_then(MetaValue::as_str) {
166        Some("none") | None => {
167            // Llama 3.x often bakes scaling into rope_freqs.weight; HF fields may be absent.
168            None
169        }
170        Some("linear") | Some("yarn") | Some("longrope") => {
171            let factor = get_f32("llama.rope.scaling.factor")
172                .or_else(|| get_f32("llama.rope.scale_linear"))
173                .unwrap_or(1.0);
174            let original = get_u32("llama.rope.scaling.original_context_length")
175                .map(|v| v as usize)
176                .unwrap_or(8192);
177            Some(Llama32RopeScaling {
178                factor,
179                low_freq_factor: 1.0,
180                high_freq_factor: 4.0,
181                original_max_position_embeddings: original,
182                rope_type: Llama32RopeType::Llama3,
183            })
184        }
185        other => {
186            return Err(anyhow::anyhow!(
187                "unsupported llama.rope.scaling.type: {other:?}"
188            ));
189        }
190    };
191
192    Ok(Llama32Config {
193        vocab_size: get_u32("llama.vocab_size").unwrap_or(128_256) as usize,
194        hidden_size,
195        intermediate_size: get_u32("llama.feed_forward_length")? as usize,
196        num_hidden_layers: get_u32("llama.block_count")? as usize,
197        num_attention_heads,
198        num_key_value_heads: get_u32("llama.attention.head_count_kv")? as usize,
199        max_position_embeddings: get_u32("llama.context_length").unwrap_or(8192) as usize,
200        rms_norm_eps: get_f32("llama.attention.layer_norm_rms_epsilon").unwrap_or(1e-5) as f64,
201        rope_theta: get_f32("llama.rope.freq_base").unwrap_or(500_000.0) as f64,
202        hidden_act: "silu".into(),
203        tie_word_embeddings: get_bool("llama.tie_word_embeddings").unwrap_or(true),
204        attention_bias: false,
205        head_dim,
206        rope_scaling,
207    })
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON config with a `rope_scaling` object and checks the
    /// derived head dimension and KV group size.
    #[test]
    fn parse_llama32_1b_like() {
        let raw_config = r#"{
            "vocab_size": 128256,
            "hidden_size": 2048,
            "intermediate_size": 8192,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "max_position_embeddings": 131072,
            "rope_theta": 500000.0,
            "rms_norm_eps": 1e-05,
            "tie_word_embeddings": true,
            "rope_scaling": {
                "factor": 32.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3"
            }
        }"#;
        let parsed = serde_json::from_str::<Llama32Config>(raw_config).unwrap();

        // No explicit head_dim in the JSON: 2048 / 32.
        let head_dim = parsed.head_dim();
        assert_eq!(head_dim, 64);
        // 32 attention heads over 8 KV heads.
        let group = parsed.kv_group_size();
        assert_eq!(group, 4);
        assert!(matches!(parsed.rope_scaling, Some(_)));
    }
}