//! Model and training configuration types, with JSON (de)serialization.

use serde::{Deserialize, Serialize};

/// Activation function used in the feed-forward (MLP) sublayer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ActivationType {
    #[default]
    SwiGLU,
    GELU,
}

/// Normalization layer used throughout the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum NormType {
    #[default]
    RmsNorm,
    LayerNorm,
}
/// Model architecture hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Maximum sequence length (context window) in tokens.
    pub max_seq_len: usize,
    /// Width of the residual stream (model dimension).
    pub hidden_size: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Optional number of key/value heads for grouped-query attention.
    #[serde(default)]
    pub num_kv_heads: Option<usize>,
    /// Width of the feed-forward hidden layer.
    pub intermediate_size: usize,
    /// Dropout probability; 0.0 disables dropout.
    pub dropout: f64,
    /// Epsilon added inside the normalization denominator for stability.
    pub layer_norm_eps: f64,
    /// Whether linear layers include bias terms.
    pub use_bias: bool,
    /// Base frequency for rotary position embeddings (RoPE).
    pub rope_theta: f64,
    #[serde(default)]
    pub activation: ActivationType,
    #[serde(default)]
    pub norm_type: NormType,
}
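
// For reference, a `Config` serializes to JSON roughly as sketched below
// (serde's default representation: field names as keys, unit enum variants
// as strings, `None` as null). Values shown are those of `gpt2_small()`;
// exact number formatting may differ.
//
// {
//   "vocab_size": 50257,
//   "max_seq_len": 1024,
//   "hidden_size": 768,
//   "num_layers": 12,
//   "num_heads": 12,
//   "num_kv_heads": null,
//   "intermediate_size": 3072,
//   "dropout": 0.1,
//   "layer_norm_eps": 1e-5,
//   "use_bias": true,
//   "rope_theta": 10000.0,
//   "activation": "GELU",
//   "norm_type": "LayerNorm"
// }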
impl Config {
    /// GPT-2 small-scale preset (~124M parameters): biases, GELU, LayerNorm.
    pub fn gpt2_small() -> Self {
        Self {
            vocab_size: 50257,
            max_seq_len: 1024,
            hidden_size: 768,
            num_layers: 12,
            num_heads: 12,
            num_kv_heads: None,
            intermediate_size: 3072,
            dropout: 0.1,
            layer_norm_eps: 1e-5,
            use_bias: true,
            rope_theta: 10000.0,
            activation: ActivationType::GELU,
            norm_type: NormType::LayerNorm,
        }
    }
    /// GPT-2 medium-scale preset (~355M parameters).
    pub fn gpt2_medium() -> Self {
        Self {
            vocab_size: 50257,
            max_seq_len: 1024,
            hidden_size: 1024,
            num_layers: 24,
            num_heads: 16,
            num_kv_heads: None,
            intermediate_size: 4096,
            dropout: 0.1,
            layer_norm_eps: 1e-5,
            use_bias: true,
            rope_theta: 10000.0,
            activation: ActivationType::GELU,
            norm_type: NormType::LayerNorm,
        }
    }
    /// GPT-2 large-scale preset (~774M parameters).
    pub fn gpt2_large() -> Self {
        Self {
            vocab_size: 50257,
            max_seq_len: 1024,
            hidden_size: 1280,
            num_layers: 36,
            num_heads: 20,
            num_kv_heads: None,
            intermediate_size: 5120,
            dropout: 0.1,
            layer_norm_eps: 1e-5,
            use_bias: true,
            rope_theta: 10000.0,
            activation: ActivationType::GELU,
            norm_type: NormType::LayerNorm,
        }
    }
    /// Minimal preset for smoke tests and debugging.
    pub fn nano() -> Self {
        Self {
            vocab_size: 1000,
            max_seq_len: 128,
            hidden_size: 64,
            num_layers: 2,
            num_heads: 2,
            num_kv_heads: None,
            intermediate_size: 256,
            dropout: 0.1,
            layer_norm_eps: 1e-5,
            use_bias: false,
            rope_theta: 10000.0,
            activation: ActivationType::GELU,
            norm_type: NormType::RmsNorm,
        }
    }
    /// Small preset for quick experiments.
    pub fn tiny() -> Self {
        Self {
            vocab_size: 1000,
            max_seq_len: 256,
            hidden_size: 128,
            num_layers: 4,
            num_heads: 4,
            num_kv_heads: None,
            intermediate_size: 512,
            dropout: 0.1,
            layer_norm_eps: 1e-5,
            use_bias: false,
            rope_theta: 10000.0,
            activation: ActivationType::GELU,
            norm_type: NormType::RmsNorm,
        }
    }
    /// LLaMA-style preset: SwiGLU, RMSNorm, no biases, no dropout.
    /// `intermediate_size` appears to follow the ~(8/3) * hidden_size SwiGLU
    /// sizing convention, rounded up to a multiple of 64.
    pub fn llama_small() -> Self {
        Self {
            vocab_size: 32000,
            max_seq_len: 2048,
            hidden_size: 1024,
            num_layers: 16,
            num_heads: 16,
            num_kv_heads: None,
            intermediate_size: 2752,
            dropout: 0.0,
            layer_norm_eps: 1e-6,
            use_bias: false,
            rope_theta: 10000.0,
            activation: ActivationType::SwiGLU,
            norm_type: NormType::RmsNorm,
        }
    }
    /// Dimension of each attention head. Assumes `hidden_size` is evenly
    /// divisible by `num_heads`; integer division silently truncates otherwise.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_heads
    }
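
    /// A convenience sketch (not part of the original API): resolves the
    /// optional `num_kv_heads` field, assuming the usual convention that
    /// `None` means one key/value head per query head (plain multi-head
    /// attention rather than grouped-query attention).
    pub fn kv_heads(&self) -> usize {
        self.num_kv_heads.unwrap_or(self.num_heads)
    }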
    /// Load a configuration from a JSON file at `path`.
    pub fn from_json(path: &str) -> anyhow::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&content)?)
    }

    /// Write this configuration to `path` as pretty-printed JSON.
    pub fn save_json(&self, path: &str) -> anyhow::Result<()> {
        let content = serde_json::to_string_pretty(self)?;
        std::fs::write(path, content)?;
        Ok(())
    }
}

/// Optimizer and training-loop hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Peak learning rate.
    pub learning_rate: f64,
    /// Weight decay coefficient.
    pub weight_decay: f64,
    /// Exponential decay rate for the optimizer's first moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the optimizer's second moment estimate.
    pub beta2: f64,
    /// Gradient-norm clipping threshold.
    pub grad_clip: f64,
    /// Micro-batch size (sequences per forward/backward pass).
    pub batch_size: usize,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Learning-rate warmup steps.
    pub warmup_steps: usize,
    /// Checkpoint interval, in steps.
    pub save_every: usize,
    /// Evaluation interval, in steps.
    pub eval_every: usize,
    /// Logging interval, in steps.
    pub log_every: usize,
    /// Training sequence length in tokens.
    pub seq_len: usize,
    /// Micro-batches accumulated before each optimizer step.
    pub gradient_accumulation_steps: usize,
}
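
// A hedged convenience sketch (not part of the original code): the number of
// tokens consumed per optimizer step that these fields imply, assuming the
// conventional meaning of gradient accumulation (micro-batch gradients are
// summed before each parameter update).
impl TrainingConfig {
    pub fn tokens_per_step(&self) -> usize {
        self.batch_size * self.seq_len * self.gradient_accumulation_steps
    }
}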

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 3e-4,
            weight_decay: 0.1,
            beta1: 0.9,
            beta2: 0.95,
            grad_clip: 1.0,
            batch_size: 32,
            epochs: 1,
            warmup_steps: 1000,
            save_every: 1000,
            eval_every: 500,
            log_every: 10,
            seq_len: 512,
            gradient_accumulation_steps: 1,
        }
    }
}
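
// A minimal test sketch exercising the JSON round trip and the serde defaults;
// it assumes only `serde_json`, which this module already depends on.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn json_round_trip() {
        let cfg = Config::llama_small();
        let json = serde_json::to_string(&cfg).unwrap();
        let back: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(back.hidden_size, cfg.hidden_size);
        assert_eq!(back.num_kv_heads, cfg.num_kv_heads);
        assert_eq!(back.activation, cfg.activation);
        assert_eq!(back.norm_type, cfg.norm_type);
    }

    #[test]
    fn serde_defaults_fill_missing_fields() {
        // `num_kv_heads`, `activation`, and `norm_type` are #[serde(default)],
        // so configs written before those fields existed still deserialize.
        let json = r#"{
            "vocab_size": 1000, "max_seq_len": 128, "hidden_size": 64,
            "num_layers": 2, "num_heads": 2, "intermediate_size": 256,
            "dropout": 0.0, "layer_norm_eps": 1e-5, "use_bias": false,
            "rope_theta": 10000.0
        }"#;
        let cfg: Config = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.num_kv_heads, None);
        assert_eq!(cfg.activation, ActivationType::SwiGLU);
        assert_eq!(cfg.norm_type, NormType::RmsNorm);
    }
}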