use std::collections::HashMap;
use std::fmt;
pub mod blas;
pub mod complex;
pub mod elementwise;
pub mod ml;
pub mod reduction;
pub mod transform;
use crate::gpu::{GpuBackend, GpuError};
/// Numeric element types a kernel can operate on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point ("f32").
    Float32,
    /// 64-bit floating point ("f64").
    Float64,
    /// 32-bit signed integer ("i32").
    Int32,
    /// 32-bit unsigned integer ("u32").
    UInt32,
    /// 16-bit half-precision floating point ("f16").
    Float16,
    /// 16-bit bfloat16 ("bf16").
    BFloat16,
}
impl fmt::Display for DataType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DataType::Float32 => write!(f, "f32"),
DataType::Float64 => write!(f, "f64"),
DataType::Int32 => write!(f, "i32"),
DataType::UInt32 => write!(f, "u32"),
DataType::Float16 => write!(f, "f16"),
DataType::BFloat16 => write!(f, "bf16"),
}
}
}
/// Rough performance profile of a kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperationType {
    /// Cost dominated by arithmetic throughput.
    ComputeIntensive,
    /// Cost dominated by memory traffic.
    MemoryIntensive,
    /// Neither compute nor memory clearly dominates.
    Balanced,
}
/// Static execution metadata attached to a kernel.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Workgroup dimensions as `[x, y, z]`.
    pub workgroup_size: [u32; 3],
    /// Local (workgroup-shared) memory the kernel requires.
    /// NOTE(review): units are presumably bytes — confirm against launch code.
    pub local_memory_usage: usize,
    /// Whether the kernel can take advantage of tensor-core hardware.
    pub supports_tensor_cores: bool,
    /// The kernel's performance profile.
    /// NOTE(review): `operationtype` should be `operation_type` per Rust
    /// naming conventions, but renaming a public field is a breaking change.
    pub operationtype: OperationType,
    /// Free-form backend-specific key/value settings.
    pub backend_metadata: HashMap<String, String>,
}
impl Default for KernelMetadata {
fn default() -> Self {
Self {
workgroup_size: [16, 16, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::Balanced,
backend_metadata: HashMap::new(),
}
}
}
/// Parameters describing a concrete kernel invocation; used when
/// requesting a specialized kernel variant from the registry.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Element type the kernel operates on.
    pub datatype: DataType,
    /// Input tensor dimensions.
    pub input_dims: Vec<usize>,
    /// Output tensor dimensions.
    pub output_dims: Vec<usize>,
    /// Named numeric tuning parameters.
    pub numeric_params: HashMap<String, f64>,
    /// Named string tuning parameters.
    pub string_params: HashMap<String, String>,
}
impl KernelParams {
pub fn new(datatype: DataType) -> Self {
Self {
datatype,
input_dims: Vec::new(),
output_dims: Vec::new(),
numeric_params: HashMap::new(),
string_params: HashMap::new(),
}
}
pub fn with_input_dims(mut self, dims: Vec<usize>) -> Self {
self.input_dims = dims;
self
}
pub fn with_output_dims(mut self, dims: Vec<usize>) -> Self {
self.output_dims = dims;
self
}
pub fn with_numeric_param(mut self, name: &str, value: f64) -> Self {
self.numeric_params.insert(name.to_string(), value);
self
}
pub fn with_string_param(mut self, name: &str, value: &str) -> Self {
self.string_params
.insert(name.to_string(), value.to_string());
self
}
}
/// A GPU compute kernel that can emit backend-specific source code.
///
/// `Send + Sync` so kernels can be stored in a registry shared across threads.
pub trait GpuKernel: Send + Sync {
    /// Unique kernel name; used as the registry key.
    fn name(&self) -> &str;
    /// Returns the kernel source text for `backend`, or an error if the
    /// backend is unsupported.
    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError>;
    /// Execution metadata (workgroup size, memory usage, profile, ...).
    fn metadata(&self) -> KernelMetadata;
    /// Whether `specialize` can produce a variant tuned to `params`.
    fn can_specialize(&self, params: &KernelParams) -> bool;
    /// Builds a kernel specialized for `params`.
    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError>;
}
/// A static kernel backed by pre-written source text for each backend.
///
/// Backends with no implementation hold whatever string was supplied at
/// construction (empty by convention in this file).
pub struct BaseKernel {
    /// Registry key; returned by `GpuKernel::name`.
    name: String,
    cuda_source: String,
    rocm_source: String,
    wgpu_source: String,
    metal_source: String,
    opencl_source: String,
    metadata: KernelMetadata,
}
impl BaseKernel {
pub fn new(
name: &str,
cuda_source: &str,
rocm_source: &str,
wgpu_source: &str,
metal_source: &str,
opencl_source: &str,
metadata: KernelMetadata,
) -> Self {
Self {
name: name.to_string(),
cuda_source: cuda_source.to_string(),
rocm_source: rocm_source.to_string(),
wgpu_source: wgpu_source.to_string(),
metal_source: metal_source.to_string(),
opencl_source: opencl_source.to_string(),
metadata,
}
}
}
impl GpuKernel for BaseKernel {
    fn name(&self) -> &str {
        &self.name
    }

    /// Returns the stored source text for `backend`.
    ///
    /// Only backends outside the five stored ones are rejected; a backend
    /// constructed with an empty source string still returns `Ok("")`.
    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        match backend {
            GpuBackend::Cuda => Ok(self.cuda_source.clone()),
            GpuBackend::Rocm => Ok(self.rocm_source.clone()),
            GpuBackend::Wgpu => Ok(self.wgpu_source.clone()),
            GpuBackend::Metal => Ok(self.metal_source.clone()),
            GpuBackend::OpenCL => Ok(self.opencl_source.clone()),
            _ => Err(GpuError::UnsupportedBackend(backend)),
        }
    }

    fn metadata(&self) -> KernelMetadata {
        self.metadata.clone()
    }

    /// `BaseKernel` holds fixed source text only, so it never specializes.
    /// (`_params` silences the unused-variable warning the original `params`
    /// produced.)
    fn can_specialize(&self, _params: &KernelParams) -> bool {
        false
    }

    fn specialize(&self, _params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        Err(GpuError::SpecializationNotSupported)
    }
}
/// Name-keyed collection of the available GPU kernels.
pub struct KernelRegistry {
    /// Kernels keyed by their `name()`.
    kernels: HashMap<String, Box<dyn GpuKernel>>,
}
impl KernelRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            kernels: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with every built-in kernel:
    /// BLAS, elementwise math, optimizers, memory utilities, transforms,
    /// reductions, ML ops, complex arithmetic, and RK4 integrator stages.
    pub fn with_default_kernels() -> Self {
        let mut registry = Self::new();
        // BLAS.
        registry.register(Box::new(blas::gemm::GemmKernel::new()));
        registry.register(Box::new(blas::axpy::AxpyKernel::new()));
        registry.register(Box::new(blas::gemv::GemvKernel::new()));
        // Elementwise math.
        registry.register(Box::new(elementwise::ElementwiseAddKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseSubKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseMulKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseDivKernel::new()));
        registry.register(Box::new(elementwise::ElementwisePowKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseSqrtKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseExpKernel::new()));
        registry.register(Box::new(elementwise::ElementwiseLogKernel::new()));
        // Optimizers.
        registry.register(Box::new(create_adam_optimizer_kernel()));
        registry.register(Box::new(create_sgd_optimizer_kernel()));
        registry.register(Box::new(create_rmsprop_optimizer_kernel()));
        registry.register(Box::new(create_adagrad_optimizer_kernel()));
        registry.register(Box::new(create_lamb_optimizer_kernel()));
        // Memory utilities and basic reductions.
        registry.register(Box::new(create_memcpy_kernel()));
        registry.register(Box::new(create_fill_kernel()));
        registry.register(Box::new(create_reduce_sum_kernel()));
        registry.register(Box::new(create_reduce_max_kernel()));
        // Transforms.
        registry.register(Box::new(transform::fft::FftKernel::new()));
        registry.register(Box::new(transform::convolution::Conv1dKernel::new()));
        registry.register(Box::new(transform::convolution::Conv2dKernel::new()));
        // Reductions.
        registry.register(Box::new(reduction::sum::SumKernel::new()));
        registry.register(Box::new(reduction::norm::NormKernel::new()));
        registry.register(Box::new(reduction::min_max::MinKernel::new()));
        registry.register(Box::new(reduction::min_max::MaxKernel::new()));
        registry.register(Box::new(reduction::mean::MeanKernel::new()));
        registry.register(Box::new(reduction::std_dev::StdDevKernel::new()));
        // ML ops.
        registry.register(Box::new(ml::activation::ReluKernel::new()));
        registry.register(Box::new(ml::activation::SigmoidKernel::new()));
        registry.register(Box::new(ml::activation::TanhKernel::new()));
        registry.register(Box::new(ml::softmax::SoftmaxKernel::new()));
        registry.register(Box::new(ml::pooling::MaxPoolKernel::new()));
        registry.register(Box::new(ml::pooling::AvgPoolKernel::new()));
        // Complex arithmetic.
        registry.register(Box::new(complex::ComplexMultiplyKernel::new()));
        registry.register(Box::new(complex::ComplexConjugateKernel::new()));
        registry.register(Box::new(complex::ComplexMatMulKernel::new()));
        // RK4 integrator stages.
        registry.register(Box::new(create_rk4_stage1_kernel()));
        registry.register(Box::new(create_rk4_stage2_kernel()));
        registry.register(Box::new(create_rk4_stage3_kernel()));
        registry.register(Box::new(create_rk4_stage4_kernel()));
        registry.register(Box::new(create_rk4_combine_kernel()));
        // NOTE(review): `createerror_estimate_kernel` is missing an underscore
        // (`create_error_estimate_kernel`); renaming requires touching its
        // definition below as well.
        registry.register(Box::new(createerror_estimate_kernel()));
        registry
    }

    /// Inserts `kernel` under its `name()`. A previously registered kernel
    /// with the same name is silently replaced (HashMap::insert semantics).
    pub fn register(&mut self, kernel: Box<dyn GpuKernel>) {
        self.kernels.insert(kernel.name().to_string(), kernel);
    }

    /// Looks up a kernel by name.
    pub fn get(&self, name: &str) -> Option<&dyn GpuKernel> {
        self.kernels.get(name).map(|k| k.as_ref())
    }

    /// Looks up `name` and asks it to specialize for `params`.
    ///
    /// Errors with `KernelNotFound` when `name` is unregistered, and with
    /// `SpecializationNotSupported` when the kernel cannot specialize for
    /// these params (the base kernel is not returned as a fallback).
    pub fn get_specialized(
        &self,
        name: &str,
        params: &KernelParams,
    ) -> Result<Box<dyn GpuKernel>, GpuError> {
        let kernel = self
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
        if kernel.can_specialize(params) {
            kernel.specialize(params)
        } else {
            Err(GpuError::SpecializationNotSupported)
        }
    }
}
impl Default for KernelRegistry {
fn default() -> Self {
Self::with_default_kernels()
}
}
#[allow(dead_code)]
fn create_rk4_stage1_kernel() -> BaseKernel {
let cuda_source = include_str!("rk4_stage1.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rk4_stage1",
cuda_source,
cuda_source, "", "", cuda_source, metadata,
)
}
#[allow(dead_code)]
fn create_rk4_stage2_kernel() -> BaseKernel {
let cuda_source = include_str!("rk4_stage2.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rk4_stage2",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_rk4_stage3_kernel() -> BaseKernel {
let cuda_source = include_str!("rk4_stage3.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rk4_stage3",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_rk4_stage4_kernel() -> BaseKernel {
let cuda_source = include_str!("rk4_stage4.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rk4_stage4",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_rk4_combine_kernel() -> BaseKernel {
let cuda_source = include_str!("rk4_combine.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::MemoryIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rk4_combine",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn createerror_estimate_kernel() -> BaseKernel {
let cuda_source = include_str!("error_estimate.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 1024, supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"error_estimate",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_adam_optimizer_kernel() -> BaseKernel {
let cuda_source = include_str!("adam_optimizer.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"adam_optimizer",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_sgd_optimizer_kernel() -> BaseKernel {
let cuda_source = include_str!("sgd_optimizer.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::MemoryIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"sgd_optimizer",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_rmsprop_optimizer_kernel() -> BaseKernel {
let cuda_source = include_str!("rmsprop_optimizer.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"rmsprop_optimizer",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_adagrad_optimizer_kernel() -> BaseKernel {
let cuda_source = include_str!("adagrad_optimizer.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"adagrad_optimizer",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_lamb_optimizer_kernel() -> BaseKernel {
let cuda_source = include_str!("lamb_optimizer.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"lamb_optimizer",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_memcpy_kernel() -> BaseKernel {
let cuda_source = include_str!("memcpy.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::MemoryIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"memcpy",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_fill_kernel() -> BaseKernel {
let cuda_source = include_str!("fill.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 0,
supports_tensor_cores: false,
operationtype: OperationType::MemoryIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"fill",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_reduce_sum_kernel() -> BaseKernel {
let cuda_source = include_str!("reduce_sum.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 1024, supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"reduce_sum",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}
#[allow(dead_code)]
fn create_reduce_max_kernel() -> BaseKernel {
let cuda_source = include_str!("reduce_max.cu");
let metadata = KernelMetadata {
workgroup_size: [256, 1, 1],
local_memory_usage: 1024, supports_tensor_cores: false,
operationtype: OperationType::ComputeIntensive,
backend_metadata: HashMap::new(),
};
BaseKernel::new(
"reduce_max",
cuda_source,
cuda_source,
"",
"",
cuda_source,
metadata,
)
}