use ariadnetor_core::backend::{ComputeBackend, ExecPolicy};
use ariadnetor_native::{NativeBackend, PerformanceManager, ThresholdTable};
/// Builds a `NativeBackend` whose performance manager is constructed from the
/// given threshold table, so tests control the dispatch thresholds directly.
fn pinned_backend(t: ThresholdTable) -> NativeBackend {
    let manager = PerformanceManager::new(t);
    NativeBackend::with_perf(manager)
}
/// Threshold table with small, hand-picked values: every op is pinned to 10
/// except `transpose`, which is pinned to 24.
fn all_pinned() -> ThresholdTable {
    // Shared value for every field other than `transpose`.
    let pin: usize = 10;
    ThresholdTable {
        transpose: 24,
        svd: pin,
        qr: pin,
        lq: pin,
        eigh: pin,
        eig: pin,
        gemm: pin,
        solve: pin,
    }
}
/// Threshold table with every field set to `usize::MAX`, used as a sentinel
/// value by the tests that expect no parallel dispatch.
fn all_sentinel() -> ThresholdTable {
    let sentinel = usize::MAX;
    ThresholdTable {
        svd: sentinel,
        qr: sentinel,
        lq: sentinel,
        eigh: sentinel,
        eig: sentinel,
        gemm: sentinel,
        solve: sentinel,
        transpose: sentinel,
    }
}
/// With the SVD threshold pinned to 10, the shapes below are all expected to
/// resolve to `ExecPolicy::Sequential`.
#[test]
fn par_for_svd_below_threshold_is_sequential() {
    let backend = pinned_backend(all_pinned());
    // One square shape plus both skewed orientations.
    for (m, n) in [(8, 8), (20, 4), (4, 20)] {
        assert_eq!(backend.par_for_svd(m, n), ExecPolicy::Sequential);
    }
}
/// With the SVD threshold pinned to 10, shapes whose smaller dimension is 10
/// are expected to resolve to `ExecPolicy::Parallel(0)`.
#[test]
fn par_for_svd_at_threshold_is_parallel() {
    let backend = pinned_backend(all_pinned());
    // Exactly-at-threshold square shape, then each rectangular orientation.
    for (m, n) in [(10, 10), (10, 20), (20, 10)] {
        assert_eq!(backend.par_for_svd(m, n), ExecPolicy::Parallel(0));
    }
}
/// Checks QR dispatch against the pinned threshold of 10 for three shapes.
///
/// NOTE(review): the test name implies a cube-root work proxy; the exact
/// formula lives in the backend and is not visible in this file.
#[test]
fn par_for_qr_uses_cbrt_proxy() {
    let backend = pinned_backend(all_pinned());
    let cases = [
        (20, 4, ExecPolicy::Sequential),
        (10, 10, ExecPolicy::Parallel(0)),
        // A very tall matrix still goes parallel even though n = 9 < 10.
        (1000, 9, ExecPolicy::Parallel(0)),
    ];
    for (m, n, expected) in cases {
        assert_eq!(backend.par_for_qr(m, n), expected);
    }
}
/// Checks LQ dispatch against the pinned threshold of 10 for three shapes.
///
/// NOTE(review): the test name implies a cube-root work proxy; the exact
/// formula lives in the backend and is not visible in this file.
#[test]
fn par_for_lq_uses_cbrt_proxy() {
    let backend = pinned_backend(all_pinned());
    let cases = [
        (20, 4, ExecPolicy::Sequential),
        (10, 10, ExecPolicy::Parallel(0)),
        // A very wide matrix still goes parallel even though m = 9 < 10.
        (9, 1000, ExecPolicy::Parallel(0)),
    ];
    for (m, n, expected) in cases {
        assert_eq!(backend.par_for_lq(m, n), expected);
    }
}
/// `par_for_eigh` is expected to switch from sequential to parallel exactly
/// when its argument reaches the pinned threshold of 10.
#[test]
fn par_for_eigh_at_threshold_flips() {
    let backend = pinned_backend(all_pinned());

    let just_below = backend.par_for_eigh(9);
    assert_eq!(just_below, ExecPolicy::Sequential);

    let at_threshold = backend.par_for_eigh(10);
    assert_eq!(at_threshold, ExecPolicy::Parallel(0));
}
/// `par_for_eig` is expected to switch from sequential to parallel exactly
/// when its argument reaches the pinned threshold of 10.
#[test]
fn par_for_eig_at_threshold_flips() {
    let backend = pinned_backend(all_pinned());
    let cases = [
        (9, ExecPolicy::Sequential),
        (10, ExecPolicy::Parallel(0)),
    ];
    for (n, expected) in cases {
        assert_eq!(backend.par_for_eig(n), expected);
    }
}
/// A 9x9x9 GEMM sits below the pinned threshold of 10 and is expected to
/// stay sequential.
#[test]
fn par_for_gemm_below_threshold_is_sequential() {
    let backend = pinned_backend(all_pinned());
    let policy = backend.par_for_gemm(9, 9, 9);
    assert_eq!(policy, ExecPolicy::Sequential);
}
/// A 10x10x10 GEMM sits exactly at the pinned threshold of 10 and is
/// expected to dispatch as `ExecPolicy::Parallel(0)`.
#[test]
fn par_for_gemm_at_threshold_is_parallel() {
    let backend = pinned_backend(all_pinned());
    let policy = backend.par_for_gemm(10, 10, 10);
    assert_eq!(policy, ExecPolicy::Parallel(0));
}
/// Non-cubic GEMM shapes: 4*5*50 = 1000 = 10^3 is expected to go parallel,
/// while 4*5*40 = 800 < 10^3 is expected to stay sequential.
///
/// NOTE(review): the test name implies the backend compares a geometric mean
/// of the three dimensions with the threshold; the formula is not visible
/// in this file.
#[test]
fn par_for_gemm_non_cubic_uses_geometric_mean() {
    let backend = pinned_backend(all_pinned());
    let cases = [
        (4, 5, 50, ExecPolicy::Parallel(0)),
        (4, 5, 40, ExecPolicy::Sequential),
    ];
    for (m, n, k, expected) in cases {
        assert_eq!(backend.par_for_gemm(m, n, k), expected);
    }
}
/// The solve policy is expected to follow the first argument only: a large
/// second argument does not force parallelism, and a tiny one does not
/// prevent it.
#[test]
fn par_for_solve_keys_on_n_not_nrhs() {
    let backend = pinned_backend(all_pinned());

    // First argument below 10, second argument huge.
    let small_system = backend.par_for_solve(9, 10_000);
    assert_eq!(small_system, ExecPolicy::Sequential);

    // First argument at 10, second argument minimal.
    let threshold_system = backend.par_for_solve(10, 1);
    assert_eq!(threshold_system, ExecPolicy::Parallel(0));
}
/// Transpose dispatch with the threshold pinned to 24: a shape whose
/// dimensions multiply to 18 stays sequential, one that multiplies to 24 goes
/// parallel, and an empty shape stays sequential.
#[test]
fn par_for_transpose_uses_total_elements() {
    let backend = pinned_backend(all_pinned());

    // 2 * 3 * 3 = 18, below the pinned 24.
    let below = backend.par_for_transpose(&[2, 3, 3]);
    // 2 * 3 * 4 = 24, exactly at the pinned value.
    let at = backend.par_for_transpose(&[2, 3, 4]);
    // No dimensions at all.
    let empty = backend.par_for_transpose(&[]);

    assert_eq!(below, ExecPolicy::Sequential);
    assert_eq!(at, ExecPolicy::Parallel(0));
    assert_eq!(empty, ExecPolicy::Sequential);
}
/// With every threshold set to `usize::MAX`, each `par_for_*` query is
/// expected to return `ExecPolicy::Sequential`, even for 10_000-sized inputs.
#[test]
fn sentinel_thresholds_never_dispatch_parallel() {
    let backend = pinned_backend(all_sentinel());
    let sequential = ExecPolicy::Sequential;

    // Decompositions.
    assert_eq!(backend.par_for_svd(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_qr(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_lq(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_eigh(10_000), sequential);
    assert_eq!(backend.par_for_eig(10_000), sequential);

    // Products, solves, and transposes.
    assert_eq!(backend.par_for_gemm(10_000, 10_000, 10_000), sequential);
    assert_eq!(backend.par_for_solve(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_transpose(&[10_000, 10_000]), sequential);
}
/// Built only with the `hptt` feature: under the laptop profile, a
/// 10_000 x 10_000 transpose is expected to stay sequential.
#[cfg(feature = "hptt")]
#[test]
fn laptop_profile_transpose_stays_sequential() {
    let manager = PerformanceManager::new(ThresholdTable::laptop());
    let backend = NativeBackend::with_perf(manager);

    let policy = backend.par_for_transpose(&[10_000, 10_000]);
    assert_eq!(policy, ExecPolicy::Sequential);
}
/// Built only without the `hptt` feature: under the laptop profile, a
/// 128 x 128 transpose is expected to stay sequential while 256 x 256 is
/// expected to go parallel.
#[cfg(not(feature = "hptt"))]
#[test]
fn laptop_profile_transpose_naive_threshold() {
    let manager = PerformanceManager::new(ThresholdTable::laptop());
    let backend = NativeBackend::with_perf(manager);

    let small = backend.par_for_transpose(&[128, 128]);
    let large = backend.par_for_transpose(&[256, 256]);

    assert_eq!(small, ExecPolicy::Sequential);
    assert_eq!(large, ExecPolicy::Parallel(0));
}
/// Under the laptop profile, 10_000-sized inputs are expected to dispatch as
/// `ExecPolicy::Parallel(0)` for GEMM, solve, and every decomposition.
#[test]
fn laptop_profile_calibrated_ops_dispatch_parallel_above_threshold() {
    let manager = PerformanceManager::new(ThresholdTable::laptop());
    let backend = NativeBackend::with_perf(manager);
    let parallel = ExecPolicy::Parallel(0);

    assert_eq!(backend.par_for_gemm(10_000, 10_000, 10_000), parallel);
    assert_eq!(backend.par_for_solve(10_000, 10_000), parallel);
    assert_eq!(backend.par_for_svd(10_000, 10_000), parallel);
    assert_eq!(backend.par_for_qr(10_000, 10_000), parallel);
    assert_eq!(backend.par_for_lq(10_000, 10_000), parallel);
    assert_eq!(backend.par_for_eigh(10_000), parallel);
    assert_eq!(backend.par_for_eig(10_000), parallel);
}
/// Under the workstation profile, the decompositions and solve are expected
/// to stay sequential even for 10_000-sized inputs.
#[test]
fn workstation_profile_decomp_and_solve_stay_sequential() {
    let manager = PerformanceManager::new(ThresholdTable::workstation());
    let backend = NativeBackend::with_perf(manager);
    let sequential = ExecPolicy::Sequential;

    assert_eq!(backend.par_for_svd(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_qr(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_lq(10_000, 10_000), sequential);
    assert_eq!(backend.par_for_eigh(10_000), sequential);
    assert_eq!(backend.par_for_eig(10_000), sequential);
    assert_eq!(backend.par_for_solve(10_000, 10_000), sequential);
}
/// Under the workstation profile, a 64^3 GEMM is expected to stay sequential
/// while a 10_000^3 GEMM is expected to dispatch as `ExecPolicy::Parallel(0)`.
#[test]
fn workstation_profile_gemm_dispatches_parallel_above_threshold() {
    let manager = PerformanceManager::new(ThresholdTable::workstation());
    let backend = NativeBackend::with_perf(manager);

    let small = backend.par_for_gemm(64, 64, 64);
    let large = backend.par_for_gemm(10_000, 10_000, 10_000);

    assert_eq!(small, ExecPolicy::Sequential);
    assert_eq!(large, ExecPolicy::Parallel(0));
}
/// Built only with the `hptt` feature: under the workstation profile, a
/// 1024 x 1024 transpose is expected to stay sequential while 2048 x 2048 is
/// expected to go parallel.
#[cfg(feature = "hptt")]
#[test]
fn workstation_profile_transpose_hptt_threshold() {
    let manager = PerformanceManager::new(ThresholdTable::workstation());
    let backend = NativeBackend::with_perf(manager);

    let small = backend.par_for_transpose(&[1024, 1024]);
    let large = backend.par_for_transpose(&[2048, 2048]);

    assert_eq!(small, ExecPolicy::Sequential);
    assert_eq!(large, ExecPolicy::Parallel(0));
}
/// Built only without the `hptt` feature: under the workstation profile, a
/// 256 x 256 transpose is expected to stay sequential while 512 x 512 is
/// expected to go parallel.
#[cfg(not(feature = "hptt"))]
#[test]
fn workstation_profile_transpose_naive_threshold() {
    let manager = PerformanceManager::new(ThresholdTable::workstation());
    let backend = NativeBackend::with_perf(manager);

    let small = backend.par_for_transpose(&[256, 256]);
    let large = backend.par_for_transpose(&[512, 512]);

    assert_eq!(small, ExecPolicy::Sequential);
    assert_eq!(large, ExecPolicy::Parallel(0));
}