/// Configuration parameters for a TRM model.
///
/// The struct derives serde `Serialize`/`Deserialize`, so it can be stored in
/// and loaded from any serde-supported format. Loading does not check the
/// values by itself; call `validate` on the result before using it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TRMConfig {
/// Hidden width. `validate` rejects 0 and requires the value to be a
/// multiple of `num_heads`. Default: 256.
pub hidden_size: usize,
/// Must be > 0 (checked by `validate`). Default: 3.
/// NOTE(review): presumably the number of outer ("H") cycles — confirm.
pub h_cycles: usize,
/// Must be > 0 (checked by `validate`). Default: 6.
/// NOTE(review): presumably the number of inner ("L") cycles — confirm.
pub l_cycles: usize,
/// Default: 2. Not checked by `validate`.
/// NOTE(review): presumably a layer count for the "L" module — confirm.
pub l_layers: usize,
/// Number of heads. `hidden_size` must be divisible by it, and
/// `head_dim` returns `hidden_size / num_heads`. Default: 8.
pub num_heads: usize,
/// Multiplier applied to `hidden_size` by `ffn_hidden_size`. Default: 4.0.
pub expansion: f32,
/// Positional-encoding selector. `validate` accepts exactly `"rope"`,
/// `"learned"`, or `"none"`. Default: `"rope"`.
pub pos_encodings: String,
/// Default: false. Not checked by `validate`.
/// NOTE(review): meaning is not visible in this file — confirm with the
/// model code that reads it.
pub mlp_t: bool,
/// Default: 10. Not checked by `validate`.
/// NOTE(review): presumably an upper bound on halting steps — confirm.
pub halt_max_steps: usize,
/// Default: 0.0. Not checked by `validate`.
/// NOTE(review): presumably a dropout probability — confirm the expected
/// range.
pub dropout: f32,
/// Default: 50257. Not checked by `validate`.
/// NOTE(review): presumably the input vocabulary size — confirm.
pub vocab_size: usize,
/// Default: 50257. Not checked by `validate`.
/// NOTE(review): presumably the size of the output dimension — confirm.
pub num_outputs: usize,
}
impl Default for TRMConfig {
fn default() -> Self {
Self {
hidden_size: 256,
h_cycles: 3,
l_cycles: 6,
l_layers: 2,
num_heads: 8,
expansion: 4.0,
pos_encodings: "rope".to_string(),
mlp_t: false,
halt_max_steps: 10,
dropout: 0.0,
vocab_size: 50257, num_outputs: 50257,
}
}
}
impl TRMConfig {
    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns `TRMError::Config` with a descriptive message when:
    /// - `hidden_size` is 0,
    /// - `num_heads` is 0,
    /// - `hidden_size` is not a multiple of `num_heads`,
    /// - `h_cycles` or `l_cycles` is 0,
    /// - `pos_encodings` is not one of `"rope"`, `"learned"`, `"none"`.
    ///
    /// The checks run in the order listed; only the first failure is reported.
    pub fn validate(&self) -> crate::Result<()> {
        if self.hidden_size == 0 {
            return Err(crate::TRMError::Config(
                "hidden_size must be > 0".to_string(),
            ));
        }
        // Must come before the divisibility test: integer `%` with a zero
        // divisor panics, which would turn a bad config into a crash instead
        // of a reportable error.
        if self.num_heads == 0 {
            return Err(crate::TRMError::Config(
                "num_heads must be > 0".to_string(),
            ));
        }
        if self.hidden_size % self.num_heads != 0 {
            return Err(crate::TRMError::Config(
                "hidden_size must be divisible by num_heads".to_string(),
            ));
        }
        if self.h_cycles == 0 || self.l_cycles == 0 {
            return Err(crate::TRMError::Config(
                "h_cycles and l_cycles must be > 0".to_string(),
            ));
        }
        if !["rope", "learned", "none"].contains(&self.pos_encodings.as_str()) {
            return Err(crate::TRMError::Config(format!(
                "Invalid pos_encodings: {}. Must be 'rope', 'learned', or 'none'",
                self.pos_encodings
            )));
        }
        Ok(())
    }

    /// Returns `hidden_size / num_heads` (integer division).
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is 0. A configuration that passed
    /// [`validate`](Self::validate) never triggers this.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_heads
    }

    /// Returns `hidden_size * expansion`, converted to `usize`.
    ///
    /// The product is computed in `f32` and converted with an `as` cast, so
    /// the fractional part is discarded, negative or NaN products become 0,
    /// and values too large for `usize` saturate at `usize::MAX`.
    pub fn ffn_hidden_size(&self) -> usize {
        (self.hidden_size as f32 * self.expansion) as usize
    }
}