use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, ArrayView3};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{NeuralError, Result};
type FeedForwardReturn<F> = (Array3<F>, Array2<F>, Array1<F>, Array2<F>, Array1<F>);
/// Computes scaled dot-product attention: softmax(Q·Kᵀ / sqrt(d_k)) · V.
///
/// Shapes: `query` is (batch, seq_q, d_k), `key` is (batch, seq_k, d_k),
/// `value` is (batch, seq_k, d_v). An optional `mask` of shape
/// (batch, seq_q, seq_k) zeroes out attention wherever its entry is zero
/// (implemented by setting the score to -1e9 before the softmax).
///
/// Returns `(output, attention_weights)` where `output` is
/// (batch, seq_q, d_v) and `attention_weights` is (batch, seq_q, seq_k).
///
/// # Errors
/// Returns `NeuralError::ShapeMismatch` when the batch/feature dimensions of
/// `key`, `value`, or `mask` are inconsistent with `query`.
pub fn scaled_dot_product_attention<F>(
    query: &ArrayView3<F>,
    key: &ArrayView3<F>,
    value: &ArrayView3<F>,
    mask: Option<&ArrayView3<F>>,
) -> Result<(Array3<F>, Array3<F>)>
where
    F: Float + Debug,
{
    let batch = query.shape()[0];
    let len_q = query.shape()[1];
    let dim_k = query.shape()[2];
    let len_k = key.shape()[1];
    let dim_v = value.shape()[2];

    // Validate that key/value/mask agree with query on the shared dimensions.
    if key.shape()[0] != batch || key.shape()[2] != dim_k {
        return Err(NeuralError::ShapeMismatch(format!(
            "Key shape mismatch in scaled_dot_product_attention: query shape {:?}, key shape {:?}",
            query.shape(),
            key.shape()
        )));
    }
    if value.shape()[0] != batch || value.shape()[1] != len_k {
        return Err(NeuralError::ShapeMismatch(
            format!("Value shape mismatch in scaled_dot_product_attention: key shape {:?}, value shape {:?}",
            key.shape(), value.shape())
        ));
    }
    if let Some(m) = mask {
        if m.shape()[0] != batch || m.shape()[1] != len_q || m.shape()[2] != len_k {
            return Err(NeuralError::ShapeMismatch(format!(
                "Mask shape mismatch in scaled_dot_product_attention: expected {:?}, got {:?}",
                [batch, len_q, len_k],
                m.shape()
            )));
        }
    }

    // Raw scores: (Q · Kᵀ) / sqrt(d_k).
    let scale = F::one() / F::from(dim_k).expect("Failed to convert to float").sqrt();
    let mut scores = Array3::<F>::zeros((batch, len_q, len_k));
    for b in 0..batch {
        for qi in 0..len_q {
            for kj in 0..len_k {
                let dot = (0..dim_k)
                    .fold(F::zero(), |acc, c| acc + query[[b, qi, c]] * key[[b, kj, c]]);
                scores[[b, qi, kj]] = dot * scale;
            }
        }
    }

    // Apply mask: a zero mask entry forces the score to a large negative value
    // so the softmax drives that weight to (numerically) zero.
    if let Some(m) = mask {
        for b in 0..batch {
            for qi in 0..len_q {
                for kj in 0..len_k {
                    if m[[b, qi, kj]] == F::zero() {
                        scores[[b, qi, kj]] =
                            F::from(-1e9).expect("Failed to convert constant to float");
                    }
                }
            }
        }
    }

    // Row-wise softmax with the usual max-subtraction for numerical stability.
    let mut weights = Array3::<F>::zeros((batch, len_q, len_k));
    for b in 0..batch {
        for qi in 0..len_q {
            let row_max = (1..len_k).fold(scores[[b, qi, 0]], |m, kj| {
                if scores[[b, qi, kj]] > m {
                    scores[[b, qi, kj]]
                } else {
                    m
                }
            });
            let mut denom = F::zero();
            for kj in 0..len_k {
                let e = (scores[[b, qi, kj]] - row_max).exp();
                weights[[b, qi, kj]] = e;
                denom = denom + e;
            }
            for kj in 0..len_k {
                weights[[b, qi, kj]] = weights[[b, qi, kj]] / denom;
            }
        }
    }

    // Weighted sum of values: output = weights · V.
    let mut context = Array3::<F>::zeros((batch, len_q, dim_v));
    for b in 0..batch {
        for qi in 0..len_q {
            for vc in 0..dim_v {
                let acc = (0..len_k)
                    .fold(F::zero(), |s, kj| s + weights[[b, qi, kj]] * value[[b, kj, vc]]);
                context[[b, qi, vc]] = acc;
            }
        }
    }
    Ok((context, weights))
}
/// Multi-head attention: project Q/K/V with `wq`/`wk`/`wv`, run scaled
/// dot-product attention independently on `num_heads` slices of the feature
/// dimension, concatenate the per-head outputs, and project with `wo`.
///
/// Shapes: `query` is (batch, seq_q, d_model); `key`/`value` are
/// (batch, seq_k, d_model); all four weight matrices are (d_model, d_model).
/// `mask`, if given, is forwarded unchanged to every head.
///
/// Returns the attended output of shape (batch, seq_q, d_model).
///
/// # Errors
/// Returns `NeuralError::ShapeMismatch` on inconsistent input/weight shapes
/// or when `d_model` is not divisible by `num_heads`.
#[allow(clippy::too_many_arguments)]
pub fn multi_head_attention<F>(
    query: &ArrayView3<F>,
    key: &ArrayView3<F>,
    value: &ArrayView3<F>,
    wq: &ArrayView2<F>,
    wk: &ArrayView2<F>,
    wv: &ArrayView2<F>,
    wo: &ArrayView2<F>,
    num_heads: usize,
    mask: Option<&ArrayView3<F>>,
) -> Result<Array3<F>>
where
    F: Float + Debug,
{
    let batch = query.shape()[0];
    let len_q = query.shape()[1];
    let len_k = key.shape()[1];
    let d_model = query.shape()[2];

    if key.shape()[0] != batch || key.shape()[2] != d_model {
        return Err(NeuralError::ShapeMismatch(format!(
            "Key shape mismatch in multi_head_attention: query shape {:?}, key shape {:?}",
            query.shape(),
            key.shape()
        )));
    }
    if value.shape()[0] != batch
        || value.shape()[1] != len_k
        || value.shape()[2] != d_model
    {
        return Err(NeuralError::ShapeMismatch(format!(
            "Value shape mismatch in multi_head_attention: key shape {:?}, value shape {:?}",
            key.shape(),
            value.shape()
        )));
    }
    if wq.shape() != [d_model, d_model]
        || wk.shape() != [d_model, d_model]
        || wv.shape() != [d_model, d_model]
        || wo.shape() != [d_model, d_model]
    {
        return Err(NeuralError::ShapeMismatch(
            "Weight matrix shape mismatch in multi_head_attention".to_string(),
        ));
    }
    if d_model % num_heads != 0 {
        return Err(NeuralError::ShapeMismatch(format!(
            "d_model ({}) must be divisible by num_heads ({})",
            d_model, num_heads
        )));
    }
    let head_dim = d_model / num_heads;

    // Batched linear projection over the last axis: out[b, r, :] = x[b, r, :] · w.
    let project = |x: &ArrayView3<F>, w: &ArrayView2<F>| -> Array3<F> {
        let rows = x.shape()[1];
        let mut out = Array3::<F>::zeros((batch, rows, d_model));
        for b in 0..batch {
            for r in 0..rows {
                for c in 0..d_model {
                    let mut acc = F::zero();
                    for m in 0..d_model {
                        acc = acc + x[[b, r, m]] * w[[m, c]];
                    }
                    out[[b, r, c]] = acc;
                }
            }
        }
        out
    };

    let q_proj = project(query, wq);
    let k_proj = project(key, wk);
    let v_proj = project(value, wv);

    // Run attention per head on the head's slice of the feature axis and write
    // each head's output straight into its slot of the concatenated buffer.
    let mut concat = Array3::<F>::zeros((batch, len_q, d_model));
    for h in 0..num_heads {
        let offset = h * head_dim;
        let mut q_h = Array3::<F>::zeros((batch, len_q, head_dim));
        let mut k_h = Array3::<F>::zeros((batch, len_k, head_dim));
        let mut v_h = Array3::<F>::zeros((batch, len_k, head_dim));
        for b in 0..batch {
            for r in 0..len_q {
                for d in 0..head_dim {
                    q_h[[b, r, d]] = q_proj[[b, r, offset + d]];
                }
            }
            for r in 0..len_k {
                for d in 0..head_dim {
                    k_h[[b, r, d]] = k_proj[[b, r, offset + d]];
                    v_h[[b, r, d]] = v_proj[[b, r, offset + d]];
                }
            }
        }
        let (ctx, _) =
            scaled_dot_product_attention(&q_h.view(), &k_h.view(), &v_h.view(), mask)?;
        for b in 0..batch {
            for r in 0..len_q {
                for d in 0..head_dim {
                    concat[[b, r, offset + d]] = ctx[[b, r, d]];
                }
            }
        }
    }

    // Final output projection with wo.
    let concat_view = concat.view();
    Ok(project(&concat_view, wo))
}
/// Builds sinusoidal positional encodings as defined in "Attention Is All
/// You Need" (Vaswani et al., 2017):
///
///   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
///   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
///
/// The table is built for `max_seq_len` positions (defaulting to `seq_len`)
/// and the first `seq_len` rows are returned as a (seq_len, d_model) array.
///
/// # Errors
/// Returns `NeuralError::InvalidArgument` when `max_seq_len < seq_len` or
/// when `d_model` is odd (sin/cos pairs require an even width).
pub fn positional_encoding<F: Float + Debug>(
    seq_len: usize,
    d_model: usize,
    max_seq_len: Option<usize>,
) -> Result<Array2<F>> {
    let max_len = max_seq_len.unwrap_or(seq_len);
    if max_len < seq_len {
        return Err(NeuralError::InvalidArgument(format!(
            "max_seq_len ({}) must be at least as large as seq_len ({})",
            max_len, seq_len
        )));
    }
    if d_model % 2 != 0 {
        return Err(NeuralError::InvalidArgument(format!(
            "d_model ({}) must be even",
            d_model
        )));
    }
    let base = F::from(1e4).expect("Failed to convert constant to float");
    let mut pos_encoding = Array2::<F>::zeros((max_len, d_model));
    for pos in 0..max_len {
        let pos_f = F::from(pos).expect("Failed to convert to float");
        for i in 0..d_model / 2 {
            // div_term = 10000^(2i / d_model) = exp((2i / d_model) * ln(10000)).
            // BUG FIX: the previous code computed exp(2i / d_model) * 10000,
            // which equals exp(2i/d_model + ln 10000) — not the intended
            // power of 10000 — giving wavelengths that grow far too slowly.
            let exponent = F::from(2 * i).expect("Failed to convert to float")
                / F::from(d_model).expect("Failed to convert to float");
            let div_term = (exponent * base.ln()).exp();
            pos_encoding[[pos, 2 * i]] = (pos_f / div_term).sin();
            pos_encoding[[pos, 2 * i + 1]] = (pos_f / div_term).cos();
        }
    }
    // Return only the requested prefix of positions.
    let result = pos_encoding.slice(s![0..seq_len, ..]).to_owned();
    Ok(result)
}
/// Position-wise feed-forward network: `ReLU(x·w1 + b1)·w2 + b2`, applied
/// independently at every (batch, sequence) position.
///
/// Shapes: `x` is (batch, seq, d_model); `w1` is (d_model, d_ff) with bias
/// `b1` of length d_ff; `w2` is (d_ff, d_model) with bias `b2` of length
/// d_model. Returns (batch, seq, d_model).
///
/// # Errors
/// Returns `NeuralError::ShapeMismatch` when any weight or bias shape is
/// inconsistent with `x`.
pub fn transformer_ffn<F>(
    x: &ArrayView3<F>,
    w1: &ArrayView2<F>,
    b1: &ArrayView1<F>,
    w2: &ArrayView2<F>,
    b2: &ArrayView1<F>,
) -> Result<Array3<F>>
where
    F: Float + Debug,
{
    let batch = x.shape()[0];
    let seq = x.shape()[1];
    let d_model = x.shape()[2];
    let d_ff = w1.shape()[1];

    if w1.shape()[0] != d_model {
        return Err(NeuralError::ShapeMismatch(format!(
            "w1 shape mismatch in transformer_ffn: x shape {:?}, w1 shape {:?}",
            x.shape(),
            w1.shape()
        )));
    }
    if b1.shape()[0] != d_ff {
        return Err(NeuralError::ShapeMismatch(format!(
            "b1 shape mismatch in transformer_ffn: b1 shape {:?}, expected [{:?}]",
            b1.shape(),
            d_ff
        )));
    }
    if w2.shape()[0] != d_ff || w2.shape()[1] != d_model {
        return Err(NeuralError::ShapeMismatch(format!(
            "w2 shape mismatch in transformer_ffn: w2 shape {:?}, expected [{:?}, {:?}]",
            w2.shape(),
            d_ff,
            d_model
        )));
    }
    if b2.shape()[0] != d_model {
        return Err(NeuralError::ShapeMismatch(format!(
            "b2 shape mismatch in transformer_ffn: b2 shape {:?}, expected [{:?}]",
            b2.shape(),
            d_model
        )));
    }

    // First layer + ReLU: hidden = max(0, x·w1 + b1).
    let mut hidden = Array3::<F>::zeros((batch, seq, d_ff));
    for b in 0..batch {
        for t in 0..seq {
            for u in 0..d_ff {
                let pre = (0..d_model)
                    .fold(b1[u], |acc, m| acc + x[[b, t, m]] * w1[[m, u]]);
                hidden[[b, t, u]] = if pre > F::zero() { pre } else { F::zero() };
            }
        }
    }

    // Second layer: output = hidden·w2 + b2.
    let mut out = Array3::<F>::zeros((batch, seq, d_model));
    for b in 0..batch {
        for t in 0..seq {
            for o in 0..d_model {
                out[[b, t, o]] = (0..d_ff)
                    .fold(b2[o], |acc, u| acc + hidden[[b, t, u]] * w2[[u, o]]);
            }
        }
    }
    Ok(out)
}
/// Backward pass of `transformer_ffn`: given upstream gradient `dout` and the
/// forward inputs, recomputes the hidden activations and returns
/// `(dx, dw1, db1, dw2, db2)` — the gradients w.r.t. the input and all
/// parameters of the two linear layers.
///
/// # Errors
/// Returns `NeuralError::ShapeMismatch` when `dout` does not match the shape
/// of `x`.
pub fn transformer_ffn_backward<F>(
    dout: &ArrayView3<F>,
    x: &ArrayView3<F>,
    w1: &ArrayView2<F>,
    b1: &ArrayView1<F>,
    w2: &ArrayView2<F>,
    b2: &ArrayView1<F>,
) -> Result<FeedForwardReturn<F>>
where
    F: Float + Debug,
{
    let batch = x.shape()[0];
    let seq = x.shape()[1];
    let d_model = x.shape()[2];
    let d_ff = w1.shape()[1];

    if dout.shape() != x.shape() {
        return Err(NeuralError::ShapeMismatch(format!(
            "dout shape mismatch in transformer_ffn_backward: dout shape {:?}, x shape {:?}",
            dout.shape(),
            x.shape()
        )));
    }

    // Recompute the forward pre-activations and ReLU outputs in one sweep;
    // both are needed below (the pre-activation for the ReLU gate, the
    // activation for dw2).
    let mut pre_act = Array3::<F>::zeros((batch, seq, d_ff));
    let mut act = Array3::<F>::zeros((batch, seq, d_ff));
    for b in 0..batch {
        for t in 0..seq {
            for u in 0..d_ff {
                let z = (0..d_model)
                    .fold(b1[u], |acc, m| acc + x[[b, t, m]] * w1[[m, u]]);
                pre_act[[b, t, u]] = z;
                act[[b, t, u]] = if z > F::zero() { z } else { F::zero() };
            }
        }
    }

    let mut dx = Array3::<F>::zeros(x.raw_dim());
    let mut dw1 = Array2::<F>::zeros(w1.raw_dim());
    let mut db1 = Array1::<F>::zeros(b1.raw_dim());
    let mut dw2 = Array2::<F>::zeros(w2.raw_dim());
    let mut db2 = Array1::<F>::zeros(b2.raw_dim());

    // Second layer: db2 = Σ dout; dw2 = actᵀ·dout; dact = dout·w2ᵀ.
    let mut dact = Array3::<F>::zeros(act.raw_dim());
    for b in 0..batch {
        for t in 0..seq {
            for o in 0..d_model {
                let g = dout[[b, t, o]];
                db2[o] = db2[o] + g;
                for u in 0..d_ff {
                    dw2[[u, o]] = dw2[[u, o]] + act[[b, t, u]] * g;
                    dact[[b, t, u]] = dact[[b, t, u]] + w2[[u, o]] * g;
                }
            }
        }
    }

    // ReLU gate (gradient passes only where the pre-activation was positive)
    // folded directly into the first-layer gradient accumulation:
    // db1 = Σ gz; dw1 = xᵀ·gz; dx = gz·w1ᵀ.
    for b in 0..batch {
        for t in 0..seq {
            for u in 0..d_ff {
                let gz = if pre_act[[b, t, u]] > F::zero() {
                    dact[[b, t, u]]
                } else {
                    F::zero()
                };
                db1[u] = db1[u] + gz;
                for m in 0..d_model {
                    dw1[[m, u]] = dw1[[m, u]] + x[[b, t, m]] * gz;
                    dx[[b, t, m]] = dx[[b, t, m]] + w1[[m, u]] * gz;
                }
            }
        }
    }
    Ok((dx, dw1, db1, dw2, db2))
}