// god_graph/transformer/loader/config.rs

use serde::{Deserialize, Serialize};

/// Accessors for the architecture hyperparameters shared by every supported
/// decoder-only transformer family.
pub trait ModelConfigTrait {
    fn vocab_size(&self) -> usize;

    fn hidden_size(&self) -> usize;

    fn intermediate_size(&self) -> usize;

    fn num_hidden_layers(&self) -> usize;

    fn num_attention_heads(&self) -> usize;

    fn num_key_value_heads(&self) -> Option<usize>;

    fn max_position_embeddings(&self) -> usize;

    fn rms_norm_eps(&self) -> f64;

    fn rope_theta(&self) -> f64;
}
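
// Illustrative sketch, not part of the original API: any implementor of
// ModelConfigTrait can drive generic shape arithmetic. For example, the number
// of elements one token occupies in the KV cache is
// 2 (K and V) * layers * kv_heads * head_dim. The function name is ours.
pub fn kv_cache_elements_per_token<C: ModelConfigTrait>(config: &C) -> usize {
    let head_dim = config.hidden_size() / config.num_attention_heads();
    // Fall back to full multi-head attention when no KV head count is given.
    let kv_heads = config
        .num_key_value_heads()
        .unwrap_or_else(|| config.num_attention_heads());
    2 * config.num_hidden_layers() * kv_heads * head_dim
}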

/// Configuration for LLaMA-family models (LLaMA, Llama 2, Llama 3).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: Option<usize>,
    pub max_position_embeddings: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub tie_word_embeddings: bool,
    pub attention_bias: bool,
}

impl LlamaConfig {
    /// Hyperparameters of the original LLaMA 7B model.
    pub fn llama_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: None,
            max_position_embeddings: 2048,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Hyperparameters of Llama 2 7B: explicit key/value head count, a 4k
    /// context window, and the 1e-5 RMSNorm epsilon Llama 2 uses.
    pub fn llama_2_7b() -> Self {
        let mut config = Self::llama_7b();
        config.num_key_value_heads = Some(32);
        config.max_position_embeddings = 4096;
        config.rms_norm_eps = 1e-5;
        config
    }

    /// Hyperparameters of Llama 3 8B (grouped-query attention with 8 KV heads).
    pub fn llama_3_8b() -> Self {
        Self {
            vocab_size: 128256,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: Some(8),
            max_position_embeddings: 8192,
            rms_norm_eps: 1e-5,
            rope_theta: 500000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Number of key/value heads, falling back to the attention head count for
    /// models without grouped-query attention.
    pub fn get_num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    /// Dimension of a single attention head.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Number of query heads that share each key/value head.
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.get_num_key_value_heads()
    }
}
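
// Illustrative sketch, not part of the original API: the projection widths
// implied by head_dim() and q_per_kv() under grouped-query attention. The
// names q_dim / kv_dim are this example's own.
pub fn llama_attention_proj_dims(config: &LlamaConfig) -> (usize, usize) {
    // Queries keep one slice per attention head: num_heads * head_dim columns.
    let q_dim = config.num_attention_heads * config.head_dim();
    // Keys/values are each shared by q_per_kv() query heads, so they only need
    // kv_heads * head_dim columns (e.g. 8 * 128 = 1024 for Llama 3 8B).
    let kv_dim = config.get_num_key_value_heads() * config.head_dim();
    (q_dim, kv_dim)
}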

impl ModelConfigTrait for LlamaConfig {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        self.num_key_value_heads
    }

    fn max_position_embeddings(&self) -> usize {
        self.max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        self.rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        self.rope_theta
    }
}

/// Configuration for Mistral-family models (sliding-window attention).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub sliding_window: Option<usize>,
    pub tie_word_embeddings: bool,
    pub attention_bias: bool,
}

impl MistralConfig {
    /// Hyperparameters of Mistral 7B (v0.1): grouped-query attention with
    /// 8 KV heads and a 4096-token sliding attention window.
    pub fn mistral_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            // Mistral 7B advertises a 32k position limit; the sliding window
            // bounds how far back any single token actually attends.
            max_position_embeddings: 32768,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            sliding_window: Some(4096),
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Dimension of a single attention head.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Number of query heads that share each key/value head.
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }
}
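
// Illustrative sketch, not part of the original API: under Mistral's
// sliding-window attention, a token at `position` attends to at most
// `sliding_window` of the most recent positions (itself included); without a
// window it attends to the full prefix. The helper name is this example's own.
pub fn mistral_attention_span(config: &MistralConfig, position: usize) -> usize {
    match config.sliding_window {
        Some(window) => (position + 1).min(window),
        None => position + 1,
    }
}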

impl ModelConfigTrait for MistralConfig {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        Some(self.num_key_value_heads)
    }

    fn max_position_embeddings(&self) -> usize {
        self.max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        self.rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        self.rope_theta
    }
}

/// A loaded model configuration, tagged by model family.
#[derive(Debug, Clone)]
pub enum ModelConfig {
    Llama(LlamaConfig),
    Mistral(MistralConfig),
}

impl ModelConfig {
    /// Loads a configuration from a JSON file (e.g. a Hugging Face-style
    /// `config.json`), detecting the family via the `sliding_window` field.
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let value: serde_json::Value = serde_json::from_reader(reader)?;

        if value.get("sliding_window").is_some() {
            let config: MistralConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Mistral(config))
        } else {
            let config: LlamaConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Llama(config))
        }
    }

    /// Returns the LLaMA config if this is a LLaMA-family model.
    pub fn as_llama(&self) -> Option<&LlamaConfig> {
        match self {
            ModelConfig::Llama(config) => Some(config),
            _ => None,
        }
    }

    /// Returns the Mistral config if this is a Mistral-family model.
    pub fn as_mistral(&self) -> Option<&MistralConfig> {
        match self {
            ModelConfig::Mistral(config) => Some(config),
            _ => None,
        }
    }
}
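
// Usage sketch (the path below is hypothetical): from_file keys the family off
// the presence of "sliding_window" in the JSON, after which callers can branch
// on the variant or stay family-agnostic via ModelConfigTrait.
#[allow(dead_code)]
fn example_load_config() -> Result<(), Box<dyn std::error::Error>> {
    let config = ModelConfig::from_file("models/llama-3-8b/config.json")?;
    match &config {
        ModelConfig::Llama(c) => println!("Llama model, {} layers", c.num_hidden_layers),
        ModelConfig::Mistral(c) => println!("Mistral model, window {:?}", c.sliding_window),
    }
    // Family-agnostic access through the shared trait:
    println!("hidden size: {}", config.hidden_size());
    Ok(())
}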

impl ModelConfigTrait for ModelConfig {
    fn vocab_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.vocab_size(),
            ModelConfig::Mistral(c) => c.vocab_size(),
        }
    }

    fn hidden_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.hidden_size(),
            ModelConfig::Mistral(c) => c.hidden_size(),
        }
    }

    fn intermediate_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.intermediate_size(),
            ModelConfig::Mistral(c) => c.intermediate_size(),
        }
    }

    fn num_hidden_layers(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.num_hidden_layers(),
            ModelConfig::Mistral(c) => c.num_hidden_layers(),
        }
    }

    fn num_attention_heads(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.num_attention_heads(),
            ModelConfig::Mistral(c) => c.num_attention_heads(),
        }
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        match self {
            ModelConfig::Llama(c) => c.num_key_value_heads(),
            ModelConfig::Mistral(c) => c.num_key_value_heads(),
        }
    }

    fn max_position_embeddings(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.max_position_embeddings(),
            ModelConfig::Mistral(c) => c.max_position_embeddings(),
        }
    }

    fn rms_norm_eps(&self) -> f64 {
        match self {
            ModelConfig::Llama(c) => c.rms_norm_eps(),
            ModelConfig::Mistral(c) => c.rms_norm_eps(),
        }
    }

    fn rope_theta(&self) -> f64 {
        match self {
            ModelConfig::Llama(c) => c.rope_theta(),
            ModelConfig::Mistral(c) => c.rope_theta(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_7b_config() {
        let config = LlamaConfig::llama_7b();

        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 11008);
        assert_eq!(config.num_hidden_layers, 32);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.head_dim(), 128);
    }

    #[test]
    fn test_llama_2_7b_config() {
        let config = LlamaConfig::llama_2_7b();

        assert_eq!(config.num_key_value_heads, Some(32));
        assert_eq!(config.max_position_embeddings, 4096);
    }

    #[test]
    fn test_llama_3_8b_config() {
        let config = LlamaConfig::llama_3_8b();

        assert_eq!(config.vocab_size, 128256);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 14336);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.num_key_value_heads, Some(8));
        assert_eq!(config.q_per_kv(), 4);
        assert_eq!(config.max_position_embeddings, 8192);
    }

    #[test]
    fn test_mistral_7b_config() {
        let config = MistralConfig::mistral_7b();

        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.num_key_value_heads, 8);
        assert_eq!(config.sliding_window, Some(4096));
        assert_eq!(config.q_per_kv(), 4);
    }
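
    // Added sketch, not part of the original suite: round-trip a config through
    // serde_json to exercise the Serialize/Deserialize derives.
    #[test]
    fn test_llama_config_serde_round_trip() {
        let config = LlamaConfig::llama_3_8b();
        let json = serde_json::to_string(&config).expect("serialize");
        let parsed: LlamaConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.vocab_size, config.vocab_size);
        assert_eq!(parsed.num_key_value_heads, config.num_key_value_heads);
        assert_eq!(parsed.rope_theta, config.rope_theta);
    }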
}