use super::attention_gqa::{add_positional_encoding, slice_pe};
#[allow(clippy::wildcard_imports)]
use super::*;
impl TransformerDecoderLayer {
    /// Creates a pre-norm decoder layer: self-attention, cross-attention,
    /// and a position-wise feed-forward block, each with its own
    /// `LayerNorm` and residual connection. All four dropout layers
    /// default to p = 0.1 (override via [`Self::with_dropout`]).
    #[must_use]
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::new(&[d_model]),
            norm2: LayerNorm::new(&[d_model]),
            norm3: LayerNorm::new(&[d_model]),
            dropout: Dropout::new(0.1),
            dropout1: Dropout::new(0.1),
            dropout2: Dropout::new(0.1),
            dropout3: Dropout::new(0.1),
            d_model,
            training: true,
        }
    }

    /// Builder-style override for the dropout probability, mirroring
    /// `PositionalEncoding::with_dropout`. Replaces all four internal
    /// dropout layers with ones using probability `dropout`.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = Dropout::new(dropout);
        self.dropout1 = Dropout::new(dropout);
        self.dropout2 = Dropout::new(dropout);
        self.dropout3 = Dropout::new(dropout);
        self
    }

    /// Decoder forward pass over a target sequence `tgt` attending to an
    /// encoder `memory`, with optional additive attention masks.
    ///
    /// Pre-norm ordering: each sub-block normalizes its input first, and
    /// the sub-block's (dropped-out) output is added back as a residual.
    pub fn forward_with_memory(
        &self,
        tgt: &Tensor,
        memory: &Tensor,
        tgt_mask: Option<&Tensor>,
        memory_mask: Option<&Tensor>,
    ) -> Tensor {
        // 1. Masked self-attention + residual.
        let tgt_norm = self.norm1.forward(tgt);
        let (attn_out, _) = self.self_attn.forward_self(&tgt_norm, tgt_mask);
        let attn_out = self.dropout1.forward(&attn_out);
        let tgt = tgt.add(&attn_out);
        // 2. Cross-attention over the encoder memory + residual.
        let tgt_norm = self.norm2.forward(&tgt);
        let (cross_out, _) = self
            .cross_attn
            .forward_qkv(&tgt_norm, memory, memory, memory_mask);
        let cross_out = self.dropout2.forward(&cross_out);
        let tgt = tgt.add(&cross_out);
        // 3. Feed-forward (linear -> GELU -> dropout -> linear) + residual.
        let tgt_norm = self.norm3.forward(&tgt);
        let ff_out = self.linear1.forward(&tgt_norm);
        let ff_out = gelu(&ff_out);
        let ff_out = self.dropout.forward(&ff_out);
        let ff_out = self.linear2.forward(&ff_out);
        let ff_out = self.dropout3.forward(&ff_out);
        tgt.add(&ff_out)
    }
}
impl Module for TransformerDecoderLayer {
    /// Self-attention-only forward pass: the input serves as both the
    /// target sequence and the memory, with no masks applied.
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_with_memory(input, input, None, None)
    }

    /// Collects parameters from every trainable sub-module in a fixed
    /// order: self-attn, cross-attn, linear1, linear2, norm1..norm3.
    fn parameters(&self) -> Vec<&Tensor> {
        let mut collected: Vec<&Tensor> = Vec::new();
        collected.extend(self.self_attn.parameters());
        collected.extend(self.cross_attn.parameters());
        collected.extend(self.linear1.parameters());
        collected.extend(self.linear2.parameters());
        collected.extend(self.norm1.parameters());
        collected.extend(self.norm2.parameters());
        collected.extend(self.norm3.parameters());
        collected
    }

    /// Mutable counterpart of `parameters`; identical ordering.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut collected: Vec<&mut Tensor> = Vec::new();
        collected.extend(self.self_attn.parameters_mut());
        collected.extend(self.cross_attn.parameters_mut());
        collected.extend(self.linear1.parameters_mut());
        collected.extend(self.linear2.parameters_mut());
        collected.extend(self.norm1.parameters_mut());
        collected.extend(self.norm2.parameters_mut());
        collected.extend(self.norm3.parameters_mut());
        collected
    }

    /// Switches this layer and its stateful sub-modules to training mode.
    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.cross_attn.train();
        for layer in [
            &mut self.dropout,
            &mut self.dropout1,
            &mut self.dropout2,
            &mut self.dropout3,
        ] {
            layer.train();
        }
    }

    /// Switches this layer and its stateful sub-modules to eval mode.
    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.cross_attn.eval();
        for layer in [
            &mut self.dropout,
            &mut self.dropout1,
            &mut self.dropout2,
            &mut self.dropout3,
        ] {
            layer.eval();
        }
    }

    /// Reports the current train/eval flag.
    fn training(&self) -> bool {
        self.training
    }
}
impl std::fmt::Debug for TransformerDecoderLayer {
    /// Compact debug view: only the dimensions and the two attention
    /// sub-modules; remaining fields are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TransformerDecoderLayer");
        builder.field("d_model", &self.d_model);
        builder.field("self_attn", &self.self_attn);
        builder.field("cross_attn", &self.cross_attn);
        builder.finish_non_exhaustive()
    }
}
/// Sinusoidal positional encoding: adds a precomputed sin/cos table
/// (built by `compute_positional_encoding`) to its input and applies
/// dropout to the result.
#[derive(Debug)]
pub struct PositionalEncoding {
/// Embedding width each position's encoding spans.
d_model: usize,
/// Longest sequence the precomputed table supports.
max_len: usize,
/// Dropout applied to the encoded output in `forward`.
dropout: Dropout,
/// Precomputed `[max_len, d_model]` sin/cos table.
pe: Tensor,
/// Train/eval flag mirrored by the `Module` impl.
training: bool,
}
impl PositionalEncoding {
    /// Builds a positional-encoding module with a precomputed sinusoidal
    /// table covering sequences up to `max_len` positions with embedding
    /// width `d_model`. Dropout defaults to p = 0.1.
    #[must_use]
    pub fn new(d_model: usize, max_len: usize) -> Self {
        let pe = compute_positional_encoding(d_model, max_len);
        Self {
            d_model,
            max_len,
            dropout: Dropout::new(0.1),
            pe,
            training: true,
        }
    }

    /// Builder-style override for the dropout probability.
    // `#[must_use]` added: this consumes `self`, so ignoring the return
    // value silently drops the module instead of configuring it.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = Dropout::new(dropout);
        self
    }
}
impl Module for PositionalEncoding {
    /// Adds positional encodings for the first `seq_len` positions (read
    /// from `input.shape()[1]`) to `input`, then applies dropout.
    ///
    /// # Panics
    /// Panics when the sequence is longer than the precomputed `max_len`.
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.shape()[1];
        assert!(
            seq_len <= self.max_len,
            "Sequence length {seq_len} exceeds max_len {}",
            self.max_len
        );
        let encoding = slice_pe(&self.pe, seq_len, self.d_model);
        let encoded = add_positional_encoding(input, &encoding);
        self.dropout.forward(&encoded)
    }

    /// Enables training mode (dropout active).
    fn train(&mut self) {
        self.dropout.train();
        self.training = true;
    }

    /// Enables eval mode (dropout disabled).
    fn eval(&mut self) {
        self.dropout.eval();
        self.training = false;
    }

    /// Reports the current train/eval flag.
    fn training(&self) -> bool {
        self.training
    }
}
/// Transposes the last two axes of `x`; all leading axes are treated as a
/// flattened batch dimension. Inputs with fewer than two axes are
/// returned unchanged (as a clone).
///
/// Uses a 32x32 tiled copy so source reads and destination writes stay
/// cache-friendly for large matrices.
pub(crate) fn transpose_last_two(x: &Tensor) -> Tensor {
let shape = x.shape();
let ndim = shape.len();
if ndim < 2 {
// Nothing to transpose for scalars / vectors.
return x.clone();
}
let last = shape[ndim - 1];
let second_last = shape[ndim - 2];
// Output shape: same leading axes, last two swapped.
let mut new_shape = shape.to_vec();
new_shape[ndim - 2] = last;
new_shape[ndim - 1] = second_last;
// Empty product yields 1, so rank-2 inputs get a single "batch".
let batch_size: usize = shape[..ndim - 2].iter().product();
let matrix_size = last * second_last;
let mut output = vec![0.0; x.data().len()];
// Tile edge length for cache blocking.
const TILE: usize = 32;
let src = x.data();
for b in 0..batch_size {
let offset = b * matrix_size;
for i0 in (0..second_last).step_by(TILE) {
let i_end = (i0 + TILE).min(second_last);
for j0 in (0..last).step_by(TILE) {
let j_end = (j0 + TILE).min(last);
for i in i0..i_end {
let src_base = offset + i * last;
for j in j0..j_end {
// Element (i, j) of the source lands at (j, i) in the output.
output[offset + j * second_last + i] = src[src_base + j];
}
}
}
}
}
Tensor::from_vec(output, &new_shape)
}
/// Batched matrix multiply. For two rank-4 tensors shaped
/// `[batch, heads, m, k]` and `[batch, heads, k, n]` this dispatches to
/// `Matrix::batched_matmul_4d`; any other rank combination falls back to
/// `Tensor::matmul`.
///
/// # Panics
/// Panics when two rank-4 inputs disagree on batch, head, or inner
/// (contraction) dimensions.
#[allow(clippy::expect_used)]
pub(crate) fn matmul_batched(a: &Tensor, b: &Tensor) -> Tensor {
    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape.len() == 4 && b_shape.len() == 4 {
        let (batch, heads, m, k1) = (a_shape[0], a_shape[1], a_shape[2], a_shape[3]);
        let k2 = b_shape[2];
        let n = b_shape[3];
        // Validate every paired dimension — not just the contraction axis —
        // so the `expect` below can honestly claim full validation.
        assert_eq!(batch, b_shape[0], "Batch dimensions must match for matmul");
        assert_eq!(heads, b_shape[1], "Head dimensions must match for matmul");
        assert_eq!(k1, k2, "Inner dimensions must match for matmul");
        let output = Matrix::batched_matmul_4d(a.data(), b.data(), batch, heads, m, k1, n)
            .expect("batched_matmul_4d failed: dimensions validated but operation failed");
        Tensor::from_vec(output, &[batch, heads, m, n])
    } else {
        a.matmul(b)
    }
}
/// Multiplies every element of `x` by `scale` via `Tensor::mul_scalar`.
pub(super) fn scale_tensor(x: &Tensor, scale: f32) -> Tensor {
x.mul_scalar(scale)
}
/// Adds an additive attention `mask` to `scores` element-wise.
///
/// When the shapes match exactly this defers to `Tensor::add`; otherwise
/// the raw buffers are summed element-by-element, which requires them to
/// hold the same number of elements.
///
/// # Panics
/// Panics when the shapes differ AND the element counts differ. The
/// previous code silently truncated to the shorter buffer via `zip`,
/// handing `from_vec` too little data for `scores.shape()`.
pub(super) fn add_mask(scores: &Tensor, mask: &Tensor) -> Tensor {
    if scores.shape() == mask.shape() {
        return scores.add(mask);
    }
    assert_eq!(
        scores.data().len(),
        mask.data().len(),
        "add_mask: scores and mask must have the same number of elements"
    );
    let data: Vec<f32> = scores
        .data()
        .iter()
        .zip(mask.data().iter())
        .map(|(&s, &m)| s + m)
        .collect();
    Tensor::from_vec(data, scores.shape())
}
/// Softmax over the final axis, delegating to `nn::functional::softmax`
/// with `dim = -1`.
pub(super) fn softmax_last_dim(x: &Tensor) -> Tensor {
crate::nn::functional::softmax(x, -1)
}
/// Applies functional dropout with probability `p`.
///
/// NOTE(review): the `training` flag passed to `functional::dropout` is
/// hard-coded to `true`, so dropout fires regardless of train/eval mode —
/// confirm that callers only invoke this while training.
pub(super) fn apply_dropout(x: &Tensor, p: f32) -> Tensor {
crate::nn::functional::dropout(x, p, true)
}
/// Splits the packed embedding axis into heads:
/// `[batch, seq, heads * head_dim]` -> `[batch, heads, seq, head_dim]`.
///
/// `x.data()` is hoisted out of the loop (the original fetched it per
/// element), and each contiguous `head_dim` run is moved with
/// `copy_from_slice` instead of a scalar inner loop — same values, fewer
/// bounds checks.
pub(super) fn reshape_for_attention(
    x: &Tensor,
    batch: usize,
    seq_len: usize,
    num_heads: usize,
    head_dim: usize,
) -> Tensor {
    let src = x.data();
    let embed_dim = num_heads * head_dim;
    let mut output = vec![0.0; batch * num_heads * seq_len * head_dim];
    for b in 0..batch {
        for s in 0..seq_len {
            for h in 0..num_heads {
                let in_base = b * seq_len * embed_dim + s * embed_dim + h * head_dim;
                let out_base =
                    b * num_heads * seq_len * head_dim + h * seq_len * head_dim + s * head_dim;
                // The head_dim run is contiguous in both layouts.
                output[out_base..out_base + head_dim]
                    .copy_from_slice(&src[in_base..in_base + head_dim]);
            }
        }
    }
    Tensor::from_vec(output, &[batch, num_heads, seq_len, head_dim])
}
/// Merges attention heads back into a packed embedding axis:
/// `[batch, heads, seq, head_dim]` -> `[batch, seq, embed_dim]`, where
/// heads and head_dim are read from `x.shape()`.
///
/// `x.data()` is hoisted out of the loop (the original fetched it per
/// element), and each contiguous `head_dim` run is moved with
/// `copy_from_slice` instead of a scalar inner loop — same values, fewer
/// bounds checks.
pub(crate) fn reshape_from_attention(
    x: &Tensor,
    batch: usize,
    seq_len: usize,
    embed_dim: usize,
) -> Tensor {
    let num_heads = x.shape()[1];
    let head_dim = x.shape()[3];
    let src = x.data();
    let mut output = vec![0.0; batch * seq_len * embed_dim];
    for b in 0..batch {
        for s in 0..seq_len {
            for h in 0..num_heads {
                let in_base =
                    b * num_heads * seq_len * head_dim + h * seq_len * head_dim + s * head_dim;
                let out_base = b * seq_len * embed_dim + s * embed_dim + h * head_dim;
                // The head_dim run is contiguous in both layouts.
                output[out_base..out_base + head_dim]
                    .copy_from_slice(&src[in_base..in_base + head_dim]);
            }
        }
    }
    Tensor::from_vec(output, &[batch, seq_len, embed_dim])
}
/// GELU activation, delegating to `crate::nn::functional::gelu`.
pub(super) fn gelu(x: &Tensor) -> Tensor {
crate::nn::functional::gelu(x)
}
/// Builds the sinusoidal positional-encoding table of shape
/// `[max_len, d_model]`: column 2i holds sin(pos / 10000^(2i/d_model)),
/// column 2i+1 the matching cos. For an odd `d_model` the final column is
/// left at 0.0 (unchanged from the original behavior).
fn compute_positional_encoding(d_model: usize, max_len: usize) -> Tensor {
    // The divisor depends only on the column pair, so compute each
    // 10000^(2i/d_model) once instead of once per (pos, i) pair. The
    // resulting values are bit-identical to the per-position computation.
    let divisors: Vec<f32> = (0..d_model / 2)
        .map(|i| 10000_f32.powf(2.0 * i as f32 / d_model as f32))
        .collect();
    let mut pe = vec![0.0; max_len * d_model];
    for pos in 0..max_len {
        for (i, &divisor) in divisors.iter().enumerate() {
            let angle = pos as f32 / divisor;
            pe[pos * d_model + 2 * i] = angle.sin();
            pe[pos * d_model + 2 * i + 1] = angle.cos();
        }
    }
    Tensor::new(&pe, &[max_len, d_model])
}