use axonml_autograd::Variable;
use axonml_nn::{Dropout, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{ones, zeros};
use crate::attention::{CausalSelfAttention, MultiHeadSelfAttention};
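/// Layer normalization over the last dimension, with a learnable scale (`weight`) and shift (`bias`).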
#[derive(Debug)]
pub struct LayerNorm {
pub weight: Parameter,
pub bias: Parameter,
pub eps: f32,
pub dim: usize,
}
impl LayerNorm {
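    /// Creates a layer norm over a trailing dimension of size `dim`; `weight` starts at ones and
    /// `bias` at zeros, both trainable, and `eps` is added to the variance for numerical stability.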
pub fn new(dim: usize, eps: f32) -> Self {
Self {
weight: Parameter::new(ones::<f32>(&[dim]), true),
bias: Parameter::new(zeros::<f32>(&[dim]), true),
eps,
dim,
}
}
}
impl Module for LayerNorm {
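    /// Normalizes the last dimension as `(x - mean) / sqrt(var + eps)`, then scales by `weight` and shifts by `bias`.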
fn forward(&self, input: &Variable) -> Variable {
let mean = input.mean_dim(-1, true);
let variance = input.var_dim(-1, true);
let x_normalized = input.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());
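        // Wrap weight and bias as Variables so gradients can flow back to the parameters.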
let weight_var = Variable::from_tensor_with_grad(
self.weight.data().clone(),
self.weight.requires_grad(),
);
let bias_var =
Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());
x_normalized.mul(&weight_var).add(&bias_var)
}
fn parameters(&self) -> Vec<Parameter> {
vec![self.weight.clone(), self.bias.clone()]
}
}
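/// Position-wise feed-forward network: two `Linear` layers with an activation and dropout in between.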
#[derive(Debug)]
pub struct FeedForward {
pub fc1: Linear,
pub fc2: Linear,
pub dropout: Dropout,
pub activation: String,
}
impl FeedForward {
pub fn new(
hidden_size: usize,
intermediate_size: usize,
dropout: f32,
activation: &str,
) -> Self {
Self {
fc1: Linear::new(hidden_size, intermediate_size),
fc2: Linear::new(intermediate_size, hidden_size),
dropout: Dropout::new(dropout),
activation: activation.to_string(),
}
}
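    /// Applies the configured activation function; unrecognized names fall back to GELU.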
fn activate(&self, x: &Variable) -> Variable {
match self.activation.as_str() {
"gelu" => x.gelu(),
"relu" => x.relu(),
"silu" | "swish" => x.silu(),
"tanh" => x.tanh(),
            _ => x.gelu(),
        }
    }
}
impl Module for FeedForward {
fn forward(&self, input: &Variable) -> Variable {
let x = self.fc1.forward(input);
let x = self.activate(&x);
let x = self.dropout.forward(&x);
self.fc2.forward(&x)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.fc1.parameters());
params.extend(self.fc2.parameters());
params
}
fn train(&mut self) {
self.dropout.train();
}
fn eval(&mut self) {
self.dropout.eval();
}
}
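/// A single Transformer encoder block: multi-head self-attention and a feed-forward network,
/// each wrapped with residual connections, dropout, and layer normalization.
/// `pre_norm` selects pre-LN or post-LN (original Transformer) ordering.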
#[derive(Debug)]
pub struct TransformerEncoderBlock {
pub attention: MultiHeadSelfAttention,
pub ln1: LayerNorm,
pub ffn: FeedForward,
pub ln2: LayerNorm,
pub dropout: Dropout,
pub pre_norm: bool,
}
impl TransformerEncoderBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
intermediate_size: usize,
dropout: f32,
layer_norm_eps: f32,
activation: &str,
pre_norm: bool,
) -> Self {
Self {
attention: MultiHeadSelfAttention::new(hidden_size, num_heads, dropout),
ln1: LayerNorm::new(hidden_size, layer_norm_eps),
ffn: FeedForward::new(hidden_size, intermediate_size, dropout, activation),
ln2: LayerNorm::new(hidden_size, layer_norm_eps),
dropout: Dropout::new(dropout),
pre_norm,
}
}
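    /// Runs the block with an optional attention mask. With `pre_norm`, layer norm precedes
    /// each sublayer; otherwise it follows the residual addition.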
pub fn forward_with_mask(
&self,
hidden_states: &Variable,
attention_mask: Option<&Tensor<f32>>,
) -> Variable {
if self.pre_norm {
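            // Pre-norm: normalize, run the sublayer, apply dropout, then add the residual.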
let residual = hidden_states.clone();
let x = self.ln1.forward(hidden_states);
let x = self.attention.forward_with_mask(&x, attention_mask);
let x = self.dropout.forward(&x);
let x = x.add(&residual);
let residual = x.clone();
let x = self.ln2.forward(&x);
let x = self.ffn.forward(&x);
let x = self.dropout.forward(&x);
x.add(&residual)
} else {
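            // Post-norm: run the sublayer, apply dropout, add the residual, then normalize.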
let residual = hidden_states.clone();
let x = self
.attention
.forward_with_mask(hidden_states, attention_mask);
let x = self.dropout.forward(&x);
let x = self.ln1.forward(&x.add(&residual));
let residual = x.clone();
let x = self.ffn.forward(&x);
let x = self.dropout.forward(&x);
self.ln2.forward(&x.add(&residual))
}
}
}
impl Module for TransformerEncoderBlock {
fn forward(&self, input: &Variable) -> Variable {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.attention.parameters());
params.extend(self.ln1.parameters());
params.extend(self.ffn.parameters());
params.extend(self.ln2.parameters());
params
}
fn train(&mut self) {
self.attention.train();
self.ffn.train();
self.dropout.train();
}
fn eval(&mut self) {
self.attention.eval();
self.ffn.eval();
self.dropout.eval();
}
}
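/// A pre-norm decoder block: causal self-attention and a feed-forward network with a
/// 4x hidden expansion, each preceded by layer norm and followed by a residual connection.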
#[derive(Debug)]
pub struct TransformerDecoderBlock {
pub attention: CausalSelfAttention,
pub ln1: LayerNorm,
pub ffn: FeedForward,
pub ln2: LayerNorm,
}
impl TransformerDecoderBlock {
pub fn new(
n_embd: usize,
n_head: usize,
max_seq_len: usize,
dropout: f32,
layer_norm_eps: f32,
activation: &str,
) -> Self {
Self {
attention: CausalSelfAttention::new(n_embd, n_head, max_seq_len, dropout),
ln1: LayerNorm::new(n_embd, layer_norm_eps),
ffn: FeedForward::new(n_embd, 4 * n_embd, dropout, activation),
ln2: LayerNorm::new(n_embd, layer_norm_eps),
}
}
}
impl Module for TransformerDecoderBlock {
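    /// Pre-norm ordering: `x + attn(ln1(x))` followed by `x + ffn(ln2(x))`.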
fn forward(&self, input: &Variable) -> Variable {
        let residual = input.clone();
        let x = self.ln1.forward(input);
        let x = self.attention.forward(&x);
let x = x.add(&residual);
let residual = x.clone();
let x = self.ln2.forward(&x);
let x = self.ffn.forward(&x);
x.add(&residual)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.attention.parameters());
params.extend(self.ln1.parameters());
params.extend(self.ffn.parameters());
params.extend(self.ln2.parameters());
params
}
fn train(&mut self) {
self.attention.train();
self.ffn.train();
}
fn eval(&mut self) {
self.attention.eval();
self.ffn.eval();
}
}
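/// A stack of [`TransformerEncoderBlock`]s applied in sequence.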
#[derive(Debug)]
pub struct TransformerEncoder {
pub layers: Vec<TransformerEncoderBlock>,
}
impl TransformerEncoder {
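    /// Builds `num_layers` identically configured encoder blocks.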
pub fn new(
num_layers: usize,
hidden_size: usize,
num_heads: usize,
intermediate_size: usize,
dropout: f32,
layer_norm_eps: f32,
activation: &str,
pre_norm: bool,
) -> Self {
let layers = (0..num_layers)
.map(|_| {
TransformerEncoderBlock::new(
hidden_size,
num_heads,
intermediate_size,
dropout,
layer_norm_eps,
activation,
pre_norm,
)
})
.collect();
Self { layers }
}
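    /// Applies each layer in order, passing the same optional attention mask to every block.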
pub fn forward_with_mask(
&self,
hidden_states: &Variable,
attention_mask: Option<&Tensor<f32>>,
) -> Variable {
let mut output = hidden_states.clone();
for layer in &self.layers {
output = layer.forward_with_mask(&output, attention_mask);
}
output
}
}
impl Module for TransformerEncoder {
fn forward(&self, input: &Variable) -> Variable {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> Vec<Parameter> {
self.layers.iter().flat_map(|l| l.parameters()).collect()
}
fn train(&mut self) {
for layer in &mut self.layers {
layer.train();
}
}
fn eval(&mut self) {
for layer in &mut self.layers {
layer.eval();
}
}
}
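/// A stack of [`TransformerDecoderBlock`]s followed by a final layer norm (`ln_f`).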
#[derive(Debug)]
pub struct TransformerDecoder {
pub layers: Vec<TransformerDecoderBlock>,
pub ln_f: LayerNorm,
}
impl TransformerDecoder {
pub fn new(
num_layers: usize,
n_embd: usize,
n_head: usize,
max_seq_len: usize,
dropout: f32,
layer_norm_eps: f32,
activation: &str,
) -> Self {
let layers = (0..num_layers)
.map(|_| {
TransformerDecoderBlock::new(
n_embd,
n_head,
max_seq_len,
dropout,
layer_norm_eps,
activation,
)
})
.collect();
Self {
layers,
ln_f: LayerNorm::new(n_embd, layer_norm_eps),
}
}
}
impl Module for TransformerDecoder {
fn forward(&self, input: &Variable) -> Variable {
let mut output = input.clone();
for layer in &self.layers {
output = layer.forward(&output);
}
self.ln_f.forward(&output)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
params.extend(self.ln_f.parameters());
params
}
fn train(&mut self) {
for layer in &mut self.layers {
layer.train();
}
}
fn eval(&mut self) {
for layer in &mut self.layers {
layer.eval();
}
}
}
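/// A Transformer block of either kind; `Module` calls are dispatched to the wrapped block.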
#[derive(Debug)]
pub enum TransformerBlock {
Encoder(TransformerEncoderBlock),
Decoder(TransformerDecoderBlock),
}
impl Module for TransformerBlock {
fn forward(&self, input: &Variable) -> Variable {
match self {
TransformerBlock::Encoder(block) => block.forward(input),
TransformerBlock::Decoder(block) => block.forward(input),
}
}
fn parameters(&self) -> Vec<Parameter> {
match self {
TransformerBlock::Encoder(block) => block.parameters(),
TransformerBlock::Decoder(block) => block.parameters(),
}
}
fn train(&mut self) {
match self {
TransformerBlock::Encoder(block) => block.train(),
TransformerBlock::Decoder(block) => block.train(),
}
}
fn eval(&mut self) {
match self {
TransformerBlock::Encoder(block) => block.eval(),
TransformerBlock::Decoder(block) => block.eval(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_layer_norm() {
let ln = LayerNorm::new(64, 1e-5);
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = ln.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
#[test]
fn test_feed_forward() {
let ffn = FeedForward::new(64, 256, 0.0, "gelu");
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = ffn.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
#[test]
fn test_encoder_block() {
let block = TransformerEncoderBlock::new(64, 4, 256, 0.0, 1e-5, "gelu", false);
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = block.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
#[test]
fn test_decoder_block() {
let block = TransformerDecoderBlock::new(64, 4, 128, 0.0, 1e-5, "gelu");
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = block.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
#[test]
fn test_transformer_encoder() {
let encoder = TransformerEncoder::new(2, 64, 4, 256, 0.0, 1e-5, "gelu", false);
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = encoder.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
#[test]
fn test_transformer_decoder() {
let decoder = TransformerDecoder::new(2, 64, 4, 128, 0.0, 1e-5, "gelu");
let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
let output = decoder.forward(&input);
assert_eq!(output.data().shape(), &[2, 8, 64]);
}
}