use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(feature = "gpu")]
use crate::gpu::{GpuBackend, GpuDevice, GpuError, GpuKernel};
/// Uniquely identifies a registered kernel by module, operation, element
/// dtype, and an optional variant tag.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct KernelId {
    pub module: String,
    pub operation: String,
    pub dtype: String,
    pub variant: Option<String>,
}

impl KernelId {
    /// Builds an id with no variant tag.
    pub fn new(module: &str, operation: &str, dtype: &str) -> Self {
        Self {
            module: module.to_string(),
            operation: operation.to_string(),
            dtype: dtype.to_string(),
            variant: None,
        }
    }

    /// Builds an id carrying a variant tag (e.g. "batched").
    pub fn with_variant(module: &str, operation: &str, dtype: &str, variant: &str) -> Self {
        Self {
            variant: Some(variant.to_string()),
            ..Self::new(module, operation, dtype)
        }
    }

    /// Renders the id as a flat kernel name.
    ///
    /// Without a variant the shape is `module_operation__dtype`; with one it
    /// is `module_operation_dtype__variant`.
    ///
    /// NOTE(review): the double underscore moves depending on whether a
    /// variant is present; the unit tests pin both shapes, so callers may
    /// already depend on this asymmetry — do not "fix" it silently.
    pub fn as_kernel_name(&self) -> String {
        let Self { module, operation, dtype, variant } = self;
        match variant {
            Some(v) => format!("{module}_{operation}_{dtype}__{v}"),
            None => format!("{module}_{operation}__{dtype}"),
        }
    }
}
/// A single backend-specific implementation of a kernel.
///
/// Several `KernelSource`s may be registered under one [`KernelId`], one per
/// backend; lookup selects the source whose `backend` matches the device.
///
/// NOTE(review): this struct is compiled unconditionally, but `GpuBackend`
/// is imported only under `#[cfg(feature = "gpu")]` at the top of the file —
/// confirm the crate still builds with the `gpu` feature disabled.
#[derive(Debug, Clone)]
pub struct KernelSource {
// Kernel source text (the built-ins embed `.cu` files via `include_str!`).
pub source: String,
// Backend this source targets.
pub backend: GpuBackend,
// Name of the entry-point function inside `source`.
pub entry_point: String,
// Workgroup dimensions (x, y, z).
pub workgroup_size: (u32, u32, u32),
// Bytes of shared memory the kernel requires (0 = none).
pub shared_memory: usize,
// Whether the kernel uses tensor-core instructions.
pub uses_tensor_cores: bool,
}
/// Cache entry pairing a compiled kernel with the device it was built on.
#[cfg(feature = "gpu")]
struct CompiledKernel {
// Shared handle to the compiled kernel; cloned out to callers on cache hits.
kernel: Arc<GpuKernel>,
// Device the kernel was compiled on. Redundant with the cache key in
// `KernelRegistry::compiled_cache`, which already includes the device id.
device_id: usize,
}
// Lazily-initialized process-wide registry; access via `KernelRegistry::global()`.
static KERNEL_REGISTRY: OnceLock<Mutex<KernelRegistry>> = OnceLock::new();
/// Registry mapping kernel ids to their backend-specific sources, plus (with
/// the `gpu` feature) a cache of kernels already compiled per device.
pub struct KernelRegistry {
// One or more backend implementations per kernel id.
sources: HashMap<KernelId, Vec<KernelSource>>,
// Compiled kernels keyed by (kernel id, device id).
#[cfg(feature = "gpu")]
compiled_cache: HashMap<(KernelId, usize), CompiledKernel>,
}
impl KernelRegistry {
    /// Creates an empty registry with no registered sources.
    fn new() -> Self {
        Self {
            sources: HashMap::new(),
            #[cfg(feature = "gpu")]
            compiled_cache: HashMap::new(),
        }
    }

    /// Returns the process-wide registry, initializing it with the built-in
    /// kernels on first access. Callers must lock the returned mutex.
    pub fn global() -> &'static Mutex<KernelRegistry> {
        KERNEL_REGISTRY.get_or_init(|| {
            let mut registry = KernelRegistry::new();
            registry.register_builtin_kernels();
            Mutex::new(registry)
        })
    }

    /// Registers every kernel family shipped with this crate.
    fn register_builtin_kernels(&mut self) {
        self.register_blas_kernels();
        self.register_reduction_kernels();
        self.register_utility_kernels();
    }

    /// Adds `source` as one implementation of `id`. Multiple sources
    /// (e.g. for different backends) may be registered under the same id.
    pub fn register_kernel(&mut self, id: KernelId, source: KernelSource) {
        self.sources.entry(id).or_default().push(source);
    }

    /// All registered sources for `id`, or `None` if the id is unknown.
    pub fn get_sources(&self, id: &KernelId) -> Option<&Vec<KernelSource>> {
        self.sources.get(id)
    }

    /// Returns the compiled kernel for `id` on `device`, compiling and
    /// caching it on first use.
    ///
    /// # Errors
    /// * `GpuError::KernelNotFound` if no source is registered for `id`.
    /// * `GpuError::BackendNotSupported` if no registered source matches the
    ///   device's backend.
    /// * Any error surfaced by `device.compile_kernel`.
    #[cfg(feature = "gpu")]
    pub fn get_kernel(
        &mut self,
        id: &KernelId,
        device: &GpuDevice,
    ) -> Result<Arc<GpuKernel>, GpuError> {
        let device_id = device.device_id();
        let cache_key = (id.clone(), device_id);
        // The device id is part of the cache key, so a hit is always for the
        // right device. (The previous `cached.device_id == device_id` guard
        // here was dead code — it could never be false.)
        if let Some(cached) = self.compiled_cache.get(&cache_key) {
            return Ok(cached.kernel.clone());
        }
        let sources = self
            .sources
            .get(id)
            .ok_or_else(|| GpuError::KernelNotFound(id.as_kernel_name()))?;
        let source = sources
            .iter()
            .find(|s| s.backend == device.backend())
            .ok_or_else(|| GpuError::BackendNotSupported(device.backend()))?;
        let kernel = device.compile_kernel(&source.source, &source.entry_point)?;
        let kernel = Arc::new(kernel);
        self.compiled_cache.insert(
            cache_key,
            CompiledKernel {
                kernel: kernel.clone(),
                device_id,
            },
        );
        Ok(kernel)
    }

    /// Drops every compiled kernel; subsequent lookups recompile on demand.
    #[cfg(feature = "gpu")]
    pub fn clear_cache(&mut self) {
        self.compiled_cache.clear();
    }

    /// Ids of all registered kernels, in arbitrary order.
    pub fn list_kernels(&self) -> Vec<KernelId> {
        self.sources.keys().cloned().collect()
    }

    /// Whether at least one source is registered for `id`.
    pub fn has_kernel(&self, id: &KernelId) -> bool {
        self.sources.contains_key(id)
    }
}
impl KernelRegistry {
fn register_blas_kernels(&mut self) {
self.register_kernel(
KernelId::new("core", "gemm", "f32"),
KernelSource {
source: include_str!("gpu/kernels/gemm_f32.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "gemm_f32".to_string(),
workgroup_size: (16, 16, 1),
shared_memory: 4096,
uses_tensor_cores: false,
},
);
self.register_kernel(
KernelId::new("core", "gemm", "f64"),
KernelSource {
source: include_str!("gpu/kernels/gemm_f64.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "gemm_f64".to_string(),
workgroup_size: (16, 16, 1),
shared_memory: 8192,
uses_tensor_cores: false,
},
);
self.register_kernel(
KernelId::new("core", "axpy", "f32"),
KernelSource {
source: include_str!("gpu/kernels/axpy.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "axpy_f32".to_string(),
workgroup_size: (256, 1, 1),
shared_memory: 0,
uses_tensor_cores: false,
},
);
}
fn register_reduction_kernels(&mut self) {
self.register_kernel(
KernelId::new("core", "reduce_sum", "f32"),
KernelSource {
source: include_str!("gpu/kernels/reduce_sum.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "reduce_sum_f32".to_string(),
workgroup_size: (256, 1, 1),
shared_memory: 1024,
uses_tensor_cores: false,
},
);
self.register_kernel(
KernelId::new("core", "reduce_max", "f32"),
KernelSource {
source: include_str!("gpu/kernels/reduce_max.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "reduce_max_f32".to_string(),
workgroup_size: (256, 1, 1),
shared_memory: 1024,
uses_tensor_cores: false,
},
);
}
fn register_utility_kernels(&mut self) {
self.register_kernel(
KernelId::new("core", "memcpy", "f32"),
KernelSource {
source: include_str!("gpu/kernels/memcpy.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "memcpy_f32".to_string(),
workgroup_size: (256, 1, 1),
shared_memory: 0,
uses_tensor_cores: false,
},
);
self.register_kernel(
KernelId::new("core", "fill", "f32"),
KernelSource {
source: include_str!("gpu/kernels/fill.cu").to_string(),
backend: GpuBackend::Cuda,
entry_point: "fill_f32".to_string(),
workgroup_size: (256, 1, 1),
shared_memory: 0,
uses_tensor_cores: false,
},
);
}
}
/// Registers `source` for `id` in the global registry.
///
/// # Panics
/// Panics if the registry mutex is poisoned (a previous holder panicked).
#[allow(dead_code)]
pub fn register_module_kernel(id: KernelId, source: KernelSource) {
    KernelRegistry::global()
        .lock()
        .expect("kernel registry mutex poisoned")
        .register_kernel(id, source);
}
/// Fetches (compiling and caching on first use) the kernel `id` for `device`
/// from the global registry.
///
/// # Errors
/// Propagates lookup and compilation errors from [`KernelRegistry::get_kernel`].
///
/// # Panics
/// Panics if the registry mutex is poisoned (a previous holder panicked).
#[cfg(feature = "gpu")]
#[allow(dead_code)]
pub fn get_kernel(id: &KernelId, device: &GpuDevice) -> Result<Arc<GpuKernel>, GpuError> {
    KernelRegistry::global()
        .lock()
        .expect("kernel registry mutex poisoned")
        .get_kernel(id, device)
}
/// Whether at least one source is registered for `id` in the global registry.
///
/// # Panics
/// Panics if the registry mutex is poisoned (a previous holder panicked).
#[allow(dead_code)]
pub fn has_kernel(id: &KernelId) -> bool {
    KernelRegistry::global()
        .lock()
        .expect("kernel registry mutex poisoned")
        .has_kernel(id)
}
/// Ids of every kernel registered in the global registry, in arbitrary order.
///
/// # Panics
/// Panics if the registry mutex is poisoned (a previous holder panicked).
#[allow(dead_code)]
pub fn list_kernels() -> Vec<KernelId> {
    KernelRegistry::global()
        .lock()
        .expect("kernel registry mutex poisoned")
        .list_kernels()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel names embed module/op/dtype, with the variant (when present)
    /// appended after the double underscore.
    #[test]
    fn test_kernel_id() {
        let plain = KernelId::new("linalg", "gemm", "f32");
        assert_eq!(plain.as_kernel_name(), "linalg_gemm__f32");

        let tagged = KernelId::with_variant("fft", "fft2d", "c64", "batched");
        assert_eq!(tagged.as_kernel_name(), "fft_fft2d_c64__batched");
    }

    /// A kernel registered through the free function is visible via lookup.
    #[test]
    fn test_kernel_registration() {
        let id = KernelId::new("test", "dummy", "f32");
        register_module_kernel(
            id.clone(),
            KernelSource {
                source: "dummy kernel".to_string(),
                backend: GpuBackend::Cuda,
                entry_point: "dummy".to_string(),
                workgroup_size: (1, 1, 1),
                shared_memory: 0,
                uses_tensor_cores: false,
            },
        );
        assert!(has_kernel(&id));
    }

    /// The built-in kernel set is registered on first access to the registry.
    #[test]
    fn test_builtin_kernels() {
        for (module, op, dtype) in [
            ("core", "gemm", "f32"),
            ("core", "reduce_sum", "f32"),
            ("core", "fill", "f32"),
        ] {
            assert!(has_kernel(&KernelId::new(module, op, dtype)));
        }
    }
}