use cubecl::prelude::*;
use crate::accelerate::Accelerator;
pub fn sniff_best_accelerator(enable: &[Accelerator]) -> Option<Accelerator> {
for accelerator in enable {
let is_enabled = match accelerator {
#[cfg(feature = "cuda")]
Accelerator::Cuda => probe_runtime::<cubecl::cuda::CudaRuntime>("CUDA"),
#[cfg(feature = "rocm")]
Accelerator::Rocm => probe_runtime::<cubecl::hip::HipRuntime>("ROCM"),
#[cfg(feature = "vulkan")]
Accelerator::Vulkan => probe_runtime::<cubecl::wgpu::WgpuRuntime>("VULKAN"),
#[cfg(feature = "metal")]
Accelerator::Metal => probe_runtime::<cubecl::wgpu::WgpuRuntime>("METAL"),
#[cfg(feature = "cpu")]
Accelerator::Cpu => probe_runtime::<cubecl::cpu::CpuRuntime>("CPU"),
};
if is_enabled {
return Some(*accelerator);
}
}
None
}
/// Reports whether runtime `R` can be used on this machine.
///
/// Creates a client for `R`'s default device and blocks until it syncs.
/// The probe counts as failed, and the reason is logged at debug level
/// labelled with `name`, in two cases:
/// - the sync returns an `Err`;
/// - a panic is raised while creating the device/client or syncing.
///
/// Returns `true` only when the sync completes with `Ok(())`.
fn probe_runtime<R: Runtime>(name: &'static str) -> bool {
    // Some runtimes may panic instead of returning an error when their driver
    // or adapter is missing. Catch that so this candidate is skipped rather than
    // taking down the whole sniff. The default panic hook still prints the panic
    // to stderr, and nothing can be caught when built with `panic = "abort"`.
    // The closure captures nothing, so it is trivially `UnwindSafe`.
    let outcome = std::panic::catch_unwind(|| {
        let device = <R::Device as Default>::default();
        let client = R::client(&device);
        cubecl::future::block_on(client.sync())
    });

    match outcome {
        Ok(Ok(())) => true,
        Ok(Err(err)) => {
            tracing::debug!(err = ?err, "could not use {name} runtime");
            false
        }
        Err(payload) => {
            // Panic payloads are usually `&'static str` or `String`; anything
            // else gets a generic placeholder.
            let reason = payload
                .downcast_ref::<&str>()
                .copied()
                .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
                .unwrap_or("non-string panic payload");
            tracing::debug!(panic = reason, "could not use {name} runtime");
            false
        }
    }
}