use super::activation::softmax_2d;
use super::NnResult;
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
Array, Array1, Array2, Array3, ArrayView, ArrayView1, ArrayView2, Axis, ScalarOperand,
};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;
/// Computes scaled dot-product attention: `softmax(Q · Kᵀ / sqrt(d_k)) · V`.
///
/// `d_k` is the number of columns of `query`. The raw score matrix has shape
/// `(query.nrows(), key.nrows())`.
///
/// # Arguments
///
/// * `query` - Query matrix; its column count defines `d_k`.
/// * `key` - Key matrix; must have the same number of columns as `query`.
/// * `value` - Value matrix; must have the same number of rows as `key`.
/// * `mask` - Optional mask with exactly the shape of the score matrix.
///   Positions where the mask equals zero are replaced by negative infinity
///   before the softmax; every other position keeps its score.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if the query/key column counts differ,
///   if the key/value row counts differ, or if the mask shape does not match
///   the score matrix.
/// * `NumRs2Error::ConversionError` if the scale factor cannot be represented
///   in `T`.
/// * Any error returned by `softmax_2d`.
pub fn scaled_dot_product_attention<T>(
    query: &ArrayView2<T>,
    key: &ArrayView2<T>,
    value: &ArrayView2<T>,
    mask: Option<&ArrayView2<T>>,
) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    let d_k = query.ncols();
    if key.ncols() != d_k {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Query and key dimension mismatch: {} vs {}",
            d_k,
            key.ncols()
        )));
    }
    if key.nrows() != value.nrows() {
        return Err(NumRs2Error::DimensionMismatch(
            "Key and value sequence length mismatch".to_string(),
        ));
    }

    // scale = 1 / sqrt(d_k)
    let scale = T::from(1.0)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert scale".to_string()))?
        / T::from(d_k)
            .ok_or_else(|| {
                NumRs2Error::ConversionError("Failed to convert dimension".to_string())
            })?
            .sqrt();

    let scores = query.dot(&key.t()) * scale;

    let masked_scores = if let Some(m) = mask {
        if m.shape() != scores.shape() {
            return Err(NumRs2Error::DimensionMismatch(
                "Mask shape mismatch with attention scores".to_string(),
            ));
        }
        let neg_inf = T::neg_infinity();
        let zero = T::zero();
        // A zero in the mask suppresses the position by forcing its score to -inf.
        Array2::from_shape_fn(scores.raw_dim(), |(i, j)| {
            if m[[i, j]] == zero {
                neg_inf
            } else {
                scores[[i, j]]
            }
        })
    } else {
        scores
    };

    // NOTE(review): the second argument is presumably the softmax axis (1 = along
    // each row); confirm against `softmax_2d`.
    let attention_weights = softmax_2d(&masked_scores.view(), 1)?;
    let output = attention_weights.dot(value);
    Ok(output)
}
/// Applies unmasked self-attention to `x`.
///
/// The input is projected through `w_q`, `w_k` and `w_v` to obtain the query,
/// key and value matrices, which are then passed to
/// `scaled_dot_product_attention` with no mask.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` when the column count of `x`
/// differs from the row count of any projection matrix, and forwards any error
/// produced by `scaled_dot_product_attention`.
pub fn self_attention<T>(
    x: &ArrayView2<T>,
    w_q: &ArrayView2<T>,
    w_k: &ArrayView2<T>,
    w_v: &ArrayView2<T>,
) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    let input_dim = x.ncols();
    let projections = [w_q, w_k, w_v];
    if projections.iter().any(|w| w.nrows() != input_dim) {
        return Err(NumRs2Error::DimensionMismatch(
            "Input dimension mismatch with projection matrices".to_string(),
        ));
    }

    let (query, key, value) = (x.dot(w_q), x.dot(w_k), x.dot(w_v));
    scaled_dot_product_attention(&query.view(), &key.view(), &value.view(), None)
}
/// Gathers rows of `embedding_matrix` according to `indices`.
///
/// The result has one row per entry of `indices` and as many columns as
/// `embedding_matrix`; row `i` is a copy of matrix row `indices[i]`.
///
/// # Errors
///
/// Returns `NumRs2Error::IndexOutOfBounds` for the first index that is not
/// smaller than the number of rows in `embedding_matrix`.
pub fn embedding<T>(indices: &[usize], embedding_matrix: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    let vocab_size = embedding_matrix.nrows();
    let embedding_dim = embedding_matrix.ncols();
    let mut output: Array2<T> = Array2::zeros((indices.len(), embedding_dim));

    // Walk the output rows and the requested indices in lockstep.
    for (mut dest, &idx) in output.outer_iter_mut().zip(indices) {
        if idx >= vocab_size {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "Index {} out of bounds for vocabulary size {}",
                idx, vocab_size
            )));
        }
        dest.assign(&embedding_matrix.row(idx));
    }

    Ok(output)
}
/// Builds a sinusoidal positional-encoding table of shape `(seq_len, d_model)`.
///
/// For every position `pos` and every pair index `i` in `0..d_model / 2`:
///
/// * `pe[pos, 2i]     = sin(pos / 10000^(2i / d_model))`
/// * `pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))`
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `d_model` is odd.
/// * `NumRs2Error::ConversionError` if a constant, index, position or the
///   model dimension cannot be represented in `T`.
pub fn positional_encoding<T>(seq_len: usize, d_model: usize) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    if !d_model.is_multiple_of(2) {
        return Err(NumRs2Error::InvalidOperation(
            "Model dimension must be even for sinusoidal positional encoding".to_string(),
        ));
    }

    let two = T::from(2.0)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert constant".to_string()))?;
    let ten_thousand = T::from(10000.0)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert constant".to_string()))?;
    let d_model_t = T::from(d_model)
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert dimension".to_string()))?;

    // The divisor 10000^(2i / d_model) depends only on the pair index `i`, so it
    // is computed once per pair instead of once per (position, pair).
    let half = d_model / 2;
    let mut divisors: Vec<T> = Vec::with_capacity(half);
    for i in 0..half {
        let i_t = T::from(i)
            .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert index".to_string()))?;
        divisors.push(ten_thousand.powf(two * i_t / d_model_t));
    }

    let mut pe: Array2<T> = Array2::zeros((seq_len, d_model));
    for pos in 0..seq_len {
        let pos_t = T::from(pos).ok_or_else(|| {
            NumRs2Error::ConversionError("Failed to convert position".to_string())
        })?;
        for (i, &divisor) in divisors.iter().enumerate() {
            let angle = pos_t / divisor;
            // Even columns hold the sine, odd columns the cosine of the same angle.
            pe[[pos, 2 * i]] = angle.sin();
            pe[[pos, 2 * i + 1]] = angle.cos();
        }
    }

    Ok(pe)
}
/// Returns `embeddings` with a sinusoidal positional encoding added element-wise.
///
/// The encoding is generated by `positional_encoding` using the row and column
/// counts of `embeddings`.
///
/// # Errors
///
/// Forwards any error returned by `positional_encoding`.
pub fn add_positional_encoding<T>(embeddings: &ArrayView2<T>) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    let seq_len = embeddings.nrows();
    let d_model = embeddings.ncols();
    let encoding = positional_encoding::<T>(seq_len, d_model)?;
    let encoded = embeddings + &encoding;
    Ok(encoded)
}
/// Looks up the rows of `embedding_matrix` named by `indices` and pools them
/// into a single vector with one entry per matrix column.
///
/// # Arguments
///
/// * `indices` - Row indices to pool; must not be empty.
/// * `embedding_matrix` - Matrix whose rows are the candidate embeddings.
/// * `mode` - Pooling strategy: `"sum"` (element-wise sum), `"mean"`
///   (sum divided by the number of indices) or `"max"` (element-wise maximum).
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `indices` is empty or `mode` is not
///   one of the three supported strings.
/// * `NumRs2Error::IndexOutOfBounds` if an index is not smaller than the
///   number of rows in `embedding_matrix`.
/// * `NumRs2Error::ConversionError` if the index count cannot be represented
///   in `T` (mean mode only).
pub fn embedding_bag<T>(
    indices: &[usize],
    embedding_matrix: &ArrayView2<T>,
    mode: &str,
) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    if indices.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Indices cannot be empty".to_string(),
        ));
    }

    let vocab_size = embedding_matrix.nrows();
    let embedding_dim = embedding_matrix.ncols();

    // Shared bounds check used by every pooling mode.
    let ensure_in_vocab = |idx: usize| -> NnResult<()> {
        if idx < vocab_size {
            Ok(())
        } else {
            Err(NumRs2Error::IndexOutOfBounds(format!(
                "Index {} out of bounds",
                idx
            )))
        }
    };

    let pooled: Array1<T> = match mode {
        "sum" | "mean" => {
            let mut total: Array1<T> = Array1::zeros(embedding_dim);
            for &idx in indices {
                ensure_in_vocab(idx)?;
                total = total + embedding_matrix.row(idx);
            }
            if mode == "mean" {
                let count = T::from(indices.len()).ok_or_else(|| {
                    NumRs2Error::ConversionError("Failed to convert count".to_string())
                })?;
                total / count
            } else {
                total
            }
        }
        "max" => {
            // Start from -inf so the first looked-up row always wins the comparison
            // (unless it contains NaN, which never compares greater).
            let mut best: Array1<T> = Array1::from_elem(embedding_dim, T::neg_infinity());
            for &idx in indices {
                ensure_in_vocab(idx)?;
                let candidate_row = embedding_matrix.row(idx);
                for (current, &candidate) in best.iter_mut().zip(candidate_row.iter()) {
                    if candidate > *current {
                        *current = candidate;
                    }
                }
            }
            best
        }
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Unknown mode: {}. Must be 'sum', 'mean', or 'max'",
                mode
            )));
        }
    };

    Ok(pooled)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// `embedding` returns one row per index, copied from the lookup table.
    #[test]
    fn test_embedding() {
        // Entry (row, col) holds row * 10 + col, so table row 0 is 0, 1, 2, 3, 4.
        let table = Array2::from_shape_fn((10, 5), |(row, col)| (row * 10 + col) as f64);
        let token_ids: [usize; 3] = [0, 2, 5];
        let looked_up =
            embedding(&token_ids, &table.view()).expect("test: valid embedding params");

        assert_eq!(looked_up.dim(), (3, 5));
        for (col, &actual) in looked_up.row(0).iter().enumerate() {
            assert_abs_diff_eq!(actual, col as f64, epsilon = 1e-6);
        }
    }

    /// Every entry of the encoding is a sine or cosine and so lies in [-1, 1].
    #[test]
    fn test_positional_encoding() {
        let table =
            positional_encoding::<f64>(10, 8).expect("test: valid positional encoding params");

        assert_eq!(table.dim(), (10, 8));
        assert!(table.iter().all(|entry| (-1.0..=1.0).contains(entry)));
    }

    /// Summing three all-ones rows yields 3.0 in every component.
    #[test]
    fn test_embedding_bag_sum() {
        let table = Array2::<f64>::from_elem((5, 3), 1.0);
        let token_ids: [usize; 3] = [0, 1, 2];
        let pooled = embedding_bag(&token_ids, &table.view(), "sum")
            .expect("test: valid embedding_bag params");

        for &component in pooled.iter() {
            assert_abs_diff_eq!(component, 3.0, epsilon = 1e-6);
        }
    }

    /// Averaging identical rows of 2.0 yields 2.0 in every component.
    #[test]
    fn test_embedding_bag_mean() {
        let table = Array2::<f64>::from_elem((5, 3), 2.0);
        let token_ids: [usize; 4] = [0, 1, 2, 3];
        let pooled = embedding_bag(&token_ids, &table.view(), "mean")
            .expect("test: valid embedding_bag params");

        for &component in pooled.iter() {
            assert_abs_diff_eq!(component, 2.0, epsilon = 1e-6);
        }
    }
}