use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
s, Array, Array1, Array2, Array3, ArrayView, ArrayView1, ArrayView2, ArrayView3, Axis,
ScalarOperand,
};
use scirs2_core::numeric::Float;
use scirs2_core::random::*;
use scirs2_core::simd_ops::SimdUnifiedOps;
/// Result alias used by every fallible function in this file; the error type is
/// always [`NumRs2Error`].
pub type TransformerResult<T> = Result<T, NumRs2Error>;
/// Multi-head scaled dot-product attention over a single `(seq_len, d_model)` sequence.
///
/// The input is projected with `w_q`, `w_k` and `w_v`, split column-wise into
/// `num_heads` heads of width `d_k`, attended per head, re-concatenated and finally
/// projected with `w_o`.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention<T>
where
T: Float + SimdUnifiedOps,
{
/// Width of the input and output feature dimension.
pub d_model: usize,
/// Number of attention heads; the constructor requires it to divide `d_model`.
pub num_heads: usize,
/// Per-head feature width, set to `d_model / num_heads` by the constructor.
pub d_k: usize,
/// Query projection, created with shape `(d_model, d_model)`.
pub w_q: Array2<T>,
/// Key projection, created with shape `(d_model, d_model)`.
pub w_k: Array2<T>,
/// Value projection, created with shape `(d_model, d_model)`.
pub w_v: Array2<T>,
/// Output projection applied to the concatenated heads, shape `(d_model, d_model)`.
pub w_o: Array2<T>,
/// Dropout probability applied to the attention weights when `training` is true.
pub dropout_p: T,
}
impl<T> MultiHeadAttention<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
/// Creates an attention layer with randomly initialised projection matrices.
///
/// Every weight is drawn uniformly from `[-1, 1)` and multiplied by
/// `1 / sqrt(d_model)`. The four matrices are drawn in the order
/// `w_q`, `w_k`, `w_v`, `w_o` from the thread-local RNG.
///
/// # Errors
/// * `InvalidOperation` if `d_model` or `num_heads` is zero, or `dropout` is
///   outside `[0, 1)`.
/// * `DimensionMismatch` if `d_model` is not a multiple of `num_heads`.
/// * `ConversionError` if a constant cannot be represented in `T`.
pub fn new(d_model: usize, num_heads: usize, dropout: f64) -> TransformerResult<Self> {
    if d_model == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "d_model must be greater than 0".to_string(),
        ));
    }
    if num_heads == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "num_heads must be greater than 0".to_string(),
        ));
    }
    if !d_model.is_multiple_of(num_heads) {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "d_model ({}) must be divisible by num_heads ({})",
            d_model, num_heads
        )));
    }
    if !(0.0..1.0).contains(&dropout) {
        return Err(NumRs2Error::InvalidOperation(format!(
            "dropout must be in [0, 1), got {}",
            dropout
        )));
    }

    let unit = T::from(1.0).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert constant".to_string())
    })?;
    let width = T::from(d_model).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert dimension".to_string())
    })?;
    let scale = unit / width.sqrt();

    // One shared initialiser keeps the four projections in lock-step.
    let mut rng = thread_rng();
    let mut init_weight = || {
        Array2::from_shape_fn((d_model, d_model), |_| {
            let u: f64 = rng.random();
            T::from(u * 2.0 - 1.0).expect("Conversion should succeed for f64 to Float") * scale
        })
    };
    let w_q = init_weight();
    let w_k = init_weight();
    let w_v = init_weight();
    let w_o = init_weight();

    let dropout_p = T::from(dropout).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert dropout probability".to_string())
    })?;

    Ok(Self {
        d_model,
        num_heads,
        d_k: d_model / num_heads,
        w_q,
        w_k,
        w_v,
        w_o,
        dropout_p,
    })
}
pub fn forward(
&self,
x: &ArrayView2<T>,
mask: Option<&ArrayView2<T>>,
training: bool,
) -> TransformerResult<Array2<T>> {
let (seq_len, input_dim) = x.dim();
if input_dim != self.d_model {
return Err(NumRs2Error::DimensionMismatch(format!(
"Input dimension {} does not match d_model {}",
input_dim, self.d_model
)));
}
let q = x.dot(&self.w_q); let k = x.dot(&self.w_k); let v = x.dot(&self.w_v);
let q_heads = self.split_heads(&q.view())?;
let k_heads = self.split_heads(&k.view())?;
let v_heads = self.split_heads(&v.view())?;
let mut attended = Array3::zeros((self.num_heads, seq_len, self.d_k));
for h in 0..self.num_heads {
let q_h = q_heads.slice(s![h, .., ..]);
let k_h = k_heads.slice(s![h, .., ..]);
let v_h = v_heads.slice(s![h, .., ..]);
let attn_output = self.scaled_dot_product_attention(
&q_h.to_owned().view(),
&k_h.to_owned().view(),
&v_h.to_owned().view(),
mask,
training,
)?;
attended
.slice_mut(s![h, .., ..])
.assign(&attn_output.view());
}
let concat = self.combine_heads(&attended.view())?;
let output = concat.dot(&self.w_o);
Ok(output)
}
fn split_heads(&self, x: &ArrayView2<T>) -> TransformerResult<Array3<T>> {
let (seq_len, d_model) = x.dim();
if d_model != self.d_model {
return Err(NumRs2Error::DimensionMismatch(
"Input dimension mismatch".to_string(),
));
}
let mut result = Array3::zeros((self.num_heads, seq_len, self.d_k));
for h in 0..self.num_heads {
for i in 0..seq_len {
for j in 0..self.d_k {
result[[h, i, j]] = x[[i, h * self.d_k + j]];
}
}
}
Ok(result)
}
fn combine_heads(&self, x: &ArrayView3<T>) -> TransformerResult<Array2<T>> {
let (num_heads, seq_len, d_k) = x.dim();
if num_heads != self.num_heads || d_k != self.d_k {
return Err(NumRs2Error::DimensionMismatch(
"Head dimensions mismatch".to_string(),
));
}
let mut result = Array2::zeros((seq_len, self.d_model));
for h in 0..self.num_heads {
for i in 0..seq_len {
for j in 0..self.d_k {
result[[i, h * self.d_k + j]] = x[[h, i, j]];
}
}
}
Ok(result)
}
fn scaled_dot_product_attention(
&self,
q: &ArrayView2<T>,
k: &ArrayView2<T>,
v: &ArrayView2<T>,
mask: Option<&ArrayView2<T>>,
training: bool,
) -> TransformerResult<Array2<T>> {
let d_k_float = T::from(self.d_k)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert d_k".to_string()))?;
let scale = T::one() / d_k_float.sqrt();
let mut scores = q.dot(&k.t()) * scale;
if let Some(m) = mask {
let neg_inf = T::from(-1e9).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert mask value".to_string())
})?;
let zero = T::zero();
for i in 0..scores.nrows() {
for j in 0..scores.ncols() {
if m[[i, j]] == zero {
scores[[i, j]] = neg_inf;
}
}
}
}
let attention_weights = self.softmax_2d(&scores.view(), 1)?;
let weights = if training && self.dropout_p > T::zero() {
self.apply_dropout(&attention_weights.view())?
} else {
attention_weights
};
let output = weights.dot(v);
Ok(output)
}
fn softmax_2d(&self, x: &ArrayView2<T>, axis: usize) -> TransformerResult<Array2<T>> {
let shape = x.dim();
let mut result = Array2::zeros(shape);
if axis == 0 {
for j in 0..shape.1 {
let col = x.column(j);
let max_val = col
.iter()
.fold(T::neg_infinity(), |a, &b| if a > b { a } else { b });
let mut sum = T::zero();
let mut exp_vals = Array1::zeros(shape.0);
for i in 0..shape.0 {
exp_vals[i] = (col[i] - max_val).exp();
sum = sum + exp_vals[i];
}
for i in 0..shape.0 {
result[[i, j]] = exp_vals[i] / sum;
}
}
} else {
for i in 0..shape.0 {
let row = x.row(i);
let max_val = row
.iter()
.fold(T::neg_infinity(), |a, &b| if a > b { a } else { b });
let mut sum = T::zero();
let mut exp_vals = Array1::zeros(shape.1);
for j in 0..shape.1 {
exp_vals[j] = (row[j] - max_val).exp();
sum = sum + exp_vals[j];
}
for j in 0..shape.1 {
result[[i, j]] = exp_vals[j] / sum;
}
}
}
Ok(result)
}
/// Applies inverted dropout to `x`.
///
/// Each element is kept when a uniform draw exceeds `dropout_p` and is then
/// multiplied by `1 / (1 - dropout_p)`; otherwise it is multiplied by zero.
///
/// # Errors
/// `ConversionError` if `dropout_p` cannot be converted to `f64`.
fn apply_dropout(&self, x: &ArrayView2<T>) -> TransformerResult<Array2<T>> {
    let mut rng = thread_rng();
    let drop_prob = self
        .dropout_p
        .to_f64()
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert dropout".to_string()))?;
    let keep_scale = T::one() / (T::one() - self.dropout_p);

    let mask = Array2::from_shape_fn(x.raw_dim(), |_| {
        let draw: f64 = rng.random();
        if draw > drop_prob {
            keep_scale
        } else {
            T::zero()
        }
    });
    Ok(x * &mask)
}
}
/// Table of positional encodings that is added to a sequence of embeddings.
#[derive(Debug, Clone)]
pub struct PositionalEncoding<T>
where
T: Float,
{
/// Largest sequence length the table can serve.
pub max_len: usize,
/// Feature width of each encoding row.
pub d_model: usize,
/// How `encodings` was generated.
pub encoding_type: PositionalEncodingType,
/// Encoding table, created with shape `(max_len, d_model)`; row `p` belongs to position `p`.
pub encodings: Array2<T>,
}
/// Selects how a [`PositionalEncoding`] table is filled at construction time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
/// Fixed sine/cosine encodings; requires an even `d_model`.
Sinusoidal,
/// Table initialised with small uniform random values.
Learned,
}
impl<T> PositionalEncoding<T>
where
T: Float + SimdUnifiedOps,
{
/// Builds a positional-encoding table of shape `(max_len, d_model)`.
///
/// # Errors
/// `InvalidOperation` if `max_len` or `d_model` is zero, or if `d_model` is odd
/// while `encoding_type` is `Sinusoidal`. Errors from table generation are
/// propagated.
pub fn new(
    max_len: usize,
    d_model: usize,
    encoding_type: PositionalEncodingType,
) -> TransformerResult<Self> {
    if max_len == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "max_len must be greater than 0".to_string(),
        ));
    }
    if d_model == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "d_model must be greater than 0".to_string(),
        ));
    }

    let encodings = match encoding_type {
        PositionalEncodingType::Sinusoidal => {
            // Sine and cosine values are written in adjacent column pairs.
            if d_model % 2 != 0 {
                return Err(NumRs2Error::InvalidOperation(
                    "d_model must be even for sinusoidal encoding".to_string(),
                ));
            }
            Self::create_sinusoidal(max_len, d_model)?
        }
        PositionalEncodingType::Learned => Self::create_learned(max_len, d_model)?,
    };

    Ok(Self {
        max_len,
        d_model,
        encoding_type,
        encodings,
    })
}
/// Fills a `(max_len, d_model)` table with sinusoidal encodings.
///
/// For position `pos` and pair index `i`, column `2i` holds
/// `sin(pos / 10000^(2i / d_model))` and column `2i + 1` holds the matching
/// cosine. With an odd `d_model` the last column is left at zero.
///
/// # Errors
/// `ConversionError` if an index or constant cannot be represented in `T`.
fn create_sinusoidal(max_len: usize, d_model: usize) -> TransformerResult<Array2<T>> {
    let two = T::from(2.0).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert constant".to_string())
    })?;
    let ten_thousand = T::from(10000.0).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert constant".to_string())
    })?;
    let d_model_t = T::from(d_model).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert dimension".to_string())
    })?;

    // The denominator depends only on the column pair, so compute it once per
    // pair instead of once per (position, pair).
    let mut denominators = Vec::with_capacity(d_model / 2);
    for i in 0..(d_model / 2) {
        let i_t = T::from(i).ok_or_else(|| {
            NumRs2Error::ConversionError("Failed to convert index".to_string())
        })?;
        denominators.push(ten_thousand.powf(two * i_t / d_model_t));
    }

    let mut pe = Array2::zeros((max_len, d_model));
    for (pos, mut row) in pe.outer_iter_mut().enumerate() {
        let pos_t = T::from(pos).ok_or_else(|| {
            NumRs2Error::ConversionError("Failed to convert position".to_string())
        })?;
        for (i, &denominator) in denominators.iter().enumerate() {
            let angle = pos_t / denominator;
            row[2 * i] = angle.sin();
            row[2 * i + 1] = angle.cos();
        }
    }
    Ok(pe)
}
/// Fills a `(max_len, d_model)` table with uniform random values in
/// `[-0.02, 0.02)` drawn from the thread-local RNG.
///
/// # Errors
/// `ConversionError` if the scale constant cannot be represented in `T`.
fn create_learned(max_len: usize, d_model: usize) -> TransformerResult<Array2<T>> {
    let mut rng = thread_rng();
    let amplitude = T::from(0.02)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert scale".to_string()))?;

    Ok(Array2::from_shape_fn((max_len, d_model), |_| {
        let u: f64 = rng.random();
        let centred = T::from(u * 2.0 - 1.0).expect("Conversion should succeed for f64 to Float");
        centred * amplitude
    }))
}
/// Adds the first `seq_len` encoding rows to `x` and returns the sum.
///
/// # Errors
/// * `DimensionMismatch` if the second dimension of `x` is not `d_model`.
/// * `InvalidOperation` if `x` has more rows than `max_len`.
pub fn forward(&self, x: &ArrayView2<T>) -> TransformerResult<Array2<T>>
where
    T: ScalarOperand,
{
    let (rows, width) = x.dim();
    if width != self.d_model {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Input dimension {} does not match d_model {}",
            width, self.d_model
        )));
    }
    if rows > self.max_len {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Sequence length {} exceeds maximum {}",
            rows, self.max_len
        )));
    }

    let active = self.encodings.slice(s![..rows, ..]);
    Ok(x + &active)
}
}
/// Two-layer position-wise feed-forward network with a GELU activation between
/// the layers.
#[derive(Debug, Clone)]
pub struct PositionwiseFeedForward<T>
where
T: Float + SimdUnifiedOps,
{
/// Width of the input and output features.
pub d_model: usize,
/// Width of the hidden layer.
pub d_ff: usize,
/// First-layer weights, created with shape `(d_model, d_ff)`.
pub w1: Array2<T>,
/// First-layer bias, created as zeros of length `d_ff`.
pub b1: Array1<T>,
/// Second-layer weights, created with shape `(d_ff, d_model)`.
pub w2: Array2<T>,
/// Second-layer bias, created as zeros of length `d_model`.
pub b2: Array1<T>,
/// Dropout probability applied to the hidden activations when `training` is true.
pub dropout_p: T,
}
impl<T> PositionwiseFeedForward<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
pub fn new(d_model: usize, d_ff: usize, dropout: f64) -> TransformerResult<Self> {
if d_model == 0 {
return Err(NumRs2Error::InvalidOperation(
"d_model must be greater than 0".to_string(),
));
}
if d_ff == 0 {
return Err(NumRs2Error::InvalidOperation(
"d_ff must be greater than 0".to_string(),
));
}
if !(0.0..1.0).contains(&dropout) {
return Err(NumRs2Error::InvalidOperation(format!(
"dropout must be in [0, 1), got {}",
dropout
)));
}
let mut rng = thread_rng();
let scale1 = T::from(1.0).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert constant".to_string())
})? / T::from(d_model)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert dimension".to_string()))?
.sqrt();
let scale2 = T::from(1.0).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert constant".to_string())
})? / T::from(d_ff)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert dimension".to_string()))?
.sqrt();
let w1 = Array2::from_shape_fn((d_model, d_ff), |_| {
let u: f64 = rng.random();
T::from(u * 2.0 - 1.0).expect("Conversion should succeed for f64 to Float") * scale1
});
let b1 = Array1::zeros(d_ff);
let w2 = Array2::from_shape_fn((d_ff, d_model), |_| {
let u: f64 = rng.random();
T::from(u * 2.0 - 1.0).expect("Conversion should succeed for f64 to Float") * scale2
});
let b2 = Array1::zeros(d_model);
let dropout_p = T::from(dropout).ok_or_else(|| {
NumRs2Error::ConversionError("Failed to convert dropout probability".to_string())
})?;
Ok(Self {
d_model,
d_ff,
w1,
b1,
w2,
b2,
dropout_p,
})
}
/// Computes `gelu(x · w1 + b1) · w2 + b2`, with dropout on the hidden
/// activations when `training` is true and `dropout_p > 0`.
///
/// # Errors
/// `DimensionMismatch` if the second dimension of `x` is not `d_model`; errors
/// from the activation or dropout step are propagated.
pub fn forward(&self, x: &ArrayView2<T>, training: bool) -> TransformerResult<Array2<T>> {
    let (_, width) = x.dim();
    if width != self.d_model {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Input dimension {} does not match d_model {}",
            width, self.d_model
        )));
    }

    // The bias vector is broadcast across every row of the projection.
    let mut pre_activation = x.dot(&self.w1);
    pre_activation.zip_mut_with(&self.b1, |h, &b| *h = *h + b);

    let activated = self.gelu(&pre_activation.view())?;
    let hidden = if training && self.dropout_p > T::zero() {
        self.apply_dropout(&activated.view())?
    } else {
        activated
    };

    let mut output = hidden.dot(&self.w2);
    output.zip_mut_with(&self.b2, |o, &b| *o = *o + b);
    Ok(output)
}
/// Element-wise GELU using the tanh approximation
/// `0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))`.
///
/// # Errors
/// `ConversionError` if a constant cannot be represented in `T`.
fn gelu(&self, x: &ArrayView2<T>) -> TransformerResult<Array2<T>> {
    let constant = |value: f64| {
        T::from(value).ok_or_else(|| {
            NumRs2Error::ConversionError("Failed to convert constant".to_string())
        })
    };
    let sqrt_2_over_pi = constant(0.7978845608028654)?;
    let coeff = constant(0.044715)?;
    let half = constant(0.5)?;

    Ok(x.mapv(|val| {
        let cubic_term = coeff * (val * val * val);
        let inner = sqrt_2_over_pi * (val + cubic_term);
        half * val * (T::one() + inner.tanh())
    }))
}
/// Applies inverted dropout to the hidden activations.
///
/// An element survives when a uniform draw exceeds `dropout_p`; survivors are
/// multiplied by `1 / (1 - dropout_p)` and the rest by zero.
///
/// # Errors
/// `ConversionError` if `dropout_p` cannot be converted to `f64`.
fn apply_dropout(&self, x: &ArrayView2<T>) -> TransformerResult<Array2<T>> {
    let mut rng = thread_rng();
    let drop_prob = self
        .dropout_p
        .to_f64()
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert dropout".to_string()))?;
    let keep_scale = T::one() / (T::one() - self.dropout_p);

    let mask = Array2::from_shape_fn(x.raw_dim(), |_| {
        let survives = rng.random::<f64>() > drop_prob;
        if survives {
            keep_scale
        } else {
            T::zero()
        }
    });
    Ok(x * &mask)
}
}
/// One encoder layer: self-attention and a feed-forward block, each followed by
/// a residual connection and layer normalisation.
#[derive(Debug, Clone)]
pub struct TransformerEncoderLayer<T>
where
T: Float + SimdUnifiedOps,
{
/// Self-attention sub-layer.
pub attention: MultiHeadAttention<T>,
/// Position-wise feed-forward sub-layer.
pub feed_forward: PositionwiseFeedForward<T>,
/// Scale of the post-attention layer norm; initialised to ones.
pub norm1_gamma: Array1<T>,
/// Shift of the post-attention layer norm; initialised to zeros.
pub norm1_beta: Array1<T>,
/// Scale of the post-feed-forward layer norm; initialised to ones.
pub norm2_gamma: Array1<T>,
/// Shift of the post-feed-forward layer norm; initialised to zeros.
pub norm2_beta: Array1<T>,
/// Value added to the variance before the square root; initialised to `1e-5`.
pub norm_eps: T,
}
impl<T> TransformerEncoderLayer<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
pub fn new(
d_model: usize,
num_heads: usize,
d_ff: usize,
dropout: f64,
) -> TransformerResult<Self> {
let attention = MultiHeadAttention::new(d_model, num_heads, dropout)?;
let feed_forward = PositionwiseFeedForward::new(d_model, d_ff, dropout)?;
let norm1_gamma = Array1::ones(d_model);
let norm1_beta = Array1::zeros(d_model);
let norm2_gamma = Array1::ones(d_model);
let norm2_beta = Array1::zeros(d_model);
let norm_eps = T::from(1e-5)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;
Ok(Self {
attention,
feed_forward,
norm1_gamma,
norm1_beta,
norm2_gamma,
norm2_beta,
norm_eps,
})
}
/// Runs the layer in post-norm order:
/// `norm1(x + attention(x))` followed by `norm2(h + feed_forward(h))`.
///
/// `mask` and `training` are forwarded to the attention sub-layer; `training`
/// is also forwarded to the feed-forward sub-layer.
///
/// # Errors
/// Propagates any error from the sub-layers or layer normalisation.
pub fn forward(
    &self,
    x: &ArrayView2<T>,
    mask: Option<&ArrayView2<T>>,
    training: bool,
) -> TransformerResult<Array2<T>> {
    let attended = self.attention.forward(x, mask, training)?;
    let after_attention = self.layer_norm(
        &(x + &attended).view(),
        &self.norm1_gamma,
        &self.norm1_beta,
    )?;

    let transformed = self.feed_forward.forward(&after_attention.view(), training)?;
    self.layer_norm(
        &(&after_attention + &transformed).view(),
        &self.norm2_gamma,
        &self.norm2_beta,
    )
}
/// Normalises every row of `x` to zero mean and unit (population) variance,
/// then applies the per-feature affine map `gamma[j] * value + beta[j]`.
///
/// `norm_eps` is added to the variance before taking the square root.
///
/// # Errors
/// `ConversionError` if the feature count cannot be represented in `T`.
fn layer_norm(
    &self,
    x: &ArrayView2<T>,
    gamma: &Array1<T>,
    beta: &Array1<T>,
) -> TransformerResult<Array2<T>> {
    let (rows, width) = x.dim();
    let n_features = T::from(width).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert feature count".to_string())
    })?;

    let mut normalised = Array2::zeros((rows, width));
    for (row, mut out_row) in x.outer_iter().zip(normalised.outer_iter_mut()) {
        let mean = row.sum() / n_features;
        let var = row.mapv(|v| (v - mean) * (v - mean)).sum() / n_features;
        let std = (var + self.norm_eps).sqrt();

        for (j, (dst, &src)) in out_row.iter_mut().zip(row.iter()).enumerate() {
            *dst = (src - mean) / std * gamma[j] + beta[j];
        }
    }
    Ok(normalised)
}
}
/// One decoder layer: self-attention, cross-attention and a feed-forward block,
/// each followed by a residual connection and layer normalisation.
#[derive(Debug, Clone)]
pub struct TransformerDecoderLayer<T>
where
T: Float + SimdUnifiedOps,
{
/// Attention over the decoder's own input.
pub self_attention: MultiHeadAttention<T>,
/// Attention sub-layer used for the encoder-decoder (cross) attention step.
pub cross_attention: MultiHeadAttention<T>,
/// Position-wise feed-forward sub-layer.
pub feed_forward: PositionwiseFeedForward<T>,
/// Scale of the layer norm after self-attention; initialised to ones.
pub norm1_gamma: Array1<T>,
/// Shift of the layer norm after self-attention; initialised to zeros.
pub norm1_beta: Array1<T>,
/// Scale of the layer norm after cross-attention; initialised to ones.
pub norm2_gamma: Array1<T>,
/// Shift of the layer norm after cross-attention; initialised to zeros.
pub norm2_beta: Array1<T>,
/// Scale of the layer norm after the feed-forward block; initialised to ones.
pub norm3_gamma: Array1<T>,
/// Shift of the layer norm after the feed-forward block; initialised to zeros.
pub norm3_beta: Array1<T>,
/// Value added to the variance before the square root; initialised to `1e-5`.
pub norm_eps: T,
}
impl<T> TransformerDecoderLayer<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
pub fn new(
d_model: usize,
num_heads: usize,
d_ff: usize,
dropout: f64,
) -> TransformerResult<Self> {
let self_attention = MultiHeadAttention::new(d_model, num_heads, dropout)?;
let cross_attention = MultiHeadAttention::new(d_model, num_heads, dropout)?;
let feed_forward = PositionwiseFeedForward::new(d_model, d_ff, dropout)?;
let norm1_gamma = Array1::ones(d_model);
let norm1_beta = Array1::zeros(d_model);
let norm2_gamma = Array1::ones(d_model);
let norm2_beta = Array1::zeros(d_model);
let norm3_gamma = Array1::ones(d_model);
let norm3_beta = Array1::zeros(d_model);
let norm_eps = T::from(1e-5)
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert epsilon".to_string()))?;
Ok(Self {
self_attention,
cross_attention,
feed_forward,
norm1_gamma,
norm1_beta,
norm2_gamma,
norm2_beta,
norm3_gamma,
norm3_beta,
norm_eps,
})
}
/// Runs the three decoder sub-layers in post-norm order.
///
/// * `x` – decoder input.
/// * `encoder_output` – passed to the cross-attention step.
/// * `tgt_mask` – mask for the self-attention step.
/// * `memory_mask` – mask for the cross-attention step.
/// * `training` – forwarded to every sub-layer.
///
/// # Errors
/// Propagates any error from the sub-layers or layer normalisation.
pub fn forward(
    &self,
    x: &ArrayView2<T>,
    encoder_output: &ArrayView2<T>,
    tgt_mask: Option<&ArrayView2<T>>,
    memory_mask: Option<&ArrayView2<T>>,
    training: bool,
) -> TransformerResult<Array2<T>> {
    // Sub-layer 1: self-attention.
    let self_attended = self.self_attention.forward(x, tgt_mask, training)?;
    let stage1 = self.layer_norm(
        &(x + &self_attended).view(),
        &self.norm1_gamma,
        &self.norm1_beta,
    )?;

    // Sub-layer 2: cross-attention.
    let cross_attended =
        self.cross_attention_forward(&stage1.view(), encoder_output, memory_mask, training)?;
    let stage2 = self.layer_norm(
        &(&stage1 + &cross_attended).view(),
        &self.norm2_gamma,
        &self.norm2_beta,
    )?;

    // Sub-layer 3: feed-forward.
    let transformed = self.feed_forward.forward(&stage2.view(), training)?;
    self.layer_norm(
        &(&stage2 + &transformed).view(),
        &self.norm3_gamma,
        &self.norm3_beta,
    )
}
fn cross_attention_forward(
&self,
query: &ArrayView2<T>,
key_value: &ArrayView2<T>,
mask: Option<&ArrayView2<T>>,
training: bool,
) -> TransformerResult<Array2<T>> {
self.cross_attention.forward(query, mask, training)
}
/// Row-wise layer normalisation followed by the affine map
/// `gamma[j] * value + beta[j]`.
///
/// Mean and population variance are computed per row; `norm_eps` is added to
/// the variance before the square root.
///
/// # Errors
/// `ConversionError` if the feature count cannot be represented in `T`.
fn layer_norm(
    &self,
    x: &ArrayView2<T>,
    gamma: &Array1<T>,
    beta: &Array1<T>,
) -> TransformerResult<Array2<T>> {
    let (rows, width) = x.dim();
    let n_features = T::from(width).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert feature count".to_string())
    })?;

    let mut normalised = Array2::zeros((rows, width));
    for (source, mut target) in x.outer_iter().zip(normalised.outer_iter_mut()) {
        let mean = source.sum() / n_features;
        let var = source.mapv(|v| (v - mean) * (v - mean)).sum() / n_features;
        let std = (var + self.norm_eps).sqrt();

        for (j, (slot, &value)) in target.iter_mut().zip(source.iter()).enumerate() {
            *slot = (value - mean) / std * gamma[j] + beta[j];
        }
    }
    Ok(normalised)
}
}
/// Stack of encoder layers applied one after another.
#[derive(Debug, Clone)]
pub struct TransformerEncoder<T>
where
T: Float + SimdUnifiedOps,
{
/// Layers in application order.
pub layers: Vec<TransformerEncoderLayer<T>>,
/// Number of layers requested at construction.
pub num_layers: usize,
}
impl<T> TransformerEncoder<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
pub fn new(
num_layers: usize,
d_model: usize,
num_heads: usize,
d_ff: usize,
dropout: f64,
) -> TransformerResult<Self> {
if num_layers == 0 {
return Err(NumRs2Error::InvalidOperation(
"num_layers must be greater than 0".to_string(),
));
}
let mut layers = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
layers.push(TransformerEncoderLayer::new(
d_model, num_heads, d_ff, dropout,
)?);
}
Ok(Self { layers, num_layers })
}
/// Feeds `x` through every layer in order, passing the same `mask` and
/// `training` flag to each, and returns the final layer's output.
///
/// # Errors
/// Stops at and returns the first error raised by a layer.
pub fn forward(
    &self,
    x: &ArrayView2<T>,
    mask: Option<&ArrayView2<T>>,
    training: bool,
) -> TransformerResult<Array2<T>> {
    self.layers.iter().try_fold(x.to_owned(), |state, layer| {
        layer.forward(&state.view(), mask, training)
    })
}
}
/// Stack of decoder layers applied one after another.
#[derive(Debug, Clone)]
pub struct TransformerDecoder<T>
where
T: Float + SimdUnifiedOps,
{
/// Layers in application order.
pub layers: Vec<TransformerDecoderLayer<T>>,
/// Number of layers requested at construction.
pub num_layers: usize,
}
impl<T> TransformerDecoder<T>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
pub fn new(
num_layers: usize,
d_model: usize,
num_heads: usize,
d_ff: usize,
dropout: f64,
) -> TransformerResult<Self> {
if num_layers == 0 {
return Err(NumRs2Error::InvalidOperation(
"num_layers must be greater than 0".to_string(),
));
}
let mut layers = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
layers.push(TransformerDecoderLayer::new(
d_model, num_heads, d_ff, dropout,
)?);
}
Ok(Self { layers, num_layers })
}
/// Feeds `x` through every decoder layer in order. Each layer receives the
/// same `encoder_output`, masks and `training` flag.
///
/// # Errors
/// Stops at and returns the first error raised by a layer.
pub fn forward(
    &self,
    x: &ArrayView2<T>,
    encoder_output: &ArrayView2<T>,
    tgt_mask: Option<&ArrayView2<T>>,
    memory_mask: Option<&ArrayView2<T>>,
    training: bool,
) -> TransformerResult<Array2<T>> {
    self.layers.iter().try_fold(x.to_owned(), |state, layer| {
        layer.forward(
            &state.view(),
            encoder_output,
            tgt_mask,
            memory_mask,
            training,
        )
    })
}
}
/// Builds a `(seq_len, seq_len)` lower-triangular mask: entry `[i, j]` is one
/// when `j <= i` and zero otherwise.
pub fn create_causal_mask<T>(seq_len: usize) -> Array2<T>
where
    T: Float,
{
    Array2::from_shape_fn((seq_len, seq_len), |(row, col)| {
        if col <= row {
            T::one()
        } else {
            T::zero()
        }
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the transformer building blocks: construction, parameter
    //! validation, output shapes and a few numerical properties.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_multi_head_attention_creation() {
        let mha = MultiHeadAttention::<f64>::new(512, 8, 0.1);
        assert!(mha.is_ok());
        let mha = mha.expect("MultiHeadAttention creation should succeed");
        assert_eq!(mha.d_model, 512);
        assert_eq!(mha.num_heads, 8);
        assert_eq!(mha.d_k, 64);
    }

    #[test]
    fn test_multi_head_attention_invalid_params() {
        // d_model not divisible by num_heads.
        let result = MultiHeadAttention::<f64>::new(511, 8, 0.1);
        assert!(result.is_err());
        // Dropout outside [0, 1).
        let result = MultiHeadAttention::<f64>::new(512, 8, 1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_multi_head_attention_forward() {
        let mha = MultiHeadAttention::<f64>::new(64, 4, 0.0)
            .expect("MultiHeadAttention creation should succeed");
        let input = Array2::ones((10, 64));
        let output = mha.forward(&input.view(), None, false);
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (10, 64));
    }

    #[test]
    fn test_positional_encoding_sinusoidal() {
        let pe = PositionalEncoding::<f64>::new(100, 64, PositionalEncodingType::Sinusoidal);
        assert!(pe.is_ok());
        let pe = pe.expect("PositionalEncoding creation should succeed");
        assert_eq!(pe.encodings.dim(), (100, 64));
        // Sine and cosine values are bounded.
        for &val in pe.encodings.iter() {
            assert!((-1.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_positional_encoding_learned() {
        let pe = PositionalEncoding::<f64>::new(100, 64, PositionalEncodingType::Learned);
        assert!(pe.is_ok());
        let pe = pe.expect("PositionalEncoding creation should succeed");
        assert_eq!(pe.encodings.dim(), (100, 64));
    }

    #[test]
    fn test_positional_encoding_forward() {
        let pe = PositionalEncoding::<f64>::new(50, 64, PositionalEncodingType::Sinusoidal)
            .expect("PositionalEncoding creation should succeed");
        let input = Array2::zeros((20, 64));
        let output = pe.forward(&input.view());
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (20, 64));
        // Adding the encodings to zeros must reproduce the encodings.
        for i in 0..20 {
            for j in 0..64 {
                assert_abs_diff_eq!(output[[i, j]], pe.encodings[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_feedforward_network() -> Result<(), Box<dyn std::error::Error>> {
        let d_model = 8;
        let d_ff = 16;
        let seq_len = 4;
        let ffn = PositionwiseFeedForward::<f64>::new(d_model, d_ff, 0.0)?;
        assert_eq!(ffn.d_model, d_model);
        assert_eq!(ffn.d_ff, d_ff);
        let input = Array2::ones((seq_len, d_model));
        let output = ffn.forward(&input.view(), false)?;
        assert_eq!(output.dim(), (seq_len, d_model));
        for &val in output.iter() {
            assert!(val.is_finite(), "Output should contain finite values");
        }
        Ok(())
    }

    #[test]
    fn test_gelu_activation() {
        let ffn =
            PositionwiseFeedForward::<f64>::new(4, 16, 0.0).expect("FFN creation should succeed");
        let input = Array2::from_shape_vec((2, 4), vec![0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5])
            .expect("Array creation should succeed");
        let output = ffn.gelu(&input.view());
        assert!(output.is_ok());
        let output = output.expect("GELU should succeed");
        // GELU(0) = 0 and GELU(1) > GELU(0).
        assert_abs_diff_eq!(output[[0, 0]], 0.0, epsilon = 0.01);
        assert!(output[[0, 1]] > output[[0, 0]]);
    }

    #[test]
    fn test_transformer_encoder_layer() {
        let layer = TransformerEncoderLayer::<f64>::new(16, 4, 64, 0.0);
        assert!(layer.is_ok());
        let layer = layer.expect("EncoderLayer creation should succeed");
        let input = Array2::ones((5, 16));
        let output = layer.forward(&input.view(), None, false);
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (5, 16));
    }

    #[test]
    fn test_transformer_decoder_layer() {
        let layer = TransformerDecoderLayer::<f64>::new(16, 4, 64, 0.0);
        assert!(layer.is_ok());
        let layer = layer.expect("DecoderLayer creation should succeed");
        let tgt_input = Array2::ones((5, 16));
        let src_input = Array2::ones((8, 16));
        let output = layer.forward(&tgt_input.view(), &src_input.view(), None, None, false);
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (5, 16));
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::<f64>::new(3, 16, 4, 64, 0.0);
        assert!(encoder.is_ok());
        let encoder = encoder.expect("Encoder creation should succeed");
        assert_eq!(encoder.num_layers, 3);
        let input = Array2::ones((5, 16));
        let output = encoder.forward(&input.view(), None, false);
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (5, 16));
    }

    #[test]
    fn test_transformer_decoder() {
        let decoder = TransformerDecoder::<f64>::new(3, 16, 4, 64, 0.0);
        assert!(decoder.is_ok());
        let decoder = decoder.expect("Decoder creation should succeed");
        assert_eq!(decoder.num_layers, 3);
        let tgt = Array2::ones((5, 16));
        let memory = Array2::ones((8, 16));
        let output = decoder.forward(&tgt.view(), &memory.view(), None, None, false);
        assert!(output.is_ok());
        let output = output.expect("Forward pass should succeed");
        assert_eq!(output.dim(), (5, 16));
    }

    #[test]
    fn test_causal_mask() {
        let mask = create_causal_mask::<f64>(5);
        assert_eq!(mask.dim(), (5, 5));
        for i in 0..5 {
            for j in 0..5 {
                if j <= i {
                    assert_eq!(mask[[i, j]], 1.0);
                } else {
                    assert_eq!(mask[[i, j]], 0.0);
                }
            }
        }
    }

    #[test]
    fn test_attention_weights_sum_to_one() {
        let mha = MultiHeadAttention::<f64>::new(64, 4, 0.0)
            .expect("MultiHeadAttention creation should succeed");
        let q = Array2::ones((5, 64));
        let k = Array2::ones((5, 64));
        let v = Array2::ones((5, 64));
        let output = mha.scaled_dot_product_attention(&q.view(), &k.view(), &v.view(), None, false);
        assert!(output.is_ok());
        let output = output.expect("Attention should succeed");
        assert_eq!(output.dim(), (5, 64));
        // Each output element is a weighted average of all-ones values, so it
        // equals the sum of that row's attention weights, which must be one.
        for &val in output.iter() {
            assert_abs_diff_eq!(val, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_layer_normalization_properties() {
        let layer = TransformerEncoderLayer::<f64>::new(64, 4, 256, 0.0)
            .expect("EncoderLayer creation should succeed");
        let input = Array2::from_shape_fn((10, 64), |(i, j)| (i * 64 + j) as f64);
        let gamma = Array1::ones(64);
        let beta = Array1::zeros(64);
        let normalized = layer.layer_norm(&input.view(), &gamma, &beta);
        assert!(normalized.is_ok());
        let normalized = normalized.expect("Layer norm should succeed");
        // With gamma = 1 and beta = 0 every row has zero mean and unit variance.
        for i in 0..10 {
            let row = normalized.row(i);
            let mean = row.sum() / 64.0;
            let var = row.mapv(|v| (v - mean) * (v - mean)).sum() / 64.0;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-5);
            assert_abs_diff_eq!(var, 1.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_masked_attention() {
        let mha = MultiHeadAttention::<f64>::new(64, 4, 0.0)
            .expect("MultiHeadAttention creation should succeed");
        let input = Array2::ones((5, 64));
        let mask = create_causal_mask::<f64>(5);
        let output = mha.forward(&input.view(), Some(&mask.view()), false);
        assert!(output.is_ok());
        let output = output.expect("Masked attention should succeed");
        assert_eq!(output.dim(), (5, 64));
    }
}