Skip to main content

hermes_llm/
config.rs

1use serde::{Deserialize, Serialize};
2
3/// Activation function type
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
5pub enum ActivationType {
6    #[default]
7    SwiGLU,
8    GELU,
9}
10
11/// Normalization type
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
13pub enum NormType {
14    #[default]
15    RmsNorm,
16    LayerNorm,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Config {
21    /// Vocabulary size
22    pub vocab_size: usize,
23    /// Maximum sequence length (context window)
24    pub max_seq_len: usize,
25    /// Embedding dimension
26    pub hidden_size: usize,
27    /// Number of transformer layers
28    pub num_layers: usize,
29    /// Number of attention heads
30    pub num_heads: usize,
31    /// Number of key-value heads (for GQA, defaults to num_heads)
32    #[serde(default)]
33    pub num_kv_heads: Option<usize>,
34    /// Intermediate size in FFN (typically 4x hidden_size)
35    pub intermediate_size: usize,
36    /// Dropout probability
37    pub dropout: f64,
38    /// Layer norm epsilon
39    pub layer_norm_eps: f64,
40    /// Whether to use bias in linear layers
41    pub use_bias: bool,
42    /// RoPE base frequency
43    pub rope_theta: f64,
44    /// Activation function (SwiGLU or GELU)
45    #[serde(default)]
46    pub activation: ActivationType,
47    /// Normalization type (RmsNorm or LayerNorm)
48    #[serde(default)]
49    pub norm_type: NormType,
50}
51
52impl Config {
53    /// GPT-2 Small configuration (124M parameters)
54    pub fn gpt2_small() -> Self {
55        Self {
56            vocab_size: 50257,
57            max_seq_len: 1024,
58            hidden_size: 768,
59            num_layers: 12,
60            num_heads: 12,
61            num_kv_heads: None,
62            intermediate_size: 3072,
63            dropout: 0.1,
64            layer_norm_eps: 1e-5,
65            use_bias: true,
66            rope_theta: 10000.0,
67            activation: ActivationType::GELU,
68            norm_type: NormType::LayerNorm,
69        }
70    }
71
72    /// GPT-2 Medium configuration (355M parameters)
73    pub fn gpt2_medium() -> Self {
74        Self {
75            vocab_size: 50257,
76            max_seq_len: 1024,
77            hidden_size: 1024,
78            num_layers: 24,
79            num_heads: 16,
80            num_kv_heads: None,
81            intermediate_size: 4096,
82            dropout: 0.1,
83            layer_norm_eps: 1e-5,
84            use_bias: true,
85            rope_theta: 10000.0,
86            activation: ActivationType::GELU,
87            norm_type: NormType::LayerNorm,
88        }
89    }
90
91    /// GPT-2 Large configuration (774M parameters)
92    pub fn gpt2_large() -> Self {
93        Self {
94            vocab_size: 50257,
95            max_seq_len: 1024,
96            hidden_size: 1280,
97            num_layers: 36,
98            num_heads: 20,
99            num_kv_heads: None,
100            intermediate_size: 5120,
101            dropout: 0.1,
102            layer_norm_eps: 1e-5,
103            use_bias: true,
104            rope_theta: 10000.0,
105            activation: ActivationType::GELU,
106            norm_type: NormType::LayerNorm,
107        }
108    }
109
110    /// Nano configuration (~500K params) - fastest for testing
111    pub fn nano() -> Self {
112        Self {
113            vocab_size: 1000,
114            max_seq_len: 128,
115            hidden_size: 64,
116            num_layers: 2,
117            num_heads: 2,
118            num_kv_heads: None,
119            intermediate_size: 256,
120            dropout: 0.1,
121            layer_norm_eps: 1e-5,
122            use_bias: false,
123            rope_theta: 10000.0,
124            activation: ActivationType::GELU,
125            norm_type: NormType::RmsNorm,
126        }
127    }
128
129    /// Tiny configuration for testing/debugging (~9M params)
130    pub fn tiny() -> Self {
131        Self {
132            vocab_size: 1000,
133            max_seq_len: 256,
134            hidden_size: 128,
135            num_layers: 4,
136            num_heads: 4,
137            num_kv_heads: None,
138            intermediate_size: 512,
139            dropout: 0.1,
140            layer_norm_eps: 1e-5,
141            use_bias: false,
142            rope_theta: 10000.0,
143            activation: ActivationType::GELU,
144            norm_type: NormType::RmsNorm,
145        }
146    }
147
148    /// LLaMA-style configuration (no bias, RMSNorm, SwiGLU)
149    pub fn llama_small() -> Self {
150        Self {
151            vocab_size: 32000,
152            max_seq_len: 2048,
153            hidden_size: 1024,
154            num_layers: 16,
155            num_heads: 16,
156            num_kv_heads: None,
157            intermediate_size: 2752,
158            dropout: 0.0,
159            layer_norm_eps: 1e-6,
160            use_bias: false,
161            rope_theta: 10000.0,
162            activation: ActivationType::SwiGLU,
163            norm_type: NormType::RmsNorm,
164        }
165    }
166
167    pub fn head_dim(&self) -> usize {
168        self.hidden_size / self.num_heads
169    }
170
171    pub fn from_json(path: &str) -> anyhow::Result<Self> {
172        let content = std::fs::read_to_string(path)?;
173        Ok(serde_json::from_str(&content)?)
174    }
175
176    pub fn save_json(&self, path: &str) -> anyhow::Result<()> {
177        let content = serde_json::to_string_pretty(self)?;
178        std::fs::write(path, content)?;
179        Ok(())
180    }
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TrainingConfig {
185    /// Learning rate
186    pub learning_rate: f64,
187    /// Weight decay for AdamW
188    pub weight_decay: f64,
189    /// Adam beta1
190    pub beta1: f64,
191    /// Adam beta2
192    pub beta2: f64,
193    /// Gradient clipping max norm
194    pub grad_clip: f64,
195    /// Batch size
196    pub batch_size: usize,
197    /// Number of training epochs
198    pub epochs: usize,
199    /// Warmup steps for learning rate scheduler
200    pub warmup_steps: usize,
201    /// Save checkpoint every N steps
202    pub save_every: usize,
203    /// Evaluate every N steps
204    pub eval_every: usize,
205    /// Log every N steps
206    pub log_every: usize,
207    /// Sequence length for training
208    pub seq_len: usize,
209    /// Gradient accumulation steps
210    pub gradient_accumulation_steps: usize,
211}
212
213impl Default for TrainingConfig {
214    fn default() -> Self {
215        Self {
216            learning_rate: 3e-4,
217            weight_decay: 0.1,
218            beta1: 0.9,
219            beta2: 0.95,
220            grad_clip: 1.0,
221            batch_size: 32,
222            epochs: 1,
223            warmup_steps: 1000,
224            save_every: 1000,
225            eval_every: 500,
226            log_every: 10,
227            seq_len: 512,
228            gradient_accumulation_steps: 1,
229        }
230    }
231}