mod cross_attention;
mod multi_head;
mod scaled_dot_product;
mod utils;
pub use cross_attention::{
attention_with_alibi, attention_with_rpe, flash_attention, linear_attention,
relative_position_attention, rotary_embedding, sparse_attention,
};
pub use multi_head::{grouped_query_attention, multi_head_attention};
pub use scaled_dot_product::{causal_attention, masked_attention, scaled_dot_product_attention};
pub use utils::{apply_mask, attention, AttentionConfig, AttentionMask};
use crate::error::LinalgResult;
use scirs2_core::ndarray::{Array3, ArrayView3};
use scirs2_core::numeric::{Float, NumAssignOps, Zero};
use std::ops::{Add, Div, Mul, Sub};
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fails the test unless `actual` is strictly within `tol` of `expected`.
    ///
    /// A NaN `actual` never satisfies the comparison and therefore fails.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tolerance {tol})"
        );
    }

    /// Query and key rows are identical, so every score is the same and each
    /// output row is expected to be the plain mean of the two value rows.
    #[test]
    fn test_basic_attention() {
        // `array!` with three levels of nesting already yields a (1, 2, 2) Array3.
        let ones = array![[[1.0, 1.0], [1.0, 1.0]]];
        let values = array![[[5.0, 6.0], [7.0, 8.0]]];
        let scale = 1.0 / 2.0_f64.sqrt();

        let out = attention(&ones.view(), &ones.view(), &values.view(), None, scale)
            .expect("Operation failed");
        assert_eq!(out.shape(), &[1, 2, 2]);

        let mean_row = [(5.0 + 7.0) / 2.0, (6.0 + 8.0) / 2.0];
        for pos in 0..2 {
            for (dim, &want) in mean_row.iter().enumerate() {
                assert_close(out[[0, pos, dim]], want, 1e-5);
            }
        }
    }

    /// With a causal mask the first position can only attend to itself, and
    /// the second position is expected to average the first two value rows.
    #[test]
    fn test_causal_attention() {
        let ones = array![[[1.0, 1.0], [1.0, 1.0]]];
        let values = array![[[1.0, 2.0], [3.0, 4.0]]];
        let scale = 1.0 / 2.0_f64.sqrt();

        let out = causal_attention(&ones.view(), &ones.view(), &values.view(), scale)
            .expect("Operation failed");

        let expected = [[1.0, 2.0], [2.0, 3.0]];
        for (pos, row) in expected.iter().enumerate() {
            for (dim, &want) in row.iter().enumerate() {
                assert_close(out[[0, pos, dim]], want, 1e-6);
            }
        }
    }
}