/// `cuda` submodule; compiled only when the `cuda` Cargo feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda;
/// `vulkan` submodule; compiled only when the `vulkan` Cargo feature is enabled.
#[cfg(feature = "vulkan")]
pub mod vulkan;
/// `metal` submodule; compiled only when the `metal` Cargo feature is enabled.
#[cfg(feature = "metal")]
pub mod metal;
/// `directml` submodule; compiled only when the `directml` Cargo feature is enabled.
#[cfg(feature = "directml")]
pub mod directml;
/// Feature flags and compute limits reported for a graphics/compute backend.
///
/// The type is a plain `Copy` value; two instances compare equal when every
/// field matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendCapabilities {
    /// Tensor-core support flag.
    pub tensor_cores: bool,
    /// Ray-tracing support flag.
    pub ray_tracing: bool,
    /// Mesh-shader support flag.
    pub mesh_shaders: bool,
    /// Variable-rate-shading support flag.
    pub variable_rate_shading: bool,
    /// Async-compute support flag.
    pub async_compute: bool,
    /// Peer-to-peer transfer support flag.
    pub p2p_transfer: bool,
    /// Maximum workgroup size as a 3-tuple (presumably the x, y and z
    /// dimensions, in that order; confirm against callers).
    pub max_workgroup_size: (u32, u32, u32),
    /// Maximum number of compute invocations (presumably per workgroup;
    /// confirm against callers).
    pub max_compute_invocations: u32,
}

impl Default for BackendCapabilities {
    /// Returns the baseline capability set: every feature flag is `false`,
    /// the workgroup size limit is `(256, 256, 64)` and the invocation limit
    /// is `256`.
    ///
    /// NOTE(review): these numbers coincide with the WebGPU default compute
    /// limits; presumably intentional, but confirm before relying on it.
    fn default() -> Self {
        let max_workgroup_size = (256, 256, 64);
        let max_compute_invocations = 256;

        Self {
            max_workgroup_size,
            max_compute_invocations,
            // No optional feature is assumed to be present.
            tensor_cores: false,
            ray_tracing: false,
            mesh_shaders: false,
            variable_rate_shading: false,
            async_compute: false,
            p2p_transfer: false,
        }
    }
}
/// Optimization strategies that can be suggested for a backend.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// in addition to `Debug` and `Clone`; hints can therefore be compared,
/// passed by value, and stored in hash-based collections.
///
/// NOTE(review): the variant names suggest per-API terminology (warp,
/// subgroup, threadgroup, wave); which backend emits which hint is not
/// visible here. Confirm against the backend submodules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationHint {
    /// Suggests using shared memory.
    UseSharedMemory,
    /// Suggests using warp-level primitives.
    UseWarpPrimitives,
    /// Suggests using subgroup operations.
    UseSubgroupOps,
    /// Suggests using threadgroup memory.
    UseThreadgroupMemory,
    /// Suggests preferring wave operations.
    PreferWaveOps,
    /// Suggests enabling asynchronous execution.
    EnableAsyncExecution,
}
/// Returns the capability set associated with `backend`.
///
/// This is a static lookup on the backend kind only; no device or adapter is
/// inspected. `Vulkan`, `Metal` and `Dx12` all receive the same profile:
/// async compute enabled, a workgroup size limit of `(1024, 1024, 64)` and an
/// invocation limit of `1024`. Every other backend receives
/// [`BackendCapabilities::default`] unchanged.
pub fn query_capabilities(backend: wgpu::Backend) -> BackendCapabilities {
    let mut caps = BackendCapabilities::default();

    // The three native backends currently share one identical profile.
    let is_native = matches!(
        backend,
        wgpu::Backend::Vulkan | wgpu::Backend::Metal | wgpu::Backend::Dx12
    );

    if is_native {
        caps.async_compute = true;
        caps.max_workgroup_size = (1024, 1024, 64);
        caps.max_compute_invocations = 1024;
    }

    caps
}