use serde::{Deserialize, Serialize};
use trustformers_core::traits::Config;
/// Hyper-parameters describing a Falcon decoder-only transformer.
///
/// The struct is (de)serializable with serde; field names are used verbatim as
/// the serialized keys. Ready-made presets are available through the
/// `falcon_*` constructors and `FalconConfig::from_pretrained_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FalconConfig {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Hidden dimension of the model; must be divisible by `num_attention_heads`.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query attention heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads. `None` means one key/value head per query
    /// head (see `FalconConfig::num_kv_heads()`).
    pub num_kv_heads: Option<usize>,
    /// Name of the activation function (e.g. `"gelu"`).
    pub hidden_act: String,
    /// Maximum sequence length the configuration is intended for.
    pub max_position_embeddings: usize,
    /// Standard deviation used for weight initialization.
    pub initializer_range: f32,
    /// Epsilon added inside layer normalization.
    pub layer_norm_epsilon: f32,
    /// Whether key/value caching is enabled.
    pub use_cache: bool,
    /// Optional padding token id.
    pub pad_token_id: Option<u32>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: u32,
    /// End-of-sequence token id.
    pub eos_token_id: u32,
    /// Whether the residual connection is taken after layer norm.
    pub apply_residual_connection_post_layernorm: bool,
    /// Dropout probability on hidden states.
    pub hidden_dropout: f32,
    /// Dropout probability on attention weights.
    pub attention_dropout: f32,
    /// Free-form model identifier (e.g. `"falcon-7b-instruct"`); inspected by
    /// `is_instruct_model()`.
    pub model_type: String,
    /// Parallel attention/MLP flag.
    pub parallel_attn: bool,
    /// Whether linear layers carry bias terms.
    pub bias: bool,
    /// Multi-query attention flag.
    pub multi_query: bool,
    /// ALiBi positional-bias flag (reported by `uses_alibi()`).
    pub alibi: bool,
    /// New decoder architecture flag (reported by `uses_new_architecture()`).
    pub new_decoder_architecture: bool,
    /// Optional flash-attention toggle; `None` means unset.
    pub use_flash_attention: Option<bool>,
}
impl Default for FalconConfig {
    /// Builds the baseline configuration that every preset extends via
    /// struct-update syntax.
    ///
    /// The shape (65 024-token vocabulary, hidden size 4544, 32 layers,
    /// 71 query heads sharing one key/value head, 2048 positions) is the same
    /// as `FalconConfig::falcon_7b()`.
    fn default() -> Self {
        Self {
            // Model shape.
            vocab_size: 65024,
            hidden_size: 4544,
            num_hidden_layers: 32,
            num_attention_heads: 71,
            num_kv_heads: Some(1),
            max_position_embeddings: 2048,
            // Layer structure flags.
            hidden_act: String::from("gelu"),
            parallel_attn: true,
            bias: false,
            multi_query: true,
            alibi: false,
            new_decoder_architecture: false,
            apply_residual_connection_post_layernorm: false,
            use_flash_attention: None,
            // Numerics and regularization.
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            hidden_dropout: 0.0,
            attention_dropout: 0.0,
            // Special tokens and runtime behaviour.
            // NOTE(review): published tiiuae/falcon-* configs use token id 11
            // for both bos and eos — confirm these ids against the tokenizer.
            pad_token_id: Some(0),
            bos_token_id: 1,
            eos_token_id: 2,
            use_cache: true,
            model_type: String::from("falcon"),
        }
    }
}
impl Config for FalconConfig {
    /// Checks the structural invariants the rest of the config relies on.
    ///
    /// `head_dim()` and `num_query_groups()` divide by the head counts, so a
    /// configuration that passes this check will not make them panic.
    ///
    /// # Errors
    ///
    /// Returns a config error when:
    /// - `num_attention_heads` or `hidden_size` is zero,
    /// - `hidden_size` is not divisible by `num_attention_heads`,
    /// - `num_kv_heads` is `Some(0)`, or `num_attention_heads` is not divisible by it,
    /// - `vocab_size` is zero.
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        use trustformers_core::errors::TrustformersError;

        let config_error = |message: &'static str| {
            TrustformersError::config_error(message, "FalconConfig::validate")
        };

        // Checked explicitly: `0.is_multiple_of(0)` is true, so a zero head
        // count with a zero hidden size would otherwise slip through and make
        // `head_dim()` divide by zero.
        if self.num_attention_heads == 0 {
            return Err(config_error("num_attention_heads must be greater than 0"));
        }
        if self.hidden_size == 0 {
            return Err(config_error("hidden_size must be greater than 0"));
        }
        if !self.hidden_size.is_multiple_of(self.num_attention_heads) {
            return Err(config_error(
                "hidden_size must be divisible by num_attention_heads",
            ));
        }
        if let Some(num_kv_heads) = self.num_kv_heads {
            if num_kv_heads == 0 {
                return Err(config_error("num_kv_heads must be greater than 0"));
            }
            if !self.num_attention_heads.is_multiple_of(num_kv_heads) {
                return Err(config_error(
                    "num_attention_heads must be divisible by num_kv_heads",
                ));
            }
        }
        if self.vocab_size == 0 {
            return Err(config_error("vocab_size must be greater than 0"));
        }
        Ok(())
    }

    /// Architecture name reported for every Falcon variant.
    fn architecture(&self) -> &'static str {
        "Falcon"
    }
}
impl FalconConfig {
    /// Preset for `tiiuae/falcon-7b`.
    ///
    /// Uses 32 layers, hidden size 4544, and 71 query heads sharing a single
    /// key/value head.
    ///
    /// NOTE(review): the published falcon-7b `config.json` sets `alibi: false`
    /// (rotary position embeddings). This preset sets `true`, and the unit
    /// tests pin that value. Confirm against the model implementation before
    /// changing.
    pub fn falcon_7b() -> Self {
        Self {
            vocab_size: 65024,
            hidden_size: 4544,
            num_hidden_layers: 32,
            num_attention_heads: 71,
            num_kv_heads: Some(1),
            max_position_embeddings: 2048,
            model_type: "falcon-7b".to_string(),
            new_decoder_architecture: false,
            alibi: true,
            ..Self::default()
        }
    }

    /// Preset for `tiiuae/falcon-7b-instruct`. It shares the falcon-7b shape
    /// and only changes `model_type`.
    pub fn falcon_7b_instruct() -> Self {
        Self {
            model_type: "falcon-7b-instruct".to_string(),
            ..Self::falcon_7b()
        }
    }

    /// Preset for `tiiuae/falcon-40b`.
    ///
    /// Uses 60 layers, hidden size 8192, and 128 query heads grouped onto 8
    /// key/value heads.
    ///
    /// NOTE(review): the published falcon-40b `config.json` sets
    /// `new_decoder_architecture: true` and `alibi: false`. This preset uses
    /// the opposite values, and the unit tests pin them. Confirm before
    /// changing.
    pub fn falcon_40b() -> Self {
        Self {
            vocab_size: 65024,
            hidden_size: 8192,
            num_hidden_layers: 60,
            num_attention_heads: 128,
            num_kv_heads: Some(8),
            max_position_embeddings: 2048,
            model_type: "falcon-40b".to_string(),
            new_decoder_architecture: false,
            alibi: true,
            ..Self::default()
        }
    }

    /// Preset for `tiiuae/falcon-40b-instruct`. It shares the falcon-40b shape
    /// and only changes `model_type`.
    pub fn falcon_40b_instruct() -> Self {
        Self {
            model_type: "falcon-40b-instruct".to_string(),
            ..Self::falcon_40b()
        }
    }

    /// Preset for `tiiuae/falcon-180b`.
    ///
    /// Uses 80 layers, hidden size 14848, 232 query heads grouped onto 8
    /// key/value heads, and the new decoder architecture without ALiBi.
    pub fn falcon_180b() -> Self {
        Self {
            vocab_size: 65024,
            hidden_size: 14848,
            num_hidden_layers: 80,
            num_attention_heads: 232,
            num_kv_heads: Some(8),
            max_position_embeddings: 2048,
            model_type: "falcon-180b".to_string(),
            new_decoder_architecture: true,
            alibi: false,
            ..Self::default()
        }
    }

    /// Preset for `tiiuae/falcon-180b-chat`. It shares the falcon-180b shape
    /// and only changes `model_type`.
    pub fn falcon_180b_chat() -> Self {
        Self {
            model_type: "falcon-180b-chat".to_string(),
            ..Self::falcon_180b()
        }
    }

    /// Width of a single attention head (`hidden_size / num_attention_heads`).
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `num_attention_heads` is zero; `validate`
    /// rejects such configs.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Effective number of key/value heads. When the field is `None`, this
    /// falls back to `num_attention_heads` (one key/value head per query head).
    pub fn num_kv_heads(&self) -> usize {
        self.num_kv_heads.unwrap_or(self.num_attention_heads)
    }

    /// Number of query heads that share each key/value head.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if the effective key/value head count is
    /// zero; `validate` rejects such configs.
    pub fn num_query_groups(&self) -> usize {
        self.num_attention_heads / self.num_kv_heads()
    }

    /// Returns the preset for a known model name, or `None` otherwise.
    ///
    /// The name may be given with or without the `tiiuae/` prefix. Matching
    /// is exact and case-sensitive.
    pub fn from_pretrained_name(name: &str) -> Option<Self> {
        match name {
            "tiiuae/falcon-7b" | "falcon-7b" => Some(Self::falcon_7b()),
            "tiiuae/falcon-7b-instruct" | "falcon-7b-instruct" => Some(Self::falcon_7b_instruct()),
            "tiiuae/falcon-40b" | "falcon-40b" => Some(Self::falcon_40b()),
            "tiiuae/falcon-40b-instruct" | "falcon-40b-instruct" => {
                Some(Self::falcon_40b_instruct())
            },
            "tiiuae/falcon-180b" | "falcon-180b" => Some(Self::falcon_180b()),
            "tiiuae/falcon-180b-chat" | "falcon-180b-chat" => Some(Self::falcon_180b_chat()),
            _ => None,
        }
    }

    /// `true` when `model_type` contains `"instruct"` or `"chat"`.
    pub fn is_instruct_model(&self) -> bool {
        self.model_type.contains("instruct") || self.model_type.contains("chat")
    }

    /// Value of the `alibi` flag.
    pub fn uses_alibi(&self) -> bool {
        self.alibi
    }

    /// Value of the `new_decoder_architecture` flag.
    pub fn uses_new_architecture(&self) -> bool {
        self.new_decoder_architecture
    }

    /// Estimates the number of weight parameters in the model.
    ///
    /// The estimate counts:
    /// - the token embedding (`vocab_size × hidden_size`),
    /// - an LM head of the same size, counted separately as if untied,
    /// - per layer, the fused QKV projection
    ///   (`hidden_size × (hidden_size + 2 · num_kv_heads · head_dim)`),
    /// - per layer, the attention output projection (`hidden_size²`),
    /// - per layer, the MLP `hidden_size → 4·hidden_size → hidden_size`
    ///   (`8 · hidden_size²`).
    ///
    /// Bias vectors and layer-norm parameters are ignored, so the result is a
    /// slight under-estimate. The count saturates at `usize::MAX` if it does
    /// not fit in `usize`, which can happen on 32-bit targets.
    ///
    /// # Panics
    ///
    /// Panics if `num_attention_heads` is zero (via `head_dim()`).
    pub fn num_parameters(&self) -> usize {
        // Widen to u128: even the 7B preset exceeds u32::MAX, so plain usize
        // arithmetic overflows on 32-bit targets. `usize as u128` is lossless
        // on every platform Rust supports.
        let hidden = self.hidden_size as u128;
        let vocab = self.vocab_size as u128;
        let layers = self.num_hidden_layers as u128;
        let kv_width = self.num_kv_heads() as u128 * self.head_dim() as u128;

        let embedding_params = vocab * hidden;
        // Fused projection: full-width queries plus one key and one value
        // projection per key/value head.
        let qkv_params = hidden * (hidden + 2 * kv_width);
        let attention_out_params = hidden * hidden;
        // Falcon's MLP expands to 4 × hidden_size and projects back
        // (two dense layers), i.e. 8 · hidden_size² weights.
        let mlp_params = 2 * hidden * (4 * hidden);
        let transformer_params = layers * (qkv_params + attention_out_params + mlp_params);
        let head_params = hidden * vocab;

        usize::try_from(embedding_params + transformer_params + head_params).unwrap_or(usize::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid preset passes; breaking head divisibility makes it fail.
    #[test]
    fn test_falcon_config_validation() {
        let valid = FalconConfig::falcon_7b();
        assert!(valid.validate().is_ok());
        // 4543 is not a multiple of 71 heads.
        let broken = FalconConfig {
            hidden_size: 4543,
            ..valid
        };
        assert!(broken.validate().is_err());
    }

    /// Pins the shape and flags of each base preset.
    #[test]
    fn test_falcon_config_presets() {
        let small = FalconConfig::falcon_7b();
        assert_eq!(
            (
                small.hidden_size,
                small.num_hidden_layers,
                small.num_attention_heads,
                small.num_kv_heads(),
            ),
            (4544, 32, 71, 1)
        );
        assert!(small.uses_alibi());
        assert!(!small.uses_new_architecture());

        let medium = FalconConfig::falcon_40b();
        assert_eq!(
            (
                medium.hidden_size,
                medium.num_hidden_layers,
                medium.num_attention_heads,
                medium.num_kv_heads(),
            ),
            (8192, 60, 128, 8)
        );

        let large = FalconConfig::falcon_180b();
        assert_eq!(
            (
                large.hidden_size,
                large.num_hidden_layers,
                large.num_attention_heads,
                large.num_kv_heads(),
            ),
            (14848, 80, 232, 8)
        );
        assert!(!large.uses_alibi());
        assert!(large.uses_new_architecture());
    }

    /// Full hub ids resolve; unknown names do not.
    #[test]
    fn test_falcon_config_from_pretrained() {
        let base = FalconConfig::from_pretrained_name("tiiuae/falcon-7b").expect("operation failed");
        assert_eq!(base.model_type, "falcon-7b");

        let chat =
            FalconConfig::from_pretrained_name("tiiuae/falcon-180b-chat").expect("operation failed");
        assert!(chat.is_instruct_model());

        assert!(FalconConfig::from_pretrained_name("unknown-model").is_none());
    }

    /// Derived head geometry for the 7B and 40B presets.
    #[test]
    fn test_falcon_config_helpers() {
        let cases = [
            (FalconConfig::falcon_7b(), 64, 1, 71),
            (FalconConfig::falcon_40b(), 64, 8, 16),
        ];
        for (config, head_dim, kv_heads, groups) in cases {
            assert_eq!(config.head_dim(), head_dim);
            assert_eq!(config.num_kv_heads(), kv_heads);
            assert_eq!(config.num_query_groups(), groups);
        }
    }

    #[test]
    fn test_config_trait() {
        assert_eq!(FalconConfig::falcon_7b().architecture(), "Falcon");
    }

    /// Parameter estimates land in a plausible band for each model size.
    #[test]
    fn test_parameter_estimation() {
        let params = FalconConfig::falcon_7b().num_parameters();
        let plausible_7b = params > 2_000_000_000 && params < 15_000_000_000;
        assert!(plausible_7b, "Expected ~7B params, got {}", params);

        let params_40b = FalconConfig::falcon_40b().num_parameters();
        let plausible_40b = params_40b > 15_000_000_000 && params_40b < 100_000_000_000;
        assert!(plausible_40b, "Expected ~40B params, got {}", params_40b);
    }

    #[test]
    fn test_default_config_hidden_size() {
        assert_eq!(FalconConfig::default().hidden_size, 4544);
    }

    #[test]
    fn test_default_config_num_attention_heads() {
        assert_eq!(FalconConfig::default().num_attention_heads, 71);
    }

    #[test]
    fn test_default_config_num_hidden_layers() {
        assert_eq!(FalconConfig::default().num_hidden_layers, 32);
    }

    #[test]
    fn test_default_config_parallel_attn_true() {
        let parallel = FalconConfig::default().parallel_attn;
        assert!(parallel, "Default config must have parallel_attn=true");
    }

    #[test]
    fn test_default_config_multi_query_true() {
        let multi_query = FalconConfig::default().multi_query;
        assert!(multi_query, "Default config must have multi_query=true");
    }

    #[test]
    fn test_falcon_7b_uses_alibi_not_rotary() {
        let alibi = FalconConfig::falcon_7b().alibi;
        assert!(alibi, "Falcon-7B must use ALiBi positional encoding");
    }

    #[test]
    fn test_falcon_180b_no_alibi() {
        let alibi = FalconConfig::falcon_180b().alibi;
        assert!(!alibi, "Falcon-180B must not use ALiBi");
    }

    #[test]
    fn test_falcon_40b_old_decoder_architecture() {
        assert!(!FalconConfig::falcon_40b().new_decoder_architecture);
    }

    #[test]
    fn test_falcon_180b_new_decoder_architecture() {
        assert!(FalconConfig::falcon_180b().new_decoder_architecture);
    }

    #[test]
    fn test_falcon_7b_instruct_is_instruct() {
        assert!(FalconConfig::falcon_7b_instruct().is_instruct_model());
    }

    #[test]
    fn test_falcon_40b_instruct_is_instruct() {
        assert!(FalconConfig::falcon_40b_instruct().is_instruct_model());
    }

    #[test]
    fn test_falcon_7b_base_not_instruct() {
        assert!(!FalconConfig::falcon_7b().is_instruct_model());
    }

    #[test]
    fn test_validation_fails_zero_vocab_size() {
        let empty_vocab = FalconConfig {
            vocab_size: 0,
            ..FalconConfig::default()
        };
        assert!(empty_vocab.validate().is_err());
    }

    /// 8 query heads cannot be split evenly across 3 key/value heads.
    #[test]
    fn test_validation_fails_kv_heads_not_divisor() {
        let uneven = FalconConfig {
            num_attention_heads: 8,
            num_kv_heads: Some(3),
            hidden_size: 8 * 64,
            ..FalconConfig::default()
        };
        assert!(uneven.validate().is_err());
    }

    #[test]
    fn test_all_preset_configs_validate() {
        let presets = [
            FalconConfig::falcon_7b(),
            FalconConfig::falcon_40b(),
            FalconConfig::falcon_180b(),
        ];
        for preset in &presets {
            assert!(
                preset.validate().is_ok(),
                "Config {:?} failed validation",
                preset.model_type
            );
        }
    }

    #[test]
    fn test_from_pretrained_short_name_7b() {
        let found = FalconConfig::from_pretrained_name("falcon-7b");
        assert!(found.is_some());
        let kv_heads = found.expect("falcon-7b config").num_kv_heads();
        assert_eq!(kv_heads, 1);
    }

    #[test]
    fn test_from_pretrained_short_name_40b() {
        assert!(FalconConfig::from_pretrained_name("falcon-40b").is_some());
    }

    #[test]
    fn test_from_pretrained_unknown_returns_none() {
        assert!(FalconConfig::from_pretrained_name("completely-unknown").is_none());
    }

    #[test]
    fn test_num_query_groups_7b() {
        assert_eq!(FalconConfig::falcon_7b().num_query_groups(), 71);
    }

    #[test]
    fn test_num_query_groups_180b() {
        assert_eq!(FalconConfig::falcon_180b().num_query_groups(), 29);
    }

    /// JSON round-trip preserves the key shape fields.
    #[test]
    fn test_config_serialization_roundtrip() {
        let original = FalconConfig::falcon_7b();
        let json = serde_json::to_string(&original).expect("serialize FalconConfig");
        let restored: FalconConfig = serde_json::from_str(&json).expect("deserialize FalconConfig");
        assert_eq!(
            (original.hidden_size, original.num_attention_heads, original.alibi),
            (restored.hidden_size, restored.num_attention_heads, restored.alibi)
        );
    }

    #[test]
    fn test_config_clone_equality() {
        let original = FalconConfig::falcon_40b();
        let copy = original.clone();
        assert_eq!(
            (original.hidden_size, original.num_kv_heads),
            (copy.hidden_size, copy.num_kv_heads)
        );
    }
}