use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;
// CPU-only ndarray backend so the example runs anywhere without GPU setup.
type Backend = NdArray;
/// Demonstrates `FlashAttentionV3::forward` on random Q/K/V tensors,
/// once without masking and once with causal masking.
fn main() {
    println!("Flash Attention v3 - Basic Usage Example");
    println!("=========================================\n");

    let device = Default::default();

    // Attention dimensions: [batch, heads, sequence, head_dim].
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    println!("Configuration:");
    println!("  Batch size: {}", batch_size);
    println!("  Number of heads: {}", num_heads);
    println!("  Sequence length: {}", seq_len);
    println!("  Head dimension: {}\n", head_dim);

    // Random inputs drawn from a standard normal distribution.
    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Clone here because the tensors are reused for the causal run below.
    println!("Running standard attention...");
    let output = FlashAttentionV3::forward(
        query.clone(),
        key.clone(),
        value.clone(),
        None,  // no attention mask
        false, // is_causal
    );
    println!("Output shape: {:?}\n", output.dims());

    // Last use of the tensors: move them instead of cloning
    // (avoids three redundant full-tensor copies).
    println!("Running causal attention...");
    let output_causal = FlashAttentionV3::forward(
        query, key, value, None, true, // is_causal
    );
    println!("Causal output shape: {:?}\n", output_causal.dims());

    println!("Example completed successfully!");
}