use serde::{Deserialize, Serialize};
use trustformers_core::traits::Config;
/// Hyperparameters describing a StableLM-family model.
///
/// Instances are usually obtained from [`Default`] or from one of the preset
/// constructors such as [`StableLMConfig::stablelm_3b`], and can be checked
/// for internal consistency with [`StableLMConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StableLMConfig {
    /// Vocabulary size. Must be non-zero to pass `validate`; sizes the
    /// embedding and LM-head terms of `estimate_parameters`.
    pub vocab_size: usize,
    /// Hidden width. Must be non-zero and divisible by
    /// `num_attention_heads` to pass `validate`.
    pub hidden_size: usize,
    /// Inner width of the feed-forward block, as used by the per-layer term
    /// of `estimate_parameters`.
    pub intermediate_size: usize,
    /// Number of layers; multiplies the per-layer term of
    /// `estimate_parameters`.
    pub num_hidden_layers: usize,
    /// Number of attention heads. Must be non-zero and divide `hidden_size`.
    pub num_attention_heads: usize,
    /// Optional number of key/value heads. When set it must be non-zero and
    /// divide `num_attention_heads`.
    // NOTE(review): `None` presumably means one key/value head per attention
    // head -- confirm against the attention implementation.
    pub num_key_value_heads: Option<usize>,
    /// Name of the activation function. `validate` accepts `"silu"`,
    /// `"gelu"`, `"relu"` and `"swiglu"`.
    pub hidden_act: String,
    /// Presumably the longest sequence length the position encoding is
    /// configured for; not consulted in this file.
    pub max_position_embeddings: usize,
    /// Presumably the standard deviation used for weight initialisation;
    /// not consulted in this file.
    pub initializer_range: f32,
    /// Epsilon for the normalisation layers (the name suggests RMSNorm);
    /// not consulted in this file.
    pub rms_norm_eps: f32,
    /// Presumably toggles key/value caching during generation; not consulted
    /// in this file.
    pub use_cache: bool,
    /// Optional padding token id.
    pub pad_token_id: Option<u32>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: u32,
    /// End-of-sequence token id.
    pub eos_token_id: u32,
    /// Presumably the base of the rotary position embedding; not consulted
    /// in this file.
    pub rope_theta: f32,
    /// Optional rotary-embedding scaling settings.
    pub rope_scaling: Option<RopeScaling>,
    /// Whether the attention projections carry bias terms (per the name;
    /// not consulted in this file).
    pub attention_bias: bool,
    /// Whether the feed-forward projections carry bias terms (per the name;
    /// not consulted in this file).
    pub mlp_bias: bool,
    /// Rotary fraction. `validate` rejects values below `0.0` or above `1.0`.
    // NOTE(review): presumably the share of each head's dimensions that
    // receives rotary embedding -- confirm against the attention code.
    pub partial_rotary_factor: f32,
    /// Free-form variant tag. The presets in this file use `"stablelm"`,
    /// `"stablelm-zephyr"`, `"stablelm-code"` and `"stablelm-2"`.
    pub model_type: String,
}
/// Settings for scaling rotary position embeddings.
///
/// Nothing in this file interprets these values; they are only stored and
/// (de)serialised alongside [`StableLMConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
    /// Name of the scaling strategy.
    // NOTE(review): the set of accepted names is not visible here -- confirm
    // against the rotary-embedding implementation.
    pub scaling_type: String,
    /// Multiplier applied by the chosen scaling strategy.
    pub scaling_factor: f32,
}
impl Default for StableLMConfig {
    /// Builds the baseline configuration that every preset constructor
    /// starts from.
    ///
    /// The dimensions are the same ones `stablelm_3b` spells out explicitly
    /// (2560 hidden units, 32 layers, 32 heads).
    fn default() -> Self {
        let hidden_act = String::from("silu");
        let model_type = String::from("stablelm");

        Self {
            // Model dimensions.
            vocab_size: 50_432,
            hidden_size: 2_560,
            intermediate_size: 6_912,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: None,
            hidden_act,
            max_position_embeddings: 4_096,

            // Initialisation, normalisation and caching.
            initializer_range: 0.02,
            rms_norm_eps: 1e-5,
            use_cache: true,

            // Special tokens.
            // NOTE(review): 100257 is not a valid index into a 50432-entry
            // vocabulary; confirm the bos/eos ids against the tokenizer.
            pad_token_id: Some(0),
            bos_token_id: 100_257,
            eos_token_id: 100_257,

            // Rotary embedding and bias switches.
            rope_theta: 10_000.0,
            rope_scaling: None,
            attention_bias: false,
            mlp_bias: false,
            partial_rotary_factor: 0.25,

            model_type,
        }
    }
}
impl StableLMConfig {
    /// Preset used for `stabilityai/stablelm-3b-4e1t` and
    /// `stabilityai/stablelm-base-alpha-3b` (see `from_pretrained_name`).
    ///
    /// The dimensions are spelled out even though they currently match the
    /// defaults, so the preset stays pinned if the defaults ever change.
    pub fn stablelm_3b() -> Self {
        Self {
            vocab_size: 50432,
            hidden_size: 2560,
            intermediate_size: 6912,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            max_position_embeddings: 4096,
            partial_rotary_factor: 0.25,
            ..Default::default()
        }
    }

    /// Preset used for `stabilityai/stablelm-base-alpha-7b`: the 3B layout
    /// with a wider hidden size (4096) and feed-forward block (11008).
    pub fn stablelm_7b() -> Self {
        Self {
            vocab_size: 50432,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            max_position_embeddings: 4096,
            partial_rotary_factor: 0.25,
            ..Default::default()
        }
    }

    /// Preset for `stabilityai/stablelm-zephyr-3b`: identical to
    /// `stablelm_3b` except for the `"stablelm-zephyr"` model type tag.
    pub fn stablelm_zephyr_3b() -> Self {
        Self {
            model_type: "stablelm-zephyr".to_string(),
            ..Self::stablelm_3b()
        }
    }

    /// Preset for `stabilityai/stable-code-3b`: `stablelm_3b` with a
    /// 49152-entry vocabulary and the `"stablelm-code"` model type tag.
    pub fn stablelm_code_3b() -> Self {
        Self {
            model_type: "stablelm-code".to_string(),
            vocab_size: 49152,
            ..Self::stablelm_3b()
        }
    }

    /// Preset for `stabilityai/stablelm-2-1_6b`: 24 layers, 32 attention
    /// heads sharing 4 key/value heads, 100352-entry vocabulary.
    pub fn stablelm_2_1_6b() -> Self {
        Self {
            vocab_size: 100352,
            hidden_size: 2048,
            intermediate_size: 5504,
            num_hidden_layers: 24,
            num_attention_heads: 32,
            num_key_value_heads: Some(4),
            max_position_embeddings: 4096,
            rope_theta: 10000.0,
            partial_rotary_factor: 0.25,
            model_type: "stablelm-2".to_string(),
            ..Default::default()
        }
    }

    /// Preset for `stabilityai/stablelm-2-12b`: 40 layers, 40 attention
    /// heads sharing 8 key/value heads, 100352-entry vocabulary.
    pub fn stablelm_2_12b() -> Self {
        Self {
            vocab_size: 100352,
            hidden_size: 5120,
            intermediate_size: 13824,
            num_hidden_layers: 40,
            num_attention_heads: 40,
            num_key_value_heads: Some(8),
            max_position_embeddings: 4096,
            rope_theta: 10000.0,
            partial_rotary_factor: 0.25,
            model_type: "stablelm-2".to_string(),
            ..Default::default()
        }
    }

    /// Looks up the preset for a known model identifier.
    ///
    /// The match is exact and case-sensitive. Returns `None` for any name
    /// that is not listed here.
    pub fn from_pretrained_name(model_name: &str) -> Option<Self> {
        let config = match model_name {
            "stabilityai/stablelm-3b-4e1t" | "stabilityai/stablelm-base-alpha-3b" => {
                Self::stablelm_3b()
            }
            "stabilityai/stablelm-base-alpha-7b" => Self::stablelm_7b(),
            "stabilityai/stablelm-zephyr-3b" => Self::stablelm_zephyr_3b(),
            "stabilityai/stable-code-3b" => Self::stablelm_code_3b(),
            "stabilityai/stablelm-2-1_6b" => Self::stablelm_2_1_6b(),
            "stabilityai/stablelm-2-12b" => Self::stablelm_2_12b(),
            _ => return None,
        };
        Some(config)
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated rule:
    /// - `vocab_size`, `hidden_size` or `num_attention_heads` is zero;
    /// - `hidden_size` is not divisible by `num_attention_heads`;
    /// - `num_key_value_heads` is `Some(0)` or does not divide
    ///   `num_attention_heads`;
    /// - `partial_rotary_factor` is NaN or outside `0.0..=1.0`;
    /// - `hidden_act` is not one of `silu`, `gelu`, `relu`, `swiglu`.
    pub fn validate(&self) -> Result<(), String> {
        if self.vocab_size == 0 {
            return Err("vocab_size must be > 0".to_string());
        }
        if self.hidden_size == 0 {
            return Err("hidden_size must be > 0".to_string());
        }
        if self.num_attention_heads == 0 {
            return Err("num_attention_heads must be > 0".to_string());
        }
        if !self.hidden_size.is_multiple_of(self.num_attention_heads) {
            return Err("hidden_size must be divisible by num_attention_heads".to_string());
        }
        if let Some(num_kv_heads) = self.num_key_value_heads {
            if num_kv_heads == 0 {
                return Err("num_key_value_heads must be > 0 when specified".to_string());
            }
            if !self.num_attention_heads.is_multiple_of(num_kv_heads) {
                return Err(
                    "num_attention_heads must be divisible by num_key_value_heads".to_string(),
                );
            }
        }
        // A containment test (instead of `< 0.0 || > 1.0`) also rejects NaN,
        // which compares false against both bounds and used to slip through.
        if !(0.0..=1.0).contains(&self.partial_rotary_factor) {
            return Err("partial_rotary_factor must be between 0.0 and 1.0".to_string());
        }
        if !matches!(
            self.hidden_act.as_str(),
            "silu" | "gelu" | "relu" | "swiglu"
        ) {
            return Err("Unsupported activation function".to_string());
        }
        Ok(())
    }

    /// Returns a rough parameter count for a model with this configuration.
    ///
    /// The estimate sums:
    /// - input embedding and LM head: `vocab_size * hidden_size` each;
    /// - per layer, query and output projections (`hidden_size^2` each) plus
    ///   key and value projections (`hidden_size * kv_width` each);
    /// - per layer, two `hidden_size * intermediate_size` feed-forward
    ///   matrices and two `hidden_size`-sized norm vectors;
    /// - one final `hidden_size`-sized norm vector.
    ///
    /// `kv_width` is `hidden_size` when `num_key_value_heads` is `None`;
    /// otherwise it is `(hidden_size / num_attention_heads) *
    /// num_key_value_heads`, so configurations with fewer key/value heads are
    /// no longer over-counted. Bias terms are not included.
    ///
    /// All arithmetic saturates at `usize::MAX`, so the function never panics
    /// or wraps, even for unvalidated configurations or narrow `usize`
    /// targets.
    pub fn estimate_parameters(&self) -> usize {
        let hidden = self.hidden_size;

        // Width of the key/value projections. The zero-heads guard keeps the
        // division safe for configurations that have not been validated.
        let kv_width = match self.num_key_value_heads {
            Some(kv_heads) if self.num_attention_heads > 0 => {
                (hidden / self.num_attention_heads).saturating_mul(kv_heads)
            }
            _ => hidden,
        };

        let embedding_params = self.vocab_size.saturating_mul(hidden);

        // Q and O are hidden x hidden; K and V are hidden x kv_width.
        let attention_params = hidden
            .saturating_mul(hidden)
            .saturating_add(hidden.saturating_mul(kv_width))
            .saturating_mul(2);

        // NOTE(review): two feed-forward matrices are counted, as before; a
        // gated MLP would carry a third -- confirm against the model code.
        let mlp_params = hidden
            .saturating_mul(self.intermediate_size)
            .saturating_mul(2);
        let norm_params = hidden.saturating_mul(2);

        let per_layer_params = attention_params
            .saturating_add(mlp_params)
            .saturating_add(norm_params);
        let layer_params = self.num_hidden_layers.saturating_mul(per_layer_params);

        let final_norm_params = hidden;
        let lm_head_params = embedding_params;

        embedding_params
            .saturating_add(layer_params)
            .saturating_add(final_norm_params)
            .saturating_add(lm_head_params)
    }
}
impl Config for StableLMConfig {
    /// Returns the architecture identifier, which is always `"stablelm"`.
    ///
    /// The value is fixed and does not depend on `model_type`, so the
    /// zephyr, code and StableLM-2 presets report the same architecture.
    fn architecture(&self) -> &'static str {
        const ARCHITECTURE: &str = "stablelm";
        ARCHITECTURE
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration carries the base StableLM settings and is valid.
    #[test]
    fn test_default_config() {
        let defaults = StableLMConfig::default();
        assert_eq!(defaults.model_type, "stablelm");
        assert_eq!(defaults.hidden_act, "silu");
        assert_eq!(defaults.partial_rotary_factor, 0.25);
        assert!(defaults.validate().is_ok());
    }

    // Spot-checks the dimensions of three presets and validates each one.
    #[test]
    fn test_predefined_configs() {
        let three_b = StableLMConfig::stablelm_3b();
        assert_eq!(three_b.hidden_size, 2560);
        assert_eq!(three_b.num_hidden_layers, 32);
        assert!(three_b.validate().is_ok());

        let seven_b = StableLMConfig::stablelm_7b();
        assert_eq!(seven_b.hidden_size, 4096);
        assert_eq!(seven_b.num_hidden_layers, 32);
        assert!(seven_b.validate().is_ok());

        let twelve_b = StableLMConfig::stablelm_2_12b();
        assert_eq!(twelve_b.hidden_size, 5120);
        assert_eq!(twelve_b.num_key_value_heads, Some(8));
        assert!(twelve_b.validate().is_ok());
    }

    // Known model names resolve to a preset; unknown names yield `None`.
    #[test]
    fn test_from_pretrained_name() {
        let known = StableLMConfig::from_pretrained_name("stabilityai/stablelm-3b-4e1t");
        assert!(known.is_some());
        assert_eq!(known.map(|cfg| cfg.hidden_size), Some(2560));

        let unknown = StableLMConfig::from_pretrained_name("nonexistent/model");
        assert!(unknown.is_none());
    }

    // Each broken variant differs from the valid default in one respect only.
    #[test]
    fn test_config_validation() {
        assert!(StableLMConfig::default().validate().is_ok());

        let empty_vocab = StableLMConfig {
            vocab_size: 0,
            ..StableLMConfig::default()
        };
        assert!(empty_vocab.validate().is_err());

        // 100 is not a multiple of 3, so the heads cannot split the hidden size.
        let indivisible_heads = StableLMConfig {
            hidden_size: 100,
            num_attention_heads: 3,
            ..StableLMConfig::default()
        };
        assert!(indivisible_heads.validate().is_err());

        let rotary_above_one = StableLMConfig {
            partial_rotary_factor: 1.5,
            ..StableLMConfig::default()
        };
        assert!(rotary_above_one.validate().is_err());
    }

    // The 3B preset should land in the low-billions range.
    #[test]
    fn test_parameter_estimation() {
        let estimate = StableLMConfig::stablelm_3b().estimate_parameters();
        let plausible = 1_000_000_000 < estimate && estimate < 10_000_000_000;
        assert!(plausible, "Expected ~3B params, got {}", estimate);
    }

    // A JSON round trip preserves the fields that are compared below.
    #[test]
    fn test_config_serialization() -> Result<(), Box<dyn std::error::Error>> {
        let original = StableLMConfig::stablelm_3b();
        let encoded = serde_json::to_string(&original)?;
        let decoded: StableLMConfig = serde_json::from_str(&encoded)?;

        assert_eq!(original.hidden_size, decoded.hidden_size);
        assert_eq!(original.partial_rotary_factor, decoded.partial_rotary_factor);
        assert_eq!(original.model_type, decoded.model_type);
        Ok(())
    }
}