//! CUDA backend smoke test and micro-benchmarks: a small f32 matmul, two bf16
//! matmul benchmarks, and a flash-attention vs. naive-attention comparison.
use xn::{Backend, Result, Tensor, cuda_backend::Device};
fn main() -> Result<()> {
    println!("Initializing CUDA device...");
    let device = Device::new(0)?;
    println!("CUDA device initialized successfully!");
    let a: Tensor<f32, Device> =
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], &device)?;
    let b: Tensor<f32, Device> =
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2], &device)?;
    println!("A shape: {:?}", a.dims());
    println!("B shape: {:?}", b.dims());
    let c = a.matmul(&b)?;
    println!("C = A @ B");
    println!("C shape: {:?}", c.dims());
    println!("C data: {:?}", c.to_vec()?);
    run_mm_benchmark(32, 1, 11264, 2048, &device)?;
    run_mm_benchmark(1, 32, 11264, 2048, &device)?;
    run_attn_benchmarks(32, 32, 128, 128, 64, &device)?;
    Ok(())
}
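
/// Times repeated broadcast matmuls of a (bs, m, k) bf16 tensor against a
/// shared (k, n) weight and reports average latency and throughput.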
fn run_mm_benchmark(bs: usize, m: usize, n: usize, k: usize, device: &Device) -> Result<()> {
    println!("\nBenchmarking matmul ({bs}x{m}x{k}) @ ({k}x{n})...");
    // Deterministic pseudo-data; the prime moduli keep values small and varied.
    let a_data: Vec<half::bf16> =
        (0..bs * m * k).map(|i| half::bf16::from_f32((i % 127) as f32 * 0.01)).collect();
    let b_data: Vec<half::bf16> =
        (0..k * n).map(|i| half::bf16::from_f32((i % 113) as f32 * 0.01)).collect();
    let a: Tensor<half::bf16, Device> = Tensor::from_vec(a_data, (bs, m, k), device)?;
    let b: Tensor<half::bf16, Device> = Tensor::from_vec(b_data, (k, n), device)?;
    // One warmup launch so one-time setup costs stay out of the timed loop.
    let _warmup = a.matmul(&b)?;
    device.synchronize()?;
    let num_iters = 1000;
    let start = std::time::Instant::now();
    for _ in 0..num_iters {
        let _c = a.matmul(&b)?;
    }
    // Launches are asynchronous; synchronize before reading the clock.
    device.synchronize()?;
    let elapsed = start.elapsed();
    let avg_us = elapsed.as_micros() as f64 / num_iters as f64;
    // 2 FLOPs (multiply + accumulate) per output element per k step.
    let flops = 2.0 * bs as f64 * m as f64 * n as f64 * k as f64;
    let tflops = flops * num_iters as f64 / elapsed.as_secs_f64() / 1e12;
    println!("{num_iters} iters in {elapsed:.2?} ({avg_us:.1} us/iter, {tflops:.2} TFLOP/s)");
    Ok(())
}
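
/// Benchmarks the fused flash-attn kernel (when the sequence lengths fit its
/// block sizes) against naive matmul + softmax attention, then compares the
/// two outputs.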
fn run_attn_benchmarks(
    batch_size: usize,
    num_heads: usize,
    len_q: usize,
    len_kv: usize,
    head_dim: usize,
    device: &Device,
) -> Result<()> {
    use xn::cuda_backend;
    // The flash-attn entry point takes 3D tensors with batch and heads
    // flattened into a single leading dimension.
    let bs = batch_size * num_heads;
    let q_numel = bs * len_q * head_dim;
    let kv_numel = bs * len_kv * head_dim;
    let q_data: Vec<half::bf16> =
        (0..q_numel).map(|i| half::bf16::from_f32((i % 127) as f32 * 0.01)).collect();
    let k_data: Vec<half::bf16> =
        (0..kv_numel).map(|i| half::bf16::from_f32((i % 113) as f32 * 0.01)).collect();
    let v_data: Vec<half::bf16> =
        (0..kv_numel).map(|i| half::bf16::from_f32((i % 97) as f32 * 0.01)).collect();
    let q: Tensor<half::bf16, Device> =
        Tensor::from_vec(q_data.clone(), (batch_size, num_heads, len_q, head_dim), device)?;
    let k: Tensor<half::bf16, Device> =
        Tensor::from_vec(k_data.clone(), (batch_size, num_heads, len_kv, head_dim), device)?;
    let v: Tensor<half::bf16, Device> =
        Tensor::from_vec(v_data.clone(), (batch_size, num_heads, len_kv, head_dim), device)?;
    let num_iters = 100;
    // Two matmuls (Q K^T and weights @ V), each 2 FLOPs per multiply-add.
    let flops =
        4.0 * batch_size as f64 * num_heads as f64 * len_q as f64 * len_kv as f64 * head_dim as f64;
    // The flash-attn kernel processes fixed-size tiles, so both sequence
    // lengths must be multiples of the block sizes.
    const BLOCK_Q: usize = 64;
    const BLOCK_KV: usize = 64;
    let flash_attn_compatible = len_q.is_multiple_of(BLOCK_Q) && len_kv.is_multiple_of(BLOCK_KV);
    let flash_dst = if flash_attn_compatible {
        println!(
            "\nBenchmarking flash-attn (bs={batch_size}, heads={num_heads}, \
             q_len={len_q}, kv_len={len_kv}, head_dim={head_dim})..."
        );
        let q_3d: Tensor<half::bf16, Device> =
            Tensor::from_vec(q_data, (bs, len_q, head_dim), device)?;
        let k_3d: Tensor<half::bf16, Device> =
            Tensor::from_vec(k_data, (bs, len_kv, head_dim), device)?;
        let v_3d: Tensor<half::bf16, Device> =
            Tensor::from_vec(v_data, (bs, len_kv, head_dim), device)?;
        let flash_dst: Tensor<half::bf16, Device> =
            Tensor::from_vec(vec![half::bf16::ZERO; q_numel], (bs, len_q, head_dim), device)?;
        // Warmup; the inner scope drops the storage borrows before the timed loop.
        {
            let q_s = q_3d.storage()?;
            let k_s = k_3d.storage()?;
            let v_s = v_3d.storage()?;
            let mut dst_s = flash_dst.storage_mut()?;
            cuda_backend::flash_attn(
                &mut dst_s, &q_s, &k_s, &v_s, batch_size, num_heads, len_q, len_kv, head_dim,
            )?;
        }
        device.synchronize()?;
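        // Timed loop; storage borrows are re-acquired on every iteration.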
        let start = std::time::Instant::now();
        for _ in 0..num_iters {
            let q_s = q_3d.storage()?;
            let k_s = k_3d.storage()?;
            let v_s = v_3d.storage()?;
            let mut dst_s = flash_dst.storage_mut()?;
            cuda_backend::flash_attn(
                &mut dst_s, &q_s, &k_s, &v_s, batch_size, num_heads, len_q, len_kv, head_dim,
            )?;
        }
        device.synchronize()?;
        let elapsed = start.elapsed();
        let avg_us = elapsed.as_micros() as f64 / num_iters as f64;
        let tflops = flops * num_iters as f64 / elapsed.as_secs_f64() / 1e12;
        println!("{num_iters} iters in {elapsed:.2?} ({avg_us:.1} us/iter, {tflops:.2} TFLOP/s)");
        Some(flash_dst)
    } else {
        println!(
            "\nSkipping flash-attn: q_len={len_q} and kv_len={len_kv} must be \
             multiples of {BLOCK_Q} and {BLOCK_KV}"
        );
        None
    };
    println!(
        "\nBenchmarking matmul-attn (bs={batch_size}, heads={num_heads}, \
         q_len={len_q}, kv_len={len_kv}, head_dim={head_dim})..."
    );
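    // Naive attention: softmax(Q K^T / sqrt(head_dim)) @ V, materializing the
    // full (len_q, len_kv) score matrix. The warmup result is kept so it can
    // be checked against the flash-attn output below.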
    let scale = half::bf16::from_f32(1.0 / (head_dim as f32).sqrt());
    let matmul_dst = {
        let scores = q.matmul_t(&k)?.scale(scale)?;
        let weights = scores.softmax()?;
        weights.matmul(&v)?
    };
    device.synchronize()?;
    let start = std::time::Instant::now();
    for _ in 0..num_iters {
        let scores = q.matmul_t(&k)?.scale(scale)?;
        let weights = scores.softmax()?;
        let _out = weights.matmul(&v)?;
    }
    device.synchronize()?;
    let elapsed = start.elapsed();
    let avg_us = elapsed.as_micros() as f64 / num_iters as f64;
    let tflops = flops * num_iters as f64 / elapsed.as_secs_f64() / 1e12;
    println!("{num_iters} iters in {elapsed:.2?} ({avg_us:.1} us/iter, {tflops:.2} TFLOP/s)");
    if let Some(flash_dst) = flash_dst {
        let flash_vec = flash_dst.to_vec()?;
        let matmul_vec = matmul_dst.to_vec()?;
        let max_diff = flash_vec
            .iter()
            .zip(matmul_vec.iter())
            .map(|(a, b): (&half::bf16, &half::bf16)| (a.to_f32() - b.to_f32()).abs())
            .fold(0.0f32, f32::max);
        println!("\nMax abs difference (flash vs matmul): {max_diff}");
    }
    Ok(())
}