use std::iter::Sum;
use arrow_arith::{aggregate::sum, arithmetic::multiply};
use arrow_array::{
types::{Float16Type, Float32Type, Float64Type},
ArrowNumericType, Float16Array, Float32Array, NativeAdapter, PrimitiveArray,
};
use criterion::{criterion_group, criterion_main, Criterion};
use num_traits::{real::Real, FromPrimitive};
#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};
use lance::linalg::dot::{dot, Dot};
use lance::utils::testing::generate_random_array_with_seed;
/// Dot product of `x` and `y` computed with Arrow compute kernels.
///
/// The arrays are multiplied element-wise with `multiply`, and the resulting
/// array is reduced with `sum`.
///
/// # Panics
///
/// Panics if `multiply` returns an error or if `sum` returns `None`.
#[inline]
fn dot_arrow_artiy<T: ArrowNumericType>(x: &PrimitiveArray<T>, y: &PrimitiveArray<T>) -> T::Native {
    sum(&multiply(x, y).unwrap()).unwrap()
}
/// Registers two dot-product benchmarks for the Arrow numeric type `T`.
///
/// A `key` array of `DIMENSION` values and a `target` array of
/// `TOTAL * DIMENSION` values are generated from fixed seeds. Each benchmark
/// iteration computes the dot product of `key` with every consecutive
/// `DIMENSION`-long window of `target` and collects the results into a
/// `PrimitiveArray<T>`:
///
/// * `Dot(<native type>, arrow_artiy)` goes through `dot_arrow_artiy`
///   (Arrow `multiply` + `sum` on array slices).
/// * `Dot(<native type>)` calls the generic `dot` function on raw value slices.
fn run_bench<T: ArrowNumericType>(c: &mut Criterion)
where
    T::Native: Real + FromPrimitive + Sum,
    NativeAdapter<T>: From<T::Native>,
{
    const DIMENSION: usize = 1024;
    const TOTAL: usize = 1024 * 1024;

    let key: PrimitiveArray<T> = generate_random_array_with_seed(DIMENSION, [0; 32]);
    let target: PrimitiveArray<T> = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
    let type_name = std::any::type_name::<T::Native>();

    let arrow_id = format!("Dot({type_name}, arrow_artiy)");
    c.bench_function(arrow_id.as_str(), |b| {
        b.iter(|| {
            let rows = target.len() / DIMENSION;
            // Lazy: the dot products are evaluated while the array is built.
            let products = (0..rows).map(|row| {
                let vector = target.slice(row * DIMENSION, DIMENSION);
                Some(dot_arrow_artiy(&key, &vector))
            });
            // SAFETY: `products` is a `Map` over a `Range<usize>`, so its
            // `size_hint` upper bound is the exact number of items yielded.
            unsafe { PrimitiveArray::<T>::from_trusted_len_iter(products) }
        });
    });

    let plain_id = format!("Dot({type_name})");
    c.bench_function(plain_id.as_str(), |b| {
        let x = key.values();
        b.iter(|| {
            let rows = target.len() / DIMENSION;
            // Lazy: the dot products are evaluated while the array is built.
            let products = (0..rows).map(|row| {
                let start = row * DIMENSION;
                let y = &target.values()[start..start + DIMENSION];
                Some(dot(x, y))
            });
            // SAFETY: `products` is a `Map` over a `Range<usize>`, so its
            // `size_hint` upper bound is the exact number of items yielded.
            unsafe { PrimitiveArray::<T>::from_trusted_len_iter(products) }
        });
    });
}
/// Criterion entry point: registers every dot-product benchmark in this file.
///
/// Runs the generic `run_bench` suite for `Float16Type`, `Float32Type` and
/// `Float64Type`. After the f16 and f32 suites it also registers a
/// `Dot(<type>, SIMD)` benchmark that calls the `Dot` trait method on the raw
/// value slices, using the same key/target sizes and seeds as `run_bench`.
fn bench_distance(c: &mut Criterion) {
    const DIMENSION: usize = 1024;
    const TOTAL: usize = 1024 * 1024;

    run_bench::<Float16Type>(c);

    c.bench_function("Dot(f16, SIMD)", |b| {
        let key: Float16Array = generate_random_array_with_seed(DIMENSION, [0; 32]);
        let target: Float16Array = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
        b.iter(|| {
            let x = key.values().as_ref();
            let rows = target.len() / DIMENSION;
            // Lazy: the dot products are evaluated while the array is built.
            let products = (0..rows).map(|row| {
                let start = row * DIMENSION;
                let y = &target.values()[start..start + DIMENSION];
                Some(x.dot(y))
            });
            // SAFETY: `products` is a `Map` over a `Range<usize>`, so its
            // `size_hint` upper bound is the exact number of items yielded.
            unsafe { Float16Array::from_trusted_len_iter(products) }
        });
    });

    run_bench::<Float32Type>(c);

    c.bench_function("Dot(f32, SIMD)", |b| {
        let key: Float32Array = generate_random_array_with_seed(DIMENSION, [0; 32]);
        let target: Float32Array = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
        b.iter(|| {
            let x = key.values().as_ref();
            let rows = target.len() / DIMENSION;
            // Lazy: the dot products are evaluated while the array is built.
            let products = (0..rows).map(|row| {
                let start = row * DIMENSION;
                let y = &target.values()[start..start + DIMENSION];
                Some(x.dot(y))
            });
            // SAFETY: `products` is a `Map` over a `Range<usize>`, so its
            // `size_hint` upper bound is the exact number of items yielded.
            unsafe { Float32Array::from_trusted_len_iter(products) }
        });
    });

    run_bench::<Float64Type>(c);
}
// Benchmark group used on Linux: significance level 0.1, 10 samples per
// benchmark, plus a pprof profiler configured for flamegraph output.
// NOTE(review): the `100` passed to `PProfProfiler::new` is presumably the
// sampling frequency — confirm against the pprof documentation.
#[cfg(target_os = "linux")]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = bench_distance);
// Benchmark group used on every other target: identical significance level
// and sample size, but no profiler attached (the pprof import above is
// Linux-only).
#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = bench_distance);
// Generates the benchmark binary's `main`, running the `benches` group.
criterion_main!(benches);