use candle_core::{Result, Tensor, Device, DType};
use candle_nn::{VarBuilder, Module};
use crate::config::TRMConfig;
use crate::layers::{Attention, SwiGLU, CastedEmbedding, RMSNorm, RotaryEmbedding};
use crate::layers::normalization::rms_norm;
use crate::layers::activations::CastedLinear;
pub mod loader;
/// Recurrent latent state handed from one model call to the next.
///
/// `TinyRecursiveModel::forward` reads both tensors at the start of a call and
/// returns an updated pair at the end.
#[derive(Debug, Clone)]
pub struct InnerCarry {
    /// High-level latent state; this is the tensor the model projects to logits.
    pub z_h: Tensor,
    /// Low-level latent state, refined in the inner cycles of the model.
    pub z_l: Tensor,
}

impl InnerCarry {
    /// Bundles two existing latent tensors into a carry.
    ///
    /// No shape or dtype validation is performed.
    pub fn new(z_h: Tensor, z_l: Tensor) -> Self {
        Self { z_h, z_l }
    }

    /// Creates a carry whose two states are zero tensors of shape
    /// `(batch_size, seq_len, hidden_size)` with the given `dtype` on `device`.
    ///
    /// # Errors
    /// Returns whatever error `Tensor::zeros` reports.
    pub fn empty(
        batch_size: usize,
        seq_len: usize,
        hidden_size: usize,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let state_shape = (batch_size, seq_len, hidden_size);
        // Each call allocates a separate tensor, so the two states never alias.
        let zero_state = || Tensor::zeros(state_shape, dtype, device);
        Ok(Self::new(zero_state()?, zero_state()?))
    }
}
/// A post-norm residual block: an optional self-attention sub-layer followed by
/// a SwiGLU feed-forward sub-layer, each wrapped as `rms_norm(x + sublayer(x))`.
pub struct TransformerBlock {
    /// Configuration the block was built from; `forward` does not read it.
    config: TRMConfig,
    /// Attention sub-layer; `None` when `config.mlp_t` is set.
    self_attn: Option<Attention>,
    /// Feed-forward sub-layer, always present.
    mlp: SwiGLU,
    /// Epsilon handed to `rms_norm`; fixed to `1e-5` by `new`.
    norm_eps: f64,
}

impl TransformerBlock {
    /// Builds a block, loading weights from the `self_attn` and `mlp`
    /// sub-paths of `vb`.
    ///
    /// When `config.mlp_t` is true no attention layer is created and `forward`
    /// applies only the feed-forward sub-layer.
    ///
    /// # Errors
    /// Propagates errors from `Attention::new` and `SwiGLU::new`.
    pub fn new(config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        let self_attn = if config.mlp_t {
            None
        } else {
            // NOTE(review): `num_heads` is passed twice and followed by `false`;
            // presumably these are the query-head count, key/value-head count and
            // a causal flag. Confirm against `Attention::new`.
            let attention = Attention::new(
                config.hidden_size,
                config.head_dim(),
                config.num_heads,
                config.num_heads,
                false,
                vb.pp("self_attn"),
            )?;
            Some(attention)
        };
        let mlp = SwiGLU::new(config.hidden_size, config.expansion, vb.pp("mlp"))?;

        Ok(Self {
            config,
            self_attn,
            mlp,
            norm_eps: 1e-5,
        })
    }

    /// Applies the block to `hidden_states`.
    ///
    /// `cos_sin` is forwarded untouched to the attention layer (when there is
    /// one); it is ignored otherwise.
    ///
    /// # Errors
    /// Propagates errors from the sub-layers, the residual additions and
    /// `rms_norm`.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        // Attention sub-layer with residual + norm, or a pass-through when the
        // block was built without attention.
        let post_attn = match &self.self_attn {
            Some(attention) => {
                let attended = attention.forward(hidden_states, cos_sin)?;
                rms_norm(&(hidden_states + attended)?, self.norm_eps)?
            }
            None => hidden_states.clone(),
        };

        // Feed-forward sub-layer with residual + norm.
        let fed_forward = self.mlp.forward(&post_attn)?;
        let output = rms_norm(&(post_attn + fed_forward)?, self.norm_eps)?;
        Ok(output)
    }
}
/// A stack of [`TransformerBlock`]s applied in sequence to an injected state.
pub struct ReasoningModule {
    /// Blocks in application order.
    layers: Vec<TransformerBlock>,
}

impl ReasoningModule {
    /// Builds `num_layers` blocks, each from its own copy of `config`, reading
    /// weights from the sub-paths `layer_0`, `layer_1`, … of `vb`.
    ///
    /// # Errors
    /// Stops at, and returns, the first error from `TransformerBlock::new`.
    pub fn new(num_layers: usize, config: TRMConfig, vb: VarBuilder) -> Result<Self> {
        let layers = (0..num_layers)
            .map(|index| {
                let layer_vb = vb.pp(format!("layer_{}", index));
                TransformerBlock::new(config.clone(), layer_vb)
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    /// Adds `input_injection` to `hidden_states` once, then threads the sum
    /// through every block in order, passing `cos_sin` to each.
    ///
    /// # Errors
    /// Propagates errors from the addition or from any block.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        input_injection: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let injected = (hidden_states + input_injection)?;
        self.layers
            .iter()
            .try_fold(injected, |state, block| block.forward(&state, cos_sin))
    }
}
/// Recursive reasoning model: token embedding, one shared [`ReasoningModule`]
/// that alternately refines two latent states, and an output projection.
pub struct TinyRecursiveModel {
    /// Model hyper-parameters (validated in `new`).
    config: TRMConfig,
    /// Token embedding table.
    embed_tokens: CastedEmbedding,
    /// Projection from the high-level state to `config.num_outputs` logits.
    lm_head: CastedLinear,
    /// Multiplier applied to embeddings: `sqrt(hidden_size)`.
    embed_scale: f64,
    /// Rotary position tables; present only when `config.pos_encodings == "rope"`.
    rotary_emb: Option<RotaryEmbedding>,
    /// The single layer stack used for both the `z_l` and the `z_h` updates.
    l_level: ReasoningModule,
    /// Initial value of `z_h`, a vector of length `hidden_size`.
    h_init: Tensor,
    /// Initial value of `z_l`, a vector of length `hidden_size`.
    l_init: Tensor,
    /// Device the weights live on; used when allocating carries.
    device: Device,
}

impl TinyRecursiveModel {
    /// Validates `config` and loads every sub-module and the two initial-state
    /// vectors from `vb` (paths `embed_tokens`, `lm_head`, `l_level`, `h_init`,
    /// `l_init`).
    ///
    /// # Errors
    /// Returns the validation error, or the first error raised while building a
    /// sub-module or fetching a tensor.
    pub fn new(config: TRMConfig, vb: VarBuilder) -> crate::Result<Self> {
        config.validate()?;

        let device = vb.device().clone();
        let dtype = vb.dtype();
        let embed_scale = (config.hidden_size as f64).sqrt();

        let embed_tokens = CastedEmbedding::new(
            config.vocab_size,
            config.hidden_size,
            vb.pp("embed_tokens"),
            dtype,
        )?;
        let lm_head = CastedLinear::new(
            config.hidden_size,
            config.num_outputs,
            false,
            vb.pp("lm_head"),
        )?;

        let rotary_emb = if config.pos_encodings == "rope" {
            // NOTE(review): 2048 and 10000.0 are hard-coded; presumably the maximum
            // position count and the RoPE base. Confirm against
            // `RotaryEmbedding::new` and consider sourcing them from the config.
            Some(RotaryEmbedding::new(
                config.head_dim(),
                2048,
                10000.0,
                &device,
            )?)
        } else {
            None
        };

        let l_level = ReasoningModule::new(config.l_layers, config.clone(), vb.pp("l_level"))?;
        let h_init = vb.get(config.hidden_size, "h_init")?;
        let l_init = vb.get(config.hidden_size, "l_init")?;

        Ok(Self {
            config,
            embed_tokens,
            lm_head,
            embed_scale,
            rotary_emb,
            l_level,
            h_init,
            l_init,
            device,
        })
    }

    /// Returns a zero-filled `F32` carry of shape
    /// `(batch_size, config.vocab_size, config.hidden_size)`.
    ///
    /// NOTE(review): the sequence dimension is taken from `vocab_size`, which
    /// looks unintended. `forward` adds `z_h` to the input embeddings, so the
    /// carry presumably has to match the input's sequence length. The behaviour
    /// is kept for compatibility; prefer [`Self::empty_carry_with_len`] when the
    /// sequence length is known.
    ///
    /// # Errors
    /// Propagates tensor-allocation errors.
    pub fn empty_carry(&self, batch_size: usize) -> Result<InnerCarry> {
        self.empty_carry_with_len(batch_size, self.config.vocab_size)
    }

    /// Returns a zero-filled `F32` carry of shape
    /// `(batch_size, seq_len, config.hidden_size)` on the model's device.
    ///
    /// # Errors
    /// Propagates tensor-allocation errors.
    pub fn empty_carry_with_len(&self, batch_size: usize, seq_len: usize) -> Result<InnerCarry> {
        InnerCarry::empty(
            batch_size,
            seq_len,
            self.config.hidden_size,
            DType::F32,
            &self.device,
        )
    }

    /// Replaces the states of flagged batch entries with the learned initial
    /// vectors and keeps the others.
    ///
    /// `reset_flag` is expanded with two trailing unit axes and broadcast over
    /// the carry, so one flag per batch entry is expected (assumes a 1-D tensor
    /// of an integer dtype accepted by `where_cond` — TODO confirm with callers).
    /// The target shape is taken from `carry.z_h` and `config.hidden_size`.
    ///
    /// # Errors
    /// Propagates shape/dtype errors from the broadcasts or from `where_cond`.
    pub fn reset_carry(&self, reset_flag: &Tensor, carry: &InnerCarry) -> Result<InnerCarry> {
        let batch_size = carry.z_h.dim(0)?;
        let seq_len = carry.z_h.dim(1)?;
        let state_shape = (batch_size, seq_len, self.config.hidden_size);

        // `where_cond` needs the condition and both branches to have the same
        // shape (it does not broadcast), so the mask is expanded explicitly.
        let reset_mask = reset_flag
            .unsqueeze(1)?
            .unsqueeze(1)?
            .broadcast_as(state_shape)?;

        let expand_init = |init: &Tensor| -> Result<Tensor> {
            init.unsqueeze(0)?.unsqueeze(0)?.broadcast_as(state_shape)
        };
        let h_init = expand_init(&self.h_init)?;
        let l_init = expand_init(&self.l_init)?;

        let z_h = reset_mask.where_cond(&h_init, &carry.z_h)?;
        let z_l = reset_mask.where_cond(&l_init, &carry.z_l)?;
        Ok(InnerCarry::new(z_h, z_l))
    }

    /// Looks up token embeddings for `input` and multiplies them by
    /// `embed_scale`.
    fn input_embeddings(&self, input: &Tensor) -> Result<Tensor> {
        self.embed_tokens
            .forward(input)?
            .affine(self.embed_scale, 0.0)
    }

    /// Runs one full recursion over `input`, starting from `carry`.
    ///
    /// For each of `config.h_cycles` outer steps, `z_l` is updated
    /// `config.l_cycles` times through the shared layer stack with
    /// `z_h + input_embeddings` as injection; then `z_h` is updated once
    /// through the same stack with `z_l` as injection. The sequence length used
    /// for the rotary tables is `input.dim(1)`.
    ///
    /// Returns the updated carry together with the logits computed from the
    /// final `z_h`.
    ///
    /// # Errors
    /// Propagates errors from the embedding, rotary, layer-stack and projection
    /// calls (for example on shape or dtype mismatch between `carry` and
    /// `input`).
    pub fn forward(&self, carry: &InnerCarry, input: &Tensor) -> Result<(InnerCarry, Tensor)> {
        let seq_len = input.dim(1)?;
        let cos_sin = match &self.rotary_emb {
            Some(rope) => Some(rope.forward_with_len(seq_len)?),
            None => None,
        };
        // Borrowed view of the tables; it is `Copy`, so it is built once and
        // reused for every layer-stack call below.
        let rope_tables: Option<(&Tensor, &Tensor)> =
            cos_sin.as_ref().map(|(c, s)| (c.as_ref(), s.as_ref()));

        let input_embeddings = self.input_embeddings(input)?;
        let mut z_h = carry.z_h.clone();
        let mut z_l = carry.z_l.clone();

        for _h_step in 0..self.config.h_cycles {
            for _l_step in 0..self.config.l_cycles {
                let injection = (&z_h + &input_embeddings)?;
                z_l = self.l_level.forward(&z_l, &injection, rope_tables)?;
            }
            z_h = self.l_level.forward(&z_h, &z_l, rope_tables)?;
        }

        let logits = self.lm_head.forward(&z_h)?;
        Ok((InnerCarry::new(z_h, z_l), logits))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    /// Shape of every activation tensor used below: (batch, sequence, hidden).
    const ACTIVATION_SHAPE: (usize, usize, usize) = (2, 16, 256);
    /// The same shape in the form returned by `Tensor::dims`.
    const EXPECTED_DIMS: [usize; 3] = [2, 16, 256];

    /// Default configuration narrowed to a 256-wide, 8-head model.
    fn small_config() -> TRMConfig {
        let mut config = TRMConfig::default();
        config.hidden_size = 256;
        config.num_heads = 8;
        config
    }

    /// Standard-normal F32 activations of `ACTIVATION_SHAPE`.
    fn random_activations(device: &Device) -> Result<Tensor> {
        Tensor::randn(0f32, 1.0, ACTIVATION_SHAPE, device)
    }

    /// An empty carry has the requested shape for both states.
    #[test]
    fn test_inner_carry_creation() -> Result<()> {
        let device = Device::Cpu;
        let carry = InnerCarry::empty(2, 16, 256, DType::F32, &device)?;
        assert_eq!(carry.z_h.dims(), &EXPECTED_DIMS);
        assert_eq!(carry.z_l.dims(), &EXPECTED_DIMS);
        Ok(())
    }

    /// A single block preserves the input shape.
    #[test]
    fn test_transformer_block() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let block = TransformerBlock::new(small_config(), vb)?;
        let input = random_activations(&device)?;
        let output = block.forward(&input, None)?;

        assert_eq!(output.dims(), &EXPECTED_DIMS);
        Ok(())
    }

    /// A two-layer reasoning module preserves the input shape.
    #[test]
    fn test_reasoning_module() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let mut config = small_config();
        config.l_layers = 2;
        let module = ReasoningModule::new(2, config, vb)?;

        let hidden = random_activations(&device)?;
        let injection = random_activations(&device)?;
        let output = module.forward(&hidden, &injection, None)?;

        assert_eq!(output.dims(), &EXPECTED_DIMS);
        Ok(())
    }
}