1use rlx_gguf::{GgufFile, MetaValue};
7use serde::Deserialize;
8use std::path::Path;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
11#[serde(rename_all = "lowercase")]
12pub enum Llama32RopeType {
13 #[default]
14 Default,
15 #[serde(rename = "llama3")]
16 Llama3,
17}
18
19#[derive(Debug, Clone, Deserialize)]
20pub struct Llama32RopeScaling {
21 pub factor: f32,
22 #[serde(default = "default_low_freq_factor")]
23 pub low_freq_factor: f32,
24 #[serde(default = "default_high_freq_factor")]
25 pub high_freq_factor: f32,
26 pub original_max_position_embeddings: usize,
27 #[serde(default)]
28 pub rope_type: Llama32RopeType,
29}
30
/// Serde fallback for `Llama32RopeScaling::low_freq_factor` when the key is absent.
fn default_low_freq_factor() -> f32 {
    1_f32
}

/// Serde fallback for `Llama32RopeScaling::high_freq_factor` when the key is absent.
fn default_high_freq_factor() -> f32 {
    4_f32
}
37
/// Llama 3.2 model hyper-parameters.
///
/// Deserializable from a JSON model config (see `Llama32Config::from_file`) or built from GGUF
/// metadata (see `Llama32Config::from_gguf`). Fields without a `#[serde(default ...)]`
/// attribute are required in the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct Llama32Config {
    /// Vocabulary size (JSON key `vocab_size`).
    pub vocab_size: usize,
    /// Hidden dimension (JSON key `hidden_size`); divided by `num_attention_heads` to derive
    /// the per-head size when `head_dim` is not given.
    pub hidden_size: usize,
    /// Feed-forward inner dimension (JSON key `intermediate_size`).
    pub intermediate_size: usize,
    /// Number of transformer layers (JSON key `num_hidden_layers`).
    pub num_hidden_layers: usize,
    /// Number of query attention heads (JSON key `num_attention_heads`).
    pub num_attention_heads: usize,
    /// Number of key/value heads (JSON key `num_key_value_heads`); see `kv_group_size()`.
    pub num_key_value_heads: usize,
    /// Maximum sequence length (JSON key `max_position_embeddings`).
    pub max_position_embeddings: usize,

    /// RMS-norm epsilon; defaults to 1e-5 when absent.
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,
    /// RoPE base frequency; defaults to 500 000 when absent.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Activation function name; defaults to `"silu"` when absent.
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,
    /// Whether the output projection shares the embedding weights; defaults to `false` when
    /// absent from JSON (note: the GGUF loader in this file defaults it to `true` instead).
    #[serde(default)]
    pub tie_word_embeddings: bool,
    /// Whether attention projections carry a bias; defaults to `false`.
    #[serde(default)]
    pub attention_bias: bool,
    /// Explicit per-head dimension. When `None`, `head_dim()` derives it as
    /// `hidden_size / num_attention_heads`.
    #[serde(default)]
    pub head_dim: Option<usize>,
    /// Optional RoPE scaling parameters; `None` when the key is absent.
    #[serde(default)]
    pub rope_scaling: Option<Llama32RopeScaling>,
}
64
/// Serde fallback for `Llama32Config::rms_norm_eps` when the key is absent.
fn default_rms_norm_eps() -> f64 {
    0.000_01
}

/// Serde fallback for `Llama32Config::rope_theta` when the key is absent.
fn default_rope_theta() -> f64 {
    5e5
}

/// Serde fallback for `Llama32Config::hidden_act` when the key is absent.
fn default_hidden_act() -> String {
    String::from("silu")
}
74
75impl Llama32Config {
76 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
77 let data = std::fs::read_to_string(path)?;
78 Ok(serde_json::from_str(&data)?)
79 }
80
81 pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
82 llama32_cfg_from_gguf(raw)
83 }
84
85 pub fn head_dim(&self) -> usize {
86 self.head_dim
87 .unwrap_or(self.hidden_size / self.num_attention_heads)
88 }
89
90 pub fn kv_group_size(&self) -> usize {
91 self.num_attention_heads / self.num_key_value_heads
92 }
93
94 pub fn q_proj_dim(&self) -> usize {
95 self.num_attention_heads * self.head_dim()
96 }
97
98 pub fn kv_proj_dim(&self) -> usize {
99 self.num_key_value_heads * self.head_dim()
100 }
101
102 #[cfg(test)]
103 pub(crate) fn tiny_test() -> Self {
104 Self {
105 vocab_size: 32,
106 hidden_size: 16,
107 intermediate_size: 32,
108 num_hidden_layers: 2,
109 num_attention_heads: 4,
110 num_key_value_heads: 2,
111 max_position_embeddings: 16,
112 rms_norm_eps: 1e-5,
113 rope_theta: 500_000.0,
114 hidden_act: "silu".into(),
115 tie_word_embeddings: false,
116 attention_bias: false,
117 head_dim: None,
118 rope_scaling: None,
119 }
120 }
121}
122
123pub fn llama32_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<Llama32Config> {
124 let arch_prefix = raw
125 .metadata
126 .get("general.architecture")
127 .and_then(MetaValue::as_str)
128 .unwrap_or("llama");
129 let get_meta = |k: &str| -> Option<&MetaValue> {
130 raw.metadata.get(k).or_else(|| {
131 let suffix = k.strip_prefix("llama.")?;
132 if arch_prefix == "llama" {
133 None
134 } else {
135 let arch_key = format!("{arch_prefix}.{suffix}");
136 raw.metadata.get(&arch_key)
137 }
138 })
139 };
140 let get_u32 = |k: &str| -> anyhow::Result<u32> {
141 get_meta(k)
142 .and_then(MetaValue::as_u32)
143 .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
144 };
145 let get_f32 = |k: &str| -> Option<f32> {
146 get_meta(k).and_then(|v| match v {
147 MetaValue::F32(x) => Some(*x),
148 _ => None,
149 })
150 };
151 let get_bool = |k: &str| -> Option<bool> {
152 get_meta(k).and_then(|v| match v {
153 MetaValue::Bool(b) => Some(*b),
154 _ => None,
155 })
156 };
157
158 let hidden_size = get_u32("llama.embedding_length")? as usize;
159 let num_attention_heads = get_u32("llama.attention.head_count")? as usize;
160 let head_dim = get_u32("llama.attention.key_length")
161 .ok()
162 .or_else(|| get_u32("llama.rope.dimension_count").ok())
163 .map(|v| v as usize);
164
165 let rope_scaling = match get_meta("llama.rope.scaling.type").and_then(MetaValue::as_str) {
166 Some("none") | None => {
167 None
169 }
170 Some("linear") | Some("yarn") | Some("longrope") => {
171 let factor = get_f32("llama.rope.scaling.factor")
172 .or_else(|| get_f32("llama.rope.scale_linear"))
173 .unwrap_or(1.0);
174 let original = get_u32("llama.rope.scaling.original_context_length")
175 .map(|v| v as usize)
176 .unwrap_or(8192);
177 Some(Llama32RopeScaling {
178 factor,
179 low_freq_factor: 1.0,
180 high_freq_factor: 4.0,
181 original_max_position_embeddings: original,
182 rope_type: Llama32RopeType::Llama3,
183 })
184 }
185 other => {
186 return Err(anyhow::anyhow!(
187 "unsupported llama.rope.scaling.type: {other:?}"
188 ));
189 }
190 };
191
192 Ok(Llama32Config {
193 vocab_size: get_u32("llama.vocab_size").unwrap_or(128_256) as usize,
194 hidden_size,
195 intermediate_size: get_u32("llama.feed_forward_length")? as usize,
196 num_hidden_layers: get_u32("llama.block_count")? as usize,
197 num_attention_heads,
198 num_key_value_heads: get_u32("llama.attention.head_count_kv")? as usize,
199 max_position_embeddings: get_u32("llama.context_length").unwrap_or(8192) as usize,
200 rms_norm_eps: get_f32("llama.attention.layer_norm_rms_epsilon").unwrap_or(1e-5) as f64,
201 rope_theta: get_f32("llama.rope.freq_base").unwrap_or(500_000.0) as f64,
202 hidden_act: "silu".into(),
203 tie_word_embeddings: get_bool("llama.tie_word_embeddings").unwrap_or(true),
204 attention_bias: false,
205 head_dim,
206 rope_scaling,
207 })
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a config shaped like Llama 3.2 1B and checks the derived head geometry.
    #[test]
    fn parse_llama32_1b_like() {
        let raw = r#"{
            "vocab_size": 128256,
            "hidden_size": 2048,
            "intermediate_size": 8192,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "max_position_embeddings": 131072,
            "rope_theta": 500000.0,
            "rms_norm_eps": 1e-05,
            "tie_word_embeddings": true,
            "rope_scaling": {
                "factor": 32.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3"
            }
        }"#;
        let cfg = serde_json::from_str::<Llama32Config>(raw).unwrap();

        // 2048 hidden / 32 heads => 64-wide heads; 32 query heads over 8 KV heads => groups of 4.
        assert_eq!(64, cfg.head_dim());
        assert_eq!(4, cfg.kv_group_size());
        assert!(matches!(cfg.rope_scaling, Some(_)));
    }
}