//! x-transformers-style encoder: a stack of pre-norm attention and
//! feed-forward layers with optional rotary position embeddings and a
//! final `ScaleNorm`.

use crate::config::EncoderConfig;
use crate::tensor::Tensor;

use super::attention::Attention;
use super::feedforward::FeedForward;
use super::residual::Residual;
use super::rotary::RotaryEmbedding;
use super::scalenorm::ScaleNorm;

/// A sub-block within a transformer layer: either multi-head self-attention
/// or a position-wise feed-forward network.
#[derive(Debug, Clone)]
pub enum LayerBlock {
    Attn(Attention),
    FF(FeedForward),
}

/// One pre-norm transformer layer: the input is normalized, passed through
/// the block, then rejoined with the residual branch.
#[derive(Debug, Clone)]
pub struct TransformerLayer {
    pub pre_norm: ScaleNorm,
    pub block: LayerBlock,
    pub residual: Residual,
}

/// Encoder built from alternating attention and feed-forward layers, with a
/// final `ScaleNorm` and optional rotary position embeddings.
#[derive(Debug, Clone)]
pub struct XTransformerEncoder {
    pub layers: Vec<TransformerLayer>,
    pub final_norm: ScaleNorm,
    pub rotary: Option<RotaryEmbedding>,
    pub dim: usize,
    pub config: EncoderConfig,
}

impl XTransformerEncoder {
    /// Builds an encoder with `config.depth` attention/feed-forward layer
    /// pairs, each wrapped in a pre-norm `ScaleNorm` and a residual.
    pub fn new(dim: usize, config: &EncoderConfig) -> Self {
        // Each depth step contributes one attention layer followed by one
        // feed-forward layer.
        let mut layers = Vec::new();
        for _ in 0..config.depth {
            layers.push(TransformerLayer {
                pre_norm: ScaleNorm::new(dim),
                block: LayerBlock::Attn(Attention::new(dim, config.heads)),
                residual: Residual::new(dim, config.scale_residual),
            });
            layers.push(TransformerLayer {
                pre_norm: ScaleNorm::new(dim),
                block: LayerBlock::FF(FeedForward::new(dim, config.ff_mult)),
                residual: Residual::new(dim, config.scale_residual),
            });
        }
        // Rotary position embeddings are optional; the config decides how
        // much of the head dimension the rotation covers.
        let rotary = if config.rotary_pos_emb {
            Some(RotaryEmbedding::new(config.rotary_emb_dim(dim)))
        } else {
            None
        };
        Self {
            layers,
            final_norm: ScaleNorm::new(dim),
            rotary,
            dim,
            config: config.clone(),
        }
    }

    /// Runs the layer stack over `x` (shape `[batch, seq, dim]`) and applies
    /// the final normalization. The output has the same shape as the input.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        // Rotary frequencies depend only on the sequence length, so they are
        // computed once here and shared by every attention layer.
        let seq_len = x.shape[1];
        let rotary_freqs = self.rotary.as_ref().map(|r| r.forward(seq_len));
        let mut x = x.clone();
        for layer in &self.layers {
            // Pre-norm: normalize the running activation, apply the block,
            // then add the unnormalized input back through the residual.
            let inner_residual = x.clone();
            x = layer.pre_norm.forward(&x);
            x = match &layer.block {
                LayerBlock::Attn(attn) => attn.forward(&x, rotary_freqs.as_ref()),
                LayerBlock::FF(ff) => ff.forward(&x),
            };
            x = layer.residual.forward(&x, &inner_residual);
        }
        self.final_norm.forward(&x)
    }
}
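
// A minimal usage sketch, written as a smoke test. `EncoderConfig::default()`
// and `Tensor::zeros(&[batch, seq, dim])` are assumptions made for
// illustration; only the fields read by `new` above are set explicitly.
// Adjust the constructors to the crate's real APIs before relying on this.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_two_layers_per_depth_step() {
        // Hypothetical config: the field names match those used in `new`,
        // but the `Default` impl and field types are assumed.
        let config = EncoderConfig {
            depth: 2,
            heads: 4,
            ff_mult: 4,
            scale_residual: false,
            rotary_pos_emb: true,
            ..EncoderConfig::default()
        };
        let encoder = XTransformerEncoder::new(64, &config);
        // Each depth step pushes an attention layer and a feed-forward layer.
        assert_eq!(encoder.layers.len(), 4);

        // Assumed constructor: a zero-filled [batch, seq, dim] tensor.
        let x = Tensor::zeros(&[1, 8, 64]);
        let y = encoder.forward(&x);
        // The encoder is shape-preserving.
        assert_eq!(y.shape, x.shape);
    }
}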