#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::string::String;
/// Selects which GPU compute backend is meant.
///
/// The enum is `#[non_exhaustive]`, so code outside the defining crate must
/// keep a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum GpuBackend {
    /// No specific backend requested. [`GpuBackend::is_available`] reports
    /// `true` for this variant when either CUDA or Metal is usable.
    Auto,
    /// CUDA backend; its availability probe is compiled in only with the
    /// `cuda` feature.
    Cuda,
    /// Metal backend; its availability probe is compiled in only with the
    /// `metal` feature.
    Metal,
    /// OpenCL backend. [`GpuBackend::is_available`] currently always returns
    /// `false` for it.
    // NOTE(review): the `dead_code` allowance presumably exists because nothing
    // constructs this variant yet; confirm before removing it.
    #[allow(dead_code)]
    OpenCL,
    /// Vulkan backend. [`GpuBackend::is_available`] currently always returns
    /// `false` for it.
    // NOTE(review): the `dead_code` allowance presumably exists because nothing
    // constructs this variant yet; confirm before removing it.
    #[allow(dead_code)]
    Vulkan,
}

impl GpuBackend {
    /// Reports whether this backend is usable.
    ///
    /// * `Cuda` / `Metal`: `false` unless the matching Cargo feature is
    ///   enabled, in which case the answer comes from that backend module's
    ///   own `is_available` probe.
    /// * `Auto`: `true` when CUDA or Metal is available. CUDA is probed first
    ///   and Metal is only consulted when CUDA is not available.
    /// * `OpenCL` / `Vulkan`: always `false`.
    #[must_use]
    pub fn is_available(self) -> bool {
        match self {
            Self::OpenCL | Self::Vulkan => false,
            Self::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    super::cuda::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            Self::Metal => {
                #[cfg(feature = "metal")]
                {
                    super::metal::is_available()
                }
                #[cfg(not(feature = "metal"))]
                {
                    false
                }
            }
            // Reuse the per-backend arms; `||` keeps the CUDA-then-Metal order
            // and skips the Metal probe once CUDA has answered `true`.
            Self::Auto => Self::Cuda.is_available() || Self::Metal.is_available(),
        }
    }

    /// Returns the fixed display name of this backend.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Vulkan => "Vulkan",
            Self::OpenCL => "OpenCL",
            Self::Metal => "Metal",
            Self::Cuda => "CUDA",
            Self::Auto => "Auto",
        }
    }
}
/// Describes a GPU device and the limits relevant to running work on it.
///
/// All fields are public plain data; this type performs no validation.
#[derive(Debug, Clone)]
pub struct GpuCapabilities {
    /// Backend this description belongs to.
    pub backend: GpuBackend,
    /// Device name as a string (empty in the default value).
    pub device_name: String,
    /// Total device memory. NOTE(review): unit is presumably bytes; confirm.
    pub total_memory: u64,
    /// Currently available device memory. NOTE(review): unit is presumably
    /// bytes; confirm.
    pub available_memory: u64,
    /// Largest supported FFT size. NOTE(review): presumably counted in
    /// elements; confirm.
    pub max_fft_size: usize,
    /// Whether the device supports `f64`.
    pub supports_f64: bool,
    /// Whether the device supports `f16`.
    pub supports_f16: bool,
    /// Number of compute units on the device.
    pub compute_units: u32,
    /// Maximum workgroup size.
    pub max_workgroup_size: u32,
}

impl Default for GpuCapabilities {
    /// Builds an empty description: backend `Auto`, an empty device name,
    /// every numeric limit set to zero and every feature flag `false`.
    fn default() -> Self {
        Self {
            // Identification.
            backend: GpuBackend::Auto,
            device_name: String::new(),
            // Feature flags.
            supports_f16: false,
            supports_f64: false,
            // Memory figures.
            available_memory: 0,
            total_memory: 0,
            // Execution limits.
            max_workgroup_size: 0,
            compute_units: 0,
            max_fft_size: 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must map to its fixed display name.
    #[test]
    fn test_backend_name() {
        assert_eq!(GpuBackend::Auto.name(), "Auto");
        assert_eq!(GpuBackend::Cuda.name(), "CUDA");
        assert_eq!(GpuBackend::Metal.name(), "Metal");
        assert_eq!(GpuBackend::OpenCL.name(), "OpenCL");
        assert_eq!(GpuBackend::Vulkan.name(), "Vulkan");
    }

    /// Availability probes must run without panicking, and the backends that
    /// are unconditionally unavailable must say so.
    #[test]
    fn test_backend_availability() {
        // These results depend on enabled features and on the host machine,
        // so only check that the calls complete.
        let _ = GpuBackend::Cuda.is_available();
        let _ = GpuBackend::Metal.is_available();
        let _ = GpuBackend::Auto.is_available();

        // These two are hard-wired to `false` regardless of features.
        assert!(!GpuBackend::OpenCL.is_available());
        assert!(!GpuBackend::Vulkan.is_available());
    }
}