god_graph/transformer/loader/config.rs

//! Model configuration for LLaMA and Mistral

use serde::{Deserialize, Serialize};

/// Base model configuration trait
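///
/// Implemented by [`LlamaConfig`], [`MistralConfig`], and the [`ModelConfig`]
/// enum, so loading code can stay generic over the model family. A minimal
/// sketch of such a caller (`kv_heads` is a hypothetical helper, not part of
/// this module):
///
/// ```ignore
/// fn kv_heads(config: &dyn ModelConfigTrait) -> usize {
///     // GQA models report fewer KV heads; MHA models report None and fall
///     // back to the full attention-head count.
///     config
///         .num_key_value_heads()
///         .unwrap_or_else(|| config.num_attention_heads())
/// }
/// ```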
pub trait ModelConfigTrait {
    /// Get vocabulary size
    fn vocab_size(&self) -> usize;

    /// Get hidden dimension
    fn hidden_size(&self) -> usize;

    /// Get intermediate dimension (FFN)
    fn intermediate_size(&self) -> usize;

    /// Get number of hidden layers
    fn num_hidden_layers(&self) -> usize;

    /// Get number of attention heads
    fn num_attention_heads(&self) -> usize;

    /// Get number of KV heads (for GQA)
    fn num_key_value_heads(&self) -> Option<usize>;

    /// Get maximum position embeddings
    fn max_position_embeddings(&self) -> usize;

    /// Get RMS norm epsilon
    fn rms_norm_eps(&self) -> f64;

    /// Get RoPE theta base
    fn rope_theta(&self) -> f64;
}

/// LLaMA model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Hidden dimension
    pub hidden_size: usize,
    /// Intermediate dimension (FFN)
    pub intermediate_size: usize,
    /// Number of hidden layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of KV heads (for GQA, LLaMA-2/3)
    pub num_key_value_heads: Option<usize>,
    /// Maximum position embeddings
    pub max_position_embeddings: usize,
    /// RMS norm epsilon
    pub rms_norm_eps: f64,
    /// RoPE theta base
    pub rope_theta: f64,
    /// Whether to use tied word embeddings
    pub tie_word_embeddings: bool,
    /// Attention bias
    pub attention_bias: bool,
}

impl LlamaConfig {
    /// Create a new LLaMA config with default values for the original LLaMA 7B
    pub fn llama_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: None, // LLaMA-1 uses MHA: same number of KV heads as Q heads
            max_position_embeddings: 2048,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Create config for LLaMA-2 7B
    pub fn llama_2_7b() -> Self {
        let mut config = Self::llama_7b();
        config.num_key_value_heads = Some(32); // LLaMA-2 7B keeps MHA; only the 70B variant uses GQA
        config.max_position_embeddings = 4096;
        config.rms_norm_eps = 1e-5; // LLaMA-2 uses a larger norm epsilon than LLaMA-1
        config
    }

    /// Create config for LLaMA-3 8B
    pub fn llama_3_8b() -> Self {
        Self {
            vocab_size: 128256,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: Some(8), // GQA with 8 KV heads
            max_position_embeddings: 8192,
            rms_norm_eps: 1e-5,
            rope_theta: 500000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Get number of KV heads (defaults to num_attention_heads if not specified)
    pub fn get_num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    /// Get head dimension
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Get number of Q heads per KV head (for GQA)
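    ///
    /// For example, LLaMA-3 8B pairs 32 query heads with 8 KV heads, so each
    /// KV head is shared by 4 query heads.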
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.get_num_key_value_heads()
    }
}

impl ModelConfigTrait for LlamaConfig {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        self.num_key_value_heads
    }

    fn max_position_embeddings(&self) -> usize {
        self.max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        self.rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        self.rope_theta
    }
}

/// Mistral model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Hidden dimension
    pub hidden_size: usize,
    /// Intermediate dimension (FFN)
    pub intermediate_size: usize,
    /// Number of hidden layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of KV heads (for GQA)
    pub num_key_value_heads: usize,
    /// Maximum position embeddings
    pub max_position_embeddings: usize,
    /// RMS norm epsilon
    pub rms_norm_eps: f64,
    /// RoPE theta base
    pub rope_theta: f64,
    /// Sliding window size (Mistral specific)
    pub sliding_window: Option<usize>,
    /// Whether to use tied word embeddings
    pub tie_word_embeddings: bool,
    /// Attention bias
    pub attention_bias: bool,
}

impl MistralConfig {
    /// Create a new Mistral config with default values for Mistral 7B (v0.1)
    pub fn mistral_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 8, // GQA with 8 KV heads
            max_position_embeddings: 32768, // 32k positions; per-token attention span is bounded by the sliding window
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            sliding_window: Some(4096),
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Get head dimension
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Get number of Q heads per KV head
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }
}

impl ModelConfigTrait for MistralConfig {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        Some(self.num_key_value_heads)
    }

    fn max_position_embeddings(&self) -> usize {
        self.max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        self.rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        self.rope_theta
    }
}

/// Enum for any supported model config
#[derive(Debug, Clone)]
pub enum ModelConfig {
    /// LLaMA model configuration
    Llama(LlamaConfig),
    /// Mistral model configuration
    Mistral(MistralConfig),
}

impl ModelConfig {
    /// Load config from a JSON file
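    ///
    /// The model family is detected heuristically: configs containing a
    /// `sliding_window` key are parsed as [`MistralConfig`], everything else
    /// as [`LlamaConfig`].
    ///
    /// # Examples
    ///
    /// A sketch of typical usage (hypothetical path; detection depends only on
    /// the JSON contents):
    ///
    /// ```ignore
    /// let config = ModelConfig::from_file("models/mistral-7b/config.json")?;
    /// assert!(config.as_mistral().is_some());
    /// ```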
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let value: serde_json::Value = serde_json::from_reader(reader)?;

        // Try to detect model type from config
        if value.get("sliding_window").is_some() {
            // Mistral has sliding_window
            let config: MistralConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Mistral(config))
        } else {
            // Default to LLaMA
            let config: LlamaConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Llama(config))
        }
    }

    /// Get the config as LlamaConfig if applicable
    pub fn as_llama(&self) -> Option<&LlamaConfig> {
        match self {
            ModelConfig::Llama(config) => Some(config),
            _ => None,
        }
    }

    /// Get the config as MistralConfig if applicable
    pub fn as_mistral(&self) -> Option<&MistralConfig> {
        match self {
            ModelConfig::Mistral(config) => Some(config),
            _ => None,
        }
    }
}

impl ModelConfigTrait for ModelConfig {
    fn vocab_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.vocab_size(),
            ModelConfig::Mistral(c) => c.vocab_size(),
        }
    }

    fn hidden_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.hidden_size(),
            ModelConfig::Mistral(c) => c.hidden_size(),
        }
    }

    fn intermediate_size(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.intermediate_size(),
            ModelConfig::Mistral(c) => c.intermediate_size(),
        }
    }

    fn num_hidden_layers(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.num_hidden_layers(),
            ModelConfig::Mistral(c) => c.num_hidden_layers(),
        }
    }

    fn num_attention_heads(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.num_attention_heads(),
            ModelConfig::Mistral(c) => c.num_attention_heads(),
        }
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        match self {
            ModelConfig::Llama(c) => c.num_key_value_heads(),
            ModelConfig::Mistral(c) => c.num_key_value_heads(),
        }
    }

    fn max_position_embeddings(&self) -> usize {
        match self {
            ModelConfig::Llama(c) => c.max_position_embeddings(),
            ModelConfig::Mistral(c) => c.max_position_embeddings(),
        }
    }

    fn rms_norm_eps(&self) -> f64 {
        match self {
            ModelConfig::Llama(c) => c.rms_norm_eps(),
            ModelConfig::Mistral(c) => c.rms_norm_eps(),
        }
    }

    fn rope_theta(&self) -> f64 {
        match self {
            ModelConfig::Llama(c) => c.rope_theta(),
            ModelConfig::Mistral(c) => c.rope_theta(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_7b_config() {
        let config = LlamaConfig::llama_7b();

        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 11008);
        assert_eq!(config.num_hidden_layers, 32);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.head_dim(), 128);
    }

    #[test]
    fn test_llama_2_7b_config() {
        let config = LlamaConfig::llama_2_7b();

        assert_eq!(config.num_key_value_heads, Some(32));
        assert_eq!(config.max_position_embeddings, 4096);
    }

    #[test]
    fn test_llama_3_8b_config() {
        let config = LlamaConfig::llama_3_8b();

        assert_eq!(config.vocab_size, 128256);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 14336);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.num_key_value_heads, Some(8));
        assert_eq!(config.q_per_kv(), 4);
        assert_eq!(config.max_position_embeddings, 8192);
    }

    #[test]
    fn test_mistral_7b_config() {
        let config = MistralConfig::mistral_7b();

        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.num_key_value_heads, 8);
        assert_eq!(config.sliding_window, Some(4096));
        assert_eq!(config.q_per_kv(), 4);
    }
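
    /// Exercises the KV-head defaulting path: `llama_7b()` leaves
    /// `num_key_value_heads` unset, so the model is plain multi-head attention
    /// and every query head has its own KV head.
    #[test]
    fn test_kv_head_defaulting() {
        let config = LlamaConfig::llama_7b();

        assert_eq!(config.num_key_value_heads, None);
        assert_eq!(config.get_num_key_value_heads(), 32);
        assert_eq!(config.q_per_kv(), 1);
    }

    /// Round-trips the built-in configs through JSON and checks that
    /// `ModelConfig::from_file` routes each one to the right variant and that
    /// the trait methods dispatch through the enum. The file names under
    /// `std::env::temp_dir()` are arbitrary.
    #[test]
    fn test_from_file_detection() {
        let dir = std::env::temp_dir();

        // LLaMA config: no `sliding_window` key, so it should parse as Llama.
        let llama_path = dir.join("god_graph_test_llama_config.json");
        let llama_json = serde_json::to_string(&LlamaConfig::llama_3_8b()).unwrap();
        std::fs::write(&llama_path, llama_json).unwrap();
        let config = ModelConfig::from_file(&llama_path).unwrap();
        assert!(config.as_llama().is_some());
        assert_eq!(config.vocab_size(), 128256);
        std::fs::remove_file(&llama_path).ok();

        // Mistral config: `sliding_window` is present, so it should parse as Mistral.
        let mistral_path = dir.join("god_graph_test_mistral_config.json");
        let mistral_json = serde_json::to_string(&MistralConfig::mistral_7b()).unwrap();
        std::fs::write(&mistral_path, mistral_json).unwrap();
        let config = ModelConfig::from_file(&mistral_path).unwrap();
        assert!(config.as_mistral().is_some());
        assert_eq!(config.num_key_value_heads(), Some(8));
        std::fs::remove_file(&mistral_path).ok();
    }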
}