use crate::errors::{Result, TrustformersError};
use crate::layers::Linear;
use crate::tensor::Tensor;
use crate::traits::Layer;
use scirs2_core::ndarray::{
s, Array1, Array2, ArrayBase, ArrayD, Axis, Dimension, IxDyn, OwnedRepr, Zip,
};
#[cfg(not(target_os = "macos"))]
use scirs2_core::simd_ops::SimdUnifiedOps;
const MIN_SIZE_FOR_BLAS: usize = 32;
#[cfg(target_os = "macos")]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
use oxiblas_blas::level3::gemm;
use oxiblas_matrix::{MatMut, MatRef};
let a_t = MatRef::new(a.as_ptr(), k, m, k);
let b_t = MatRef::new(b.as_ptr(), n, k, n);
let c_t = MatMut::new(c.as_mut_ptr(), n, m, n);
gemm(1.0, b_t, a_t, 0.0, c_t);
}
#[cfg(not(target_os = "macos"))]
#[inline]
fn blas_sgemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
let a_arr = Array2::from_shape_vec((m, k), a.to_vec()).expect("BLAS input shape mismatch");
let b_arr = Array2::from_shape_vec((k, n), b.to_vec()).expect("BLAS input shape mismatch");
let mut c_arr = Array2::from_shape_vec((m, n), c.to_vec()).expect("BLAS output shape mismatch");
f32::simd_gemm(1.0, &a_arr.view(), &b_arr.view(), 0.0, &mut c_arr);
if let Some(slice) = c_arr.as_slice() {
c.copy_from_slice(slice);
} else {
for (i, &val) in c_arr.iter().enumerate() {
c[i] = val;
}
}
}
#[derive(Debug, Clone)]
pub struct FlashAttention {
num_heads: usize,
hidden_size: usize,
head_dim: usize,
query: Linear,
key: Linear,
value: Linear,
out_proj: Linear,
#[allow(dead_code)]
dropout_prob: f32,
block_size: usize,
causal: bool,
use_flash_attention_2: bool,
}
impl FlashAttention {
pub fn new(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
block_size: Option<usize>,
causal: bool,
) -> Result<Self> {
Self::new_with_version(
hidden_size,
num_heads,
dropout_prob,
bias,
block_size,
causal,
true,
)
}
pub fn new_with_version(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
block_size: Option<usize>,
causal: bool,
use_flash_attention_2: bool,
) -> Result<Self> {
if !hidden_size.is_multiple_of(num_heads) {
return Err(TrustformersError::invalid_config(format!(
"hidden_size {} must be divisible by num_heads {}",
hidden_size, num_heads
)));
}
let head_dim = hidden_size / num_heads;
let block_size = block_size.unwrap_or(64);
Ok(Self {
num_heads,
hidden_size,
head_dim,
query: Linear::new(hidden_size, hidden_size, bias),
key: Linear::new(hidden_size, hidden_size, bias),
value: Linear::new(hidden_size, hidden_size, bias),
out_proj: Linear::new(hidden_size, hidden_size, bias),
dropout_prob,
block_size,
causal,
use_flash_attention_2,
})
}
fn split_heads(&self, tensor: &Tensor) -> Result<Tensor> {
let shape = tensor.shape();
if shape.len() != 3 {
return Err(TrustformersError::tensor_op_error(
&format!(
"Input tensor must have 3 dimensions for split_heads, got {}",
shape.len()
),
"FlashAttention::split_heads",
));
}
match tensor {
Tensor::F32(arr) => {
let batch_size = shape[0];
let seq_len = shape[1];
let reshaped = arr
.to_shape(IxDyn(&[batch_size, seq_len, self.num_heads, self.head_dim]))
.map_err(|_| {
TrustformersError::shape_error("Failed to reshape in split_heads".into())
})?
.to_owned();
let transposed = reshaped.permuted_axes(vec![0, 2, 1, 3]);
Ok(Tensor::F32(transposed))
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor type",
"FlashAttention::split_heads",
)),
}
}
fn merge_heads(&self, tensor: &Tensor) -> Result<Tensor> {
let shape = tensor.shape();
if shape.len() != 4 {
return Err(TrustformersError::tensor_op_error(
"Input tensor must have 4 dimensions",
"FlashAttention::merge_heads",
));
}
match tensor {
Tensor::F32(arr) => {
let batch_size = shape[0];
let seq_len = shape[2];
let transposed = arr.clone().permuted_axes(vec![0, 2, 1, 3]);
let merged = transposed
.to_shape(IxDyn(&[batch_size, seq_len, self.hidden_size]))
.map_err(|_| {
TrustformersError::shape_error("Failed to reshape in merge_heads".into())
})?
.to_owned();
Ok(Tensor::F32(merged))
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor type",
"FlashAttention::merge_heads",
)),
}
}
fn flash_attention_forward(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
_mask: Option<&Tensor>,
) -> Result<Tensor> {
if self.use_flash_attention_2 {
self.flash_attention_2_forward(q, k, v, _mask)
} else {
self.flash_attention_1_forward(q, k, v, _mask)
}
}
fn flash_attention_1_forward(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
_mask: Option<&Tensor>,
) -> Result<Tensor> {
let q_shape = q.shape();
let batch_size = q_shape[0];
let num_heads = q_shape[1];
let seq_len = q_shape[2];
let head_dim = q_shape[3];
let scale = 1.0 / (head_dim as f32).sqrt();
match (q, k, v) {
(Tensor::F32(q_arr), Tensor::F32(k_arr), Tensor::F32(v_arr)) => {
let mut output = ArrayD::zeros(IxDyn(&[batch_size, num_heads, seq_len, head_dim]));
let mut l = ArrayD::zeros(IxDyn(&[batch_size, num_heads, seq_len])); let mut m =
ArrayD::from_elem(IxDyn(&[batch_size, num_heads, seq_len]), f32::NEG_INFINITY);
let num_blocks = seq_len.div_ceil(self.block_size);
for i in 0..num_blocks {
let q_start = i * self.block_size;
let q_end = (q_start + self.block_size).min(seq_len);
let q_block_size = q_end - q_start;
let q_block = q_arr.slice(s![.., .., q_start..q_end, ..]).to_owned();
let mut o_block =
ArrayD::zeros(IxDyn(&[batch_size, num_heads, q_block_size, head_dim]));
let mut l_block = ArrayD::zeros(IxDyn(&[batch_size, num_heads, q_block_size]));
let mut m_block = ArrayD::from_elem(
IxDyn(&[batch_size, num_heads, q_block_size]),
f32::NEG_INFINITY,
);
for j in 0..num_blocks {
let k_start = j * self.block_size;
let k_end = (k_start + self.block_size).min(seq_len);
if self.causal && k_start >= q_end {
break;
}
let k_block = k_arr.slice(s![.., .., k_start..k_end, ..]).to_owned();
let v_block = v_arr.slice(s![.., .., k_start..k_end, ..]).to_owned();
let k_transposed = k_block.permuted_axes([0, 1, 3, 2]);
let scores = self.batched_matmul_slices(&q_block, &k_transposed)?;
let mut scores = scores.mapv(|x| x * scale);
if self.causal {
for b in 0..batch_size {
for h in 0..num_heads {
for qi in 0..q_block_size {
for ki in 0..(k_end - k_start) {
let global_qi = q_start + qi;
let global_ki = k_start + ki;
if global_qi < global_ki {
scores[[b, h, qi, ki]] = f32::NEG_INFINITY;
}
}
}
}
}
}
let m_new = scores.map_axis(Axis(3), |row| {
row.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x))
});
let m_prev = m_block.clone();
let m_combined = Zip::from(&m_block)
.and(&m_new)
.map_collect(|&m_old, &m_curr| m_old.max(m_curr));
let scores_shape = scores.shape();
let mut m_combined_expanded = ArrayD::zeros(IxDyn(scores_shape));
for b in 0..batch_size {
for h in 0..num_heads {
for qi in 0..q_block_size {
let m_val = m_combined[[b, h, qi]];
for ki in 0..(k_end - k_start) {
m_combined_expanded[[b, h, qi, ki]] = m_val;
}
}
}
}
let exp_scores = Zip::from(&scores)
.and(&m_combined_expanded)
.map_collect(|&score, &m_max| (score - m_max).exp());
let exp_prev = Zip::from(&m_prev)
.and(&m_combined)
.map_collect(|&m_old, &m_new| (m_old - m_new).exp());
let l_new = exp_scores.sum_axis(Axis(3));
let l_prev_scaled =
Zip::from(&l_block).and(&exp_prev).map_collect(|&l, &exp| l * exp);
l_block = l_prev_scaled + l_new;
let mut o_prev_scaled = o_block.clone();
for b in 0..batch_size {
for h in 0..num_heads {
for qi in 0..q_block_size {
let exp_val = exp_prev[[b, h, qi]];
for d in 0..head_dim {
o_prev_scaled[[b, h, qi, d]] *= exp_val;
}
}
}
}
let attn_v = self.batched_matmul_slices(&exp_scores, &v_block)?;
o_block = o_prev_scaled + attn_v;
m_block = m_combined;
}
let l_inv = l_block.mapv(|x: f32| if x > 0.0 { 1.0 / x } else { 0.0 });
for b in 0..batch_size {
for h in 0..num_heads {
for qi in 0..q_block_size {
let l_val = l_inv[[b, h, qi]];
for d in 0..head_dim {
o_block[[b, h, qi, d]] *= l_val;
}
}
}
}
output.slice_mut(s![.., .., q_start..q_end, ..]).assign(&o_block);
l.slice_mut(s![.., .., q_start..q_end]).assign(&l_block);
m.slice_mut(s![.., .., q_start..q_end]).assign(&m_block);
}
Ok(Tensor::F32(output))
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for flash attention",
"FlashAttention::flash_attention_forward",
)),
}
}
fn batched_matmul_slices<D1, D2>(
&self,
a: &ArrayBase<OwnedRepr<f32>, D1>,
b: &ArrayBase<OwnedRepr<f32>, D2>,
) -> Result<ArrayD<f32>>
where
D1: Dimension,
D2: Dimension,
{
let a_dyn = a.view().into_dyn().to_owned();
let b_dyn = b.view().into_dyn().to_owned();
self.batched_matmul_4d(&a_dyn, &b_dyn)
}
fn batched_matmul_4d(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 4 || b_shape.len() != 4 {
return Err(TrustformersError::tensor_op_error(
"Both tensors must be 4D",
"FlashAttention::batched_matmul_4d",
));
}
let batch = a_shape[0];
let heads = a_shape[1];
let m = a_shape[2];
let k = a_shape[3];
let n = b_shape[3];
if a_shape[0] != b_shape[0] || a_shape[1] != b_shape[1] || k != b_shape[2] {
return Err(TrustformersError::tensor_op_error(
"Shape mismatch in batched matmul",
"FlashAttention::batched_matmul_4d",
));
}
let mut result = ArrayD::zeros(IxDyn(&[batch, heads, m, n]));
for b_idx in 0..batch {
for h_idx in 0..heads {
let a_slice = a.index_axis(Axis(0), b_idx);
let a_mat = a_slice.index_axis(Axis(0), h_idx);
let b_slice = b.index_axis(Axis(0), b_idx);
let b_mat = b_slice.index_axis(Axis(0), h_idx);
let a_data: Vec<f32> = a_mat.iter().cloned().collect();
let b_data: Vec<f32> = b_mat.iter().cloned().collect();
let product =
if m >= MIN_SIZE_FOR_BLAS && k >= MIN_SIZE_FOR_BLAS && n >= MIN_SIZE_FOR_BLAS {
let mut result_vec = vec![0.0f32; m * n];
blas_sgemm(&a_data, &b_data, &mut result_vec, m, k, n);
Array2::from_shape_vec((m, n), result_vec)
.map_err(|e| TrustformersError::shape_error(e.to_string()))?
} else {
let a_2d = Array2::from_shape_vec((m, k), a_data)
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
let b_2d = Array2::from_shape_vec((k, n), b_data)
.map_err(|e| TrustformersError::shape_error(e.to_string()))?;
a_2d.dot(&b_2d)
};
result
.index_axis_mut(Axis(0), b_idx)
.index_axis_mut(Axis(0), h_idx)
.assign(&product);
}
}
Ok(result)
}
fn flash_attention_2_forward(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
_mask: Option<&Tensor>,
) -> Result<Tensor> {
let q_shape = q.shape();
let batch_size = q_shape[0];
let num_heads = q_shape[1];
let seq_len = q_shape[2];
let head_dim = q_shape[3];
let scale = 1.0 / (head_dim as f32).sqrt();
match (q, k, v) {
(Tensor::F32(q_arr), Tensor::F32(k_arr), Tensor::F32(v_arr)) => {
let mut output = ArrayD::zeros(IxDyn(&[batch_size, num_heads, seq_len, head_dim]));
let num_blocks = seq_len.div_ceil(self.block_size);
for b in 0..batch_size {
for h in 0..num_heads {
let q_batch = q_arr.index_axis(Axis(0), b);
let k_batch = k_arr.index_axis(Axis(0), b);
let v_batch = v_arr.index_axis(Axis(0), b);
let q_head = q_batch.index_axis(Axis(0), h);
let k_head = k_batch.index_axis(Axis(0), h);
let v_head = v_batch.index_axis(Axis(0), h);
for i in 0..num_blocks {
let q_start = i * self.block_size;
let q_end = (q_start + self.block_size).min(seq_len);
let q_block_size = q_end - q_start;
let q_block = q_head.slice(s![q_start..q_end, ..]).to_owned();
let mut o_block = Array2::<f32>::zeros((q_block_size, head_dim));
let mut l_block = Array1::<f32>::zeros(q_block_size);
let mut m_block =
Array1::<f32>::from_elem(q_block_size, f32::NEG_INFINITY);
for j in 0..num_blocks {
let k_start = j * self.block_size;
let k_end = (k_start + self.block_size).min(seq_len);
let k_block_size = k_end - k_start;
if self.causal && k_start >= q_end {
break;
}
let k_block = k_head.slice(s![k_start..k_end, ..]).to_owned();
let v_block = v_head.slice(s![k_start..k_end, ..]).to_owned();
let mut scores = Array2::<f32>::zeros((q_block_size, k_block_size));
for qi in 0..q_block_size {
for ki in 0..k_block_size {
let mut dot_product = 0.0;
for d in 0..head_dim {
dot_product += q_block[[qi, d]] * k_block[[ki, d]];
}
scores[[qi, ki]] = dot_product * scale;
}
}
if self.causal {
for qi in 0..q_block_size {
for ki in 0..k_block_size {
let global_qi = q_start + qi;
let global_ki = k_start + ki;
if global_qi < global_ki {
scores[[qi, ki]] = f32::NEG_INFINITY;
}
}
}
}
let m_new =
scores.fold_axis(Axis(1), f32::NEG_INFINITY, |&acc, &x| {
acc.max(x)
});
let m_prev = m_block.clone();
let m_combined = Array1::<f32>::from_shape_fn(q_block_size, |i| {
m_block[i].max(m_new[i])
});
let mut exp_scores =
Array2::<f32>::zeros((q_block_size, k_block_size));
for qi in 0..q_block_size {
for ki in 0..k_block_size {
exp_scores[[qi, ki]] =
(scores[[qi, ki]] - m_combined[qi]).exp();
}
}
let exp_prev = Array1::<f32>::from_shape_fn(q_block_size, |i| {
(m_prev[i] - m_combined[i]).exp()
});
let l_new = exp_scores.sum_axis(Axis(1));
for qi in 0..q_block_size {
l_block[qi] = l_block[qi] * exp_prev[qi] + l_new[qi];
}
for qi in 0..q_block_size {
for d in 0..head_dim {
o_block[[qi, d]] *= exp_prev[qi];
}
}
for qi in 0..q_block_size {
for d in 0..head_dim {
let mut attn_v_val = 0.0;
for ki in 0..k_block_size {
attn_v_val += exp_scores[[qi, ki]] * v_block[[ki, d]];
}
o_block[[qi, d]] += attn_v_val;
}
}
m_block = m_combined;
}
for qi in 0..q_block_size {
let l_inv = if l_block[qi] > 0.0 { 1.0 / l_block[qi] } else { 0.0 };
for d in 0..head_dim {
o_block[[qi, d]] *= l_inv;
}
}
for qi in 0..q_block_size {
for d in 0..head_dim {
output[[b, h, q_start + qi, d]] = o_block[[qi, d]];
}
}
}
}
}
Ok(Tensor::F32(output))
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for flash attention 2",
"FlashAttention::flash_attention_2_forward",
)),
}
}
}
#[derive(Debug, Clone)]
pub struct FlashAttentionInput {
pub hidden_states: Tensor,
pub attention_mask: Option<Tensor>,
}
impl Layer for FlashAttention {
type Input = FlashAttentionInput;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let hidden_states = input.hidden_states;
let was_2d = match &hidden_states {
Tensor::F32(arr) => arr.ndim() == 2,
_ => false,
};
let hidden_states = match &hidden_states {
Tensor::F32(arr) => {
if arr.ndim() == 2 {
let shape = arr.shape();
let expanded = arr
.view()
.into_shape_with_order(IxDyn(&[1, shape[0], shape[1]]))
.map_err(|e| {
TrustformersError::shape_error(format!(
"Failed to add batch dimension: {}",
e
))
})?;
Tensor::F32(expanded.to_owned())
} else {
hidden_states
}
},
_ => hidden_states,
};
let query_states = self.query.forward(hidden_states.clone())?;
let key_states = self.key.forward(hidden_states.clone())?;
let value_states = self.value.forward(hidden_states)?;
let query_states = self.split_heads(&query_states)?;
let key_states = self.split_heads(&key_states)?;
let value_states = self.split_heads(&value_states)?;
let context = self.flash_attention_forward(
&query_states,
&key_states,
&value_states,
input.attention_mask.as_ref(),
)?;
let context = self.merge_heads(&context)?;
let result = self.out_proj.forward(context)?;
if was_2d {
match &result {
Tensor::F32(arr) => {
if arr.shape()[0] == 1 {
let squeezed = arr.index_axis(Axis(0), 0).to_owned();
Ok(Tensor::F32(squeezed))
} else {
Ok(result)
}
},
_ => Ok(result),
}
} else {
Ok(result)
}
}
}
#[derive(Debug, Clone)]
pub struct MultiQueryAttention {
#[allow(dead_code)]
num_heads: usize,
#[allow(dead_code)]
hidden_size: usize,
#[allow(dead_code)]
head_dim: usize,
#[allow(dead_code)]
query: Linear,
#[allow(dead_code)]
key: Linear,
#[allow(dead_code)]
value: Linear,
#[allow(dead_code)]
out_proj: Linear,
#[allow(dead_code)]
dropout_prob: f32,
}
impl MultiQueryAttention {
pub fn new(
hidden_size: usize,
num_heads: usize,
dropout_prob: f32,
bias: bool,
) -> Result<Self> {
if !hidden_size.is_multiple_of(num_heads) {
return Err(TrustformersError::invalid_config(format!(
"hidden_size {} must be divisible by num_heads {}",
hidden_size, num_heads
)));
}
let head_dim = hidden_size / num_heads;
Ok(Self {
num_heads,
hidden_size,
head_dim,
query: Linear::new(hidden_size, hidden_size, bias),
key: Linear::new(hidden_size, head_dim, bias), value: Linear::new(hidden_size, head_dim, bias), out_proj: Linear::new(hidden_size, hidden_size, bias),
dropout_prob,
})
}
}
#[derive(Debug, Clone)]
pub struct GroupedQueryAttention {
#[allow(dead_code)]
num_query_heads: usize,
#[allow(dead_code)]
num_key_value_heads: usize,
#[allow(dead_code)]
hidden_size: usize,
#[allow(dead_code)]
head_dim: usize,
#[allow(dead_code)]
query: Linear,
#[allow(dead_code)]
key: Linear,
#[allow(dead_code)]
value: Linear,
#[allow(dead_code)]
out_proj: Linear,
#[allow(dead_code)]
dropout_prob: f32,
}
impl GroupedQueryAttention {
pub fn new(
hidden_size: usize,
num_query_heads: usize,
num_key_value_heads: usize,
dropout_prob: f32,
bias: bool,
) -> Result<Self> {
if !hidden_size.is_multiple_of(num_query_heads) {
return Err(TrustformersError::invalid_config(format!(
"hidden_size {} must be divisible by num_query_heads {}",
hidden_size, num_query_heads
)));
}
if !num_query_heads.is_multiple_of(num_key_value_heads) {
return Err(TrustformersError::invalid_config(format!(
"num_query_heads {} must be divisible by num_key_value_heads {}",
num_query_heads, num_key_value_heads
)));
}
let head_dim = hidden_size / num_query_heads;
let kv_hidden_size = num_key_value_heads * head_dim;
Ok(Self {
num_query_heads,
num_key_value_heads,
hidden_size,
head_dim,
query: Linear::new(hidden_size, hidden_size, bias),
key: Linear::new(hidden_size, kv_hidden_size, bias),
value: Linear::new(hidden_size, kv_hidden_size, bias),
out_proj: Linear::new(hidden_size, hidden_size, bias),
dropout_prob,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_flash_attention_creation() {
let flash_attn = FlashAttention::new(768, 12, 0.1, true, Some(64), false);
assert!(flash_attn.is_ok());
let flash_attn = flash_attn.expect("Failed to create FlashAttention");
assert_eq!(flash_attn.num_heads, 12);
assert_eq!(flash_attn.hidden_size, 768);
assert_eq!(flash_attn.head_dim, 64);
assert_eq!(flash_attn.block_size, 64);
assert!(!flash_attn.causal);
}
#[test]
fn test_flash_attention_forward_pass() {
let flash_attn = FlashAttention::new(256, 8, 0.0, true, Some(32), false)
.expect("Failed to create FlashAttention");
let hidden_states = Tensor::randn(&[2, 128, 256]).expect("Failed to create random tensor");
let input = FlashAttentionInput {
hidden_states,
attention_mask: None,
};
let output = flash_attn.forward(input);
assert!(output.is_ok());
let output = output.expect("Forward pass failed");
assert_eq!(output.shape(), vec![2, 128, 256]);
}
#[test]
fn test_multi_query_attention_creation() {
let mqa = MultiQueryAttention::new(768, 12, 0.1, true);
assert!(mqa.is_ok());
let mqa = mqa.expect("Failed to create MultiQueryAttention");
assert_eq!(mqa.num_heads, 12);
assert_eq!(mqa.hidden_size, 768);
assert_eq!(mqa.head_dim, 64);
}
#[test]
fn test_grouped_query_attention_creation() {
let gqa = GroupedQueryAttention::new(768, 12, 4, 0.1, true);
assert!(gqa.is_ok());
let gqa = gqa.expect("Failed to create GroupedQueryAttention");
assert_eq!(gqa.num_query_heads, 12);
assert_eq!(gqa.num_key_value_heads, 4);
assert_eq!(gqa.hidden_size, 768);
assert_eq!(gqa.head_dim, 64);
}
#[test]
fn test_flash_attention_causal() {
let flash_attn = FlashAttention::new(256, 8, 0.0, true, Some(32), true)
.expect("Failed to create FlashAttention");
let hidden_states = Tensor::randn(&[1, 64, 256]).expect("Failed to create random tensor");
let input = FlashAttentionInput {
hidden_states,
attention_mask: None,
};
let output = flash_attn.forward(input);
assert!(output.is_ok());
let output = output.expect("Forward pass failed");
assert_eq!(output.shape(), vec![1, 64, 256]);
}
#[test]
fn test_flash_attention_deterministic() {
let flash_attn = FlashAttention::new(128, 4, 0.0, true, Some(16), false)
.expect("Failed to create FlashAttention");
let hidden_states = Tensor::ones(&[1, 32, 128]).expect("Failed to create ones tensor");
let input = FlashAttentionInput {
hidden_states: hidden_states.clone(),
attention_mask: None,
};
let output1 = flash_attn.forward(input.clone()).expect("Forward pass failed");
let output2 = flash_attn.forward(input).expect("Forward pass failed");
let data1 = output1.data().expect("Failed to get data");
let data2 = output2.data().expect("Failed to get data");
for (a, b) in data1.iter().zip(data2.iter()) {
assert!((a - b).abs() < 1e-6, "Outputs should be deterministic");
}
}
#[test]
fn test_flash_attention_2_creation() {
let flash_attn_2 =
FlashAttention::new_with_version(768, 12, 0.1, true, Some(64), false, true);
assert!(flash_attn_2.is_ok());
let flash_attn_2 = flash_attn_2.expect("Failed to create FlashAttention-2");
assert_eq!(flash_attn_2.num_heads, 12);
assert_eq!(flash_attn_2.hidden_size, 768);
assert_eq!(flash_attn_2.head_dim, 64);
assert_eq!(flash_attn_2.block_size, 64);
assert!(!flash_attn_2.causal);
assert!(flash_attn_2.use_flash_attention_2);
}
#[test]
fn test_flash_attention_2_forward_pass() {
let flash_attn_2 =
FlashAttention::new_with_version(256, 8, 0.0, true, Some(32), false, true)
.expect("Failed to create FlashAttention-2");
let hidden_states = Tensor::randn(&[2, 128, 256]).expect("Failed to create random tensor");
let input = FlashAttentionInput {
hidden_states,
attention_mask: None,
};
let output = flash_attn_2.forward(input);
assert!(output.is_ok());
let output = output.expect("Forward pass failed");
assert_eq!(output.shape(), vec![2, 128, 256]);
}
#[test]
fn test_flash_attention_2_vs_1_consistency() {
let flash_attn_1 =
FlashAttention::new_with_version(128, 4, 0.0, true, Some(16), false, false)
.expect("Failed to create FlashAttention-1");
let flash_attn_2 =
FlashAttention::new_with_version(128, 4, 0.0, true, Some(16), false, true)
.expect("Failed to create FlashAttention-2");
let hidden_states = Tensor::ones(&[1, 32, 128]).expect("Failed to create ones tensor");
let input = FlashAttentionInput {
hidden_states: hidden_states.clone(),
attention_mask: None,
};
let output1 = flash_attn_1.forward(input.clone()).expect("Forward pass failed");
let output2 = flash_attn_2.forward(input).expect("Forward pass failed");
let data1 = output1.data().expect("Failed to get data");
let data2 = output2.data().expect("Failed to get data");
let mut max_diff: f32 = 0.0;
for (a, b) in data1.iter().zip(data2.iter()) {
max_diff = max_diff.max((a - b).abs());
}
assert!(
max_diff < 1000.0,
"FlashAttention-2 output differs too much from FlashAttention-1: max_diff = {}",
max_diff
);
}
#[test]
fn test_flash_attention_2_causal() {
let flash_attn_2 =
FlashAttention::new_with_version(256, 8, 0.0, true, Some(32), true, true)
.expect("Failed to create FlashAttention-2");
let hidden_states = Tensor::randn(&[1, 64, 256]).expect("Failed to create random tensor");
let input = FlashAttentionInput {
hidden_states,
attention_mask: None,
};
let output = flash_attn_2.forward(input);
assert!(output.is_ok());
let output = output.expect("Forward pass failed");
assert_eq!(output.shape(), vec![1, 64, 256]);
}
}