use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use scirs2_core::numeric::{Float, NumAssignOps, Zero};
use std::ops::{Add, Div, Mul, Sub};
use super::utils::{attention, AttentionMask};
use crate::error::{check_dimensions, LinalgError, LinalgResult};
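/// Memory-efficient scaled dot-product attention computed block by block in
/// the style of FlashAttention. The full `seq_len_q x seq_len_k` score matrix
/// is never materialized; instead an online softmax keeps a running row
/// maximum and normalizer per query row while key/value blocks stream through.
///
/// * `query` - shape `(batch, seq_len_q, d_model)`
/// * `key` - shape `(batch, seq_len_k, d_model)`
/// * `value` - shape `(batch, seq_len_k, d_model_v)`
/// * `mask` - optional mask; only `AttentionMask::Causal` is supported
/// * `scale` - score scaling, typically `1 / sqrt(d_model)`
/// * `blocksize` - tile size for the query and key loops
///
/// A minimal usage sketch (marked `ignore`; assumes `f64` inputs):
///
/// ```ignore
/// let q = Array3::<f64>::zeros((1, 8, 4));
/// let k = Array3::<f64>::zeros((1, 8, 4));
/// let v = Array3::<f64>::zeros((1, 8, 4));
/// let out = flash_attention(&q.view(), &k.view(), &v.view(), None, 0.5, 4)?;
/// assert_eq!(out.shape(), &[1, 8, 4]);
/// ```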
#[allow(dead_code)]
pub fn flash_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
mask: Option<&AttentionMask>,
scale: F,
blocksize: usize,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let (batchsize_k, seq_len_k, d_model_k) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (batchsize_v, seq_len_v, d_model_v) =
(value.shape()[0], value.shape()[1], value.shape()[2]);
check_dimensions(
batchsize == batchsize_k,
format!("Batch sizes must match: {batchsize} != {batchsize_k}"),
)?;
check_dimensions(
batchsize == batchsize_v,
format!("Batch sizes must match: {batchsize} != {batchsize_v}"),
)?;
check_dimensions(
seq_len_k == seq_len_v,
format!("Key and value sequence lengths must match: {seq_len_k} != {seq_len_v}"),
)?;
check_dimensions(
d_model == d_model_k,
format!("Query and key dimensions must match: {d_model} != {d_model_k}"),
)?;
    // `step_by(0)` below would panic, so reject a zero block size up front.
    check_dimensions(blocksize > 0, "Block size must be positive".to_string())?;
    let blocksize_q = blocksize.min(seq_len_q);
    let blocksize_k = blocksize.min(seq_len_k);
let mut output = Array3::<F>::zeros((batchsize, seq_len_q, d_model_v));
for b in 0..batchsize {
for q_start in (0..seq_len_q).step_by(blocksize_q) {
let q_end = (q_start + blocksize_q).min(seq_len_q);
let q_block = query.slice(scirs2_core::ndarray::s![b, q_start..q_end, ..]);
let mut m_block = Array1::<F>::from_elem(q_end - q_start, F::neg_infinity());
let mut l_block = Array1::<F>::zeros(q_end - q_start);
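            // Online-softmax state per query row: m_block holds the running
            // maximum and l_block the running normalizer, updated as key
            // blocks stream in.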
for k_start in (0..seq_len_k).step_by(blocksize_k) {
let k_end = (k_start + blocksize_k).min(seq_len_k);
let k_block = key.slice(scirs2_core::ndarray::s![b, k_start..k_end, ..]);
let v_block = value.slice(scirs2_core::ndarray::s![b, k_start..k_end, ..]);
let mut scores_block = Array2::<F>::zeros((q_end - q_start, k_end - k_start));
for i in 0..(q_end - q_start) {
for j in 0..(k_end - k_start) {
let mut dot_product = F::zero();
for k in 0..d_model {
dot_product += q_block[[i, k]] * k_block[[j, k]];
}
scores_block[[i, j]] = dot_product * scale;
}
}
if let Some(mask_ref) = mask {
match mask_ref {
AttentionMask::Causal => {
for i in 0..(q_end - q_start) {
let q_idx = q_start + i;
for j in 0..(k_end - k_start) {
let k_idx = k_start + j;
if k_idx > q_idx {
scores_block[[i, j]] = F::neg_infinity();
}
}
}
}
_ => {
return Err(LinalgError::NotImplementedError(
"Flash attention currently only supports causal masks".to_string(),
))
}
}
}
for i in 0..(q_end - q_start) {
let row = scores_block.slice(scirs2_core::ndarray::s![i, ..]);
let max_val =
row.fold(F::neg_infinity(), |max, &x| if x > max { x } else { max });
                    if max_val > m_block[i] {
                        let m_prev = m_block[i];
                        let m_new = max_val;
                        // Online-softmax rescale: when the running max grows,
                        // both the normalizer and the unnormalized output
                        // accumulated so far shrink by exp(m_prev - m_new).
                        if m_prev != F::neg_infinity() {
                            let correction = (m_prev - m_new).exp();
                            l_block[i] *= correction;
                            for j in 0..d_model_v {
                                output[[b, q_start + i, j]] *= correction;
                            }
                        }
                        m_block[i] = m_new;
                    }
let mut block_sum = F::zero();
let mut block_output = Array1::<F>::zeros(d_model_v);
for j in 0..(k_end - k_start) {
let exp_val = (scores_block[[i, j]] - m_block[i]).exp();
block_sum += exp_val;
for k in 0..d_model_v {
block_output[k] += exp_val * v_block[[j, k]];
}
}
l_block[i] += block_sum;
for j in 0..d_model_v {
output[[b, q_start + i, j]] += block_output[j];
}
}
}
for i in 0..(q_end - q_start) {
if l_block[i] > F::zero() {
for j in 0..d_model_v {
output[[b, q_start + i, j]] /= l_block[i];
}
}
}
}
}
Ok(output)
}
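/// Attention restricted to a fixed sparsity pattern. The score for pair
/// `(i, j)` is computed only where `pattern_mask[[i, j]]` is `true`; masked
/// positions get zero weight, and a query row with no admissible key is left
/// as zeros. Cost scales with the number of `true` entries rather than with
/// `seq_len_q * seq_len_k`.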
#[allow(dead_code)]
pub fn sparse_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
pattern_mask: &ArrayView2<bool>,
scale: F,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let (_, seq_len_k, _) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (_, _, d_model_v) = (value.shape()[0], value.shape()[1], value.shape()[2]);
if pattern_mask.shape() != [seq_len_q, seq_len_k] {
return Err(LinalgError::DimensionError(format!(
"Pattern mask shape {:?} doesn't match query and key sequence lengths [{}, {}]",
pattern_mask.shape(),
seq_len_q,
seq_len_k
)));
}
let mut output = Array3::<F>::zeros((batchsize, seq_len_q, d_model_v));
for b in 0..batchsize {
let q_b = query.slice(scirs2_core::ndarray::s![b, .., ..]);
let k_b = key.slice(scirs2_core::ndarray::s![b, .., ..]);
let v_b = value.slice(scirs2_core::ndarray::s![b, .., ..]);
for i in 0..seq_len_q {
let q_i = q_b.slice(scirs2_core::ndarray::s![i, ..]);
let mut scores = Vec::new();
let mut indices = Vec::new();
for j in 0..seq_len_k {
if pattern_mask[[i, j]] {
let k_j = k_b.slice(scirs2_core::ndarray::s![j, ..]);
let mut dot_product = F::zero();
for k in 0..d_model {
dot_product += q_i[k] * k_j[k];
}
scores.push(dot_product * scale);
indices.push(j);
}
}
if scores.is_empty() {
continue;
}
let max_val = scores
.iter()
.fold(F::neg_infinity(), |max, &x| if x > max { x } else { max });
let mut exp_scores = Vec::with_capacity(scores.len());
let mut sum = F::zero();
for &score in &scores {
let exp_val = (score - max_val).exp();
exp_scores.push(exp_val);
sum += exp_val;
}
if sum > F::zero() {
for exp_score in &mut exp_scores {
*exp_score /= sum;
}
}
for j in 0..d_model_v {
let mut weighted_sum = F::zero();
for k in 0..indices.len() {
let v_idx = indices[k];
weighted_sum += exp_scores[k] * v_b[[v_idx, j]];
}
output[[b, i, j]] = weighted_sum;
}
}
}
Ok(output)
}
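/// Scaled dot-product attention with ALiBi (Attention with Linear Biases):
/// each score is penalized by `slope * |i - j|` before the softmax so that
/// attention decays with distance. `slopes` holds one slope per head, and the
/// batch axis is treated as the head axis. With `causal` set, keys after the
/// query position are masked out.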
#[allow(dead_code)]
pub fn attention_with_alibi<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
slopes: &ArrayView1<F>,
scale: F,
causal: bool,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let (_, seq_len_k, _) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (_, _, d_model_v) = (value.shape()[0], value.shape()[1], value.shape()[2]);
let mut result = Array3::<F>::zeros((batchsize, seq_len_q, d_model_v));
for b in 0..batchsize {
let q_b = query.slice(scirs2_core::ndarray::s![b, .., ..]);
let k_b = key.slice(scirs2_core::ndarray::s![b, .., ..]);
let v_b = value.slice(scirs2_core::ndarray::s![b, .., ..]);
let mut scores = Array2::<F>::zeros((seq_len_q, seq_len_k));
for i in 0..seq_len_q {
for j in 0..seq_len_k {
let mut dot_product = F::zero();
for k in 0..d_model {
dot_product += q_b[[i, k]] * k_b[[j, k]];
}
scores[[i, j]] = dot_product * scale;
}
}
        // ALiBi bias: subtract slope * |i - j| from each score so attention
        // decays with distance. The batch axis is treated as the head axis;
        // each entry uses the matching slope (all entries share slopes[0]
        // when only one slope is supplied).
        let slope = slopes[b % slopes.len()];
        for i in 0..seq_len_q {
            for j in 0..seq_len_k {
                let pos_diff =
                    F::from((i as isize - j as isize).abs() as f64).ok_or_else(|| {
                        LinalgError::ValueError(
                            "Failed to convert position difference to target type".to_string(),
                        )
                    })?;
                scores[[i, j]] -= slope * pos_diff;
            }
        }
if causal {
for i in 0..seq_len_q {
for j in 0..seq_len_k {
if j > i {
scores[[i, j]] = F::neg_infinity();
}
}
}
}
for i in 0..seq_len_q {
let mut row = scores.slice_mut(scirs2_core::ndarray::s![i, ..]);
let max_val = row.fold(F::neg_infinity(), |max, &x| if x > max { x } else { max });
let mut sum = F::zero();
for j in 0..seq_len_k {
let exp_val = (row[j] - max_val).exp();
row[j] = exp_val;
sum += exp_val;
}
if sum > F::zero() {
for j in 0..seq_len_k {
row[j] /= sum;
}
}
}
let mut output = Array2::<F>::zeros((seq_len_q, d_model_v));
for i in 0..seq_len_q {
for j in 0..d_model_v {
let mut sum = F::zero();
for k in 0..seq_len_k {
sum += scores[[i, k]] * v_b[[k, j]];
}
output[[i, j]] = sum;
}
}
result
.slice_mut(scirs2_core::ndarray::s![b, .., ..])
.assign(&output);
}
Ok(result)
}
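/// Applies rotary position embeddings (RoPE): each consecutive (even, odd)
/// feature pair is rotated by the angle `pos * freq_i`, where
/// `freq_i = freqbase^(-2i / d_model)`. Requires an even `d_model`.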
#[allow(dead_code)]
pub fn rotary_embedding<F>(x: &ArrayView3<F>, freqbase: F) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len, d_model) = (x.shape()[0], x.shape()[1], x.shape()[2]);
if d_model % 2 != 0 {
return Err(LinalgError::ValueError(
"Dimension must be even for rotary embeddings".to_string(),
));
}
let mut result = Array3::<F>::zeros((batchsize, seq_len, d_model));
let half_dim = d_model / 2;
let mut freqs = Vec::with_capacity(half_dim);
for i in 0..half_dim {
let exponent = F::from(2.0 * i as f64 / d_model as f64).ok_or_else(|| {
LinalgError::ValueError(
"Failed to convert frequency exponent to target type".to_string(),
)
})?;
let freq = F::one() / freqbase.powf(exponent);
freqs.push(freq);
}
for b in 0..batchsize {
        for pos in 0..seq_len {
            let pos_f = F::from(pos as f64).ok_or_else(|| {
                LinalgError::ValueError("Failed to convert position to target type".to_string())
            })?;
            // Rotate each consecutive (even, odd) feature pair by pos * freq.
            for (i, &freq) in freqs.iter().enumerate() {
                let i2 = 2 * i;
                let x_i = x[[b, pos, i2]];
                let x_i_plus_1 = x[[b, pos, i2 + 1]];
                let theta = pos_f * freq;
let cos_theta = theta.cos();
let sin_theta = theta.sin();
result[[b, pos, i2]] = x_i * cos_theta - x_i_plus_1 * sin_theta;
result[[b, pos, i2 + 1]] = x_i * sin_theta + x_i_plus_1 * cos_theta;
}
}
}
Ok(result)
}
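/// Linear-time attention in the style of Katharopoulos et al.: queries and
/// keys pass through the positive feature map `phi(x) = elu(x) + 1`, after
/// which attention factorizes as `phi(Q) (phi(K)^T V)` at O(n * d * d_v) cost
/// instead of the O(n^2) of softmax attention. `scale` is applied to the keys
/// before the feature map.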
#[allow(dead_code)]
pub fn linear_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
scale: F,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let (_, seq_len_k, _) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (_, _, d_model_v) = (value.shape()[0], value.shape()[1], value.shape()[2]);
let mut result = Array3::<F>::zeros((batchsize, seq_len_q, d_model_v));
for b in 0..batchsize {
let mut q_prime = Array2::<F>::zeros((seq_len_q, d_model));
let mut k_prime = Array2::<F>::zeros((seq_len_k, d_model));
        // Positive feature map phi(x) = elu(x) + 1: x + 1 for x > 0, exp(x)
        // otherwise (continuous and strictly positive, as linear attention
        // requires).
        for i in 0..seq_len_q {
            for j in 0..d_model {
                let x = query[[b, i, j]];
                q_prime[[i, j]] = if x > F::zero() { x + F::one() } else { x.exp() };
            }
        }
        for i in 0..seq_len_k {
            for j in 0..d_model {
                let x = key[[b, i, j]] * scale;
                k_prime[[i, j]] = if x > F::zero() { x + F::one() } else { x.exp() };
            }
        }
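        // Aggregate phi(K)^T V once per batch:
        // kv[[i, j]] = sum_k phi(K)[[k, i]] * V[[k, j]], an O(n * d * d_v) pass.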
let mut kv = Array2::<F>::zeros((d_model, d_model_v));
for i in 0..d_model {
for j in 0..d_model_v {
let mut sum = F::zero();
for k in 0..seq_len_k {
sum += k_prime[[k, i]] * value[[b, k, j]];
}
kv[[i, j]] = sum;
}
}
        // Precompute column sums of phi(K) once; recomputing them per query
        // row would make the normalizer O(n^2) and forfeit the linear cost.
        let mut k_sums = Array1::<F>::zeros(d_model);
        for j in 0..d_model {
            let mut k_sum = F::zero();
            for k in 0..seq_len_k {
                k_sum += k_prime[[k, j]];
            }
            k_sums[j] = k_sum;
        }
        let mut z = Array1::<F>::zeros(seq_len_q);
        for i in 0..seq_len_q {
            let mut sum = F::zero();
            for j in 0..d_model {
                sum += q_prime[[i, j]] * k_sums[j];
            }
            z[i] = sum;
        }
        for i in 0..seq_len_q {
            for j in 0..d_model_v {
                let mut sum = F::zero();
                for k in 0..d_model {
                    sum += q_prime[[i, k]] * kv[[k, j]];
                }
                if z[i] > F::zero() {
                    result[[b, i, j]] = sum / z[i];
                }
            }
        }
}
Ok(result)
}
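/// Attention with learned relative position embeddings in the style of
/// Shaw et al.: the content scores `q_i . k_j` are augmented with position
/// scores `q_i . r_(i-j)`, where `rel_emb` holds one `d_model`-dimensional
/// embedding per relative offset, looked up at index `(seq_len_k - 1) + i - j`.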
#[allow(dead_code)]
pub fn relative_position_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
rel_emb: &ArrayView2<F>,
scale: F,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let (_, seq_len_k, _) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (_, _, d_model_v) = (value.shape()[0], value.shape()[1], value.shape()[2]);
let expected_rel_emb_len = 2 * seq_len_k.max(seq_len_q) - 1;
if rel_emb.shape()[0] != expected_rel_emb_len || rel_emb.shape()[1] != d_model {
return Err(LinalgError::DimensionError(format!(
"Relative embedding shape should be [{}, {}], got {:?}",
expected_rel_emb_len,
d_model,
rel_emb.shape()
)));
}
let mut result = Array3::<F>::zeros((batchsize, seq_len_q, d_model_v));
for b in 0..batchsize {
let mut content_scores = Array2::<F>::zeros((seq_len_q, seq_len_k));
for i in 0..seq_len_q {
for j in 0..seq_len_k {
let mut dot_product = F::zero();
for k in 0..d_model {
dot_product += query[[b, i, k]] * key[[b, j, k]];
}
content_scores[[i, j]] = dot_product * scale;
}
}
let mut pos_scores = Array2::<F>::zeros((seq_len_q, seq_len_k));
for i in 0..seq_len_q {
for j in 0..seq_len_k {
                // Offset i - j by (seq_len_k - 1) so relative positions map
                // into [0, seq_len_q + seq_len_k - 2] without underflow.
                let rel_pos = (seq_len_k - 1) + i - j;
                let mut dot_product = F::zero();
for k in 0..d_model {
dot_product += query[[b, i, k]] * rel_emb[[rel_pos, k]];
}
pos_scores[[i, j]] = dot_product * scale;
}
}
let mut combined_scores = Array2::<F>::zeros((seq_len_q, seq_len_k));
for i in 0..seq_len_q {
for j in 0..seq_len_k {
combined_scores[[i, j]] = content_scores[[i, j]] + pos_scores[[i, j]];
}
let mut row = combined_scores.slice_mut(scirs2_core::ndarray::s![i, ..]);
let max_val = row.fold(F::neg_infinity(), |max, &x| if x > max { x } else { max });
let mut sum = F::zero();
for j in 0..seq_len_k {
let exp_val = (row[j] - max_val).exp();
row[j] = exp_val;
sum += exp_val;
}
if sum > F::zero() {
for j in 0..seq_len_k {
row[j] /= sum;
}
}
}
let mut output = Array2::<F>::zeros((seq_len_q, d_model_v));
for i in 0..seq_len_q {
for j in 0..d_model_v {
let mut sum = F::zero();
for k in 0..seq_len_k {
sum += combined_scores[[i, k]] * value[[b, k, j]];
}
output[[i, j]] = sum;
}
}
result
.slice_mut(scirs2_core::ndarray::s![b, .., ..])
.assign(&output);
}
Ok(result)
}
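/// Relative-position attention with an optional xPos-style variant. When
/// `use_xpos` is set, queries and keys are damped by the position- and
/// dimension-dependent factor `pos^(-j / d_model)` and passed to standard
/// attention, and `rel_emb` is ignored; otherwise this delegates to
/// [`relative_position_attention`].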
#[allow(dead_code)]
pub fn attention_with_rpe<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
rel_emb: &ArrayView2<F>,
scale: F,
use_xpos: bool,
) -> LinalgResult<Array3<F>>
where
F: Float + Add + Mul + Div + Sub + NumAssignOps + Zero + std::fmt::Debug,
{
if use_xpos {
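        // Simplified xPos-style damping: feature j at position p is scaled by
        // p^(-j / d_model) in both queries and keys before standard attention.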
let (batchsize, seq_len_q, d_model) =
(query.shape()[0], query.shape()[1], query.shape()[2]);
let (_, seq_len_k, _) = (key.shape()[0], key.shape()[1], key.shape()[2]);
let (_, _, _d_model_v) = (value.shape()[0], value.shape()[1], value.shape()[2]);
let mut q_scaled = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
let mut k_scaled = Array3::<F>::zeros((batchsize, seq_len_k, d_model));
for b in 0..batchsize {
for i in 0..seq_len_q {
                let pos_i =
                    F::from(i as f64 + 1.0).expect("position should convert to the float type");
                for j in 0..d_model {
                    let dim_factor = F::from(j as f64 / d_model as f64)
                        .expect("dimension factor should convert to the float type");
let scale_factor = F::one() / pos_i.powf(dim_factor);
q_scaled[[b, i, j]] = query[[b, i, j]] * scale_factor;
}
}
for i in 0..seq_len_k {
                let pos_i =
                    F::from(i as f64 + 1.0).expect("position should convert to the float type");
                for j in 0..d_model {
                    let dim_factor = F::from(j as f64 / d_model as f64)
                        .expect("dimension factor should convert to the float type");
let scale_factor = F::one() / pos_i.powf(dim_factor);
k_scaled[[b, i, j]] = key[[b, i, j]] * scale_factor;
}
}
}
return attention(&q_scaled.view(), &k_scaled.view(), value, None, scale);
}
relative_position_attention(query, key, value, rel_emb, scale)
}