1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
5pub enum ActivationType {
6 #[default]
7 SwiGLU,
8 GELU,
9}
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
13pub enum NormType {
14 #[default]
15 RmsNorm,
16 LayerNorm,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Config {
21 pub vocab_size: usize,
23 pub max_seq_len: usize,
25 pub hidden_size: usize,
27 pub num_layers: usize,
29 pub num_heads: usize,
31 #[serde(default)]
33 pub num_kv_heads: Option<usize>,
34 pub intermediate_size: usize,
36 pub dropout: f64,
38 pub layer_norm_eps: f64,
40 pub use_bias: bool,
42 pub rope_theta: f64,
44 #[serde(default)]
46 pub activation: ActivationType,
47 #[serde(default)]
49 pub norm_type: NormType,
50}
51
52impl Config {
53 pub fn gpt2_small() -> Self {
55 Self {
56 vocab_size: 50257,
57 max_seq_len: 1024,
58 hidden_size: 768,
59 num_layers: 12,
60 num_heads: 12,
61 num_kv_heads: None,
62 intermediate_size: 3072,
63 dropout: 0.1,
64 layer_norm_eps: 1e-5,
65 use_bias: true,
66 rope_theta: 10000.0,
67 activation: ActivationType::GELU,
68 norm_type: NormType::LayerNorm,
69 }
70 }
71
72 pub fn gpt2_medium() -> Self {
74 Self {
75 vocab_size: 50257,
76 max_seq_len: 1024,
77 hidden_size: 1024,
78 num_layers: 24,
79 num_heads: 16,
80 num_kv_heads: None,
81 intermediate_size: 4096,
82 dropout: 0.1,
83 layer_norm_eps: 1e-5,
84 use_bias: true,
85 rope_theta: 10000.0,
86 activation: ActivationType::GELU,
87 norm_type: NormType::LayerNorm,
88 }
89 }
90
91 pub fn gpt2_large() -> Self {
93 Self {
94 vocab_size: 50257,
95 max_seq_len: 1024,
96 hidden_size: 1280,
97 num_layers: 36,
98 num_heads: 20,
99 num_kv_heads: None,
100 intermediate_size: 5120,
101 dropout: 0.1,
102 layer_norm_eps: 1e-5,
103 use_bias: true,
104 rope_theta: 10000.0,
105 activation: ActivationType::GELU,
106 norm_type: NormType::LayerNorm,
107 }
108 }
109
110 pub fn nano() -> Self {
112 Self {
113 vocab_size: 1000,
114 max_seq_len: 128,
115 hidden_size: 64,
116 num_layers: 2,
117 num_heads: 2,
118 num_kv_heads: None,
119 intermediate_size: 256,
120 dropout: 0.1,
121 layer_norm_eps: 1e-5,
122 use_bias: false,
123 rope_theta: 10000.0,
124 activation: ActivationType::GELU,
125 norm_type: NormType::RmsNorm,
126 }
127 }
128
129 pub fn tiny() -> Self {
131 Self {
132 vocab_size: 1000,
133 max_seq_len: 256,
134 hidden_size: 128,
135 num_layers: 4,
136 num_heads: 4,
137 num_kv_heads: None,
138 intermediate_size: 512,
139 dropout: 0.1,
140 layer_norm_eps: 1e-5,
141 use_bias: false,
142 rope_theta: 10000.0,
143 activation: ActivationType::GELU,
144 norm_type: NormType::RmsNorm,
145 }
146 }
147
148 pub fn llama_small() -> Self {
150 Self {
151 vocab_size: 32000,
152 max_seq_len: 2048,
153 hidden_size: 1024,
154 num_layers: 16,
155 num_heads: 16,
156 num_kv_heads: None,
157 intermediate_size: 2752,
158 dropout: 0.0,
159 layer_norm_eps: 1e-6,
160 use_bias: false,
161 rope_theta: 10000.0,
162 activation: ActivationType::SwiGLU,
163 norm_type: NormType::RmsNorm,
164 }
165 }
166
167 pub fn head_dim(&self) -> usize {
168 self.hidden_size / self.num_heads
169 }
170
171 pub fn from_json(path: &str) -> anyhow::Result<Self> {
172 let content = std::fs::read_to_string(path)?;
173 Ok(serde_json::from_str(&content)?)
174 }
175
176 pub fn save_json(&self, path: &str) -> anyhow::Result<()> {
177 let content = serde_json::to_string_pretty(self)?;
178 std::fs::write(path, content)?;
179 Ok(())
180 }
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TrainingConfig {
185 pub learning_rate: f64,
187 pub weight_decay: f64,
189 pub beta1: f64,
191 pub beta2: f64,
193 pub grad_clip: f64,
195 pub batch_size: usize,
197 pub epochs: usize,
199 pub warmup_steps: usize,
201 pub save_every: usize,
203 pub eval_every: usize,
205 pub log_every: usize,
207 pub seq_len: usize,
209 pub gradient_accumulation_steps: usize,
211}
212
213impl Default for TrainingConfig {
214 fn default() -> Self {
215 Self {
216 learning_rate: 3e-4,
217 weight_decay: 0.1,
218 beta1: 0.9,
219 beta2: 0.95,
220 grad_clip: 1.0,
221 batch_size: 32,
222 epochs: 1,
223 warmup_steps: 1000,
224 save_every: 1000,
225 eval_every: 500,
226 log_every: 10,
227 seq_len: 512,
228 gradient_accumulation_steps: 1,
229 }
230 }
231}