use serde::{Deserialize, Serialize};
/// Top-level configuration for an xLSTM model.
///
/// Serializable and deserializable through serde. The `Default` impl gives
/// the base preset; the defaults quoted on each field below are the values
/// set there.
///
/// NOTE(review): the field descriptions are derived from the field names and
/// default values only; confirm exact semantics against the model code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XLSTMConfig {
/// Vocabulary size. Default: 32000.
pub vocab_size: usize,
/// Hidden size of the model. Default: 768.
pub hidden_size: usize,
/// Intermediate size (presumably the feed-forward width; confirm).
/// Default: 3072.
pub intermediate_size: usize,
/// Number of layers. Default: 12.
pub num_layers: usize,
/// Number of heads. Default: 12.
pub num_heads: usize,
/// Maximum sequence length. Default: 2048.
pub max_sequence_length: usize,
/// Dropout value (presumably a probability; confirm). Default: 0.1.
pub dropout: f32,
/// Epsilon used for layer normalization. Default: 1e-5.
pub layer_norm_epsilon: f64,
/// Block layout settings; see `XLSTMBlockConfig`.
pub block_config: XLSTMBlockConfig,
/// Initial forget-gate bias. Default: 3.0.
pub initial_forget_gate_bias: f32,
/// Pre-layer-norm switch. Default: true.
pub use_pre_ln: bool,
/// Post-layer-norm switch. Default: false.
pub use_post_ln: bool,
/// Exponential gating settings; see `ExponentialGatingConfig`.
pub exponential_gating: ExponentialGatingConfig,
}
impl Default for XLSTMConfig {
fn default() -> Self {
Self {
vocab_size: 32000,
hidden_size: 768,
intermediate_size: 3072,
num_layers: 12,
num_heads: 12,
max_sequence_length: 2048,
dropout: 0.1,
layer_norm_epsilon: 1e-5,
block_config: XLSTMBlockConfig::default(),
initial_forget_gate_bias: 3.0, use_pre_ln: true,
use_post_ln: false,
exponential_gating: ExponentialGatingConfig::default(),
}
}
}
/// Describes which block types make up the layer stack.
///
/// NOTE(review): nothing visible in this file checks that
/// `slstm_blocks + mlstm_blocks` or `block_pattern.len()` agree with each
/// other or with `XLSTMConfig::num_layers`; confirm how consumers reconcile
/// them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XLSTMBlockConfig {
/// Overall block type. Both presets in this file that set it use `Mixed`.
pub block_type: XLSTMBlockType,
/// Count of sLSTM blocks. Default: 4.
pub slstm_blocks: usize,
/// Count of mLSTM blocks. Default: 8.
pub mlstm_blocks: usize,
/// Ordered sequence of block types. The default has 6 entries
/// (presumably repeated or truncated to fit the layer count; confirm).
pub block_pattern: Vec<XLSTMBlockType>,
}
impl Default for XLSTMBlockConfig {
fn default() -> Self {
Self {
block_type: XLSTMBlockType::Mixed,
slstm_blocks: 4,
mlstm_blocks: 8,
block_pattern: vec![
XLSTMBlockType::SLstm,
XLSTMBlockType::SLstm,
XLSTMBlockType::MLstm,
XLSTMBlockType::MLstm,
XLSTMBlockType::SLstm,
XLSTMBlockType::MLstm,
],
}
}
}
/// Kind of xLSTM block.
///
/// Used both as the overall `XLSTMBlockConfig::block_type` and as the
/// element type of `XLSTMBlockConfig::block_pattern`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum XLSTMBlockType {
/// sLSTM block (configured by `SLstmConfig`, presumably; confirm).
SLstm,
/// mLSTM block (configured by `MLstmConfig`, presumably; confirm).
MLstm,
/// Mixed stack. In this file it only appears as a `block_type` value,
/// never as an entry of a `block_pattern`.
Mixed,
}
/// Settings for exponential gating.
///
/// NOTE(review): how the bounds and temperature are applied is not visible
/// in this file; the descriptions below follow the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExponentialGatingConfig {
/// Whether exponential gating is enabled. Default: true.
pub enabled: bool,
/// Lower bound for gate values (presumably a clamp; confirm).
/// Default: 1e-6.
pub min_gate_value: f32,
/// Upper bound for gate values (presumably a clamp; confirm).
/// Default: 10.0.
pub max_gate_value: f32,
/// Temperature. Default: 1.0.
pub temperature: f32,
}
impl Default for ExponentialGatingConfig {
fn default() -> Self {
Self {
enabled: true,
min_gate_value: 1e-6,
max_gate_value: 10.0,
temperature: 1.0,
}
}
}
/// Configuration for an sLSTM block.
///
/// NOTE(review): field semantics follow the field names; this file shows
/// only the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SLstmConfig {
/// Hidden size. Default: 768.
pub hidden_size: usize,
/// Whether exponential gating is used. Default: true.
pub use_exponential_gating: bool,
/// Whether memory mixing is used. Default: true.
pub use_memory_mixing: bool,
/// Dropout value (presumably a probability; confirm). Default: 0.1.
pub dropout: f32,
}
impl Default for SLstmConfig {
fn default() -> Self {
Self {
hidden_size: 768,
use_exponential_gating: true,
use_memory_mixing: true,
dropout: 0.1,
}
}
}
/// Configuration for an mLSTM block.
///
/// `MLstmConfig::new` keeps `head_dim` equal to `hidden_size / num_heads`;
/// constructing the struct directly does not enforce that relation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLstmConfig {
/// Hidden size. Default: 768.
pub hidden_size: usize,
/// Number of heads. Default: 12.
pub num_heads: usize,
/// Per-head dimension. Default: 64 (which equals 768 / 12).
pub head_dim: usize,
/// Whether a causal mask is used. Default: true.
pub use_causal_mask: bool,
/// Whether exponential gating is used. Default: true.
pub use_exponential_gating: bool,
/// Dropout value (presumably a probability; confirm). Default: 0.1.
pub dropout: f32,
/// Memory dimension. Default: 64.
pub memory_dimension: usize,
}
impl Default for MLstmConfig {
fn default() -> Self {
Self {
hidden_size: 768,
num_heads: 12,
head_dim: 64, use_causal_mask: true,
use_exponential_gating: true,
dropout: 0.1,
memory_dimension: 64,
}
}
}
impl MLstmConfig {
    /// Creates an mLSTM configuration for the given hidden size and head
    /// count, deriving `head_dim` as `hidden_size / num_heads`.
    ///
    /// All remaining fields (`use_causal_mask`, `use_exponential_gating`,
    /// `dropout`, `memory_dimension`) take their `Default` values; in
    /// particular `memory_dimension` is not rescaled to the new `head_dim`.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero, or if `hidden_size` is not evenly
    /// divisible by `num_heads`.
    pub fn new(hidden_size: usize, num_heads: usize) -> Self {
        // `0.is_multiple_of(0)` is true, so without this guard `new(0, 0)`
        // would pass the divisibility check and fail on the division below
        // with an uninformative divide-by-zero panic.
        assert!(num_heads > 0, "Number of heads must be greater than zero");
        assert!(
            hidden_size.is_multiple_of(num_heads),
            "Hidden size must be divisible by number of heads"
        );

        Self {
            hidden_size,
            num_heads,
            head_dim: hidden_size / num_heads,
            ..Default::default()
        }
    }
}
impl XLSTMConfig {
pub fn small() -> Self {
Self {
vocab_size: 32000,
hidden_size: 512,
intermediate_size: 2048,
num_layers: 8,
num_heads: 8,
max_sequence_length: 1024,
..Default::default()
}
}
pub fn base() -> Self {
Self::default()
}
pub fn large() -> Self {
Self {
vocab_size: 50000,
hidden_size: 1024,
intermediate_size: 4096,
num_layers: 24,
num_heads: 16,
max_sequence_length: 4096,
..Default::default()
}
}
pub fn xlstm_7b() -> Self {
Self {
vocab_size: 50000,
hidden_size: 4096,
intermediate_size: 16384,
num_layers: 32,
num_heads: 32,
max_sequence_length: 8192,
block_config: XLSTMBlockConfig {
block_type: XLSTMBlockType::Mixed,
slstm_blocks: 12,
mlstm_blocks: 20,
block_pattern: (0..32)
.map(
|i| {
if i % 3 == 0 {
XLSTMBlockType::SLstm
} else {
XLSTMBlockType::MLstm
}
},
)
.collect(),
},
..Default::default()
}
}
}