use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// GPU compute API that a kernel can be compiled for and launched on.
///
/// Passed to `KernelInterface::compile` to select the target. `DirectML` has
/// no corresponding `CompiledKernelVariant`, so nothing in this file can
/// produce a compiled kernel for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
/// NVIDIA CUDA.
Cuda,
/// Apple Metal.
Metal,
/// OpenCL.
OpenCL,
/// Vulkan compute.
Vulkan,
/// Microsoft DirectML.
DirectML,
}
/// A GPU kernel that can be compiled for one or more backends, sized for a
/// problem, and have its inputs checked before launch.
///
/// `Send + Sync` lets implementations be stored as `Box<dyn KernelInterface>`
/// inside `KernelManager`'s `Arc<RwLock<..>>` registry.
pub trait KernelInterface: Send + Sync {
/// Identifier for this kernel.
fn name(&self) -> &str;
/// Backends this kernel can be compiled for.
fn supported_backends(&self) -> Vec<GpuBackend>;
/// Builds the backend-specific artifact for `backend`. Implementations are
/// expected to return `Err` for backends they do not support, as
/// `MatMulKernel` does.
fn compile(&self, backend: GpuBackend) -> RusTorchResult<CompiledKernelVariant>;
/// Chooses block/grid dimensions and resource hints for `problem_size`.
fn launch_config(&self, problem_size: ProblemSize) -> LaunchConfiguration;
/// Checks the input tensors before a launch; returns `Err` if they are unusable.
fn validate_inputs(&self, inputs: &[&Tensor<f32>]) -> RusTorchResult<()>;
}
/// Shape information handed to `KernelInterface::launch_config`.
#[derive(Debug, Clone)]
pub struct ProblemSize {
/// Presumably the output element count (the tests set it to rows * cols);
/// not read by any code in this file.
pub total_elements: usize,
/// Presumably one shape per input tensor, in input order; not read by any
/// code in this file.
pub input_dims: Vec<Vec<usize>>,
/// Output shape; `MatMulKernel::launch_config` derives its grid from it as
/// (rows, cols).
pub output_dims: Vec<usize>,
/// Optional batch size; not read by any code in this file.
pub batch_size: Option<usize>,
}
/// Backend-neutral launch parameters produced by `KernelInterface::launch_config`.
#[derive(Debug, Clone)]
pub struct LaunchConfiguration {
/// Threads per block / threadgroup / work-group along (x, y, z).
pub block_dims: (u32, u32, u32),
/// Number of blocks / threadgroups / work-groups along (x, y, z).
pub grid_dims: (u32, u32, u32),
/// Shared-memory request. `MatMulKernel` sets it to the size of its two f32
/// tiles in bytes, so presumably a byte count in general — TODO confirm.
pub shared_memory: usize,
/// Optional stream index; semantics not visible in this file
/// (`MatMulKernel` leaves it `None`).
pub stream_idx: Option<usize>,
/// Semantics not visible in this file; `MatMulKernel` sets it to 0.
pub dynamic_parallelism: usize,
}
/// A kernel artifact for exactly one backend, as returned by
/// `KernelInterface::compile`.
///
/// `KernelManager` caches these and clones one out on every lookup. There is
/// no `DirectML` variant.
#[derive(Clone)]
pub enum CompiledKernelVariant {
Cuda(CudaKernel),
Metal(MetalKernel),
OpenCL(OpenCLKernel),
Vulkan(VulkanKernel),
}
/// CUDA kernel artifact.
#[derive(Clone)]
pub struct CudaKernel {
/// Kernel code bytes. NOTE(review): `MatMulKernel::compile` stores CUDA C
/// source text here (not PTX), despite the field name.
pub ptx_code: Vec<u8>,
/// Entry-point function name inside the code.
pub function_name: String,
/// Target compute capability, presumably (major, minor) — TODO confirm.
pub compute_capability: (u32, u32),
/// Registers-per-thread hint; `MatMulKernel` hard-codes 32 rather than measuring it.
pub registers_per_thread: u32,
/// Upper bound on threads per block for this kernel.
pub max_threads_per_block: u32,
/// Whether the kernel is flagged as using tensor cores.
pub uses_tensor_cores: bool,
}
/// Metal kernel artifact.
#[derive(Clone)]
pub struct MetalKernel {
/// Shader bytes. NOTE(review): `MatMulKernel::compile` stores Metal Shading
/// Language source text here, not a compiled library.
pub shader_library: Vec<u8>,
/// Name of the `kernel` function to launch.
pub function_name: String,
/// SIMD-group width hint; `MatMulKernel` hard-codes 32.
pub thread_execution_width: u32,
/// Upper bound on threads per threadgroup for this kernel.
pub max_threads_per_threadgroup: u32,
/// Whether the kernel is marked as using SIMD features (`MatMulKernel` always sets `true`).
pub uses_simd: bool,
}
/// OpenCL kernel artifact, kept as OpenCL C source.
#[derive(Clone)]
pub struct OpenCLKernel {
/// OpenCL C source code.
pub source_code: String,
/// Name of the `__kernel` function to launch.
pub function_name: String,
/// Local work-group size along (x, y, z).
pub work_group_size: (usize, usize, usize),
/// Minimum OpenCL version required, presumably (major, minor).
pub required_version: (u32, u32),
}
/// Vulkan compute kernel artifact. Nothing in this file constructs one.
#[derive(Clone)]
pub struct VulkanKernel {
/// SPIR-V module as 32-bit words.
pub spirv_code: Vec<u32>,
/// Entry-point name within the SPIR-V module.
pub entry_point: String,
/// Workgroup size along (x, y, z).
pub workgroup_size: (u32, u32, u32),
}
/// Tiled single-precision matrix multiply `C = A · B` for row-major
/// `A` (M×K), `B` (K×N) and `C` (M×N), with source generators for CUDA,
/// Metal and OpenCL.
pub struct MatMulKernel {
    /// Edge length of the square tiles the generated kernels stage in
    /// on-chip memory. Each thread loads one element of each tile, so every
    /// generated kernel must run with `tile_size × tile_size` threads per
    /// block / threadgroup / work-group.
    tile_size: usize,
    /// Not read by the code generators below; none of the generated sources
    /// use tensor-core instructions.
    use_tensor_cores: bool,
    /// Always `false` from `new`; not read by the code generators below.
    transpose_a: bool,
    /// Always `false` from `new`; not read by the code generators below.
    transpose_b: bool,
}

impl MatMulKernel {
    /// Creates a kernel using `tile_size × tile_size` tiles and no transposes.
    ///
    /// NOTE(review): `tile_size` is not validated; 0 produces kernels that
    /// divide by `TILE_SIZE`.
    pub fn new(tile_size: usize, use_tensor_cores: bool) -> Self {
        Self {
            tile_size,
            use_tensor_cores,
            transpose_a: false,
            transpose_b: false,
        }
    }

    /// CUDA C++ source for `matmul_kernel(A, B, C, M, N, K)`.
    ///
    /// Out-of-range tile elements are zero-filled, so M, N and K need not be
    /// multiples of the tile size. Shared tiles get one extra column of padding.
    fn generate_cuda_code(&self) -> String {
        format!(r#"
extern "C" __global__ void matmul_kernel(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {{
const int TILE_SIZE = {};
__shared__ float tileA[TILE_SIZE][TILE_SIZE + 1]; // +1 for bank conflict avoidance
__shared__ float tileB[TILE_SIZE][TILE_SIZE + 1];
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
int row = by * TILE_SIZE + ty;
int col = bx * TILE_SIZE + tx;
float sum = 0.0f;
// Loop over tiles
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {{
// Load tile from A
if (row < M && t * TILE_SIZE + tx < K) {{
tileA[ty][tx] = A[row * K + t * TILE_SIZE + tx];
}} else {{
tileA[ty][tx] = 0.0f;
}}
// Load tile from B
if (col < N && t * TILE_SIZE + ty < K) {{
tileB[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
}} else {{
tileB[ty][tx] = 0.0f;
}}
__syncthreads();
// Compute partial dot product
#pragma unroll
for (int k = 0; k < TILE_SIZE; ++k) {{
sum += tileA[ty][k] * tileB[k][tx];
}}
__syncthreads();
}}
// Write result
if (row < M && col < N) {{
C[row * N + col] = sum;
}}
}}
"#, self.tile_size)
    }

    /// Metal Shading Language source for `matmul_kernel`, taking
    /// A, B, C, M, N, K in buffers 0–5.
    fn generate_metal_code(&self) -> String {
        format!(r#"
#include <metal_stdlib>
using namespace metal;
kernel void matmul_kernel(
device const float* A [[buffer(0)]],
device const float* B [[buffer(1)]],
device float* C [[buffer(2)]],
constant int& M [[buffer(3)]],
constant int& N [[buffer(4)]],
constant int& K [[buffer(5)]],
uint2 gid [[thread_position_in_grid]],
uint2 tid [[thread_position_in_threadgroup]],
uint2 tgid [[threadgroup_position_in_grid]]
) {{
const int TILE_SIZE = {};
threadgroup float tileA[TILE_SIZE][TILE_SIZE];
threadgroup float tileB[TILE_SIZE][TILE_SIZE];
int row = tgid.y * TILE_SIZE + tid.y;
int col = tgid.x * TILE_SIZE + tid.x;
float sum = 0.0f;
// Loop over tiles
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {{
// Load tiles with bounds checking
if (row < M && t * TILE_SIZE + tid.x < K) {{
tileA[tid.y][tid.x] = A[row * K + t * TILE_SIZE + tid.x];
}} else {{
tileA[tid.y][tid.x] = 0.0f;
}}
if (col < N && t * TILE_SIZE + tid.y < K) {{
tileB[tid.y][tid.x] = B[(t * TILE_SIZE + tid.y) * N + col];
}} else {{
tileB[tid.y][tid.x] = 0.0f;
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute partial products
for (int k = 0; k < TILE_SIZE; ++k) {{
sum = fma(tileA[tid.y][k], tileB[k][tid.x], sum);
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}}
// Write result
if (row < M && col < N) {{
C[row * N + col] = sum;
}}
}}
"#, self.tile_size)
    }

    /// OpenCL C source for `matmul_kernel(A, B, C, M, N, K)`.
    ///
    /// `TILE_SIZE` is emitted as a preprocessor constant. OpenCL C is C99-based,
    /// so a `const int` is not an integral constant expression there. A
    /// `__local` array sized by one would be a variable-length array, which
    /// OpenCL C does not support, and the program would fail to build.
    fn generate_opencl_code(&self) -> String {
        format!(r#"
#define TILE_SIZE {}
__kernel void matmul_kernel(
__global const float* A,
__global const float* B,
__global float* C,
int M, int N, int K
) {{
__local float tileA[TILE_SIZE][TILE_SIZE];
__local float tileB[TILE_SIZE][TILE_SIZE];
int row = get_group_id(1) * TILE_SIZE + get_local_id(1);
int col = get_group_id(0) * TILE_SIZE + get_local_id(0);
float sum = 0.0f;
// Loop over tiles
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {{
// Load tiles
if (row < M && t * TILE_SIZE + get_local_id(0) < K) {{
tileA[get_local_id(1)][get_local_id(0)] =
A[row * K + t * TILE_SIZE + get_local_id(0)];
}} else {{
tileA[get_local_id(1)][get_local_id(0)] = 0.0f;
}}
if (col < N && t * TILE_SIZE + get_local_id(1) < K) {{
tileB[get_local_id(1)][get_local_id(0)] =
B[(t * TILE_SIZE + get_local_id(1)) * N + col];
}} else {{
tileB[get_local_id(1)][get_local_id(0)] = 0.0f;
}}
barrier(CLK_LOCAL_MEM_FENCE);
// Compute partial products
for (int k = 0; k < TILE_SIZE; ++k) {{
sum += tileA[get_local_id(1)][k] * tileB[k][get_local_id(0)];
}}
barrier(CLK_LOCAL_MEM_FENCE);
}}
// Write result
if (row < M && col < N) {{
C[row * N + col] = sum;
}}
}}
"#, self.tile_size)
    }
}
impl KernelInterface for MatMulKernel {
    fn name(&self) -> &str {
        "matmul_kernel"
    }

    /// CUDA, Metal and OpenCL; `compile` rejects every other backend.
    fn supported_backends(&self) -> Vec<GpuBackend> {
        vec![GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::OpenCL]
    }

    /// Packages the generated kernel source for `backend`.
    ///
    /// No device compilation happens here: the CUDA and Metal payloads hold
    /// the generated source text as bytes.
    ///
    /// # Errors
    /// `RusTorchError::Unsupported` for Vulkan and DirectML.
    fn compile(&self, backend: GpuBackend) -> RusTorchResult<CompiledKernelVariant> {
        let function_name = self.name().to_string();
        match backend {
            GpuBackend::Cuda => Ok(CompiledKernelVariant::Cuda(CudaKernel {
                ptx_code: self.generate_cuda_code().into_bytes(),
                function_name,
                compute_capability: (7, 0),
                registers_per_thread: 32,
                max_threads_per_block: 1024,
                uses_tensor_cores: self.use_tensor_cores,
            })),
            GpuBackend::Metal => Ok(CompiledKernelVariant::Metal(MetalKernel {
                shader_library: self.generate_metal_code().into_bytes(),
                function_name,
                thread_execution_width: 32,
                max_threads_per_threadgroup: 1024,
                uses_simd: true,
            })),
            GpuBackend::OpenCL => Ok(CompiledKernelVariant::OpenCL(OpenCLKernel {
                source_code: self.generate_opencl_code(),
                function_name,
                // The kernel indexes its local tiles with get_local_id(0/1), so
                // the work-group must be exactly one tile. A fixed 16x16 would
                // be wrong for any other tile size.
                work_group_size: (self.tile_size, self.tile_size, 1),
                required_version: (2, 0),
            })),
            GpuBackend::Vulkan | GpuBackend::DirectML => Err(RusTorchError::Unsupported(
                format!("{:?} backend not supported for MatMul", backend)
            )),
        }
    }

    /// One `tile × tile` thread block per output tile. Grid x covers the
    /// output columns and grid y the rows.
    ///
    /// The output is read as `(rows, cols)` from the trailing two entries of
    /// `output_dims`. A 1-D output is treated as a single row, and an empty
    /// one yields an empty grid instead of panicking on indexing. Counts
    /// saturate at `u32::MAX` instead of silently truncating.
    fn launch_config(&self, problem_size: ProblemSize) -> LaunchConfiguration {
        // A zero tile size would otherwise divide by zero below.
        let tile = self.tile_size.max(1);
        let (rows, cols) = match problem_size.output_dims.as_slice() {
            [.., rows, cols] => (*rows, *cols),
            [cols] => (1, *cols),
            [] => (0, 0),
        };
        let to_u32 = |value: usize| value.min(u32::MAX as usize) as u32;
        // Ceiling division that cannot overflow the way `extent + tile - 1` can.
        let blocks_along = |extent: usize| to_u32(extent / tile + usize::from(extent % tile != 0));
        let block = to_u32(tile);
        LaunchConfiguration {
            block_dims: (block, block, 1),
            grid_dims: (blocks_along(cols), blocks_along(rows), 1),
            // Two f32 tiles (A and B) of `tile × tile` elements. The CUDA
            // variant pads each tile row by one element, so its static usage
            // is slightly larger than this.
            shared_memory: 2 * tile * tile * std::mem::size_of::<f32>(),
            stream_idx: None,
            dynamic_parallelism: 0,
        }
    }

    /// Requires exactly two 2-D inputs whose inner dimensions agree
    /// (`A.shape[1] == B.shape[0]`).
    ///
    /// # Errors
    /// `RusTorchError::InvalidArgument` describing the first violated requirement.
    fn validate_inputs(&self, inputs: &[&Tensor<f32>]) -> RusTorchResult<()> {
        if inputs.len() != 2 {
            return Err(RusTorchError::InvalidArgument(
                format!("MatMul expects 2 inputs, got {}", inputs.len())
            ));
        }
        let a_shape = inputs[0].shape();
        let b_shape = inputs[1].shape();
        if a_shape.len() != 2 || b_shape.len() != 2 {
            return Err(RusTorchError::InvalidArgument(
                "MatMul expects 2D tensors".into()
            ));
        }
        if a_shape[1] != b_shape[0] {
            return Err(RusTorchError::InvalidArgument(
                format!("Incompatible dimensions: ({}, {}) x ({}, {})",
                    a_shape[0], a_shape[1], b_shape[0], b_shape[1])
            ));
        }
        Ok(())
    }
}
/// Direct 2-D convolution (one GPU thread per output element) over NCHW
/// `f32` tensors. In this file it only generates CUDA source and does not
/// implement `KernelInterface`.
pub struct ConvolutionKernel {
    /// Filter dimensions; presumably (out_channels, in_channels, height, width)
    /// — TODO confirm. Not read by the generator, which takes the filter size
    /// as kernel arguments.
    filter_dims: (usize, usize, usize, usize),
    /// Stride, presumably (h, w); passed to the kernel as arguments, not baked in.
    stride: (usize, usize),
    /// Padding, presumably (h, w); passed to the kernel as arguments, not baked in.
    padding: (usize, usize),
    /// Dilation as (h, w). Baked into the generated source.
    dilation: (usize, usize),
    /// Number of channel groups. Baked into the generated source; values
    /// below 1 are treated as 1.
    groups: usize,
}

impl ConvolutionKernel {
    /// Creates a convolution with dilation `(1, 1)` and a single group.
    pub fn new(
        filter_dims: (usize, usize, usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        Self {
            filter_dims,
            stride,
            padding,
            dilation: (1, 1),
            groups: 1,
        }
    }

    /// CUDA source for `conv2d_kernel`.
    ///
    /// Inputs are indexed NCHW and filters as (out_c, in_c / groups, kh, kw).
    /// `bias` may be null. Dilation and group count come from `self` and are
    /// emitted as constants, so the kernel's argument list is unchanged. With
    /// the defaults, dilation (1, 1) and one group, the arithmetic reduces to
    /// a plain ungrouped convolution.
    fn generate_cuda_conv_code(&self) -> String {
        format!(r#"
extern "C" __global__ void conv2d_kernel(
const float* __restrict__ input,
const float* __restrict__ filter,
const float* __restrict__ bias,
float* __restrict__ output,
int batch, int in_channels, int out_channels,
int in_height, int in_width,
int out_height, int out_width,
int filter_height, int filter_width,
int stride_h, int stride_w,
int pad_h, int pad_w
) {{
const int DILATION_H = {dilation_h};
const int DILATION_W = {dilation_w};
const int GROUPS = {groups};
// Optimized convolution implementation
int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (out_idx >= batch * out_channels * out_height * out_width) return;
// Decompose output index
int w_out = out_idx % out_width;
int h_out = (out_idx / out_width) % out_height;
int c_out = (out_idx / (out_width * out_height)) % out_channels;
int n = out_idx / (out_width * out_height * out_channels);
// Input-channel slice belonging to this output channel's group
int in_per_group = in_channels / GROUPS;
int out_per_group = out_channels / GROUPS;
int group = c_out / out_per_group;
float sum = bias ? bias[c_out] : 0.0f;
// Convolution computation
for (int c = 0; c < in_per_group; ++c) {{
int c_in = group * in_per_group + c;
for (int kh = 0; kh < filter_height; ++kh) {{
for (int kw = 0; kw < filter_width; ++kw) {{
int h_in = h_out * stride_h - pad_h + kh * DILATION_H;
int w_in = w_out * stride_w - pad_w + kw * DILATION_W;
if (h_in >= 0 && h_in < in_height && w_in >= 0 && w_in < in_width) {{
int input_idx = ((n * in_channels + c_in) * in_height + h_in) * in_width + w_in;
int filter_idx = ((c_out * in_per_group + c) * filter_height + kh) * filter_width + kw;
sum += input[input_idx] * filter[filter_idx];
}}
}}
}}
}}
output[out_idx] = sum;
}}
"#,
            dilation_h = self.dilation.0,
            dilation_w = self.dilation.1,
            // Guard the kernel's divisions by GROUPS.
            groups = self.groups.max(1),
        )
    }
}
/// Element-wise operation over flat `f32` buffers, one element per GPU thread.
pub struct ElementWiseKernel {
    /// Operation the generated kernel applies.
    operation: ElementWiseOp,
    /// Not read by the code generator below.
    vector_width: usize,
}

/// Operations supported by `ElementWiseKernel`.
///
/// Binary operations combine `a` and `b`. `Exp`, `Log`, `Sigmoid`, `Tanh`
/// and `ReLU` are unary and read only `a`.
#[derive(Debug, Clone, Copy)]
pub enum ElementWiseOp {
    Add,
    Subtract,
    Multiply,
    Divide,
    Maximum,
    Minimum,
    Power,
    Exp,
    Log,
    Sigmoid,
    Tanh,
    ReLU,
}

impl ElementWiseKernel {
    /// CUDA source for `elementwise_kernel(a, b, c, n)`, which writes
    /// `c[idx] = op(a[idx], b[idx])` for every `idx < n`.
    fn generate_cuda_elementwise(&self) -> String {
        // Exhaustive on purpose. A catch-all arm previously compiled every
        // unlisted op (Subtract, Divide, Exp, ...) as a plain copy of `a`.
        // Without one, a new variant fails to build until it gets CUDA code.
        let op_code = match self.operation {
            ElementWiseOp::Add => "c[idx] = a[idx] + b[idx];",
            ElementWiseOp::Subtract => "c[idx] = a[idx] - b[idx];",
            ElementWiseOp::Multiply => "c[idx] = a[idx] * b[idx];",
            ElementWiseOp::Divide => "c[idx] = a[idx] / b[idx];",
            ElementWiseOp::Maximum => "c[idx] = fmaxf(a[idx], b[idx]);",
            ElementWiseOp::Minimum => "c[idx] = fminf(a[idx], b[idx]);",
            ElementWiseOp::Power => "c[idx] = powf(a[idx], b[idx]);",
            ElementWiseOp::Exp => "c[idx] = expf(a[idx]);",
            ElementWiseOp::Log => "c[idx] = logf(a[idx]);",
            ElementWiseOp::Sigmoid => "c[idx] = 1.0f / (1.0f + expf(-a[idx]));",
            ElementWiseOp::Tanh => "c[idx] = tanhf(a[idx]);",
            ElementWiseOp::ReLU => "c[idx] = fmaxf(0.0f, a[idx]);",
        };
        format!(r#"
extern "C" __global__ void elementwise_kernel(
const float* __restrict__ a,
const float* __restrict__ b,
float* __restrict__ c,
int n
) {{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {{
{}
}}
}}
"#, op_code)
    }
}
/// Registry of named kernels plus a cache of their compiled form for one backend.
///
/// Both maps sit behind `Arc<RwLock<..>>`. Whenever both locks are held at
/// once, `registry` is acquired before `cache`, so the two cannot deadlock
/// against each other.
pub struct KernelManager {
    /// Compiled kernels, keyed by the name they were registered under.
    cache: Arc<RwLock<HashMap<String, CompiledKernelVariant>>>,
    /// Backend every cached kernel was compiled for.
    backend: GpuBackend,
    /// Registered kernel definitions, keyed by name.
    registry: Arc<RwLock<HashMap<String, Box<dyn KernelInterface>>>>,
}

impl KernelManager {
    /// Creates a manager targeting `backend` with the built-in kernels
    /// (currently only `"matmul"`) already registered.
    pub fn new(backend: GpuBackend) -> Self {
        let mut manager = Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            backend,
            registry: Arc::new(RwLock::new(HashMap::new())),
        };
        manager.register_builtin_kernels();
        manager
    }

    /// Registers the kernels every manager provides out of the box.
    fn register_builtin_kernels(&mut self) {
        let matmul = Box::new(MatMulKernel::new(16, true));
        self.register_kernel("matmul", matmul);
    }

    /// Registers `kernel` under `name`, replacing any kernel previously
    /// registered with that name.
    ///
    /// Any compiled artifact cached under `name` is discarded, so the next
    /// `get_kernel(name)` compiles the new definition rather than returning
    /// the old one.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn register_kernel(&mut self, name: &str, kernel: Box<dyn KernelInterface>) {
        let mut registry = self.registry.write().unwrap();
        registry.insert(name.to_string(), kernel);
        // Invalidate while still holding the registry write lock. `get_kernel`
        // compiles and caches under a registry read lock, so it cannot insert a
        // stale artifact after this removal.
        self.cache.write().unwrap().remove(name);
    }

    /// Returns the compiled form of the kernel registered as `name`.
    ///
    /// On first use the kernel is compiled for this manager's backend and the
    /// result is cached.
    ///
    /// # Errors
    /// `RusTorchError::NotFound` if no kernel is registered under `name`;
    /// otherwise any error returned by the kernel's `compile`.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn get_kernel(&self, name: &str) -> RusTorchResult<CompiledKernelVariant> {
        {
            let cache = self.cache.read().unwrap();
            if let Some(kernel) = cache.get(name) {
                return Ok(kernel.clone());
            }
        }
        let registry = self.registry.read().unwrap();
        let kernel_interface = registry
            .get(name)
            .ok_or_else(|| RusTorchError::NotFound(format!("Kernel '{}' not found", name)))?;
        let compiled = kernel_interface.compile(self.backend)?;
        // Still under the registry read lock, so `register_kernel` cannot
        // replace this definition between compiling and caching it.
        self.cache
            .write()
            .unwrap()
            .insert(name.to_string(), compiled.clone());
        Ok(compiled)
    }

    /// Validates `inputs`, resolves the compiled kernel `name` and dispatches it.
    ///
    /// Inputs are validated before compiling, so invalid inputs never trigger
    /// or cache a compilation.
    ///
    /// # Errors
    /// `RusTorchError::NotFound` for an unknown `name`, any error from the
    /// kernel's `validate_inputs` or `compile`, or a platform launch error.
    pub fn launch(
        &self,
        name: &str,
        inputs: &[&Tensor<f32>],
        output: &mut Tensor<f32>,
        problem_size: ProblemSize,
    ) -> RusTorchResult<()> {
        // Scoped so the registry read lock is released before `get_kernel`
        // takes it again and before the launch itself runs.
        let config = {
            let registry = self.registry.read().unwrap();
            let kernel_interface = registry
                .get(name)
                .ok_or_else(|| RusTorchError::NotFound(format!("Kernel '{}' not found", name)))?;
            kernel_interface.validate_inputs(inputs)?;
            kernel_interface.launch_config(problem_size)
        };
        let compiled = self.get_kernel(name)?;
        self.platform_launch(compiled, inputs, output, config)
    }

    /// Dispatches a compiled kernel on its backend.
    ///
    /// NOTE(review): every arm is currently a stub that does no work and
    /// returns `Ok(())`, so `launch` reports success without writing `output`.
    fn platform_launch(
        &self,
        kernel: CompiledKernelVariant,
        _inputs: &[&Tensor<f32>],
        _output: &mut Tensor<f32>,
        _config: LaunchConfiguration,
    ) -> RusTorchResult<()> {
        match kernel {
            CompiledKernelVariant::Cuda(_cuda_kernel) => Ok(()),
            CompiledKernelVariant::Metal(_metal_kernel) => Ok(()),
            CompiledKernelVariant::OpenCL(_opencl_kernel) => Ok(()),
            CompiledKernelVariant::Vulkan(_vulkan_kernel) => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The untiled-tensor-core 16×16 matmul kernel shared by several tests.
    fn plain_matmul() -> MatMulKernel {
        MatMulKernel::new(16, false)
    }

    #[test]
    fn test_matmul_kernel() {
        let kernel = plain_matmul();
        for &backend in &[GpuBackend::Cuda, GpuBackend::Metal, GpuBackend::OpenCL] {
            assert!(kernel.compile(backend).is_ok());
        }
    }

    #[test]
    fn test_launch_configuration() {
        let size = ProblemSize {
            total_elements: 1024 * 1024,
            input_dims: vec![vec![1024, 512], vec![512, 1024]],
            output_dims: vec![1024, 1024],
            batch_size: None,
        };
        let LaunchConfiguration { block_dims, grid_dims, .. } = plain_matmul().launch_config(size);
        assert_eq!(block_dims, (16, 16, 1));
        let (blocks_x, blocks_y, _) = grid_dims;
        assert!(blocks_x > 0);
        assert!(blocks_y > 0);
    }

    #[test]
    fn test_kernel_manager() {
        let manager = KernelManager::new(GpuBackend::Cuda);
        assert!(manager.get_kernel("matmul").is_ok());
    }

    #[test]
    fn test_kernel_validation() {
        let kernel = plain_matmul();
        let lhs = Tensor::<f32>::zeros(&[32, 64]);
        let rhs = Tensor::<f32>::zeros(&[64, 32]);
        assert!(kernel.validate_inputs(&[&lhs, &rhs]).is_ok());
        // Inner dimensions 64 vs 32 do not line up.
        let mismatched = Tensor::<f32>::zeros(&[32, 32]);
        assert!(kernel.validate_inputs(&[&lhs, &mismatched]).is_err());
    }
}