use serde::{Deserialize, Serialize};
/// Read-only view of the architecture hyper-parameters shared by every
/// supported decoder family.
///
/// All methods take `&self`, so the trait can be used as a trait object
/// (`&dyn ModelConfigTrait`).
pub trait ModelConfigTrait {
    /// Size of the token vocabulary.
    fn vocab_size(&self) -> usize;
    /// Width of the hidden (residual) representation.
    fn hidden_size(&self) -> usize;
    /// Inner width of the feed-forward block.
    fn intermediate_size(&self) -> usize;
    /// Number of stacked decoder layers.
    fn num_hidden_layers(&self) -> usize;
    /// Number of query attention heads.
    fn num_attention_heads(&self) -> usize;
    /// Number of key/value heads, if the implementor records one explicitly.
    fn num_key_value_heads(&self) -> Option<usize>;
    /// Maximum sequence position the configuration declares.
    fn max_position_embeddings(&self) -> usize;
    /// Epsilon used by RMS normalisation.
    fn rms_norm_eps(&self) -> f64;
    /// Base frequency for rotary position embeddings.
    fn rope_theta(&self) -> f64;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: Option<usize>,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub tie_word_embeddings: bool,
pub attention_bias: bool,
}
impl LlamaConfig {
    /// Preset for the original Llama 7B configuration.
    pub fn llama_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: None,
            max_position_embeddings: 2048,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Preset for Llama 2 7B.
    ///
    /// It shares the layer shape of [`LlamaConfig::llama_7b`] but records an
    /// explicit KV-head count and a 4096-token context. It also uses
    /// `rms_norm_eps = 1e-5` rather than Llama 1's `1e-6`.
    pub fn llama_2_7b() -> Self {
        Self {
            num_key_value_heads: Some(32),
            max_position_embeddings: 4096,
            // Llama 2 ships with norm_eps = 1e-5; inheriting Llama 1's 1e-6 was wrong.
            rms_norm_eps: 1e-5,
            ..Self::llama_7b()
        }
    }

    /// Preset for Llama 3 8B (grouped-query attention with 8 KV heads).
    pub fn llama_3_8b() -> Self {
        Self {
            vocab_size: 128256,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: Some(8),
            max_position_embeddings: 8192,
            rms_norm_eps: 1e-5,
            rope_theta: 500000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Effective number of key/value heads.
    ///
    /// If `num_key_value_heads` is unset, this returns `num_attention_heads`,
    /// i.e. one KV head per query head.
    pub fn get_num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    /// Per-head width: `hidden_size / num_attention_heads` (integer division).
    ///
    /// # Panics
    /// Panics if `num_attention_heads` is zero.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Number of query heads sharing each key/value head (integer division).
    ///
    /// # Panics
    /// Panics if the effective KV-head count is zero.
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.get_num_key_value_heads()
    }
}
/// Exposes the public `LlamaConfig` fields through the shared trait.
/// Each accessor returns the stored field unchanged. `num_key_value_heads`
/// is returned as stored, without applying the attention-head fallback.
impl ModelConfigTrait for LlamaConfig {
    fn vocab_size(&self) -> usize {
        let Self { vocab_size, .. } = *self;
        vocab_size
    }

    fn hidden_size(&self) -> usize {
        let Self { hidden_size, .. } = *self;
        hidden_size
    }

    fn intermediate_size(&self) -> usize {
        let Self { intermediate_size, .. } = *self;
        intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        let Self { num_hidden_layers, .. } = *self;
        num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        let Self { num_attention_heads, .. } = *self;
        num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        let Self { num_key_value_heads, .. } = *self;
        num_key_value_heads
    }

    fn max_position_embeddings(&self) -> usize {
        let Self { max_position_embeddings, .. } = *self;
        max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        let Self { rms_norm_eps, .. } = *self;
        rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        let Self { rope_theta, .. } = *self;
        rope_theta
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub sliding_window: Option<usize>,
pub tie_word_embeddings: bool,
pub attention_bias: bool,
}
impl MistralConfig {
    /// Preset for Mistral 7B (v0.1): grouped-query attention with 8 KV heads
    /// and a 4096-token sliding window.
    pub fn mistral_7b() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 14336,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 8,
            // Matches the published Mistral-7B-v0.1 config.json. The sliding
            // window limits how far attention looks back; it does not cap the
            // position range. The previous 2048 was not the model's value.
            max_position_embeddings: 32768,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            sliding_window: Some(4096),
            tie_word_embeddings: false,
            attention_bias: false,
        }
    }

    /// Per-head width: `hidden_size / num_attention_heads` (integer division).
    ///
    /// # Panics
    /// Panics if `num_attention_heads` is zero.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Number of query heads sharing each key/value head (integer division).
    ///
    /// # Panics
    /// Panics if `num_key_value_heads` is zero.
    pub fn q_per_kv(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }
}
/// Exposes the public `MistralConfig` fields through the shared trait.
/// Mistral always stores a KV-head count, so `num_key_value_heads` is
/// always `Some`.
impl ModelConfigTrait for MistralConfig {
    fn vocab_size(&self) -> usize {
        let Self { vocab_size, .. } = *self;
        vocab_size
    }

    fn hidden_size(&self) -> usize {
        let Self { hidden_size, .. } = *self;
        hidden_size
    }

    fn intermediate_size(&self) -> usize {
        let Self { intermediate_size, .. } = *self;
        intermediate_size
    }

    fn num_hidden_layers(&self) -> usize {
        let Self { num_hidden_layers, .. } = *self;
        num_hidden_layers
    }

    fn num_attention_heads(&self) -> usize {
        let Self { num_attention_heads, .. } = *self;
        num_attention_heads
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        let Self { num_key_value_heads, .. } = *self;
        Option::Some(num_key_value_heads)
    }

    fn max_position_embeddings(&self) -> usize {
        let Self { max_position_embeddings, .. } = *self;
        max_position_embeddings
    }

    fn rms_norm_eps(&self) -> f64 {
        let Self { rms_norm_eps, .. } = *self;
        rms_norm_eps
    }

    fn rope_theta(&self) -> f64 {
        let Self { rope_theta, .. } = *self;
        rope_theta
    }
}
/// A model configuration tagged by architecture family.
///
/// Produced by `ModelConfig::from_file`. Use `as_llama` / `as_mistral` to
/// borrow the family-specific struct.
#[derive(Clone, Debug)]
pub enum ModelConfig {
    /// Llama-family hyper-parameters.
    Llama(LlamaConfig),
    /// Mistral-family hyper-parameters.
    Mistral(MistralConfig),
}
impl ModelConfig {
    /// Load a configuration from a JSON file (Hugging Face `config.json` layout).
    ///
    /// The architecture is selected as follows:
    /// * `"model_type": "mistral"` → [`ModelConfig::Mistral`];
    /// * `"model_type": "llama"` → [`ModelConfig::Llama`];
    /// * any other or missing `model_type`: a `sliding_window` key (even when
    ///   its value is `null`) selects Mistral, otherwise Llama.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened, is not valid JSON, or
    /// does not deserialize into the selected configuration struct.
    pub fn from_file<P: AsRef<std::path::Path>>(
        path: P,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let value: serde_json::Value = serde_json::from_reader(reader)?;

        // Prefer the explicit architecture tag. Key presence alone is a weak
        // signal (other architectures also carry `sliding_window`).
        let is_mistral = match value.get("model_type").and_then(serde_json::Value::as_str) {
            Some("mistral") => true,
            Some("llama") => false,
            // Unknown or missing tag: keep the legacy key-presence heuristic.
            _ => value.get("sliding_window").is_some(),
        };

        if is_mistral {
            let config: MistralConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Mistral(config))
        } else {
            let config: LlamaConfig = serde_json::from_value(value)?;
            Ok(ModelConfig::Llama(config))
        }
    }

    /// Borrow the inner [`LlamaConfig`], or `None` for other families.
    pub fn as_llama(&self) -> Option<&LlamaConfig> {
        match self {
            ModelConfig::Llama(config) => Some(config),
            // Explicit arm (not `_`) so adding a variant forces a review here.
            ModelConfig::Mistral(_) => None,
        }
    }

    /// Borrow the inner [`MistralConfig`], or `None` for other families.
    pub fn as_mistral(&self) -> Option<&MistralConfig> {
        match self {
            ModelConfig::Mistral(config) => Some(config),
            ModelConfig::Llama(_) => None,
        }
    }
}
impl ModelConfig {
    /// Borrow the wrapped family-specific config as a trait object, so every
    /// `ModelConfigTrait` method below forwards through this single `match`.
    fn as_dyn_config(&self) -> &dyn ModelConfigTrait {
        match self {
            ModelConfig::Llama(inner) => inner,
            ModelConfig::Mistral(inner) => inner,
        }
    }
}

/// Forwards every accessor to the wrapped `LlamaConfig` / `MistralConfig`.
impl ModelConfigTrait for ModelConfig {
    fn vocab_size(&self) -> usize {
        self.as_dyn_config().vocab_size()
    }

    fn hidden_size(&self) -> usize {
        self.as_dyn_config().hidden_size()
    }

    fn intermediate_size(&self) -> usize {
        self.as_dyn_config().intermediate_size()
    }

    fn num_hidden_layers(&self) -> usize {
        self.as_dyn_config().num_hidden_layers()
    }

    fn num_attention_heads(&self) -> usize {
        self.as_dyn_config().num_attention_heads()
    }

    fn num_key_value_heads(&self) -> Option<usize> {
        self.as_dyn_config().num_key_value_heads()
    }

    fn max_position_embeddings(&self) -> usize {
        self.as_dyn_config().max_position_embeddings()
    }

    fn rms_norm_eps(&self) -> f64 {
        self.as_dyn_config().rms_norm_eps()
    }

    fn rope_theta(&self) -> f64 {
        self.as_dyn_config().rope_theta()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_7b_config() {
        let config = LlamaConfig::llama_7b();
        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 11008);
        assert_eq!(config.num_hidden_layers, 32);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.head_dim(), 128);
    }

    #[test]
    fn test_llama_7b_kv_heads_fall_back_to_attention_heads() {
        // With no explicit KV-head count, every query head gets its own KV head.
        let config = LlamaConfig::llama_7b();
        assert_eq!(config.num_key_value_heads, None);
        assert_eq!(config.get_num_key_value_heads(), 32);
        assert_eq!(config.q_per_kv(), 1);
    }

    #[test]
    fn test_llama_2_7b_config() {
        let config = LlamaConfig::llama_2_7b();
        assert_eq!(config.num_key_value_heads, Some(32));
        assert_eq!(config.max_position_embeddings, 4096);
    }

    #[test]
    fn test_llama_3_8b_config() {
        let config = LlamaConfig::llama_3_8b();
        assert_eq!(config.vocab_size, 128256);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.intermediate_size, 14336);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.num_key_value_heads, Some(8));
        assert_eq!(config.q_per_kv(), 4);
        assert_eq!(config.max_position_embeddings, 8192);
    }

    #[test]
    fn test_mistral_7b_config() {
        let config = MistralConfig::mistral_7b();
        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.num_key_value_heads, 8);
        assert_eq!(config.sliding_window, Some(4096));
        assert_eq!(config.q_per_kv(), 4);
    }

    #[test]
    fn test_model_config_dispatch() {
        let llama = ModelConfig::Llama(LlamaConfig::llama_3_8b());
        assert!(llama.as_llama().is_some());
        assert!(llama.as_mistral().is_none());
        assert_eq!(llama.vocab_size(), 128256);
        assert_eq!(llama.num_hidden_layers(), 32);
        assert_eq!(llama.num_key_value_heads(), Some(8));
        assert_eq!(llama.rope_theta(), 500000.0);

        let mistral = ModelConfig::Mistral(MistralConfig::mistral_7b());
        assert!(mistral.as_mistral().is_some());
        assert!(mistral.as_llama().is_none());
        assert_eq!(mistral.intermediate_size(), 14336);
        assert_eq!(mistral.num_attention_heads(), 32);
        assert_eq!(mistral.num_key_value_heads(), Some(8));
    }
}