use serde::{Deserialize, Serialize};
/// Configuration values for an OPT model.
///
/// Derives serde's `Serialize`/`Deserialize` in addition to `Debug` and `Clone`.
/// [`OptConfig::validate`] checks the structural constraints mentioned on the
/// individual fields; ready-made presets exist as associated functions
/// (`opt_125m`, `opt_350m`, `opt_6_7b`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptConfig {
    /// Vocabulary size (50272 in every preset); `validate` rejects `0`.
    pub vocab_size: usize,
    /// Hidden size; `validate` requires it to be non-zero and divisible by
    /// `num_attention_heads`.
    pub hidden_size: usize,
    /// Number of hidden layers; `validate` rejects `0`.
    pub num_hidden_layers: usize,
    /// Number of attention heads; `validate` rejects `0`. `head_dim` divides
    /// `hidden_size` by this value.
    pub num_attention_heads: usize,
    /// Feed-forward dimension; `validate` rejects `0`.
    // NOTE(review): presumably the inner width of each layer's feed-forward block — confirm.
    pub ffn_dim: usize,
    /// Maximum number of position embeddings (2048 in every preset).
    pub max_position_embeddings: usize,
    /// Word-embedding projection dimension; `validate` rejects `0`.
    ///
    /// It is not required to equal `hidden_size`: the 350m preset uses 512
    /// here with a hidden size of 1024.
    pub word_embed_proj_dim: usize,
    /// Epsilon used for layer normalization (`1e-5` in every preset).
    pub layer_norm_eps: f64,
    /// Dropout value (`0.0` in every preset).
    // NOTE(review): presumably a probability in [0, 1] — confirm against the consumer.
    pub dropout: f64,
    /// Layer-norm placement flag; `false` only in the 350m preset.
    // NOTE(review): the name suggests "normalize before each sub-block" (pre-norm);
    // how it is consumed is not visible here — confirm.
    pub do_layer_norm_before: bool,
    /// Name of the activation function (`"relu"` in every preset). Stored as a
    /// free-form string; `validate` does not inspect it.
    pub activation_function: String,
    /// Cache flag (`true` in every preset).
    // NOTE(review): presumably toggles key/value caching during generation — confirm.
    pub use_cache: bool,
    /// Beginning-of-sequence token id (`2` in every preset).
    pub bos_token_id: u32,
    /// End-of-sequence token id (`2` in every preset).
    pub eos_token_id: u32,
    /// Optional padding token id (`Some(1)` in every preset).
    pub pad_token_id: Option<u32>,
}
impl Default for OptConfig {
fn default() -> Self {
Self {
vocab_size: 50272,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
ffn_dim: 3072,
max_position_embeddings: 2048,
word_embed_proj_dim: 768,
layer_norm_eps: 1e-5,
dropout: 0.0,
do_layer_norm_before: true,
activation_function: "relu".to_string(),
use_cache: true,
bos_token_id: 2,
eos_token_id: 2,
pad_token_id: Some(1),
}
}
}
impl OptConfig {
    /// Returns the OPT-125m preset, which is exactly [`OptConfig::default`].
    pub fn opt_125m() -> Self {
        Self::default()
    }

    /// Returns the OPT-350m preset.
    ///
    /// Unlike the other presets here, `word_embed_proj_dim` (512) differs from
    /// `hidden_size` (1024) and `do_layer_norm_before` is `false`.
    pub fn opt_350m() -> Self {
        Self {
            hidden_size: 1024,
            num_hidden_layers: 24,
            num_attention_heads: 16,
            ffn_dim: 4096,
            word_embed_proj_dim: 512,
            do_layer_norm_before: false,
            // Vocabulary, context length, token ids, epsilon, dropout, activation and
            // cache flag are the same as in the default configuration.
            ..Self::default()
        }
    }

    /// Returns the OPT-6.7b preset.
    pub fn opt_6_7b() -> Self {
        Self {
            hidden_size: 4096,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            ffn_dim: 16384,
            word_embed_proj_dim: 4096,
            do_layer_norm_before: true,
            // Everything not tied to the model size matches the default configuration.
            ..Self::default()
        }
    }

    /// Returns `hidden_size / num_attention_heads` using integer division.
    ///
    /// Yields `0` instead of panicking when `num_attention_heads` is `0`; such a
    /// configuration is rejected by [`OptConfig::validate`].
    pub fn head_dim(&self) -> usize {
        self.hidden_size
            .checked_div(self.num_attention_heads)
            .unwrap_or(0)
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated rule, checked in this order:
    ///
    /// 1. `vocab_size`, `hidden_size`, `num_hidden_layers`, `num_attention_heads` and
    ///    `ffn_dim` must each be non-zero;
    /// 2. `hidden_size` must be divisible by `num_attention_heads`;
    /// 3. `word_embed_proj_dim` must be non-zero;
    /// 4. `max_position_embeddings` must be non-zero;
    /// 5. `layer_norm_eps` must be finite and not negative;
    /// 6. `dropout` must lie in `[0, 1]` (NaN is rejected).
    pub fn validate(&self) -> Result<(), String> {
        if self.vocab_size == 0 {
            return Err("vocab_size must be > 0".to_string());
        }
        if self.hidden_size == 0 {
            return Err("hidden_size must be > 0".to_string());
        }
        if self.num_hidden_layers == 0 {
            return Err("num_hidden_layers must be > 0".to_string());
        }
        if self.num_attention_heads == 0 {
            return Err("num_attention_heads must be > 0".to_string());
        }
        if self.ffn_dim == 0 {
            return Err("ffn_dim must be > 0".to_string());
        }
        if !self.hidden_size.is_multiple_of(self.num_attention_heads) {
            return Err("hidden_size must be divisible by num_attention_heads".to_string());
        }
        if self.word_embed_proj_dim == 0 {
            return Err("word_embed_proj_dim must be > 0".to_string());
        }

        // The checks below are deliberately placed after the ones above so that any
        // configuration rejected by an earlier rule keeps reporting the same message.
        if self.max_position_embeddings == 0 {
            return Err("max_position_embeddings must be > 0".to_string());
        }
        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps < 0.0 {
            return Err("layer_norm_eps must be finite and >= 0".to_string());
        }
        // `contains` is false for NaN, so NaN is rejected together with out-of-range values.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err("dropout must be in [0, 1]".to_string());
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the 125m-sized values and validates.
    #[test]
    fn test_opt_config_default() {
        let config = OptConfig::default();
        assert_eq!(
            (
                config.vocab_size,
                config.hidden_size,
                config.num_hidden_layers,
                config.num_attention_heads,
            ),
            (50272, 768, 12, 12)
        );
        assert_eq!(config.validate(), Ok(()));
    }

    /// The 125m preset uses equal hidden and embedding-projection widths.
    #[test]
    fn test_opt_config_125m() {
        let config = OptConfig::opt_125m();
        assert_eq!(
            (config.hidden_size, config.ffn_dim, config.word_embed_proj_dim),
            (768, 3072, 768)
        );
        assert_eq!(config.validate(), Ok(()));
    }

    /// The 350m preset has its own layer/head counts and a narrower projection.
    #[test]
    fn test_opt_config_350m() {
        let config = OptConfig::opt_350m();
        assert_eq!(
            (
                config.hidden_size,
                config.num_hidden_layers,
                config.num_attention_heads,
                config.ffn_dim,
                config.word_embed_proj_dim,
            ),
            (1024, 24, 16, 4096, 512)
        );
        assert_eq!(config.validate(), Ok(()));
    }

    /// The 6.7b preset exposes the expected dimensions and validates.
    #[test]
    fn test_opt_config_6_7b() {
        let config = OptConfig::opt_6_7b();
        assert_eq!(
            (config.hidden_size, config.num_hidden_layers, config.ffn_dim),
            (4096, 32, 16384)
        );
        assert_eq!(config.validate(), Ok(()));
    }

    /// `head_dim` is the hidden size divided by the number of heads.
    #[test]
    fn test_opt_head_dim() {
        let cases = [(OptConfig::opt_125m(), 64), (OptConfig::opt_6_7b(), 128)];
        for (config, expected) in cases {
            assert_eq!(config.head_dim(), expected);
        }
    }

    /// Only the 350m preset projects embeddings to a width other than `hidden_size`.
    #[test]
    fn test_opt_embed_proj_dim() {
        let medium = OptConfig::opt_350m();
        assert_ne!(medium.word_embed_proj_dim, medium.hidden_size);

        let small = OptConfig::opt_125m();
        assert_eq!(small.word_embed_proj_dim, small.hidden_size);
    }

    /// Indivisible head counts and a zero hidden size are both rejected.
    #[test]
    fn test_opt_config_validation_invalid() {
        let indivisible_heads = OptConfig {
            num_attention_heads: 7,
            ..Default::default()
        };
        assert!(indivisible_heads.validate().is_err());

        let zero_hidden = OptConfig {
            hidden_size: 0,
            num_attention_heads: 12,
            ..Default::default()
        };
        assert!(zero_hidden.validate().is_err());
    }
}