use crate::types::KernelMode;
/// Registration record describing a kernel.
///
/// Values of this type are gathered through the `inventory` crate (see the
/// `inventory::collect!` invocation right after the struct) and can be
/// enumerated with `registered_kernels`.
#[derive(Debug, Clone)]
pub struct KernelRegistration {
/// Identifier of the kernel.
/// NOTE(review): presumably unique across registrations — nothing here enforces it.
pub id: &'static str,
/// Kernel mode; see `crate::types::KernelMode` for the possible values.
pub mode: KernelMode,
/// Grid size associated with the kernel.
/// NOTE(review): presumably the launch grid dimension — confirm with the consumer.
pub grid_size: u32,
/// Block size associated with the kernel.
/// NOTE(review): presumably threads per block — confirm with the consumer.
pub block_size: u32,
/// Names this kernel publishes to.
/// NOTE(review): presumably topic/channel names — confirm against the code that reads this field.
pub publishes_to: &'static [&'static str],
}
// Declares `KernelRegistration` as a type collected by `inventory`.
inventory::collect!(KernelRegistration);
/// Returns an iterator over every [`KernelRegistration`] collected through
/// `inventory`.
///
/// The registrations are yielded in whatever order `inventory` produces them;
/// this function does not sort or filter.
pub fn registered_kernels() -> impl Iterator<Item = &'static KernelRegistration> {
    // `inventory::iter::<T>` is itself an `IntoIterator` value; converting it
    // explicitly is equivalent to invoking it with `()`.
    IntoIterator::into_iter(inventory::iter::<KernelRegistration>)
}
/// Registration record describing a stencil kernel.
///
/// Values of this type are gathered through the `inventory` crate (see the
/// `inventory::collect!` invocation right after the struct), enumerated with
/// `registered_stencil_kernels`, and looked up by `id` with
/// `find_stencil_kernel`.
#[derive(Debug, Clone)]
pub struct StencilKernelRegistration {
/// Identifier of the kernel; this is the key `find_stencil_kernel` compares against.
pub id: &'static str,
/// Grid descriptor, stored as a string.
/// NOTE(review): the set of accepted values is not visible in this file — confirm.
pub grid: &'static str,
/// Tile width. NOTE(review): presumably in grid cells — confirm units.
pub tile_width: u32,
/// Tile height. NOTE(review): presumably in grid cells — confirm units.
pub tile_height: u32,
/// Halo size. NOTE(review): presumably the border width around each tile — confirm.
pub halo: u32,
/// CUDA source associated with the kernel.
/// NOTE(review): presumably the source text itself rather than a path — confirm.
pub cuda_source: &'static str,
}
// Declares `StencilKernelRegistration` as a type collected by `inventory`.
inventory::collect!(StencilKernelRegistration);
/// Returns an iterator over every [`StencilKernelRegistration`] collected
/// through `inventory`.
///
/// The registrations are yielded in whatever order `inventory` produces them;
/// this function does not sort or filter.
pub fn registered_stencil_kernels() -> impl Iterator<Item = &'static StencilKernelRegistration> {
    // `inventory::iter::<T>` is itself an `IntoIterator` value; converting it
    // explicitly is equivalent to invoking it with `()`.
    IntoIterator::into_iter(inventory::iter::<StencilKernelRegistration>)
}
/// Looks up a stencil kernel registration by identifier.
///
/// Returns the first registration yielded by `registered_stencil_kernels`
/// whose `id` is exactly equal to `id`, or `None` when no registration matches.
pub fn find_stencil_kernel(id: &str) -> Option<&'static StencilKernelRegistration> {
    let has_requested_id =
        |registration: &&'static StencilKernelRegistration| registration.id == id;
    registered_stencil_kernels().find(has_requested_id)
}
/// Registration record describing a GPU kernel.
///
/// Values of this type are gathered through the `inventory` crate (see the
/// `inventory::collect!` invocation right after the struct), enumerated with
/// `registered_gpu_kernels`, and looked up by `id` with `find_gpu_kernel`.
#[derive(Debug, Clone)]
pub struct GpuKernelRegistration {
/// Identifier of the kernel; this is the key `find_gpu_kernel` compares against.
pub id: &'static str,
/// Block size associated with the kernel.
/// NOTE(review): presumably threads per block — confirm with the consumer.
pub block_size: u32,
/// Capability names associated with the kernel.
/// NOTE(review): presumably the capabilities the kernel requires, using the same
/// vocabulary as `backend_supports_capability` (e.g. "f64") — confirm.
pub capabilities: &'static [&'static str],
/// Backend names associated with the kernel.
/// NOTE(review): presumably the backends it has implementations for — confirm.
pub backends: &'static [&'static str],
/// Backend names in fallback order.
/// NOTE(review): presumably intended as the `fallback_order` argument of
/// `select_backend` — confirm against callers.
pub fallback_order: &'static [&'static str],
}
// Declares `GpuKernelRegistration` as a type collected by `inventory`.
inventory::collect!(GpuKernelRegistration);
/// Returns an iterator over every [`GpuKernelRegistration`] collected through
/// `inventory`.
///
/// The registrations are yielded in whatever order `inventory` produces them;
/// this function does not sort or filter.
pub fn registered_gpu_kernels() -> impl Iterator<Item = &'static GpuKernelRegistration> {
    // `inventory::iter::<T>` is itself an `IntoIterator` value; converting it
    // explicitly is equivalent to invoking it with `()`.
    IntoIterator::into_iter(inventory::iter::<GpuKernelRegistration>)
}
/// Looks up a GPU kernel registration by identifier.
///
/// Returns the first registration yielded by `registered_gpu_kernels` whose
/// `id` is exactly equal to `id`, or `None` when no registration matches.
pub fn find_gpu_kernel(id: &str) -> Option<&'static GpuKernelRegistration> {
    let has_requested_id =
        |registration: &&'static GpuKernelRegistration| registration.id == id;
    registered_gpu_kernels().find(has_requested_id)
}
/// Reports whether `backend` supports `capability`.
///
/// Both arguments are compared as exact, case-sensitive strings:
///
/// * `"cuda"` and `"cpu"` support every capability.
/// * `"metal"` supports everything except `"f64"`, `"cooperative_groups"` and
///   `"dynamic_parallelism"`.
/// * `"wgpu"` supports everything except `"f64"`, `"i64"`, `"atomic64"`,
///   `"cooperative_groups"` and `"dynamic_parallelism"`.
/// * Any other backend name supports nothing.
pub fn backend_supports_capability(backend: &str, capability: &str) -> bool {
    match backend {
        "cuda" | "cpu" => true,
        "metal" => !matches!(
            capability,
            "f64" | "cooperative_groups" | "dynamic_parallelism"
        ),
        "wgpu" => !matches!(
            capability,
            "f64" | "i64" | "atomic64" | "cooperative_groups" | "dynamic_parallelism"
        ),
        // Unknown backends are never considered capable of anything.
        _ => false,
    }
}
/// Picks the first backend in `fallback_order` that can be used.
///
/// A candidate is usable when all of the following hold:
///
/// * it is one of the known backend names (`"cuda"`, `"metal"`, `"wgpu"`,
///   `"cpu"`),
/// * it appears in `available_backends`, and
/// * `backend_supports_capability` returns `true` for every entry of
///   `required_capabilities` (trivially true when that slice is empty).
///
/// Candidates are tried strictly in the order given by `fallback_order`.
/// Unknown backend names are skipped, so a later known candidate can still be
/// selected.
///
/// # Returns
///
/// The canonical `'static` name of the selected backend, or `None` when no
/// candidate is usable (including when `fallback_order` is empty).
pub fn select_backend(
    fallback_order: &[&str],
    required_capabilities: &[&str],
    available_backends: &[&str],
) -> Option<&'static str> {
    fallback_order.iter().copied().find_map(|backend| {
        // Map the caller-supplied name onto a `'static` literal. An unknown
        // name yields `None` for this candidate only, so the search moves on
        // to the next entry instead of aborting the whole selection.
        let canonical: &'static str = match backend {
            "cuda" => "cuda",
            "metal" => "metal",
            "wgpu" => "wgpu",
            "cpu" => "cpu",
            _ => return None,
        };

        let usable = available_backends.contains(&backend)
            && required_capabilities
                .iter()
                .all(|cap| backend_supports_capability(backend, cap));

        if usable {
            Some(canonical)
        } else {
            None
        }
    })
}