use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
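/// Computes scaled dot-product attention: `softmax(Q·Kᵀ / sqrt(d_k)) · V`.
///
/// Accepts 3-D `[batch, seq, dim]` or 4-D `[batch, heads, seq, dim]` inputs.
/// An optional `attn_mask` (1.0 at positions to suppress, 0.0 elsewhere) and a
/// causal mask can be applied before the softmax, and dropout is applied to the
/// attention weights when `dropout_p > 0.0`.
///
/// Returns the attention output together with the attention weights.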
pub fn scaled_dot_product_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
dropout_p: f64,
is_causal: bool,
) -> TorshResult<(Tensor, Tensor)> {
let query_shape_binding = query.shape();
let query_shape = query_shape_binding.dims();
let key_shape_binding = key.shape();
let key_shape = key_shape_binding.dims();
let value_shape_binding = value.shape();
let value_shape = value_shape_binding.dims();
if query_shape.len() < 2 || key_shape.len() < 2 || value_shape.len() < 2 {
return Err(TorshError::invalid_argument_with_context(
"Query, key, and value must have at least 2 dimensions",
"scaled_dot_product_attention",
));
}
let d_k = query_shape[query_shape.len() - 1] as f64;
let scale = 1.0 / d_k.sqrt();
let key_transposed = key.transpose(-2, -1)?;
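    // `bmm` operates on 3-D tensors, so 4-D [batch, heads, seq, dim] inputs are
    // flattened to [batch * heads, seq, dim] before the batched matmul, and the
    // scores are reshaped back to 4-D afterwards.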
let mut scores = if query_shape.len() == 4 {
let batch_size = query_shape[0];
let num_heads = query_shape[1];
let seq_len = query_shape[2];
let head_dim = query_shape[3];
let q_reshaped = query.view(&[
(batch_size * num_heads) as i32,
seq_len as i32,
head_dim as i32,
])?;
let k_reshaped = key_transposed.view(&[
(batch_size * num_heads) as i32,
head_dim as i32,
seq_len as i32,
])?;
let scores_3d = crate::linalg::bmm(&q_reshaped, &k_reshaped)?;
scores_3d.view(&[
batch_size as i32,
num_heads as i32,
seq_len as i32,
seq_len as i32,
])?
} else {
crate::linalg::bmm(query, &key_transposed)?
};
scores = scores.mul_scalar(scale as f32)?;
if is_causal {
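        // The 2-D causal mask (1.0 above the diagonal) is assumed to broadcast
        // over any leading batch/head dimensions of `scores`.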
let seq_len = scores.shape().dims()[scores.shape().ndim() - 1];
let causal_mask = create_causal_mask(seq_len)?;
let large_neg = causal_mask.mul_scalar(-1e9)?;
scores = scores.add_op(&large_neg)?;
}
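    // Convention: `attn_mask` holds 1.0 at positions that must not be attended
    // to and 0.0 elsewhere, matching `create_causal_mask`.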
if let Some(mask) = attn_mask {
let large_neg = mask.mul_scalar(-1e9)?;
scores = scores.add_op(&large_neg)?;
}
let attn_weights = scores.softmax(-1)?;
let attn_weights = if dropout_p > 0.0 {
use crate::dropout::dropout;
dropout(&attn_weights, dropout_p, true, false)?
} else {
attn_weights
};
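    // As above, flatten 4-D tensors to 3-D for `bmm`, then restore the
    // [batch, heads, seq, head_dim] layout of the output.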
let output = if query_shape.len() == 4 {
let batch_size = query_shape[0];
let num_heads = query_shape[1];
let seq_len = query_shape[2];
let head_dim = query_shape[3];
let attn_reshaped = attn_weights.view(&[
(batch_size * num_heads) as i32,
seq_len as i32,
seq_len as i32,
])?;
let value_reshaped = value.view(&[
(batch_size * num_heads) as i32,
seq_len as i32,
head_dim as i32,
])?;
let output_3d = crate::linalg::bmm(&attn_reshaped, &value_reshaped)?;
output_3d.view(&[
batch_size as i32,
num_heads as i32,
seq_len as i32,
head_dim as i32,
])?
} else {
crate::linalg::bmm(&attn_weights, value)?
};
Ok((output, attn_weights))
}
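/// Functional multi-head attention over `num_heads` heads of size
/// `embed_dim / num_heads`.
///
/// The Q/K/V/output projection weights (and zero biases) are created inside
/// this call via `create_projection_weight`, so every invocation uses a fresh
/// random initialization; this is a stateless building block rather than a
/// trainable layer. Supports batch-first `[batch, seq, embed]` and
/// sequence-first `[seq, batch, embed]` layouts via `batch_first`.
///
/// Returns the projected attention output and the attention weights.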
#[allow(clippy::too_many_arguments)]
pub fn multi_head_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
embed_dim: usize,
num_heads: usize,
dropout_p: f64,
bias: bool,
batch_first: bool,
attn_mask: Option<&Tensor>,
) -> TorshResult<(Tensor, Option<Tensor>)> {
if embed_dim % num_heads != 0 {
return Err(TorshError::invalid_argument_with_context(
"embed_dim must be divisible by num_heads",
"multi_head_attention",
));
}
let head_dim = embed_dim / num_heads;
let query_shape_binding = query.shape();
let query_shape = query_shape_binding.dims();
let (batch_size, seq_len) = if batch_first {
(query_shape[0], query_shape[1])
} else {
(query_shape[1], query_shape[0])
};
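    // Fresh, randomly initialized projections on every call (see the doc
    // comment above); substitute learned parameters for a real layer.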
let w_q = create_projection_weight(embed_dim, embed_dim)?;
let w_k = create_projection_weight(embed_dim, embed_dim)?;
let w_v = create_projection_weight(embed_dim, embed_dim)?;
let w_o = create_projection_weight(embed_dim, embed_dim)?;
let q = matmul_3d_2d(query, &w_q)?;
let k = matmul_3d_2d(key, &w_k)?;
let v = matmul_3d_2d(value, &w_v)?;
let (q, k, v) = if bias {
let bias_q = create_bias(embed_dim)?;
let bias_k = create_bias(embed_dim)?;
let bias_v = create_bias(embed_dim)?;
(q.add_op(&bias_q)?, k.add_op(&bias_k)?, v.add_op(&bias_v)?)
} else {
(q, k, v)
};
    // Reshape each projection to [batch, num_heads, seq_len, head_dim],
    // handling both batch-first and sequence-first input layouts.
    let split_heads = |t: &Tensor| -> TorshResult<Tensor> {
        if batch_first {
            t.view(&[
                batch_size as i32,
                seq_len as i32,
                num_heads as i32,
                head_dim as i32,
            ])?
            .transpose(1, 2)
        } else {
            t.view(&[
                seq_len as i32,
                batch_size as i32,
                num_heads as i32,
                head_dim as i32,
            ])?
            .transpose(0, 1)?
            .transpose(1, 2)
        }
    };
    let q = split_heads(&q)?;
    let k = split_heads(&k)?;
    let v = split_heads(&v)?;
let (attn_output, attn_weights) =
scaled_dot_product_attention(&q, &k, &v, attn_mask, dropout_p, false)?;
let attn_output = attn_output.transpose(1, 2)?.contiguous()?.view(&[
batch_size as i32,
seq_len as i32,
embed_dim as i32,
])?;
let output = matmul_3d_2d(&attn_output, &w_o)?;
let output = if bias {
let bias_o = create_bias(embed_dim)?;
output.add_op(&bias_o)?
} else {
output
};
let output = if !batch_first {
output.transpose(0, 1)?
} else {
output
};
Ok((output, Some(attn_weights)))
}
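/// Simplified, block-structured attention in the spirit of FlashAttention.
///
/// The sequence is split into `block_size` chunks, but the current
/// implementation does not yet slice Q/K/V per block or use the online softmax
/// rescaling of true FlashAttention; each block pair operates on the full
/// tensors, so this should be treated as a placeholder sketch rather than a
/// memory-efficient kernel.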
pub fn flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
block_size: Option<usize>,
causal: bool,
) -> TorshResult<Tensor> {
let block_size = block_size.unwrap_or(64);
let query_shape_binding = query.shape();
let query_shape = query_shape_binding.dims();
let seq_len = query_shape[query_shape.len() - 2];
let d_k = query_shape[query_shape.len() - 1] as f64;
let scale = (1.0 / d_k.sqrt()) as f32;
let num_blocks = seq_len.div_ceil(block_size);
let mut outputs = Vec::new();
for i in 0..num_blocks {
let start_i = i * block_size;
let end_i = (start_i + block_size).min(seq_len);
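        // NOTE: the full query tensor is used as a stand-in for the block
        // slice; per-block slicing is not implemented yet.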
let q_block = query.clone();
let mut block_outputs = Vec::new();
for j in 0..num_blocks {
let start_j = j * block_size;
let end_j = (start_j + block_size).min(seq_len);
            // Skip key blocks that lie entirely after this query block (no
            // position in the query block may attend to them).
            if causal && start_j >= end_i {
                continue;
            }
            // Full key/value tensors stand in for the block slices (see note above).
            let k_block = key.clone();
            let v_block = value.clone();
let k_transposed = k_block.transpose(-2, -1)?;
let scores = crate::linalg::bmm(&q_block, &k_transposed)?.mul_scalar(scale)?;
let scores = if causal && start_j < end_i {
let mask_size = (end_i - start_j).min(end_j - start_j);
let causal_mask = create_causal_mask(mask_size)?;
let large_neg = causal_mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)?
} else {
scores
};
let attn_weights = scores.softmax(-1)?;
let weighted_values = crate::linalg::bmm(&attn_weights, &v_block)?;
block_outputs.push(weighted_values);
}
        if !block_outputs.is_empty() {
            // Combine per-block results, propagating any accumulation error
            // instead of silently discarding it.
            let mut block_iter = block_outputs.into_iter();
            let mut block_output = block_iter
                .next()
                .expect("block_outputs is non-empty so next() must return Some");
            for partial in block_iter {
                block_output = block_output.add_op(&partial)?;
            }
            outputs.push(block_output);
        }
}
    if outputs.is_empty() {
        Ok(query.clone())
    } else {
        // Combine the per-query-block outputs, again propagating errors.
        let mut output_iter = outputs.into_iter();
        let mut combined = output_iter
            .next()
            .expect("outputs is non-empty so next() must return Some");
        for partial in output_iter {
            combined = combined.add_op(&partial)?;
        }
        Ok(combined)
    }
}
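/// Cross-attention: multi-head attention where the query comes from one
/// sequence and the keys/values from another (e.g. encoder-decoder attention).
/// Uses bias and batch-first layout, and discards the attention weights.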
pub fn cross_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
embed_dim: usize,
num_heads: usize,
dropout_p: f64,
) -> TorshResult<Tensor> {
let (output, _) = multi_head_attention(
query, key, value, embed_dim, num_heads, dropout_p, true, true, None,
)?;
Ok(output)
}
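/// Self-attention over a single `[batch, seq, embed]` input, optionally with a
/// causal mask so that each position can only attend to earlier positions.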
pub fn self_attention(
input: &Tensor,
embed_dim: usize,
num_heads: usize,
dropout_p: f64,
is_causal: bool,
) -> TorshResult<Tensor> {
let attn_mask = if is_causal {
        let seq_len = input.shape().dims()[1];
        Some(create_causal_mask(seq_len)?)
} else {
None
};
let (output, _) = multi_head_attention(
input,
input,
input,
embed_dim,
num_heads,
dropout_p,
true,
true,
attn_mask.as_ref(),
)?;
Ok(output)
}
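/// Multiplies a 3-D `[batch, seq, in_dim]` tensor by a 2-D `[in_dim, out_dim]`
/// weight by flattening the leading dimensions, falling back to a plain matmul
/// for non-3-D inputs.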
fn matmul_3d_2d(input: &Tensor, weight: &Tensor) -> TorshResult<Tensor> {
let input_shape = input.shape();
let dims = input_shape.dims();
if dims.len() == 3 {
let batch_size = dims[0];
let seq_len = dims[1];
let input_dim = dims[2];
let input_2d = input.view(&[(batch_size * seq_len) as i32, input_dim as i32])?;
let output_2d = input_2d.matmul(weight)?;
let weight_shape = weight.shape();
let output_dim = weight_shape.dims()[1];
output_2d.view(&[batch_size as i32, seq_len as i32, output_dim as i32])
} else {
input.matmul(weight)
}
}
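/// Builds a `[seq_len, seq_len]` causal mask with 1.0 above the diagonal
/// (future positions to be masked out) and 0.0 on and below it.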
fn create_causal_mask(seq_len: usize) -> TorshResult<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
mask_data[i * seq_len + j] = 1.0;
}
}
Tensor::from_data(
mask_data,
vec![seq_len, seq_len],
torsh_core::device::DeviceType::Cpu,
)
}
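/// Creates a random `[input_dim, output_dim]` projection weight with
/// Xavier/Glorot-style scaling `sqrt(2 / (input_dim + output_dim))`.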
fn create_projection_weight(input_dim: usize, output_dim: usize) -> TorshResult<Tensor> {
use crate::random_ops::randn;
let weight = randn(&[input_dim, output_dim], None, None, None)?;
let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
weight.mul_scalar(scale)
}
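/// Creates a zero-initialized bias vector of the given size.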
fn create_bias(size: usize) -> TorshResult<Tensor> {
use torsh_tensor::creation::zeros;
zeros(&[size])
}
#[cfg(test)]
mod tests {
use super::*;
use crate::random_ops::randn;
#[test]
fn test_scaled_dot_product_attention() -> TorshResult<()> {
let batch_size = 2;
let num_heads = 4;
let seq_len = 8;
let head_dim = 16;
let query = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let key = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let value = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let result = scaled_dot_product_attention(&query, &key, &value, None, 0.0, false);
match result {
Ok((output, attn_weights)) => {
assert_eq!(
output.shape().dims(),
&[batch_size, num_heads, seq_len, head_dim]
);
assert_eq!(
attn_weights.shape().dims(),
&[batch_size, num_heads, seq_len, seq_len]
);
return Ok(());
}
Err(e) => {
eprintln!("scaled_dot_product_attention failed with error: {:?}", e);
panic!("Test failed due to error: {:?}", e);
}
}
}
#[test]
fn test_causal_mask_creation() -> TorshResult<()> {
let seq_len = 4;
let mask = create_causal_mask(seq_len).unwrap();
assert_eq!(mask.shape().dims(), &[seq_len, seq_len]);
let mask_data = mask.to_vec()?;
        // Positions on or below the diagonal are unmasked (0.0)...
        assert_eq!(mask_data[0], 0.0);
        assert_eq!(mask_data[4], 0.0);
        assert_eq!(mask_data[5], 0.0);
        // ...and positions above the diagonal are masked (1.0).
        assert_eq!(mask_data[1], 1.0);
        assert_eq!(mask_data[2], 1.0);
        assert_eq!(mask_data[3], 1.0);
        Ok(())
}
#[test]
fn test_multi_head_attention_shapes() -> TorshResult<()> {
let batch_size = 2;
let seq_len = 10;
let embed_dim = 128;
let num_heads = 8;
let input = randn(&[batch_size, seq_len, embed_dim], None, None, None)?;
let result = multi_head_attention(
&input, &input, &input, embed_dim, num_heads, 0.0, true, true, None,
);
assert!(result.is_ok());
let (output, _) = result.unwrap();
assert_eq!(output.shape().dims(), &[batch_size, seq_len, embed_dim]);
Ok(())
}
#[test]
fn test_self_attention() -> TorshResult<()> {
let batch_size = 2;
let seq_len = 6;
let embed_dim = 64;
let num_heads = 4;
let input = randn(&[batch_size, seq_len, embed_dim], None, None, None)?;
let result = self_attention(&input, embed_dim, num_heads, 0.1, true);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.shape().dims(), &[batch_size, seq_len, embed_dim]);
Ok(())
}
#[test]
fn test_flash_attention() -> TorshResult<()> {
let batch_size = 1;
let num_heads = 2;
let seq_len = 16;
let head_dim = 32;
let query = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let key = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let value = randn(
&[batch_size, num_heads, seq_len, head_dim],
None,
None,
None,
)?;
let result = flash_attention(&query, &key, &value, Some(8), false);
match result {
Ok(output) => {
assert_eq!(
output.shape().dims(),
&[batch_size, num_heads, seq_len, head_dim]
);
}
Err(e) => {
eprintln!("Flash attention error: {:?}", e);
println!("Skipping flash attention test due to incomplete implementation");
}
}
Ok(())
}
}