use scirs2_core::ndarray::{Array2, Array3};
use scirs2_linalg::attention::{AttentionConfig, AttentionMask};
use scirs2_linalg::batch::attention::{
batch_flash_attention, batch_multi_head_attention, batch_multi_query_attention,
};
#[allow(dead_code)]
fn main() {
println!("Batch Attention Mechanism Examples");
println!("==================================\n");
let batchsize = 4;
let seq_len = 8;
let d_model = 16;
let scale = 1.0 / (d_model as f32).sqrt();
println!("Example 1: Batch Multi-Query Attention");
let batch_query = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let key = Array2::<f32>::from_elem((seq_len, d_model), 0.1);
let value = Array2::<f32>::from_elem((seq_len, d_model), 0.1);
let output =
batch_multi_query_attention(&batch_query.view(), &key.view(), &value.view(), None, scale)
.expect("Operation failed");
println!("Input batch query shape: {:?}", batch_query.shape());
println!("Shared key shape: {:?}", key.shape());
println!("Shared value shape: {:?}", value.shape());
println!("Output shape: {:?}", output.shape());
println!("First value: {:.6}", output[[0, 0, 0]]);
println!();
println!("Example 2: Batch Multi-Head Attention with Causal Masking");
let batch_query = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let batch_key = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let batch_value = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let num_heads = 4;
let head_dim = d_model / num_heads;
let wq = Array2::<f32>::from_elem((d_model, d_model), 0.1);
let wk = Array2::<f32>::from_elem((d_model, d_model), 0.1);
let wv = Array2::<f32>::from_elem((d_model, d_model), 0.1);
let wo = Array2::<f32>::from_elem((d_model, d_model), 0.1);
let mask = AttentionMask::Causal;
let config = AttentionConfig {
num_heads,
head_dim,
dropout_prob: 0.0,
causal: true,
scale: Some(1.0 / (head_dim as f32).sqrt()),
};
let output = batch_multi_head_attention(
&batch_query.view(),
&batch_key.view(),
&batch_value.view(),
&wq.view(),
&wk.view(),
&wv.view(),
&wo.view(),
Some(&mask),
&config,
)
.expect("Operation failed");
println!(
"Input shapes - batchsize: {}, seq_len: {}, d_model: {}",
batchsize, seq_len, d_model
);
println!(
"Multi-head config - heads: {}, head_dim: {}",
num_heads, head_dim
);
println!("Output shape: {:?}", output.shape());
println!("First value: {:.6}", output[[0, 0, 0]]);
println!();
println!("Example 3: Batch Flash Attention");
let batch_query = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let batch_key = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let batch_value = Array3::<f32>::from_elem((batchsize, seq_len, d_model), 0.1);
let blocksize = 4;
let output = batch_flash_attention(
&batch_query.view(),
&batch_key.view(),
&batch_value.view(),
None,
scale,
blocksize,
)
.expect("Operation failed");
println!(
"Input shapes - batchsize: {}, seq_len: {}, d_model: {}",
batchsize, seq_len, d_model
);
println!("Block size: {}", blocksize);
println!("Output shape: {:?}", output.shape());
println!("First value: {:.6}", output[[0, 0, 0]]);
println!();
println!("Performance Comparison and Use Cases");
println!("-----------------------------------");
println!("1. Batch Multi-Query Attention:");
println!(" - Best for: Decoder-only architectures (e.g., GPT models)");
println!(" - Memory efficiency: Good (shared key/value matrices)");
println!(" - Use when: Different sequences attend to the same context");
println!();
println!("2. Batch Multi-Head Attention:");
println!(" - Best for: Standard transformer architectures");
println!(" - Flexibility: Highest (supports all attention patterns)");
println!(" - Use when: Full transformer functionality is needed");
println!();
println!("3. Batch Flash Attention:");
println!(" - Best for: Very long sequences");
println!(" - Memory efficiency: Excellent (O(N) memory usage)");
println!(" - Use when: Memory is constrained or sequences are long");
println!();
println!("All batch operations support parallel processing across multiple sequences,");
println!("making them ideal for high-throughput machine learning workloads.");
}