use core::f32;
#[allow(unused_imports)]
use num_traits::Float as _;
use crate::{
Backend, TensorMetadata,
element::ElementConversion,
tensor::{BoolTensor, FloatTensor},
};
/// Reference (unfused) scaled dot-product attention built from primitive backend ops.
///
/// The computation is carried out step by step:
///
/// 1. `scores = matmul(query, transpose(key)) / sqrt(d)`, where `d` is the size of the
///    last of the four dimensions of `query`.
/// 2. If `mask` is provided, `float_mask_fill` writes negative infinity into the scores
///    at the positions selected by the mask.
/// 3. A max-shifted softmax is taken along dimension 3 of the scores.
/// 4. The result is `matmul(softmax, value)`.
///
/// # Arguments
///
/// * `query` - Query tensor; it is read as a rank-4 tensor.
/// * `key` - Key tensor; transposed with `float_transpose` before being multiplied.
/// * `value` - Value tensor, weighted by the softmax output.
/// * `mask` - Optional boolean mask applied to the scaled scores.
///
/// # Returns
///
/// The attention output as a float tensor of backend `B`.
///
/// # Panics
///
/// NOTE(review): `dims::<4>()` presumably panics when `query` is not rank 4 — confirm
/// against the shape API.
pub fn naive_attention<B: Backend>(
    query: FloatTensor<B>,
    key: FloatTensor<B>,
    value: FloatTensor<B>,
    mask: Option<BoolTensor<B>>,
) -> FloatTensor<B> {
    // Axis over which the softmax normalisation is performed.
    const SOFTMAX_DIM: usize = 3;

    // Only the trailing dimension of the query is needed, for the 1/sqrt(d) scaling.
    let [_, _, _, head_dim] = query.shape().dims::<4>();
    let scale = (head_dim as f32).sqrt().elem();

    // Scaled dot product: (Q x K^T) / sqrt(d).
    let key_t = B::float_transpose(key);
    let raw_scores = B::float_matmul(query, key_t);
    let scaled_scores = B::float_div_scalar(raw_scores, scale);

    // Masked positions are set to -inf so that exp() turns them into zero weight.
    // NOTE(review): a row in which every position is masked would have max = -inf and
    // presumably yield NaN after the subtraction below — confirm whether callers can
    // hit this.
    let scores = match mask {
        Some(mask) => B::float_mask_fill(scaled_scores, mask, f32::NEG_INFINITY.elem()),
        None => scaled_scores,
    };

    // Softmax with the per-row maximum subtracted first, which keeps exp() from
    // overflowing on large scores.
    let row_max = B::float_max_dim(scores.clone(), SOFTMAX_DIM);
    let shifted = B::float_sub(scores, row_max);
    let exp_scores = B::float_exp(shifted);
    let normalizer = B::float_sum_dim(exp_scores.clone(), SOFTMAX_DIM);
    let weights = B::float_div(exp_scores, normalizer);

    // Weighted combination of the values.
    B::float_matmul(weights, value)
}