burn_attention 0.1.0

Flash Attention v3 implementation for the Burn framework
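
Flash Attention computes standard scaled dot-product attention, softmax(Q Kᵀ / √d_head) · V, but evaluates it in tiles so the full seq_len × seq_len score matrix is never materialized, which reduces memory use and improves throughput on long sequences.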
The example below shows basic usage of FlashAttentionV3 on the NdArray backend:
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    println!("Flash Attention v3 - Basic Usage Example");
    println!("=========================================\n");

    let device = Default::default();

    // Configuration
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    println!("Configuration:");
    println!("  Batch size: {}", batch_size);
    println!("  Number of heads: {}", num_heads);
    println!("  Sequence length: {}", seq_len);
    println!("  Head dimension: {}\n", head_dim);

    // Create random input tensors
    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    println!("Running standard attention...");
    let output = FlashAttentionV3::forward(
        query.clone(),
        key.clone(),
        value.clone(),
        None,
        false, // no causal mask: every position attends to the full sequence
    );
    println!("Output shape: {:?}\n", output.dims());

    println!("Running causal attention...");
    let output_causal = FlashAttentionV3::forward(
        query.clone(),
        key.clone(),
        value.clone(),
        None,
        true, // causal mask: each position attends only to itself and earlier positions
    );
    println!("Causal output shape: {:?}\n", output_causal.dims());

    println!("Example completed successfully!");
}
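
For sanity-checking the result, the output can be compared against a naive attention formulation written directly with Burn tensor ops. The sketch below is not part of this crate; it is only a reference implementation, and it materializes the full score matrix that Flash Attention avoids.

use burn::tensor::{activation::softmax, backend::Backend, Tensor};

/// Naive reference attention: softmax(Q Kᵀ / sqrt(head_dim)) V.
/// Builds the full [batch, heads, seq, seq] score matrix, which is
/// exactly the intermediate that Flash Attention never materializes.
fn naive_attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
) -> Tensor<B, 4> {
    let [_, _, _, head_dim] = query.dims();
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Attention scores: [batch, heads, seq_q, seq_k]
    let scores = query.matmul(key.swap_dims(2, 3)).mul_scalar(scale);
    // Normalize over the key dimension, then weight the values
    let weights = softmax(scores, 3);
    weights.matmul(value)
}

In a test, the maximum absolute difference between this reference and the FlashAttentionV3 output, e.g. (output - reference).abs().max().into_scalar(), should be small up to floating-point rounding.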