use ariadnetor_core::backend::ExecPolicy;
/// Per-operation size thresholds used to decide between sequential and parallel execution.
///
/// Each field is the cut-off for one operation. The values are intended for
/// `PerformanceManager::policy_by_n`, which selects parallel execution when `n >= threshold`
/// and treats `usize::MAX` as a "never parallelise" sentinel.
/// NOTE(review): the call sites that read these fields are not in this file, so the unit of
/// `n` (e.g. matrix dimension vs. element count) should be confirmed per operation.
///
/// Use [`ThresholdTable::detect`] to pick a preset for the current machine, or
/// [`ThresholdTable::laptop`] / [`ThresholdTable::workstation`] to choose one explicitly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThresholdTable {
    /// Threshold for the `svd` operation.
    pub svd: usize,
    /// Threshold for the `qr` operation.
    pub qr: usize,
    /// Threshold for the `lq` operation.
    pub lq: usize,
    /// Threshold for the `eigh` operation.
    pub eigh: usize,
    /// Threshold for the `eig` operation.
    pub eig: usize,
    /// Threshold for the `gemm` operation.
    pub gemm: usize,
    /// Threshold for the `solve` operation.
    pub solve: usize,
    /// Threshold for the `transpose` operation.
    pub transpose: usize,
}

impl ThresholdTable {
    /// Largest parallelism estimate that still maps to the laptop preset; anything
    /// strictly greater selects the workstation preset.
    const LAPTOP_MAX_PARALLELISM: usize = 16;

    /// Preset used by [`ThresholdTable::detect`] when the reported parallelism is at most 16.
    ///
    /// Every operation gets a finite threshold, except `transpose` when the `hptt` feature is
    /// enabled: it is then set to the `usize::MAX` sentinel (never parallel).
    #[must_use]
    pub fn laptop() -> Self {
        Self {
            svd: 384,
            qr: 384,
            lq: 512,
            eigh: 256,
            eig: 256,
            gemm: 192,
            solve: 768,
            transpose: if cfg!(feature = "hptt") {
                usize::MAX
            } else {
                65_536
            },
        }
    }

    /// Preset used by [`ThresholdTable::detect`] when the reported parallelism exceeds 16.
    ///
    /// Only `gemm` and `transpose` get finite thresholds. Every other operation uses the
    /// `usize::MAX` sentinel. The `transpose` threshold is larger when the `hptt` feature
    /// is enabled.
    #[must_use]
    pub fn workstation() -> Self {
        Self {
            svd: usize::MAX,
            qr: usize::MAX,
            lq: usize::MAX,
            eigh: usize::MAX,
            eig: usize::MAX,
            gemm: 768,
            solve: usize::MAX,
            transpose: if cfg!(feature = "hptt") {
                4_194_304
            } else {
                262_144
            },
        }
    }

    /// Picks a preset based on `std::thread::available_parallelism`.
    ///
    /// If the parallelism query fails, a value of 1 is assumed, which selects the laptop preset.
    #[must_use]
    pub fn detect() -> Self {
        let n = std::thread::available_parallelism().map_or(1, |p| p.get());
        Self::profile_for_parallelism(n)
    }

    /// Maps a parallelism estimate to a preset: `n > 16` selects the workstation preset,
    /// anything else the laptop preset. Kept separate from `detect` so the boundary can be
    /// tested without depending on the host machine.
    fn profile_for_parallelism(n: usize) -> Self {
        if n > Self::LAPTOP_MAX_PARALLELISM {
            Self::workstation()
        } else {
            Self::laptop()
        }
    }
}
/// Holds a [`ThresholdTable`] and converts size thresholds into execution policies.
#[derive(Clone, Debug)]
pub struct PerformanceManager {
    thresholds: ThresholdTable,
}

impl PerformanceManager {
    /// Builds a manager around the supplied threshold table.
    pub fn new(thresholds: ThresholdTable) -> Self {
        PerformanceManager { thresholds }
    }

    /// Borrows the threshold table this manager was constructed with.
    pub fn thresholds(&self) -> &ThresholdTable {
        &self.thresholds
    }

    /// Chooses an execution policy for a problem of size `n` given a `threshold`.
    ///
    /// Returns `ExecPolicy::Sequential` when `threshold` is the `usize::MAX` sentinel
    /// (regardless of `n`) or when `n` is below `threshold`. Otherwise it returns
    /// `ExecPolicy::Parallel(0)`.
    /// NOTE(review): the meaning of the `0` payload is defined by `ExecPolicy`, which is
    /// outside this file. It presumably means "backend default thread count"; confirm.
    pub(crate) fn policy_by_n(threshold: usize, n: usize) -> ExecPolicy {
        // The sentinel must win even when `n == usize::MAX`, so test it before comparing sizes.
        if threshold == usize::MAX {
            return ExecPolicy::Sequential;
        }
        if n < threshold {
            ExecPolicy::Sequential
        } else {
            ExecPolicy::Parallel(0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two tables field by field, so these tests do not depend on
    /// `ThresholdTable` implementing `PartialEq`.
    fn same_table(a: &ThresholdTable, b: &ThresholdTable) -> bool {
        (a.svd, a.qr, a.lq, a.eigh, a.eig, a.gemm, a.solve, a.transpose)
            == (b.svd, b.qr, b.lq, b.eigh, b.eig, b.gemm, b.solve, b.transpose)
    }

    #[test]
    fn laptop_constants_pinned() {
        let t = ThresholdTable::laptop();
        assert_eq!(t.svd, 384);
        assert_eq!(t.qr, 384);
        assert_eq!(t.lq, 512);
        assert_eq!(t.eigh, 256);
        assert_eq!(t.eig, 256);
        assert_eq!(t.gemm, 192);
        assert_eq!(t.solve, 768);
        #[cfg(feature = "hptt")]
        assert_eq!(t.transpose, usize::MAX);
        #[cfg(not(feature = "hptt"))]
        assert_eq!(t.transpose, 65_536);
    }

    #[test]
    fn workstation_constants_pinned() {
        let t = ThresholdTable::workstation();
        assert_eq!(t.svd, usize::MAX);
        assert_eq!(t.qr, usize::MAX);
        assert_eq!(t.lq, usize::MAX);
        assert_eq!(t.eigh, usize::MAX);
        assert_eq!(t.eig, usize::MAX);
        assert_eq!(t.gemm, 768);
        assert_eq!(t.solve, usize::MAX);
        #[cfg(feature = "hptt")]
        assert_eq!(t.transpose, 4_194_304);
        #[cfg(not(feature = "hptt"))]
        assert_eq!(t.transpose, 262_144);
    }

    #[test]
    fn laptop_and_workstation_presets_differ() {
        // Guards the profile-selection tests below: they are only meaningful if the
        // two presets can be told apart.
        assert!(!same_table(
            &ThresholdTable::laptop(),
            &ThresholdTable::workstation()
        ));
    }

    #[test]
    fn policy_by_n_below_threshold_is_sequential() {
        assert_eq!(
            PerformanceManager::policy_by_n(256, 255),
            ExecPolicy::Sequential
        );
    }

    #[test]
    fn policy_by_n_at_threshold_is_parallel() {
        assert_eq!(
            PerformanceManager::policy_by_n(256, 256),
            ExecPolicy::Parallel(0)
        );
    }

    #[test]
    fn policy_by_n_above_threshold_is_parallel() {
        assert_eq!(
            PerformanceManager::policy_by_n(256, 1024),
            ExecPolicy::Parallel(0)
        );
    }

    #[test]
    fn policy_by_n_zero_threshold_is_always_parallel() {
        assert_eq!(
            PerformanceManager::policy_by_n(0, 0),
            ExecPolicy::Parallel(0)
        );
    }

    #[test]
    fn profile_for_parallelism_pins_core_count_boundary() {
        // Compare whole tables rather than a single field, so a preset that mixes
        // values from both profiles is caught.
        assert!(same_table(
            &ThresholdTable::profile_for_parallelism(16),
            &ThresholdTable::laptop()
        ));
        assert!(same_table(
            &ThresholdTable::profile_for_parallelism(17),
            &ThresholdTable::workstation()
        ));
    }

    #[test]
    fn profile_for_parallelism_extremes() {
        // 1 is the fallback `detect` uses when the parallelism query fails.
        assert!(same_table(
            &ThresholdTable::profile_for_parallelism(0),
            &ThresholdTable::laptop()
        ));
        assert!(same_table(
            &ThresholdTable::profile_for_parallelism(1),
            &ThresholdTable::laptop()
        ));
        assert!(same_table(
            &ThresholdTable::profile_for_parallelism(usize::MAX),
            &ThresholdTable::workstation()
        ));
    }

    #[test]
    fn detect_returns_one_of_the_presets() {
        let t = ThresholdTable::detect();
        assert!(
            same_table(&t, &ThresholdTable::laptop())
                || same_table(&t, &ThresholdTable::workstation())
        );
    }

    #[test]
    fn policy_by_n_sentinel_is_always_sequential() {
        assert_eq!(
            PerformanceManager::policy_by_n(usize::MAX, 0),
            ExecPolicy::Sequential
        );
        assert_eq!(
            PerformanceManager::policy_by_n(usize::MAX, usize::MAX),
            ExecPolicy::Sequential
        );
    }

    #[test]
    fn policy_by_n_only_exact_max_is_sentinel() {
        // A threshold one below the sentinel is an ordinary (huge) threshold.
        assert_eq!(
            PerformanceManager::policy_by_n(usize::MAX - 1, usize::MAX),
            ExecPolicy::Parallel(0)
        );
    }
}