use cudarc::cutensor::sys as ct_sys;
pub const TOP_K_ALGOS: &[ct_sys::cutensorAlgo_t] = &[
ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_DEFAULT,
ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_GETT,
ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_TGETT,
ct_sys::cutensorAlgo_t::CUTENSOR_ALGO_TTGT,
];
/// Source of timing measurements for individual cuTENSOR algorithms.
///
/// [`autotune_pick`] calls [`Measure::measure`] once per candidate. It treats a smaller
/// returned value as better and skips candidates for which `None` is returned. The unit
/// of the returned value is up to the implementor; NOTE(review): presumably elapsed
/// time, confirm with implementors.
pub trait Measure {
    /// Produces a timing for `candidate`, or `None` when no usable measurement exists.
    fn measure(&mut self, candidate: ct_sys::cutensorAlgo_t) -> Option<f64>;
}
/// Measures every algorithm in [`TOP_K_ALGOS`] with `measure` and returns the fastest one.
///
/// Each candidate is measured exactly once, in the order listed in [`TOP_K_ALGOS`].
/// A candidate is skipped when `measure` returns `None` or a `NaN` timing.
///
/// `NaN` has to be rejected explicitly. Every comparison against it is false, so a `NaN`
/// accepted as the first measurement could never be displaced by a real, faster timing.
///
/// On ties, the earlier candidate wins, because only a strictly smaller time replaces the
/// current best.
///
/// Returns `None` when no candidate produced a usable measurement.
pub fn autotune_pick<M: Measure>(measure: &mut M) -> Option<ct_sys::cutensorAlgo_t> {
    let mut best: Option<(ct_sys::cutensorAlgo_t, f64)> = None;
    for algo in TOP_K_ALGOS.iter().copied() {
        // Treat a NaN timing like a failed measurement so it cannot poison `best`.
        if let Some(t) = measure.measure(algo).filter(|t| !t.is_nan()) {
            let improves = match best {
                Some((_, best_time)) => t < best_time,
                None => true,
            };
            if improves {
                best = Some((algo, t));
            }
        }
    }
    best.map(|(algo, _)| algo)
}
#[cfg(test)]
mod tests {
    use super::*;

    type Algo = ct_sys::cutensorAlgo_t;

    /// Test double whose "measurement" is whatever the wrapped closure returns.
    struct Mock<F: FnMut(ct_sys::cutensorAlgo_t) -> Option<f64>>(F);

    impl<F> Measure for Mock<F>
    where
        F: FnMut(ct_sys::cutensorAlgo_t) -> Option<f64>,
    {
        fn measure(&mut self, a: ct_sys::cutensorAlgo_t) -> Option<f64> {
            let timing_fn = &mut self.0;
            timing_fn(a)
        }
    }

    #[test]
    fn autotune_picks_best_of_topk() {
        // Every candidate reports a distinct time; the smallest (GETT) must win.
        let mut distinct = Mock(|a: Algo| match a {
            Algo::CUTENSOR_ALGO_DEFAULT => Some(10.0),
            Algo::CUTENSOR_ALGO_GETT => Some(2.5),
            Algo::CUTENSOR_ALGO_TGETT => Some(5.0),
            Algo::CUTENSOR_ALGO_TTGT => Some(7.5),
            _ => None,
        });
        let winner = autotune_pick(&mut distinct).expect("a winner");
        assert_eq!(winner, Algo::CUTENSOR_ALGO_GETT);

        // No candidate can be measured, so there is nothing to pick.
        let mut unmeasurable = Mock(|_| None);
        assert!(autotune_pick(&mut unmeasurable).is_none());

        // DEFAULT and GETT tie; the one listed earlier in TOP_K_ALGOS is kept.
        let mut tied = Mock(|a: Algo| match a {
            Algo::CUTENSOR_ALGO_DEFAULT | Algo::CUTENSOR_ALGO_GETT => Some(3.0),
            _ => None,
        });
        let winner = autotune_pick(&mut tied).expect("a winner");
        assert_eq!(winner, Algo::CUTENSOR_ALGO_DEFAULT);
    }
}