use std::array;
use std::cell::Cell;
use std::fmt::{self, Display};
use std::hint::black_box;
use std::num::NonZeroUsize;
use std::time::Instant;
use criterion::{BatchSize, BenchmarkGroup, Criterion, measurement::WallTime};
use pastey::paste;
use la_stack::{Matrix, Vector};
const RANDOM_INPUTS_PER_DIM: SampleCount = SampleCount::new_unchecked(50);
const RANDOM_INPUT_ARRAY_LEN: usize = RANDOM_INPUTS_PER_DIM.get();
const RANDOM_TIMING_PASSES: SampleCount = SampleCount::new_unchecked(5);
const RANDOM_SEED: [u8; 32] = [0; 32];
const RANDOM_PERCENTILES: [RandomPercentile; 3] = [
RandomPercentile::P50,
RandomPercentile::P95,
RandomPercentile::P99,
];
/// Unwraps `result`, panicking with `"{operation} failed: {err}"` on error.
///
/// Benchmark setup has no way to recover from a failed construction or
/// solve, so any error is treated as fatal.
///
/// `#[track_caller]` makes the panic report the location of the call to
/// `require_ok` instead of this helper, so a failure points straight at the
/// benchmark that triggered it.
///
/// # Panics
///
/// Panics if `result` is `Err`.
#[track_caller]
fn require_ok<T, E: Display>(result: Result<T, E>, operation: &str) -> T {
    match result {
        Ok(value) => value,
        Err(err) => panic!("{operation} failed: {err}"),
    }
}
/// Configuration errors raised while preparing the random-input benchmarks.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExactBenchConfigError {
    /// A sample count of zero was requested.
    EmptyCorpus,
    /// An integer range whose bounds are out of order (`min > max`).
    UnorderedRange { min: i16, max: i16 },
}
impl Display for ExactBenchConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnorderedRange { min, max } => {
                write!(f, "random integer range must be ordered: {min}..={max}")
            }
            Self::EmptyCorpus => f.write_str("random input corpus must be nonempty"),
        }
    }
}
/// A strictly positive sample count that can be built in `const` contexts.
#[derive(Clone, Copy)]
struct SampleCount {
    len: NonZeroUsize,
}
impl SampleCount {
    /// Builds a count from a value the caller guarantees is nonzero.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero; inside a `const` initializer this becomes a
    /// compile-time error.
    const fn new_unchecked(len: usize) -> Self {
        if let Some(nonzero) = NonZeroUsize::new(len) {
            Self { len: nonzero }
        } else {
            panic!("random input corpus must be nonempty")
        }
    }
    /// Builds a count, reporting `ExactBenchConfigError::EmptyCorpus` for zero.
    const fn new(len: usize) -> Result<Self, ExactBenchConfigError> {
        match NonZeroUsize::new(len) {
            Some(nonzero) => Ok(Self { len: nonzero }),
            None => Err(ExactBenchConfigError::EmptyCorpus),
        }
    }
    /// Returns the count as a plain `usize` (always at least 1).
    const fn get(self) -> usize {
        let Self { len } = self;
        len.get()
    }
}
/// An inclusive `i16` range, stored as its lower bound plus its element count.
#[derive(Clone, Copy)]
struct I16Range {
    min: i16,
    /// Number of values in `min..=max`; always in `1..=65536`.
    width: u64,
}
impl I16Range {
    /// Builds the inclusive range `min..=max`.
    ///
    /// # Errors
    ///
    /// Returns `ExactBenchConfigError::UnorderedRange` when `min > max`.
    fn new(min: i16, max: i16) -> Result<Self, ExactBenchConfigError> {
        if max < min {
            return Err(ExactBenchConfigError::UnorderedRange { min, max });
        }
        // Here `max >= min`, so `abs_diff` is exactly `max - min`; widening to
        // `u64` before adding 1 keeps the full `i16` span (65536) representable.
        let width = u64::from(max.abs_diff(min)) + 1;
        Ok(Self { min, width })
    }
}
/// A percentile of the per-input timing distribution to benchmark.
#[derive(Clone, Copy)]
enum RandomPercentile {
    P50,
    P95,
    P99,
}
impl RandomPercentile {
    /// Returns the percentile's numeric rank together with its label.
    const fn spec(self) -> (usize, &'static str) {
        match self {
            Self::P50 => (50, "p50"),
            Self::P95 => (95, "p95"),
            Self::P99 => (99, "p99"),
        }
    }
    /// Numeric percentile rank (50, 95 or 99).
    const fn value(self) -> usize {
        self.spec().0
    }
    /// Short label used as the benchmark-name suffix, e.g. `"p95"`.
    const fn name(self) -> &'static str {
        self.spec().1
    }
}
/// Entry `(r, c)` of the deterministic `D`×`D` benchmark matrix.
///
/// Diagonal entries are `D + 1 + r·10⁻³`, computed with a fused multiply-add.
/// Off-diagonal entries are `0.1 / (r + c + 1)`, each at most `0.1`. Every row's
/// off-diagonal sum therefore stays well below its diagonal entry.
#[inline]
#[allow(clippy::cast_precision_loss)]
const fn matrix_entry<const D: usize>(r: usize, c: usize) -> f64 {
    if r != c {
        return 0.1 / ((r + c + 1) as f64);
    }
    let dim_plus_one = (D as f64) + 1.0;
    (r as f64).mul_add(1.0e-3, dim_plus_one)
}
/// Builds the rows of the deterministic `D`×`D` benchmark matrix, entry by
/// entry, from `matrix_entry`.
#[inline]
const fn make_matrix_rows<const D: usize>() -> [[f64; D]; D] {
    let mut rows = [[0.0; D]; D];
    // Walk the matrix in row-major order with one flat index. A `while` loop
    // is required because `for` loops are not allowed in a `const fn`.
    let mut k = 0;
    while k < D * D {
        let (r, c) = (k / D, k % D);
        rows[r][c] = matrix_entry::<D>(r, c);
        k += 1;
    }
    rows
}
/// Right-hand side `[1.0, 2.0, …, D]` used by the deterministic benchmarks.
#[inline]
#[allow(clippy::cast_precision_loss)]
fn make_vector_array<const D: usize>() -> [f64; D] {
    array::from_fn(|i| (i as f64) + 1.0)
}
/// One generated benchmark case: a matrix plus the right-hand side the solve
/// operations pair with it.
#[derive(Copy, Clone)]
struct ExactRandomInput<const D: usize> {
    /// Right-hand side consumed by the `solve_*` operations.
    rhs: Vector<D>,
    /// Matrix every operation is applied to.
    matrix: Matrix<D>,
}
/// Exact-arithmetic operation exercised by the random-input percentile benchmarks.
#[derive(Clone, Copy)]
enum ExactRandomOperation {
    DetSignExact,
    DetExact,
    SolveExact,
    SolveExactF64,
}
impl ExactRandomOperation {
    /// Benchmark-id prefix; it matches the name of the `Matrix` method that
    /// `run_random_operation` calls for this operation.
    const fn name(self) -> &'static str {
        use self::ExactRandomOperation as Op;
        match self {
            Op::DetSignExact => "det_sign_exact",
            Op::DetExact => "det_exact",
            Op::SolveExact => "solve_exact",
            Op::SolveExactF64 => "solve_exact_f64",
        }
    }
}
/// SplitMix64 pseudo-random generator. It is tiny and fully deterministic, so
/// the random benchmark corpus is the same on every run.
struct SplitMix64 {
    state: u64,
}
impl SplitMix64 {
    /// Increment added to the state on every step.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose internal state starts at `state`.
    const fn new(state: u64) -> Self {
        Self { state }
    }

    /// Advances the state and returns the next mixed 64-bit output.
    const fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let z = self.state;
        let z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Draws a value in `range` by reducing the next output modulo the range width.
    #[allow(clippy::cast_possible_truncation)]
    fn next_i16(&mut self, range: I16Range) -> i16 {
        let draw = self.next_u64();
        // Assuming `range` came from `I16Range::new`, `draw % width` < 65536 fits
        // in an `i32`, and `min + offset` stays inside `min..=max`, i.e. inside `i16`.
        let offset = (draw % range.width) as i32;
        (i32::from(range.min) + offset) as i16
    }
}
/// Derives the RNG seed for dimension `D`.
///
/// The dimension is XOR-ed into a fixed starting constant, then every byte of
/// `RANDOM_SEED` is folded in with a shift, a 7-bit rotation and an XOR of its
/// index. Each dimension therefore starts its generator from a different
/// state.
fn random_seed_for_dim<const D: usize>() -> u64 {
    let dim = require_ok(u64::try_from(D), "dimension seed conversion");
    RANDOM_SEED
        .iter()
        .enumerate()
        .fold(0xC0DE_CAFE_D15C_A11Au64 ^ dim, |acc, (index, &byte)| {
            let shift = require_ok(u32::try_from((index % 8) * 8), "seed shift conversion");
            let mixed = (acc ^ (u64::from(byte) << shift)).rotate_left(7);
            mixed ^ require_ok(u64::try_from(index), "seed index conversion")
        })
}
/// Builds the deterministic random corpus of `RANDOM_INPUT_ARRAY_LEN` inputs
/// for dimension `D`.
///
/// Every matrix and right-hand-side entry is an integer drawn from `-10..=10` by
/// a `SplitMix64` seeded by `random_seed_for_dim::<D>()`. Draws happen in
/// row-major order (one per matrix entry), followed by `D` draws for the
/// right-hand side.
///
/// Diagonal entries are then pushed away from zero by `10·D + 1`, keeping their
/// sign. That exceeds the largest possible off-diagonal row sum, `10·(D − 1)`,
/// so every matrix is strictly diagonally dominant and hence nonsingular. The
/// exact solves should therefore never be handed a singular system.
///
/// # Panics
///
/// Panics if `D` does not fit in a `u32`, or if `Matrix`/`Vector` construction
/// rejects the generated values.
fn make_random_input_corpus<const D: usize>() -> [ExactRandomInput<D>; RANDOM_INPUT_ARRAY_LEN] {
    let mut rng = SplitMix64::new(random_seed_for_dim::<D>());
    let entry_range = require_ok(I16Range::new(-10, 10), "random integer range");
    // Loop-invariant: compute the diagonal shift once rather than once per input.
    // `u32` (not `u8`) so dimensions above 255 are not rejected.
    let shift =
        f64::from(require_ok(u32::try_from(D), "dimension shift conversion")).mul_add(10.0, 1.0);
    array::from_fn(|_| {
        let mut rows = [[0.0; D]; D];
        for (r, row) in rows.iter_mut().enumerate() {
            for (c, entry) in row.iter_mut().enumerate() {
                // Exactly one draw per entry, in row-major order, so the corpus
                // is reproducible for a given seed.
                let value = rng.next_i16(entry_range);
                *entry = if r != c {
                    f64::from(value)
                } else if value >= 0 {
                    f64::from(value) + shift
                } else {
                    f64::from(value) - shift
                };
            }
        }
        let rhs = array::from_fn(|_| f64::from(rng.next_i16(entry_range)));
        ExactRandomInput {
            matrix: require_ok(
                Matrix::<D>::try_from_rows(rows),
                "random matrix construction",
            ),
            rhs: require_ok(Vector::<D>::try_new(rhs), "random RHS vector construction"),
        }
    })
}
fn run_random_operation<const D: usize>(
operation: ExactRandomOperation,
input: ExactRandomInput<D>,
) {
match operation {
ExactRandomOperation::DetSignExact => {
let sign = require_ok(
black_box(input.matrix).det_sign_exact(),
"exact determinant sign",
);
black_box(sign);
}
ExactRandomOperation::DetExact => {
let det = require_ok(black_box(input.matrix).det_exact(), "exact determinant");
black_box(det);
}
ExactRandomOperation::SolveExact => {
let x = require_ok(
black_box(input.matrix).solve_exact(black_box(input.rhs)),
"exact linear solve",
);
let _ = black_box(x);
}
ExactRandomOperation::SolveExactF64 => {
let x = require_ok(
black_box(input.matrix).solve_exact_f64(black_box(input.rhs)),
"exact linear solve converted to f64",
);
let _ = black_box(x);
}
}
}
/// Times a single run of `operation` on `input` and returns the elapsed
/// wall-clock time in nanoseconds.
fn time_random_operation<const D: usize>(
    operation: ExactRandomOperation,
    input: ExactRandomInput<D>,
) -> u128 {
    let started_at = Instant::now();
    run_random_operation(operation, input);
    let elapsed = started_at.elapsed();
    elapsed.as_nanos()
}
/// Total nanoseconds for `RANDOM_TIMING_PASSES` back-to-back timed runs of
/// `operation` on `input`.
///
/// Summing several passes smooths out single-run timer noise when inputs are
/// ranked by cost.
fn time_random_operation_repeated<const D: usize>(
    operation: ExactRandomOperation,
    input: ExactRandomInput<D>,
) -> u128 {
    (0..RANDOM_TIMING_PASSES.get())
        .map(|_| time_random_operation(operation, input))
        .sum()
}
/// Index into a sorted sample of `count` values that corresponds to `percentile`.
///
/// Computes `(count − 1) · p / 100` rounded to the nearest integer, with halves
/// rounded up. For percentiles ≤ 100 the result never exceeds `count − 1`.
const fn percentile_index(count: SampleCount, percentile: RandomPercentile) -> usize {
    let last = count.get() - 1;
    let scaled = last * percentile.value();
    // Adding half of the divisor before integer division rounds to nearest.
    (scaled + 50) / 100
}
fn percentile_input_indices<const D: usize>(
corpus: &[ExactRandomInput<D>; RANDOM_INPUT_ARRAY_LEN],
operation: ExactRandomOperation,
) -> [Vec<usize>; RANDOM_PERCENTILES.len()] {
let input_count = require_ok(SampleCount::new(corpus.len()), "random input corpus size");
let mut timings = [(0_u128, 0_usize); RANDOM_INPUT_ARRAY_LEN];
for (i, input) in corpus.iter().copied().enumerate() {
timings[i] = (time_random_operation_repeated(operation, input), i);
}
timings.sort_unstable();
RANDOM_PERCENTILES.map(|percentile| {
let timing_idx = percentile_index(input_count, percentile);
let threshold = timings[timing_idx].0;
let mut indices = Vec::new();
for &(elapsed, input_idx) in &timings {
if elapsed <= threshold {
indices.push(input_idx);
}
}
indices
})
}
/// Registers one benchmark per entry of `RANDOM_PERCENTILES` for `operation`.
///
/// Each benchmark cycles round-robin through the inputs that
/// `percentile_input_indices` selected for its percentile. Picking and copying
/// the input happens in the `iter_batched` setup closure, which Criterion does
/// not time.
fn bench_random_percentile_operation<const D: usize>(
    group: &mut BenchmarkGroup<'_, WallTime>,
    corpus: &[ExactRandomInput<D>; RANDOM_INPUT_ARRAY_LEN],
    operation: ExactRandomOperation,
) {
    let index_sets = percentile_input_indices(corpus, operation);
    for (percentile, input_indices) in RANDOM_PERCENTILES.into_iter().zip(index_sets) {
        let input_count = require_ok(
            SampleCount::new(input_indices.len()),
            "percentile input set size",
        );
        let bench_name = format!("{}_{}", operation.name(), percentile.name());
        let next_slot = Cell::new(0_usize);
        group.bench_function(bench_name, move |bencher| {
            // Round-robin over the selected inputs, wrapping at the end.
            let next_input = || {
                let slot = next_slot.get();
                next_slot.set((slot + 1) % input_count.get());
                corpus[input_indices[slot]]
            };
            bencher.iter_batched(
                next_input,
                |sample| run_random_operation(operation, sample),
                BatchSize::SmallInput,
            );
        });
    }
}
/// 3×3 matrix one tiny perturbation away from the singular
/// `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`.
///
/// The `(0, 0)` entry is `1 + 2⁻⁵⁰`, which is exactly representable in `f64`.
/// The determinant is therefore `-3·2⁻⁵⁰`: nonzero, but small enough to be
/// swamped by ordinary `f64` rounding error.
#[inline]
fn near_singular_3x3() -> Matrix<3> {
    // Bit pattern of 2^-50: biased exponent 0x3CD = 973 = 1023 - 50, zero mantissa.
    let perturbation = f64::from_bits(0x3CD0_0000_0000_0000);
    let rows = [
        [1.0 + perturbation, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ];
    require_ok(
        Matrix::<3>::try_from_rows(rows),
        "near-singular matrix construction",
    )
}
/// 3×3 matrix with `f64::MAX / 2` on the diagonal and ones elsewhere.
///
/// Products of two diagonal entries overflow `f64`. This exercises the exact
/// routines on magnitudes that floating-point arithmetic cannot multiply.
#[inline]
fn large_entries_3x3() -> Matrix<3> {
    let big = f64::MAX / 2.0;
    let rows = [
        [big, 1.0, 1.0],
        [1.0, big, 1.0],
        [1.0, 1.0, big],
    ];
    require_ok(
        Matrix::<3>::try_from_rows(rows),
        "large-entry matrix construction",
    )
}
#[inline]
#[allow(clippy::cast_precision_loss)]
fn hilbert<const D: usize>() -> Matrix<D> {
let mut rows = [[0.0; D]; D];
let mut r = 0;
while r < D {
let mut c = 0;
while c < D {
rows[r][c] = 1.0 / ((r + c + 1) as f64);
c += 1;
}
r += 1;
}
require_ok(
Matrix::<D>::try_from_rows(rows),
"Hilbert matrix construction",
)
}
/// Registers the four exact-arithmetic benchmarks (`det_sign_exact`,
/// `det_exact`, `solve_exact`, `solve_exact_f64`) on a fixed matrix `m` and
/// right-hand side `rhs`, for the hand-picked extreme cases.
fn bench_extreme_group<const D: usize>(
    group: &mut BenchmarkGroup<'_, WallTime>,
    m: Matrix<D>,
    rhs: Vector<D>,
) {
    group.bench_function("det_sign_exact", |b| {
        b.iter(|| {
            black_box(require_ok(
                black_box(m).det_sign_exact(),
                "exact determinant sign",
            ));
        });
    });
    group.bench_function("det_exact", |b| {
        b.iter(|| {
            black_box(require_ok(black_box(m).det_exact(), "exact determinant"));
        });
    });
    group.bench_function("solve_exact", |b| {
        b.iter(|| {
            let _ = black_box(require_ok(
                black_box(m).solve_exact(black_box(rhs)),
                "exact linear solve",
            ));
        });
    });
    group.bench_function("solve_exact_f64", |b| {
        b.iter(|| {
            let _ = black_box(require_ok(
                black_box(m).solve_exact_f64(black_box(rhs)),
                "exact linear solve converted to f64",
            ));
        });
    });
}
/// Expands to the `exact_d{$d}` benchmark group.
///
/// The group benchmarks the f64 determinant paths (`det`, `det_direct`) next
/// to the exact routines (`det_exact`, `det_exact_f64`, `det_sign_exact`,
/// `solve_exact`, `solve_exact_f64`). All of them run on the deterministic
/// matrix from `make_matrix_rows` and the right-hand side from
/// `make_vector_array`.
macro_rules! gen_exact_benches_for_dim {
    ($c:expr, $d:literal) => {
        paste! {{
            let a = require_ok(
                Matrix::<$d>::try_from_rows(make_matrix_rows::<$d>()),
                "benchmark matrix construction",
            );
            let rhs = require_ok(
                Vector::<$d>::try_new(make_vector_array::<$d>()),
                "benchmark RHS vector construction",
            );
            let mut [<group_d $d>] = ($c).benchmark_group(concat!("exact_d", stringify!($d)));
            [<group_d $d>].bench_function("det", |bencher| {
                bencher.iter(|| {
                    let det = require_ok(black_box(a).det(), "f64 determinant");
                    black_box(det);
                });
            });
            [<group_d $d>].bench_function("det_direct", |bencher| {
                bencher.iter(|| {
                    let det = black_box(a).det_direct();
                    black_box(det);
                });
            });
            [<group_d $d>].bench_function("det_exact", |bencher| {
                bencher.iter(|| {
                    let det = require_ok(black_box(a).det_exact(), "exact determinant");
                    black_box(det);
                });
            });
            [<group_d $d>].bench_function("det_exact_f64", |bencher| {
                bencher.iter(|| {
                    let det = require_ok(
                        black_box(a).det_exact_f64(),
                        "exact determinant converted to f64",
                    );
                    black_box(det);
                });
            });
            [<group_d $d>].bench_function("det_sign_exact", |bencher| {
                bencher.iter(|| {
                    let sign = require_ok(black_box(a).det_sign_exact(), "exact determinant sign");
                    black_box(sign);
                });
            });
            // Solve results are discarded with `let _ =`, matching
            // `run_random_operation` and `bench_extreme_group`, so the expansion
            // does not rely on an outer `#[allow(unused_must_use)]`.
            [<group_d $d>].bench_function("solve_exact", |bencher| {
                bencher.iter(|| {
                    let x = require_ok(
                        black_box(a).solve_exact(black_box(rhs)),
                        "exact linear solve",
                    );
                    let _ = black_box(x);
                });
            });
            [<group_d $d>].bench_function("solve_exact_f64", |bencher| {
                bencher.iter(|| {
                    let x = require_ok(
                        black_box(a).solve_exact_f64(black_box(rhs)),
                        "exact linear solve converted to f64",
                    );
                    let _ = black_box(x);
                });
            });
            [<group_d $d>].finish();
        }};
    };
}
/// Expands to the `exact_random_percentile_d{$d}` benchmark group.
///
/// It builds the random corpus for dimension `$d`, then registers the
/// percentile benchmarks for every exact operation, in a fixed order.
macro_rules! gen_random_percentile_benches_for_dim {
    ($c:expr, $d:literal) => {
        paste! {{
            let corpus = make_random_input_corpus::<$d>();
            let mut [<group_random_percentile_d $d>] =
                ($c).benchmark_group(concat!("exact_random_percentile_d", stringify!($d)));
            for operation in [
                ExactRandomOperation::DetSignExact,
                ExactRandomOperation::DetExact,
                ExactRandomOperation::SolveExact,
                ExactRandomOperation::SolveExactF64,
            ] {
                bench_random_percentile_operation(
                    &mut [<group_random_percentile_d $d>],
                    &corpus,
                    operation,
                );
            }
            [<group_random_percentile_d $d>].finish();
        }};
    };
}
fn main() {
let mut c = Criterion::default().configure_from_args();
#[allow(unused_must_use)]
{
gen_exact_benches_for_dim!(&mut c, 2);
gen_exact_benches_for_dim!(&mut c, 3);
gen_exact_benches_for_dim!(&mut c, 4);
gen_exact_benches_for_dim!(&mut c, 5);
}
#[allow(unused_must_use)]
{
gen_random_percentile_benches_for_dim!(&mut c, 2);
gen_random_percentile_benches_for_dim!(&mut c, 3);
gen_random_percentile_benches_for_dim!(&mut c, 4);
gen_random_percentile_benches_for_dim!(&mut c, 5);
}
{
let mut group = c.benchmark_group("exact_near_singular_3x3");
bench_extreme_group(
&mut group,
near_singular_3x3(),
require_ok(
Vector::<3>::try_new([1.0, 2.0, 3.0]),
"near-singular RHS vector construction",
),
);
group.finish();
}
{
let mut group = c.benchmark_group("exact_large_entries_3x3");
bench_extreme_group(
&mut group,
large_entries_3x3(),
require_ok(
Vector::<3>::try_new([1.0, 1.0, 1.0]),
"large-entry RHS vector construction",
),
);
group.finish();
}
{
let mut group = c.benchmark_group("exact_hilbert_4x4");
bench_extreme_group(
&mut group,
hilbert::<4>(),
require_ok(
Vector::<4>::try_new([1.0; 4]),
"Hilbert RHS vector construction",
),
);
group.finish();
}
{
let mut group = c.benchmark_group("exact_hilbert_5x5");
bench_extreme_group(
&mut group,
hilbert::<5>(),
require_ok(
Vector::<5>::try_new([1.0; 5]),
"Hilbert RHS vector construction",
),
);
group.finish();
}
c.final_summary();
}