use std::cell::Cell;
use std::collections::HashMap;
use flodl::nn::{Dropout, GELU, LayerNorm, Linear, Module, NamedInputModule, Parameter};
use flodl::{scaled_dot_product_attention, Device, Result, Variable};
#[cfg(test)]
use flodl::{DType, Tensor, TensorOptions};
use crate::path::prefix_params;
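/// Parameter-name suffixes for the sub-modules of a transformer encoder layer.
///
/// The layer itself is identical across naming schemes; the scheme only picks
/// the prefixes attached to exported parameters so they line up with the
/// corresponding checkpoint layout.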
#[derive(Debug, Clone, Copy)]
pub struct LayerNaming {
pub query: &'static str,
pub key: &'static str,
pub value: &'static str,
pub attn_output: &'static str,
pub attn_layer_norm: &'static str,
pub ffn_up: &'static str,
pub ffn_down: &'static str,
pub ffn_layer_norm: &'static str,
}
impl LayerNaming {
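    /// Suffixes matching Hugging Face BERT checkpoints
    /// (`attention.self.query`, `attention.output.dense`, ...).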
pub const BERT: Self = Self {
query: "attention.self.query",
key: "attention.self.key",
value: "attention.self.value",
attn_output: "attention.output.dense",
attn_layer_norm: "attention.output.LayerNorm",
ffn_up: "intermediate.dense",
ffn_down: "output.dense",
ffn_layer_norm: "output.LayerNorm",
};
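    /// Suffixes matching Hugging Face DistilBERT checkpoints
    /// (`attention.q_lin`, `ffn.lin1`, ...).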
pub const DISTILBERT: Self = Self {
query: "attention.q_lin",
key: "attention.k_lin",
value: "attention.v_lin",
attn_output: "attention.out_lin",
attn_layer_norm: "sa_layer_norm",
ffn_up: "ffn.lin1",
ffn_down: "ffn.lin2",
ffn_layer_norm: "output_layer_norm",
};
}
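/// Hyperparameters for a single encoder layer.
///
/// `hidden_size` must be divisible by `num_attention_heads`.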
#[derive(Debug, Clone, Copy)]
pub struct TransformerLayerConfig {
pub hidden_size: i64,
pub num_attention_heads: i64,
pub intermediate_size: i64,
pub hidden_dropout_prob: f64,
pub attention_probs_dropout_prob: f64,
pub layer_norm_eps: f64,
}
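/// A BERT/DistilBERT-style encoder layer: multi-head self-attention followed
/// by a feed-forward network, each with dropout, a residual connection, and
/// post-LayerNorm.
///
/// Usage sketch (mirrors the `forward_runs_end_to_end` test below; `config`
/// and `hidden_states` are assumed to exist):
///
/// ```ignore
/// let layer = TransformerLayer::on_device(&config, LayerNaming::BERT, Device::CPU)?;
/// layer.set_training(false);
/// // hidden_states: [batch, seq, hidden]; the output keeps the same shape.
/// let out = layer.forward(&hidden_states)?;
/// ```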
pub struct TransformerLayer {
query: Linear,
key: Linear,
value: Linear,
attn_output: Linear,
attn_layer_norm: LayerNorm,
attn_out_dropout: Dropout,
attn_dropout_prob: f64,
training: Cell<bool>,
num_heads: i64,
head_dim: i64,
ffn_up: Linear,
activation: GELU,
ffn_down: Linear,
ffn_layer_norm: LayerNorm,
ffn_dropout: Dropout,
naming: LayerNaming,
}
impl TransformerLayer {
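    /// Constructs the layer's sub-modules on `device`.
    ///
    /// Panics if `hidden_size` is not divisible by `num_attention_heads`.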
pub fn on_device(
config: &TransformerLayerConfig,
naming: LayerNaming,
device: Device,
) -> Result<Self> {
assert!(
config.hidden_size % config.num_attention_heads == 0,
"hidden_size ({}) must be divisible by num_attention_heads ({})",
config.hidden_size, config.num_attention_heads,
);
let head_dim = config.hidden_size / config.num_attention_heads;
Ok(TransformerLayer {
query: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
key: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
value: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
attn_output: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
attn_layer_norm: LayerNorm::on_device_with_eps(
config.hidden_size, config.layer_norm_eps, device,
)?,
attn_out_dropout: Dropout::new(config.hidden_dropout_prob),
attn_dropout_prob: config.attention_probs_dropout_prob,
training: Cell::new(true),
num_heads: config.num_attention_heads,
head_dim,
ffn_up: Linear::on_device(config.hidden_size, config.intermediate_size, device)?,
activation: GELU::new(),
ffn_down: Linear::on_device(config.intermediate_size, config.hidden_size, device)?,
ffn_layer_norm: LayerNorm::on_device_with_eps(
config.hidden_size, config.layer_norm_eps, device,
)?,
ffn_dropout: Dropout::new(config.hidden_dropout_prob),
naming,
})
}
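    /// Shared forward pass: self-attention with an optional additive attention
    /// mask, then the feed-forward block; each half ends with dropout, a
    /// residual connection, and LayerNorm.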
fn forward_impl(
&self,
input: &Variable,
attention_mask: Option<&Variable>,
) -> Result<Variable> {
let shape = input.shape();
let batch = shape[0];
let seq = shape[1];
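        // Project to Q/K/V and split the hidden dim into heads:
        // [batch, seq, hidden] -> [batch, num_heads, seq, head_dim].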
let q = self.query.forward(input)?
.reshape(&[batch, seq, self.num_heads, self.head_dim])?
.transpose(1, 2)?;
let k = self.key.forward(input)?
.reshape(&[batch, seq, self.num_heads, self.head_dim])?
.transpose(1, 2)?;
let v = self.value.forward(input)?
.reshape(&[batch, seq, self.num_heads, self.head_dim])?
.transpose(1, 2)?;
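        // Attention-probability dropout is only active while training.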
let dropout_p = if self.training.get() { self.attn_dropout_prob } else { 0.0 };
let mask_data = attention_mask.map(|m| m.data());
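        // Fused scaled dot-product attention over the per-head Q/K/V, with the
        // optional additive mask applied to the attention scores (the trailing
        // `false, None` arguments are presumed to disable causal masking and
        // keep the default scale).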
let context = scaled_dot_product_attention(
&q, &k, &v,
mask_data.as_ref(),
dropout_p,
false,
None,
)?;
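        // Merge the heads back: [batch, num_heads, seq, head_dim] -> [batch, seq, hidden].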
let attn_flat = context.transpose(1, 2)?
.reshape(&[batch, seq, self.num_heads * self.head_dim])?;
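        // Output projection and dropout, then residual connection and LayerNorm.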
let attn_proj = self.attn_output.forward(&attn_flat)?;
let attn_dropped = self.attn_out_dropout.forward(&attn_proj)?;
let residual1 = self.attn_layer_norm.forward(&attn_dropped.add(input)?)?;
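        // Feed-forward block: up-projection, GELU, down-projection, dropout,
        // residual connection, LayerNorm.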
let ffn_hidden = self.activation.forward(&self.ffn_up.forward(&residual1)?)?;
let ffn_out = self.ffn_down.forward(&ffn_hidden)?;
let ffn_dropped = self.ffn_dropout.forward(&ffn_out)?;
self.ffn_layer_norm.forward(&ffn_dropped.add(&residual1)?)
}
}
impl Module for TransformerLayer {
fn name(&self) -> &str { "transformer_layer" }
fn forward(&self, input: &Variable) -> Result<Variable> {
self.forward_impl(input, None)
}
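    // Parameters are emitted with the naming scheme's prefixes so the exported
    // names match the chosen checkpoint layout.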
fn parameters(&self) -> Vec<Parameter> {
let n = self.naming;
let mut out = Vec::new();
out.extend(prefix_params(n.query, self.query.parameters()));
out.extend(prefix_params(n.key, self.key.parameters()));
out.extend(prefix_params(n.value, self.value.parameters()));
out.extend(prefix_params(n.attn_output, self.attn_output.parameters()));
out.extend(prefix_params(n.attn_layer_norm, self.attn_layer_norm.parameters()));
out.extend(prefix_params(n.ffn_up, self.ffn_up.parameters()));
out.extend(prefix_params(n.ffn_down, self.ffn_down.parameters()));
out.extend(prefix_params(n.ffn_layer_norm, self.ffn_layer_norm.parameters()));
out
}
fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
fn set_training(&self, training: bool) {
self.training.set(training);
self.attn_out_dropout.set_training(training);
self.ffn_dropout.set_training(training);
}
}
impl NamedInputModule for TransformerLayer {
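    // Reads an optional additive `"attention_mask"` entry from `refs` and
    // forwards it to the attention computation.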
fn forward_named(
&self,
input: &Variable,
refs: &HashMap<String, Variable>,
) -> Result<Variable> {
self.forward_impl(input, refs.get("attention_mask"))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn mini_config() -> TransformerLayerConfig {
TransformerLayerConfig {
hidden_size: 8,
num_attention_heads: 2,
intermediate_size: 16,
hidden_dropout_prob: 0.0,
attention_probs_dropout_prob: 0.0,
layer_norm_eps: 1e-12,
}
}
#[test]
fn bert_naming_emits_bert_suffixes() {
let layer = TransformerLayer::on_device(
&mini_config(), LayerNaming::BERT, Device::CPU,
).unwrap();
let names: Vec<String> = layer.parameters().into_iter().map(|p| p.name).collect();
assert!(names.iter().any(|n| n == "attention.self.query.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.self.query.bias"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.output.dense.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.output.LayerNorm.weight"),"got: {names:?}");
assert!(names.iter().any(|n| n == "intermediate.dense.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "output.dense.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "output.LayerNorm.weight"), "got: {names:?}");
}
#[test]
fn distilbert_naming_emits_distilbert_suffixes() {
let layer = TransformerLayer::on_device(
&mini_config(), LayerNaming::DISTILBERT, Device::CPU,
).unwrap();
let names: Vec<String> = layer.parameters().into_iter().map(|p| p.name).collect();
assert!(names.iter().any(|n| n == "attention.q_lin.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.k_lin.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.v_lin.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "attention.out_lin.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "sa_layer_norm.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "ffn.lin1.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "ffn.lin2.weight"), "got: {names:?}");
assert!(names.iter().any(|n| n == "output_layer_norm.weight"), "got: {names:?}");
}
#[test]
fn parameter_count_identical_across_namings() {
let bert = TransformerLayer::on_device(
&mini_config(), LayerNaming::BERT, Device::CPU,
).unwrap();
let distil = TransformerLayer::on_device(
&mini_config(), LayerNaming::DISTILBERT, Device::CPU,
).unwrap();
assert_eq!(bert.parameters().len(), distil.parameters().len());
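        // 8 sub-modules (q, k, v, attn_output, two LayerNorms, ffn_up, ffn_down),
        // each contributing a weight and a bias.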
assert_eq!(bert.parameters().len(), 16);
}
#[test]
fn forward_runs_end_to_end() {
let layer = TransformerLayer::on_device(
&mini_config(), LayerNaming::BERT, Device::CPU,
).unwrap();
layer.set_training(false);
let x = Variable::new(
Tensor::zeros(
&[2, 4, 8],
TensorOptions { dtype: DType::Float32, device: Device::CPU },
).unwrap(),
false,
);
let out = layer.forward(&x).unwrap();
assert_eq!(out.data().shape(), vec![2, 4, 8]);
}
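    // An all-zero additive mask contributes nothing to the attention scores,
    // so the masked and unmasked outputs must agree elementwise.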
#[test]
fn zero_additive_mask_matches_unmasked() {
let cfg = mini_config();
let dev = Device::CPU;
let layer = TransformerLayer::on_device(&cfg, LayerNaming::BERT, dev).unwrap();
layer.set_training(false);
let batch = 1;
let seq = 3;
let hidden = cfg.hidden_size;
let x_data: Vec<f32> = (0..(batch * seq * hidden) as usize)
.map(|i| (i as f32) * 0.01)
.collect();
let x = Variable::new(
Tensor::from_f32(&x_data, &[batch, seq, hidden], dev).unwrap(),
false,
);
let zero_mask = Variable::new(
Tensor::zeros(
&[batch, 1, 1, seq],
TensorOptions { dtype: DType::Float32, device: dev },
).unwrap(),
false,
);
let mut refs = HashMap::new();
refs.insert("attention_mask".to_string(), zero_mask);
let unmasked = layer.forward(&x).unwrap();
let with_zero = layer.forward_named(&x, &refs).unwrap();
let a: Vec<f32> = unmasked.data().to_f32_vec().unwrap();
let b: Vec<f32> = with_zero.data().to_f32_vec().unwrap();
assert_eq!(a.len(), b.len());
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
assert!(
(x - y).abs() < 1e-6,
"diverged at {i}: unmasked={x}, zero-masked={y}",
);
}
}
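    // Masking out the last position through the extended (additive) attention
    // mask must measurably change the layer output.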
#[test]
fn padding_mask_changes_output() {
use crate::models::bert::build_extended_attention_mask;
let cfg = mini_config();
let dev = Device::CPU;
let layer = TransformerLayer::on_device(&cfg, LayerNaming::BERT, dev).unwrap();
layer.set_training(false);
let batch = 1;
let seq = 4;
let hidden = cfg.hidden_size;
let x_data: Vec<f32> = (0..(batch * seq * hidden) as usize)
.map(|i| ((i as f32) * 0.017).sin())
.collect();
let x = Variable::new(
Tensor::from_f32(&x_data, &[batch, seq, hidden], dev).unwrap(),
false,
);
let raw = Tensor::from_f32(&[1.0, 1.0, 1.0, 0.0], &[batch, seq], dev).unwrap();
let additive = build_extended_attention_mask(&raw).unwrap();
let mut refs = HashMap::new();
refs.insert("attention_mask".to_string(), Variable::new(additive, false));
let unmasked = layer.forward(&x).unwrap();
let masked = layer.forward_named(&x, &refs).unwrap();
let a: Vec<f32> = unmasked.data().to_f32_vec().unwrap();
let b: Vec<f32> = masked.data().to_f32_vec().unwrap();
let max_diff = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0_f32, f32::max);
assert!(
max_diff > 1e-4,
"masking a position must change attention output; max_diff={max_diff}",
);
}
}