use serde::{Deserialize, Serialize};
use trustformers_core::{errors::Result, traits::Config};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GPTNeoXConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub layer_norm_eps: f32,
pub hidden_act: String,
pub rotary_emb_base: f32,
pub rotary_pct: f32,
pub use_parallel_residual: bool,
pub tie_word_embeddings: bool,
pub initializer_range: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
impl Default for GPTNeoXConfig {
fn default() -> Self {
Self {
vocab_size: 50432,
hidden_size: 2048,
num_hidden_layers: 16,
num_attention_heads: 16,
intermediate_size: 8192,
max_position_embeddings: 2048,
layer_norm_eps: 1e-5,
hidden_act: "gelu".to_string(),
rotary_emb_base: 10000.0,
rotary_pct: 1.0,
use_parallel_residual: false,
tie_word_embeddings: false,
initializer_range: 0.02,
bos_token_id: Some(0),
eos_token_id: Some(2),
}
}
}
impl Config for GPTNeoXConfig {
fn validate(&self) -> Result<()> {
if self.hidden_size == 0 {
return Err(
trustformers_core::errors::TrustformersError::invalid_config(
"hidden_size must be greater than 0".to_string(),
),
);
}
if self.num_hidden_layers == 0 {
return Err(
trustformers_core::errors::TrustformersError::invalid_config(
"num_hidden_layers must be greater than 0".to_string(),
),
);
}
if self.num_attention_heads == 0 {
return Err(
trustformers_core::errors::TrustformersError::invalid_config(
"num_attention_heads must be greater than 0".to_string(),
),
);
}
if !self.hidden_size.is_multiple_of(self.num_attention_heads) {
return Err(
trustformers_core::errors::TrustformersError::invalid_config(format!(
"hidden_size must be divisible by num_attention_heads ({})",
self.num_attention_heads
)),
);
}
if self.rotary_pct < 0.0 || self.rotary_pct > 1.0 {
return Err(
trustformers_core::errors::TrustformersError::invalid_config(
"rotary_pct must be between 0.0 and 1.0".to_string(),
),
);
}
Ok(())
}
fn architecture(&self) -> &'static str {
"gpt_neox"
}
}
impl GPTNeoXConfig {
pub fn rinna_japanese_3_6b() -> Self {
Self {
vocab_size: 32000,
hidden_size: 2816,
num_hidden_layers: 36,
num_attention_heads: 22,
intermediate_size: 11264,
max_position_embeddings: 2048,
layer_norm_eps: 1e-5,
hidden_act: "gelu".to_string(),
rotary_emb_base: 10000.0,
rotary_pct: 1.0,
use_parallel_residual: false,
tie_word_embeddings: false,
initializer_range: 0.02,
bos_token_id: Some(2),
eos_token_id: Some(3),
}
}
pub fn pythia_160m() -> Self {
Self {
vocab_size: 50304,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
max_position_embeddings: 2048,
layer_norm_eps: 1e-5,
hidden_act: "gelu".to_string(),
rotary_emb_base: 10000.0,
rotary_pct: 0.25,
use_parallel_residual: true,
tie_word_embeddings: false,
initializer_range: 0.02,
bos_token_id: Some(0),
eos_token_id: Some(0),
}
}
}