use serde::{Deserialize, Serialize};
use trustformers_core::errors::invalid_config;
use trustformers_core::traits::Config;
/// Configuration for the Hyena architecture.
///
/// Only defaulting, validation and a few derived estimates are visible in this file;
/// the model code that consumes these values lives elsewhere. Notes below that describe
/// what a field *means* are readings of its name and should be confirmed against the
/// model implementation (TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyenaConfig {
    /// Presumably the token vocabulary size.
    pub vocab_size: usize,
    /// Presumably the width of the hidden representation.
    pub hidden_size: usize,
    /// Number of hidden layers; multiplied by `filter_order` in `receptive_field`.
    pub num_hidden_layers: usize,
    /// Presumably the width of the feed-forward (intermediate) projection.
    pub intermediate_size: usize,
    /// Name of the activation function, stored as a free-form string.
    pub hidden_act: String,
    /// Presumably the dropout probability applied to hidden states.
    pub hidden_dropout_prob: f32,
    /// Sequence length used by `memory_advantage` and `is_long_context_optimized`.
    pub max_position_embeddings: usize,
    /// Presumably the spread used when initialising weights.
    pub initializer_range: f32,
    /// Presumably the epsilon used by layer normalisation.
    pub layer_norm_eps: f32,
    /// Presumably the id of the padding token.
    pub pad_token_id: u32,
    /// Hyena order; `validate` rejects values below 2.
    pub order: usize,
    /// Filter order; `validate` rejects 0. Also feeds `receptive_field` and
    /// `memory_advantage`.
    pub filter_order: usize,
    /// Meaning not evident from this file — TODO document.
    pub local_order: usize,
    /// Meaning not evident from this file — TODO document.
    pub outer_mixing: bool,
    /// Convolution kernel size; `validate` rejects even values.
    pub conv_kernel_size: usize,
    /// Presumably toggles positional embeddings.
    pub use_positional_embeddings: bool,
    /// Meaning not evident from this file — TODO document.
    pub short_filter_order: usize,
    /// Meaning not evident from this file — TODO document.
    pub modulate: bool,
    /// Meaning not evident from this file — TODO document.
    pub w: f32,
    /// Meaning not evident from this file — TODO document.
    pub wd: f32,
    /// Presumably toggles bias terms.
    pub bias: bool,
    /// Meaning not evident from this file — TODO document.
    pub num_inner_mlps: usize,
    /// Meaning not evident from this file — TODO document.
    pub normalized: bool,
    /// Flag read by `is_long_context_optimized`.
    pub use_flashfft: bool,
}
/// Baseline values; the `hyena_*` presets override a handful of these fields.
impl Default for HyenaConfig {
    /// Builds the baseline configuration from fixed literal values.
    fn default() -> Self {
        Self {
            // Sizes and generic hyper-parameters.
            vocab_size: 50_257,
            hidden_size: 768,
            num_hidden_layers: 12,
            intermediate_size: 3_072,
            hidden_act: String::from("gelu"),
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 32_768,
            initializer_range: 0.02,
            layer_norm_eps: 1e-5,
            pad_token_id: 0,

            // Hyena-specific settings. `order`, `filter_order` and `conv_kernel_size`
            // are chosen so that `validate` accepts the default configuration.
            order: 2,
            filter_order: 64,
            local_order: 3,
            outer_mixing: true,
            conv_kernel_size: 3,
            use_positional_embeddings: false,
            short_filter_order: 3,
            modulate: true,
            w: 1.0,
            wd: 0.1,
            bias: true,
            num_inner_mlps: 2,
            normalized: false,
            use_flashfft: true,
        }
    }
}
/// [`Config`] implementation: sanity checks and the architecture label for Hyena.
impl Config for HyenaConfig {
    /// Checks the constraints this configuration must satisfy.
    ///
    /// Checks run in the order listed and the first failure is returned.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_config` error when
    /// - `order` is less than 2,
    /// - `filter_order` is 0, or
    /// - `conv_kernel_size` is even (0 counts as even).
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        // Each error now names the field that failed; previously every check passed the
        // same "config_field" placeholder, so callers could not tell the failures apart.
        // NOTE(review): assumes the first argument of `invalid_config` is the field
        // label — confirm against trustformers_core.
        if self.order < 2 {
            return Err(invalid_config(
                "order",
                "Hyena order must be at least 2".to_string(),
            ));
        }
        if self.filter_order == 0 {
            return Err(invalid_config(
                "filter_order",
                "filter_order must be greater than 0".to_string(),
            ));
        }
        if self.conv_kernel_size.is_multiple_of(2) {
            return Err(invalid_config(
                "conv_kernel_size",
                "conv_kernel_size should be odd for symmetric padding".to_string(),
            ));
        }
        Ok(())
    }

    /// Returns the fixed architecture label `"Hyena"`.
    fn architecture(&self) -> &'static str {
        "Hyena"
    }
}
/// Ready-made presets and a few derived estimates for [`HyenaConfig`].
impl HyenaConfig {
    /// Preset with `hidden_size` 768, 12 layers, `intermediate_size` 3072 and
    /// `max_position_embeddings` 32768; every other field comes from [`Default`].
    pub fn hyena_small() -> Self {
        let base = Self::default();
        Self {
            max_position_embeddings: 32_768,
            intermediate_size: 3_072,
            num_hidden_layers: 12,
            hidden_size: 768,
            ..base
        }
    }

    /// Preset with `hidden_size` 1024, 24 layers, `intermediate_size` 4096,
    /// `max_position_embeddings` 65536 and `filter_order` 128; the rest is default.
    pub fn hyena_medium() -> Self {
        let base = Self::default();
        Self {
            filter_order: 128,
            max_position_embeddings: 65_536,
            intermediate_size: 4_096,
            num_hidden_layers: 24,
            hidden_size: 1_024,
            ..base
        }
    }

    /// Preset with `hidden_size` 1280, 36 layers, `intermediate_size` 5120,
    /// `max_position_embeddings` 131072 and `filter_order` 256; the rest is default.
    pub fn hyena_large() -> Self {
        let base = Self::default();
        Self {
            filter_order: 256,
            max_position_embeddings: 131_072,
            intermediate_size: 5_120,
            num_hidden_layers: 36,
            hidden_size: 1_280,
            ..base
        }
    }

    /// Preset with a 12-entry vocabulary, `hidden_size` 256, 8 layers,
    /// `intermediate_size` 1024 and `max_position_embeddings` 1048576. `order`,
    /// `filter_order` and `use_positional_embeddings` are set explicitly to 2, 64 and
    /// `false`; the rest is default.
    pub fn hyena_dna() -> Self {
        let base = Self::default();
        Self {
            use_positional_embeddings: false,
            order: 2,
            filter_order: 64,
            max_position_embeddings: 1_048_576,
            intermediate_size: 1_024,
            num_hidden_layers: 8,
            hidden_size: 256,
            vocab_size: 12,
            ..base
        }
    }

    /// Preset with `max_position_embeddings` 262144, `filter_order` 512 and
    /// `use_flashfft` enabled; the rest is default.
    pub fn hyena_long() -> Self {
        let base = Self::default();
        Self {
            use_flashfft: true,
            filter_order: 512,
            max_position_embeddings: 262_144,
            ..base
        }
    }

    /// Returns `filter_order` multiplied by `num_hidden_layers`.
    pub fn receptive_field(&self) -> usize {
        self.num_hidden_layers * self.filter_order
    }

    /// Estimates how many times larger a quadratic attention footprint (`n * n`) is
    /// than a Hyena footprint (`n * filter_order`), with `n = max_position_embeddings`.
    ///
    /// Computed in `f32`; the result is not finite when either
    /// `max_position_embeddings` or `filter_order` is 0.
    pub fn memory_advantage(&self) -> f32 {
        let positions = self.max_position_embeddings as f32;
        let filter_len = self.filter_order as f32;
        let quadratic_cost = positions * positions;
        let hyena_cost = positions * filter_len;
        // Deliberately left as a ratio of the two products rather than simplified, so
        // rounding and the zero-input behaviour stay exactly as they were.
        quadratic_cost / hyena_cost
    }

    /// Returns `true` when `use_flashfft` is set and `max_position_embeddings` is at
    /// least 32768.
    pub fn is_long_context_optimized(&self) -> bool {
        self.use_flashfft && self.max_position_embeddings >= 32_768
    }
}