// oxidized_transformers/layers/transformer/layer.rs
use candle_core::{ModuleT, Tensor};
use candle_nn::VarBuilder;
use snafu::{ResultExt, Snafu};
use crate::architectures::{BuildDecoderLayer, DecoderLayer};
use crate::architectures::{BuildEncoderLayer, EncoderLayer};
use crate::error::BoxedError;
use crate::kv_cache::LayerKeyValueCache;
use crate::layers::attention::{Attention, AttentionMask, BuildAttention, SelfAttentionConfig};
use crate::layers::build_module::BuildModule;
use crate::layers::feedforward::PointwiseFeedForwardConfig;
use crate::layers::identity::Identity;
/// Configuration for a transformer layer.
///
/// Describes how to build the self-attention block, the point-wise
/// feed-forward block, the residual layer norms, and whether attention
/// and feed-forward are applied in parallel or sequentially.
#[derive(Debug)]
pub struct TransformerLayerConfig {
    // Module applied after the attention residual connection (Identity by default).
    attn_residual_layer_norm: Box<dyn BuildModule>,
    // Configuration of the self-attention block.
    attention: SelfAttentionConfig,
    // Configuration of the point-wise feed-forward block.
    feedforward: PointwiseFeedForwardConfig,
    // Module applied after the feed-forward residual connection (Identity by default).
    ffn_residual_layer_norm: Box<dyn BuildModule>,
    // Module applied to the summed attention + feed-forward outputs when
    // parallel attention is used (Identity by default).
    parallel_attn_dropout: Box<dyn BuildModule>,
    // When true, attention and feed-forward both read the layer input and
    // their outputs are summed; when false, feed-forward reads the
    // attention output.
    use_parallel_attention: bool,
}
impl TransformerLayerConfig {
    /// Construct the shared transformer layer from this configuration.
    ///
    /// Parameters are looked up under prefixes of `vb` matching the
    /// sub-module names (`attention`, `ffn`, ...).
    fn build_layer(&self, vb: VarBuilder) -> Result<TransformerLayer, TransformerLayerError> {
        // Build sub-modules in the same order as the original struct
        // literal so the first failing component determines the error.
        let attn_residual_layer_norm = self
            .attn_residual_layer_norm
            .build(vb.push_prefix("attn_residual_layer_norm"))
            .context(CreateLayerNormSnafu)?;
        let ffn = self
            .feedforward
            .build(vb.push_prefix("ffn"))
            .context(BuildPointwiseFeedForwardSnafu)?;
        let ffn_residual_layer_norm = self
            .ffn_residual_layer_norm
            .build(vb.push_prefix("ffn_residual_layer_norm"))
            .context(CreateLayerNormSnafu)?;
        let mha = self
            .attention
            .build(vb.push_prefix("attention"))
            .context(BuildAttentionSnafu)?;
        let parallel_attention_dropout = self
            .parallel_attn_dropout
            .build(vb.push_prefix("parallel_attention_dropout"))
            .context(BuildParallelAttentionDropoutSnafu)?;
        Ok(TransformerLayer {
            attn_residual_layer_norm,
            ffn,
            ffn_residual_layer_norm,
            mha,
            parallel_attention_dropout,
            use_parallel_attention: self.use_parallel_attention,
        })
    }

    /// Set the module applied after the attention residual connection.
    pub fn attn_residual_layer_norm(
        self,
        attn_residual_layer_norm: Box<dyn BuildModule>,
    ) -> Self {
        Self {
            attn_residual_layer_norm,
            ..self
        }
    }

    /// Set the self-attention configuration.
    pub fn attention(self, attention: SelfAttentionConfig) -> Self {
        Self { attention, ..self }
    }

    /// Set the point-wise feed-forward configuration.
    pub fn feedforward(self, feedforward: PointwiseFeedForwardConfig) -> Self {
        Self { feedforward, ..self }
    }

    /// Set the module applied after the feed-forward residual connection.
    pub fn ffn_residual_layer_norm(
        self,
        ffn_residual_layer_norm: Box<dyn BuildModule>,
    ) -> Self {
        Self {
            ffn_residual_layer_norm,
            ..self
        }
    }

    /// Set the module applied to the summed attention and feed-forward
    /// outputs when parallel attention is enabled.
    pub fn parallel_attn_dropout(self, parallel_attn_dropout: Box<dyn BuildModule>) -> Self {
        Self {
            parallel_attn_dropout,
            ..self
        }
    }

    /// Set whether attention and feed-forward are applied in parallel.
    pub fn use_parallel_attention(self, use_parallel_attention: bool) -> Self {
        Self {
            use_parallel_attention,
            ..self
        }
    }
}
impl Default for TransformerLayerConfig {
fn default() -> Self {
Self {
attn_residual_layer_norm: Box::new(Identity),
attention: SelfAttentionConfig::default(),
feedforward: PointwiseFeedForwardConfig::default(),
ffn_residual_layer_norm: Box::new(Identity),
parallel_attn_dropout: Box::new(Identity),
use_parallel_attention: false,
}
}
}
impl BuildDecoderLayer for TransformerLayerConfig {
    type Cache = LayerKeyValueCache;

    /// Build a transformer decoder layer from this configuration.
    fn build_decoder_layer(
        &self,
        vb: VarBuilder,
    ) -> Result<Box<dyn DecoderLayer<Cache = Self::Cache>>, BoxedError> {
        let inner = self.build_layer(vb)?;
        Ok(Box::new(TransformerDecoderLayer { inner }))
    }
}
impl BuildEncoderLayer for TransformerLayerConfig {
    /// Build a transformer encoder layer from this configuration.
    fn build_encoder_layer(&self, vb: VarBuilder) -> Result<Box<dyn EncoderLayer>, BoxedError> {
        let inner = self.build_layer(vb)?;
        Ok(Box::new(TransformerEncoderLayer { inner }))
    }
}
/// Errors that can occur while building or applying a transformer layer.
#[derive(Debug, Snafu)]
pub enum TransformerLayerError {
    /// The self-attention module could not be built.
    #[snafu(display("Cannot build attention layer"))]
    BuildAttention { source: BoxedError },
    /// The parallel attention dropout module could not be built.
    #[snafu(display("Cannot build parallel attention dropout"))]
    BuildParallelAttentionDropout { source: BoxedError },
    /// The point-wise feed-forward module could not be built.
    #[snafu(display("Cannot build pointwise feed-forward layer"))]
    BuildPointwiseFeedForward { source: BoxedError },
    /// A residual layer norm module could not be built.
    #[snafu(display("Cannot create layer norm"))]
    CreateLayerNorm { source: BoxedError },
    /// Applying the point-wise feed-forward module failed.
    #[snafu(display("Cannot apply point-wise feed-forward layer"))]
    FeedForward { source: candle_core::Error },
    /// Combining the parallel attention and feed-forward outputs failed.
    #[snafu(display("Cannot apply parallel attention"))]
    ParallelAttention { source: candle_core::Error },
    /// Applying a residual connection (add + layer norm) failed.
    #[snafu(display("Cannot apply residual connection"))]
    Residual { source: candle_core::Error },
    /// Applying the self-attention module failed.
    #[snafu(display("Cannot apply self-attention"))]
    SelfAttention { source: BoxedError },
}
/// Transformer layer shared by the encoder and decoder wrappers.
///
/// Holds the built sub-modules from [`TransformerLayerConfig`].
struct TransformerLayer {
    // Applied after the attention residual connection.
    attn_residual_layer_norm: Box<dyn ModuleT>,
    // Applied after the feed-forward residual connection.
    ffn_residual_layer_norm: Box<dyn ModuleT>,
    // Self-attention block.
    mha: Box<dyn Attention>,
    // Applied to the summed attention + feed-forward outputs in the
    // parallel-attention formulation.
    parallel_attention_dropout: Box<dyn ModuleT>,
    // Point-wise feed-forward block.
    ffn: Box<dyn ModuleT>,
    // Selects the parallel vs. sequential formulation in `forward`.
    use_parallel_attention: bool,
}
impl TransformerLayer {
    /// Apply the transformer layer to the given hidden representations.
    ///
    /// * `input` - Hidden representations to process.
    /// * `attention_mask` - Attention mask for the input.
    /// * `cache` - Key-value cache (use a no-op cache when not decoding
    ///   incrementally).
    /// * `positions` - Optional position identifiers.
    /// * `train` - Whether the layer runs in training mode.
    /// * `use_causal_mask` - Whether self-attention applies a causal mask.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut LayerKeyValueCache,
        positions: Option<&Tensor>,
        train: bool,
        use_causal_mask: bool,
    ) -> Result<Tensor, TransformerLayerError> {
        // Self-attention comes first in both formulations.
        let attn_out = self
            .mha
            .forward_t(
                input,
                attention_mask,
                cache,
                positions,
                train,
                use_causal_mask,
            )
            .context(SelfAttentionSnafu)?;
        if self.use_parallel_attention {
            // Parallel formulation: attention and feed-forward both read
            // the layer input; their (dropped-out) sum is added to the
            // residual and normalized.
            let ffn_out = self
                .ffn
                .forward_t(input, train)
                .context(FeedForwardSnafu)?;
            let parallel = (attn_out + ffn_out)
                .and_then(|xs| self.parallel_attention_dropout.forward_t(&xs, train))
                .context(ParallelAttentionSnafu)?;
            (input.clone() + parallel)
                .and_then(|xs| self.ffn_residual_layer_norm.forward_t(&xs, train))
                .context(ResidualSnafu)
        } else {
            // Sequential formulation: the normalized attention residual
            // feeds the feed-forward block, followed by a second residual
            // connection and normalization.
            let residual = (input.clone() + &attn_out)
                .and_then(|xs| self.attn_residual_layer_norm.forward_t(&xs, train))
                .context(ResidualSnafu)?;
            let ffn_out = self
                .ffn
                .forward_t(&residual, train)
                .context(FeedForwardSnafu)?;
            (residual + ffn_out)
                .and_then(|xs| self.ffn_residual_layer_norm.forward_t(&xs, train))
                .context(ResidualSnafu)
        }
    }
}
/// Transformer decoder layer.
///
/// Wraps a [`TransformerLayer`], applying it with a causal mask and a
/// caller-provided key-value cache.
pub struct TransformerDecoderLayer {
    inner: TransformerLayer,
}
impl DecoderLayer for TransformerDecoderLayer {
    type Cache = LayerKeyValueCache;

    /// Apply the decoder layer; a causal mask is always used.
    fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError> {
        self.inner
            .forward(input, attention_mask, cache, positions, train, true)
            .map_err(Into::into)
    }
}
/// Transformer encoder layer.
///
/// Wraps a [`TransformerLayer`], applying it without a causal mask and
/// without key-value caching.
pub struct TransformerEncoderLayer {
    inner: TransformerLayer,
}
impl EncoderLayer for TransformerEncoderLayer {
    /// Apply the encoder layer; no causal mask is used.
    fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError> {
        // Encoders do not decode incrementally, so a no-op cache suffices.
        let mut cache = LayerKeyValueCache::no_cache();
        self.inner
            .forward(input, attention_mask, &mut cache, positions, train, false)
            .boxed()
    }
}