use std::collections::HashMap;
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use mathlex::{BinaryOp, Expression, UnaryOp};
use mathlex_eval::{EvalInput, NumericResult, compile, eval};
fn ast_linear() -> Expression {
Expression::Binary {
op: BinaryOp::Add,
left: Box::new(Expression::Binary {
op: BinaryOp::Mul,
left: Box::new(Expression::Integer(2)),
right: Box::new(Expression::Variable("x".into())),
}),
right: Box::new(Expression::Integer(3)),
}
}
fn ast_trig_complex() -> Expression {
let x_sq = Expression::Binary {
op: BinaryOp::Pow,
left: Box::new(Expression::Variable("x".into())),
right: Box::new(Expression::Integer(2)),
};
let cos_x = Expression::Function {
name: "cos".into(),
args: vec![Expression::Variable("x".into())],
};
let inner = Expression::Binary {
op: BinaryOp::Add,
left: Box::new(x_sq),
right: Box::new(cos_x),
};
Expression::Function {
name: "sin".into(),
args: vec![inner],
}
}
fn ast_quadratic() -> Expression {
Expression::Binary {
op: BinaryOp::Add,
left: Box::new(Expression::Binary {
op: BinaryOp::Pow,
left: Box::new(Expression::Variable("x".into())),
right: Box::new(Expression::Integer(2)),
}),
right: Box::new(Expression::Integer(1)),
}
}
fn ast_grid() -> Expression {
Expression::Binary {
op: BinaryOp::Add,
left: Box::new(Expression::Binary {
op: BinaryOp::Pow,
left: Box::new(Expression::Variable("x".into())),
right: Box::new(Expression::Integer(2)),
}),
right: Box::new(Expression::Variable("y".into())),
}
}
fn ast_negation() -> Expression {
Expression::Unary {
op: UnaryOp::Neg,
operand: Box::new(Expression::Variable("x".into())),
}
}
fn linspace(n: usize) -> Vec<f64> {
(0..n).map(|i| i as f64 / n as f64).collect()
}
fn bench_scalar_eval(c: &mut Criterion) {
let consts: HashMap<&str, NumericResult> = HashMap::new();
let compiled_linear = compile(&ast_linear(), &consts).expect("compile linear");
let compiled_trig = compile(&ast_trig_complex(), &consts).expect("compile trig");
let mut group = c.benchmark_group("scalar_eval");
group.bench_function("linear_2x_plus_3", |b| {
b.iter(|| {
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(black_box(1.5)));
let handle = eval(black_box(&compiled_linear), args).expect("eval");
black_box(handle.scalar().expect("scalar"))
});
});
group.bench_function("complex_sin_x2_plus_cos_x", |b| {
b.iter(|| {
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(black_box(0.7)));
let handle = eval(black_box(&compiled_trig), args).expect("eval");
black_box(handle.scalar().expect("scalar"))
});
});
group.finish();
}
fn bench_compile_vs_eval(c: &mut Criterion) {
let consts: HashMap<&str, NumericResult> = HashMap::new();
let compiled_trig = compile(&ast_trig_complex(), &consts).expect("compile");
let mut group = c.benchmark_group("compile_vs_eval");
group.bench_function("compile_sin_x2_plus_cos_x", |b| {
b.iter(|| {
let ast = black_box(ast_trig_complex());
black_box(compile(&ast, &consts).expect("compile"))
});
});
group.bench_function("compile_negation", |b| {
b.iter(|| {
let ast = black_box(ast_negation());
black_box(compile(&ast, &consts).expect("compile"))
});
});
group.bench_function("eval_only_sin_x2_plus_cos_x", |b| {
b.iter(|| {
let mut args = HashMap::new();
args.insert("x", EvalInput::Scalar(black_box(0.7)));
let handle = eval(black_box(&compiled_trig), args).expect("eval");
black_box(handle.scalar().expect("scalar"))
});
});
group.finish();
}
fn bench_broadcast_1d(c: &mut Criterion) {
let consts: HashMap<&str, NumericResult> = HashMap::new();
let compiled = compile(&ast_quadratic(), &consts).expect("compile quadratic");
let sizes: &[usize] = &[10, 100, 1_000, 10_000, 100_000];
let mut group = c.benchmark_group("broadcast_1d");
for &n in sizes {
let xs = linspace(n);
group.bench_with_input(BenchmarkId::from_parameter(n), &xs, |b, xs| {
b.iter(|| {
let mut args = HashMap::new();
args.insert("x", EvalInput::from(black_box(xs.clone())));
let handle = eval(black_box(&compiled), args).expect("eval");
black_box(handle.to_array().expect("to_array"))
});
});
}
group.finish();
}
fn bench_broadcast_grid(c: &mut Criterion) {
let consts: HashMap<&str, NumericResult> = HashMap::new();
let compiled = compile(&ast_grid(), &consts).expect("compile grid");
let grids: &[(usize, usize)] = &[(10, 10), (100, 100)];
let mut group = c.benchmark_group("broadcast_grid");
for &(nx, ny) in grids {
let xs = linspace(nx);
let ys = linspace(ny);
let label = format!("{nx}x{ny}");
group.bench_with_input(
BenchmarkId::new("grid", &label),
&(xs, ys),
|b, (xs, ys)| {
b.iter(|| {
let mut args = HashMap::new();
args.insert("x", EvalInput::from(black_box(xs.clone())));
args.insert("y", EvalInput::from(black_box(ys.clone())));
let handle = eval(black_box(&compiled), args).expect("eval");
black_box(handle.to_array().expect("to_array"))
});
},
);
}
group.finish();
}
criterion_group!(
benches,
bench_scalar_eval,
bench_compile_vs_eval,
bench_broadcast_1d,
bench_broadcast_grid,
);
criterion_main!(benches);