use burn::prelude::*;
use burn::module::{Param, ParamId};
use crate::config::EncoderConfig;
use super::scalenorm::ScaleNorm;
use super::attention::Attention;
use super::feedforward::FeedForward;
use super::residual::Residual;
/// Transformer encoder: a stack of attention and feed-forward sub-layers,
/// followed by a final `ScaleNorm`, with optional precomputed rotary tables.
///
/// The six `Vec` fields are parallel. `new` pushes exactly one entry to each
/// per layer, and the forward pass indexes all of them with the same layer
/// index, so they are expected to have equal length.
#[derive(Module, Debug)]
pub struct XTransformerEncoder<B: Backend> {
/// Per-layer norm applied to the input of the attention sub-layer.
pub attn_norms: Vec<ScaleNorm<B>>,
/// Per-layer attention modules; they receive the rotary tables (if any).
pub attns: Vec<Attention<B>>,
/// Per-layer residual modules combining the attention output with the
/// un-normalised input of that sub-layer.
pub attn_residuals: Vec<Residual<B>>,
/// Per-layer norm applied to the input of the feed-forward sub-layer.
pub ff_norms: Vec<ScaleNorm<B>>,
/// Per-layer feed-forward modules.
pub ffs: Vec<FeedForward<B>>,
/// Per-layer residual modules combining the feed-forward output with the
/// un-normalised input of that sub-layer.
pub ff_residuals: Vec<Residual<B>>,
/// Norm applied once to the output of the last layer.
pub final_norm: ScaleNorm<B>,
/// `cos(pos * inv_freq[j])` table of shape `[max_seq_len, rotary_dim / 2]`,
/// or `None` when `rotary_dim` is 0.
// NOTE(review): being a `Param`, this table is presumably stored in the
// module record and visible to optimizers -- confirm that is intended.
pub rot_cos: Option<Param<Tensor<B, 2>>>,
/// `sin(pos * inv_freq[j])` table with the same shape as `rot_cos`,
/// or `None` when `rotary_dim` is 0.
pub rot_sin: Option<Param<Tensor<B, 2>>>,
/// Rotary embedding width taken from the config; 0 when rotary position
/// embeddings are disabled.
pub rotary_dim: usize,
/// The `dim` value passed to `new` and forwarded to every sub-module.
pub dim: usize,
}
impl<B: Backend> XTransformerEncoder<B> {
    /// Builds an encoder with `config.depth` layers of width `dim`.
    ///
    /// # Arguments
    /// * `dim` - width forwarded to every norm, attention, feed-forward and
    ///   residual sub-module.
    /// * `max_seq_len` - number of positions for which the rotary cos/sin
    ///   tables are precomputed.
    /// * `config` - supplies `depth`, `heads`, `ff_mult`, `scale_residual`,
    ///   `rotary_pos_emb` and `rotary_emb_dim(dim)`.
    /// * `device` - device on which all sub-modules and tables are created.
    ///
    /// When `config.rotary_pos_emb` is false, or the configured rotary width
    /// is 0, no tables are built and `rot_cos` / `rot_sin` are `None`.
    ///
    // NOTE(review): the forward pass slices rows `0..seq_len` out of the
    // rotary tables, so inputs longer than `max_seq_len` are presumably
    // rejected by that slice -- confirm against burn's `slice` behaviour.
    pub fn new(
        dim: usize,
        max_seq_len: usize,
        config: &EncoderConfig,
        device: &B::Device,
    ) -> Self {
        // Row-major `[max_seq_len, rotary_dim / 2]` table holding
        // `pos * inv_freq[j]`, with
        // `inv_freq[j] = 1 / 10000^(2j / rotary_dim)`.
        fn rotary_angles(max_seq_len: usize, rotary_dim: usize) -> Vec<f32> {
            let half = rotary_dim / 2;
            // The inverse frequencies depend only on `j`; compute them once.
            let inv_freqs: Vec<f32> = (0..half)
                .map(|j| 1.0 / 10000.0f64.powf(2.0 * j as f64 / rotary_dim as f64) as f32)
                .collect();
            (0..max_seq_len)
                .flat_map(|pos| inv_freqs.iter().map(move |&inv_freq| pos as f32 * inv_freq))
                .collect()
        }

        let layer_count = config.depth;
        let mut attn_norms = Vec::with_capacity(layer_count);
        let mut attns = Vec::with_capacity(layer_count);
        let mut attn_residuals = Vec::with_capacity(layer_count);
        let mut ff_norms = Vec::with_capacity(layer_count);
        let mut ffs = Vec::with_capacity(layer_count);
        let mut ff_residuals = Vec::with_capacity(layer_count);

        // Sub-modules are created layer by layer, attention half first, so
        // that the construction order stays stable.
        for _ in 0..layer_count {
            attn_norms.push(ScaleNorm::new(dim, device));
            attns.push(Attention::new(dim, config.heads, device));
            attn_residuals.push(Residual::new(dim, config.scale_residual, device));

            ff_norms.push(ScaleNorm::new(dim, device));
            ffs.push(FeedForward::new(dim, config.ff_mult, device));
            ff_residuals.push(Residual::new(dim, config.scale_residual, device));
        }

        // A width of 0 doubles as the "rotary disabled" marker.
        let rotary_dim = if config.rotary_pos_emb {
            config.rotary_emb_dim(dim)
        } else {
            0
        };

        let (rot_cos, rot_sin) = match rotary_dim {
            0 => (None, None),
            _ => {
                let half = rotary_dim / 2;
                let angles = Tensor::<B, 2>::from_data(
                    TensorData::new(rotary_angles(max_seq_len, rotary_dim), [max_seq_len, half]),
                    device,
                );
                let cos_table = Param::initialized(ParamId::new(), angles.clone().cos());
                let sin_table = Param::initialized(ParamId::new(), angles.sin());
                (Some(cos_table), Some(sin_table))
            }
        };

        let final_norm = ScaleNorm::new(dim, device);

        Self {
            attn_norms,
            attns,
            attn_residuals,
            ff_norms,
            ffs,
            ff_residuals,
            final_norm,
            rot_cos,
            rot_sin,
            rotary_dim,
            dim,
        }
    }
}
// Public forward pass for builds where the `wgpu-kernels-metal` feature is
// disabled; only `B: Backend` is required.
#[cfg(not(feature = "wgpu-kernels-metal"))]
impl<B> XTransformerEncoder<B>
where
    B: Backend,
{
    /// Runs the full encoder stack on `x` and returns the rank-3 result.
    ///
    /// This is a thin entry point that hands `self` and `x` to
    /// `encoder_forward`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        encoder_forward(self, x)
    }
}
// Public forward pass for builds where the `wgpu-kernels-metal` feature is
// enabled; the backend must additionally implement `FusedOps`.
#[cfg(feature = "wgpu-kernels-metal")]
impl<B> XTransformerEncoder<B>
where
    B: Backend + crate::model_burn::FusedOps,
{
    /// Runs the full encoder stack on `x` and returns the rank-3 result.
    ///
    /// This is a thin entry point that hands `self` and `x` to
    /// `encoder_forward`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        encoder_forward(self, x)
    }
}
/// Free-function form of the encoder forward pass (feature
/// `wgpu-kernels-metal` disabled); forwards straight to `run_encoder`.
#[cfg(not(feature = "wgpu-kernels-metal"))]
fn encoder_forward<B>(enc: &XTransformerEncoder<B>, x: Tensor<B, 3>) -> Tensor<B, 3>
where
    B: Backend,
{
    run_encoder(enc, x)
}
/// Free-function form of the encoder forward pass (feature
/// `wgpu-kernels-metal` enabled, so `B` must also implement `FusedOps`);
/// forwards straight to `run_encoder`.
#[cfg(feature = "wgpu-kernels-metal")]
fn encoder_forward<B>(enc: &XTransformerEncoder<B>, x: Tensor<B, 3>) -> Tensor<B, 3>
where
    B: Backend + crate::model_burn::FusedOps,
{
    run_encoder(enc, x)
}
/// Intermediate dispatch step (feature `wgpu-kernels-metal` disabled):
/// passes `enc` and `x` on to the macro-generated `encoder_body`.
#[cfg(not(feature = "wgpu-kernels-metal"))]
fn run_encoder<B>(enc: &XTransformerEncoder<B>, x: Tensor<B, 3>) -> Tensor<B, 3>
where
    B: Backend,
{
    encoder_body(enc, x)
}
/// Intermediate dispatch step (feature `wgpu-kernels-metal` enabled, so `B`
/// must also implement `FusedOps`): passes `enc` and `x` on to the
/// macro-generated `encoder_body`.
#[cfg(feature = "wgpu-kernels-metal")]
fn run_encoder<B>(enc: &XTransformerEncoder<B>, x: Tensor<B, 3>) -> Tensor<B, 3>
where
    B: Backend + crate::model_burn::FusedOps,
{
    encoder_body(enc, x)
}
// Generates `encoder_body`, the actual layer loop of the encoder.
//
// The macro arguments are spliced verbatim after `B: Backend`, which lets the
// same body be emitted with or without extra backend bounds (for example
// `+ crate::model_burn::FusedOps`).
macro_rules! define_encoder_body {
    ($($bound:tt)*) => {
        /// Applies every encoder layer to `x`, then the final norm.
        ///
        /// Each layer runs two sub-layers in sequence (attention, then
        /// feed-forward). A sub-layer normalises its input, applies the
        /// module, and passes the result together with the un-normalised
        /// input to the matching residual module.
        ///
        /// Dimension 1 of `x` is used as the sequence length when slicing
        /// the rotary tables.
        fn encoder_body<B>(
            enc: &XTransformerEncoder<B>,
            x: Tensor<B, 3>,
        ) -> Tensor<B, 3>
        where
            B: Backend $($bound)*,
        {
            let seq_len = x.dims()[1];

            // Trim the precomputed tables to the first `seq_len` positions.
            // Both tables are required; if either is missing, none is used.
            let (rot_cos, rot_sin) =
                if let (Some(cos), Some(sin)) = (&enc.rot_cos, &enc.rot_sin) {
                    let cos_table = cos.val();
                    // The column count is taken from the cosine table and
                    // applied to both slices.
                    let half = cos_table.dims()[1];
                    let sin_table = sin.val();
                    (
                        Some(cos_table.slice([0..seq_len, 0..half])),
                        Some(sin_table.slice([0..seq_len, 0..half])),
                    )
                } else {
                    (None, None)
                };

            let mut hidden = x;
            for (layer, attn) in enc.attns.iter().enumerate() {
                // Attention sub-layer.
                let skip = hidden.clone();
                let normed = enc.attn_norms[layer].forward(hidden);
                let attended = attn.forward(normed, rot_cos.as_ref(), rot_sin.as_ref());
                hidden = enc.attn_residuals[layer].forward(attended, skip);

                // Feed-forward sub-layer.
                let skip = hidden.clone();
                let normed = enc.ff_norms[layer].forward(hidden);
                let projected = enc.ffs[layer].forward(normed);
                hidden = enc.ff_residuals[layer].forward(projected, skip);
            }

            enc.final_norm.forward(hidden)
        }
    };
}
// Feature `wgpu-kernels-metal` disabled: no extra tokens are passed, so the
// generated `encoder_body` is bounded by `B: Backend` alone.
#[cfg(not(feature = "wgpu-kernels-metal"))]
define_encoder_body!();
// Feature `wgpu-kernels-metal` enabled: the passed tokens are appended to the
// bound, giving `B: Backend + crate::model_burn::FusedOps`.
#[cfg(feature = "wgpu-kernels-metal")]
define_encoder_body!(+ crate::model_burn::FusedOps);