/// Name of the environment variable that `print_tree` checks before emitting its diagnostics.
const KEY: &str = "TREE";
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ndarray::{Array, Dim};
use std::env;
use svod_schedule::{HeuristicsConfig, OptStrategy, OptimizerConfig};
use svod_tensor::{PrepareConfig, Tensor};
/// Builds a `Tensor` reshaped to `[rows, cols]`.
///
/// The backing data is the sequence `i * 0.01` for `i` in `0..rows * cols`.
///
/// # Panics
///
/// Panics with "reshape should succeed" if the reshape fails.
fn create_matrix(rows: usize, cols: usize) -> Tensor {
    let element_count = rows * cols;
    let mut values: Vec<f32> = Vec::with_capacity(element_count);
    for index in 0..element_count {
        values.push(index as f32 * 0.01);
    }

    let shape = [rows as isize, cols as isize];
    Tensor::from_slice(&values)
        .try_reshape(shape)
        .expect("reshape should succeed")
}
/// Builds a two-dimensional `ndarray` array of shape `(rows, cols)`.
///
/// The backing vector holds `i * 0.01` for `i` in `0..rows * cols`, i.e. the same values
/// that `create_matrix` generates.
///
/// # Panics
///
/// Panics with "array from vec should succeed" if the array cannot be built from the vector.
fn create_ndarray(rows: usize, cols: usize) -> ndarray::Array<f32, Dim<[usize; 2]>> {
    let total = rows * cols;
    let values = (0..total)
        .map(|idx| idx as f32 * 0.01)
        .collect::<Vec<f32>>();
    let shape = (rows, cols);
    Array::from_shape_vec(shape, values).expect("array from vec should succeed")
}
/// Returns `2 * m * k * n` as a `u64`.
///
/// Each dimension is widened to `u64` before multiplying, so the product is computed in
/// 64-bit arithmetic regardless of the platform's `usize` width.
fn matmul_flops(m: usize, k: usize, n: usize) -> u64 {
    // Seed with `2 * m`, then scale by the remaining dimensions in order.
    [k, n]
        .iter()
        .fold(2 * m as u64, |acc, &dim| acc * dim as u64)
}
/// Writes diagnostics for a prepared plan to stderr, but only when the environment
/// variable named by `KEY` can be read.
///
/// The dump consists of a banner with `config` and `size`, the plan's kernel count, the
/// result tensor's UOp tree, and for each prepared kernel its AST tree, entry point and code.
fn print_tree(config: &str, size: usize, plan: &svod_runtime::ExecutionPlan, result: &Tensor) {
    // Stay silent unless the variable is present (and readable) in the environment.
    if env::var(KEY).is_err() {
        return;
    }

    eprintln!("\n=== {config} (size={size}) ===");
    let kernel_count = plan.kernels().count();
    eprintln!("Kernel count: {}", kernel_count);
    eprintln!("UOp tree:\n{}", result.uop().tree());

    for (index, prepared) in plan.prepared_kernels().iter().enumerate() {
        // The AST tree is emitted first, followed by the kernel's header line and its code.
        eprintln!("UOp tree:\n{}", prepared.ast.tree());
        eprintln!(" Kernel {index}: {}", prepared.kernel.entry_point);
        eprintln!("{}", prepared.kernel.code);
    }
}
/// Criterion benchmark for square matrix multiplication.
///
/// For each size in 256, 512 and 1024 it measures, in this order:
/// 1. executing a plan prepared with the `OptStrategy::Heuristic` configuration,
/// 2. executing a plan prepared with the `OptStrategy::Beam` configuration (width 4),
/// 3. `ndarray`'s `dot` on arrays filled with the same values.
///
/// Throughput is reported as `matmul_flops(size, size, size)` elements per iteration.
/// Plans are prepared outside the timed closures, so only `execute` is measured.
fn bench_matmul(c: &mut Criterion) {
    const BEAM_WIDTH: usize = 4;
    const SIZES: [usize; 3] = [256, 512, 1024];

    let mut group = c.benchmark_group("matmul_optimization");

    let heuristic_config: PrepareConfig = OptimizerConfig::builder()
        .strategy(OptStrategy::Heuristic)
        .heuristics(HeuristicsConfig::builder().build())
        .build()
        .into();
    let beam_config: PrepareConfig = OptimizerConfig::builder()
        .strategy(OptStrategy::Beam { width: BEAM_WIDTH })
        .build()
        .into();

    for size in SIZES {
        group.throughput(Throughput::Elements(matmul_flops(size, size, size)));

        // Tensor operands shared by both optimizer strategies.
        let lhs = create_matrix(size, size);
        let rhs = create_matrix(size, size);

        // Heuristic strategy: prepare once, then time only the execution.
        let mut heuristic_result = lhs.matmul(&rhs).expect("matmul should succeed");
        let heuristic_plan = heuristic_result
            .prepare_with(&heuristic_config)
            .expect("prepare should succeed");
        print_tree("HEURISTIC", size, &heuristic_plan, &heuristic_result);
        let heuristic_id = BenchmarkId::new("heuristic", size);
        group.bench_with_input(heuristic_id, &heuristic_plan, |bencher, plan| {
            bencher.iter(|| plan.execute().expect("execute should succeed"));
        });

        // Beam strategy: same shape of measurement with the beam configuration.
        let mut beam_result = lhs.matmul(&rhs).expect("matmul should succeed");
        let beam_plan = beam_result
            .prepare_with(&beam_config)
            .expect("prepare should succeed");
        print_tree("BEAM", size, &beam_plan, &beam_result);
        let beam_id = BenchmarkId::new(format!("beam_w{BEAM_WIDTH}"), size);
        group.bench_with_input(beam_id, &beam_plan, |bencher, plan| {
            bencher.iter(|| plan.execute().expect("execute should succeed"));
        });

        // Baseline: ndarray's matrix product on equally-filled arrays.
        let nd_lhs = create_ndarray(size, size);
        let nd_rhs = create_ndarray(size, size);
        let ndarray_id = BenchmarkId::new("ndarray multiplication", size);
        group.bench_with_input(ndarray_id, &(nd_lhs, nd_rhs), |bencher, (left, right)| {
            bencher.iter(|| left.dot(right));
        });
    }

    group.finish();
}
// Register `bench_matmul` in the `benches` group and generate the benchmark entry point.
criterion_group!(benches, bench_matmul);
criterion_main!(benches);