use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Operations that can be requested from a [`UnifiedKernelExecutor`].
///
/// The variants are payload-free tags, so the type is `Copy`; `Debug` is
/// required because the executors in this file format the operation with
/// `{:?}` when building "not implemented" error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelOp {
    /// Addition of two tensors (executors dispatch this to `Tensor::add`).
    Add,
    /// Multiplication of two tensors (executors dispatch this to `Tensor::mul`).
    Mul,
    /// Subtraction of two tensors (only the CPU fallback dispatches this, to `Tensor::sub`).
    Sub,
    /// Division; no executor in this file implements it.
    Div,
    /// Matrix multiplication (executors dispatch this to `Tensor::matmul`).
    MatMul,
    /// 2-D convolution; the CUDA executor only has a stub that reports it as unsupported.
    Conv2D,
    /// Batch normalisation; no executor in this file implements it.
    BatchNorm,
    /// Sum reduction; no executor in this file implements it.
    ReduceSum,
    /// Mean reduction; no executor in this file implements it.
    ReduceMean,
    /// ReLU activation; no executor in this file implements it.
    ReLU,
    /// Softmax; no executor in this file implements it.
    Softmax,
}
/// Shape information and named numeric options passed along with a kernel call.
///
/// `Default` yields empty shapes and an empty option map, which is exactly
/// what the previous hand-written implementation produced.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct KernelParams {
    /// Shapes of the input tensors (presumably one entry per input — confirm with callers).
    pub input_shapes: Vec<Vec<usize>>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// Named numeric options. The `optimize_params` implementations in this
    /// file insert backend-specific keys here (for example `"cuda_tile_size"`).
    pub extra_params: HashMap<String, f64>,
}
/// Performance figures recorded for the most recent kernel execution.
///
/// `Clone` is required because every `get_metrics` implementation in this
/// file returns a copy of the value stored behind a mutex. `Default` yields
/// a zero duration and `0.0` for every figure.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct KernelMetrics {
    /// Wall-clock time measured around the dispatch of the last operation.
    pub execution_time: Duration,
    /// Memory bandwidth figure. The CUDA `f32` path stores an estimate in
    /// GB/s; the other executors store fixed placeholder values.
    pub memory_bandwidth: f64,
    /// Occupancy figure (the executors store values on a 0–100 scale).
    pub occupancy: f64,
    /// Throughput figure. The CUDA `f32` path stores an estimate in GFLOP/s;
    /// the other executors store fixed placeholder values.
    pub flops: f64,
}
/// Backend-independent interface for running tensor kernels.
///
/// The trait is object-safe: execution is exposed through one concrete method
/// per element type (`f32`, `f64`) rather than through a generic method.
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): `DeviceType` is not imported in the visible part of this
/// file — confirm that it is brought into scope.
pub trait UnifiedKernelExecutor: Send + Sync {
/// Runs `op` on `f32` inputs and returns the resulting tensor, or an error
/// if the operation is unsupported or fails.
fn execute_f32(&self, op: KernelOp, inputs: &[&Tensor<f32>], params: &KernelParams) -> RusTorchResult<Tensor<f32>>;
/// Runs `op` on `f64` inputs and returns the resulting tensor, or an error
/// if the operation is unsupported or fails.
fn execute_f64(&self, op: KernelOp, inputs: &[&Tensor<f64>], params: &KernelParams) -> RusTorchResult<Tensor<f64>>;
/// Reports whether this executor can run `op`.
fn supports_operation(&self, op: KernelOp) -> bool;
/// Identifies the device this executor targets.
fn device_type(&self) -> DeviceType;
/// Returns a snapshot of the metrics recorded by the executor.
fn get_metrics(&self) -> KernelMetrics;
/// Lets the backend add or adjust entries in `params` for `op` before execution.
fn optimize_params(&self, op: KernelOp, params: &mut KernelParams) -> RusTorchResult<()>;
}
/// Private, element-type-parameterised execution hook.
///
/// An executor implements this once per supported element type and forwards
/// the object-safe `execute_f32` / `execute_f64` entry points to it.
trait ExecuteGeneric<T>
where
    T: Float + 'static + Send + Sync,
{
    /// Runs `op` over `inputs` using `params`, returning the produced tensor
    /// or the error reported by the backend.
    fn execute(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<T>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>;
}
/// Executor targeting a CUDA device.
pub struct CudaUnifiedExecutor {
    /// Index of the CUDA device this executor was created for.
    device_id: usize,
    /// Metrics of the most recent execution, shared behind a mutex.
    metrics: Arc<Mutex<KernelMetrics>>,
}

#[cfg(feature = "cuda")]
impl CudaUnifiedExecutor {
    /// Creates an executor for the CUDA device `device_id`.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::DeviceNotFound(device_id)` when a
    /// `CudaKernelExecutor` cannot be constructed for that device.
    pub fn new(device_id: usize) -> RusTorchResult<Self> {
        use crate::gpu::cuda_kernels::CudaKernelExecutor;

        // The low-level executor is only constructed to validate the device;
        // it is not stored and is dropped when this function returns.
        let probe = CudaKernelExecutor::new(device_id);
        let _probe = probe.map_err(|_| RusTorchError::DeviceNotFound(device_id))?;

        let metrics = Arc::new(Mutex::new(KernelMetrics::default()));
        Ok(Self { device_id, metrics })
    }
}
#[cfg(feature = "cuda")]
impl ExecuteGeneric<f32> for CudaUnifiedExecutor {
    /// Dispatches `op` to the matching `execute_cuda_*` helper and records
    /// timing plus estimated performance figures.
    ///
    /// Only `Add`, `Mul`, `MatMul` and `Conv2D` are dispatched; every other
    /// operation yields `RusTorchError::UnsupportedOperation`. Metrics are
    /// updated whether or not the operation succeeded, and a failure to
    /// compute or store them never replaces the operation's own result.
    fn execute(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f32>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f32>> {
        let start_time = Instant::now();
        let result = match op {
            KernelOp::Add => self.execute_cuda_add(inputs, params),
            KernelOp::Mul => self.execute_cuda_mul(inputs, params),
            KernelOp::MatMul => self.execute_cuda_matmul(inputs, params),
            KernelOp::Conv2D => self.execute_cuda_conv2d(inputs, params),
            _ => Err(RusTorchError::UnsupportedOperation(format!(
                "Operation {:?} not implemented for CUDA",
                op
            ))),
        };
        let execution_time = start_time.elapsed();

        // Coarse category understood by `calculate_cuda_metrics`. Listed
        // exhaustively so a new `KernelOp` variant forces a decision here.
        let operation = match op {
            KernelOp::MatMul => "matmul",
            KernelOp::Conv2D => "conv2d",
            KernelOp::Add | KernelOp::Mul | KernelOp::Sub | KernelOp::Div | KernelOp::ReLU => {
                "elementwise"
            }
            KernelOp::BatchNorm
            | KernelOp::ReduceSum
            | KernelOp::ReduceMean
            | KernelOp::Softmax => "other",
        };

        // Computed before taking the lock so the guard is held only for the writes.
        let estimates = self.calculate_cuda_metrics(operation, inputs, execution_time);

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.execution_time = execution_time;
            // Metrics are best-effort: an estimation error must not mask `result`.
            if let Ok((memory_bandwidth, occupancy, flops)) = estimates {
                metrics.memory_bandwidth = memory_bandwidth;
                metrics.occupancy = occupancy;
                metrics.flops = flops;
            }
        }
        result
    }
}

#[cfg(feature = "cuda")]
impl CudaUnifiedExecutor {
    /// Estimates `(memory_bandwidth, occupancy, flops)` for one execution.
    ///
    /// These are heuristics derived from the input sizes and the elapsed
    /// time, not values measured on the device:
    ///
    /// * `memory_bandwidth` — input bytes × 2 per second, in GB/s.
    /// * `occupancy` — a fixed figure on a 0–100 scale chosen from
    ///   `operation` and the total element count.
    /// * `flops` — estimated operations per second, in GFLOP/s.
    ///
    /// `operation` is one of `"matmul"`, `"conv2d"` or `"elementwise"`; any
    /// other string selects the fallback figures. When `execution_time` is
    /// zero the two rate figures are reported as `0.0` instead of dividing
    /// by zero.
    ///
    /// NOTE(review): this reads the `data` and `shape` members of `Tensor`
    /// as fields — confirm that matches the `Tensor` definition.
    fn calculate_cuda_metrics(
        &self,
        operation: &str,
        input_tensors: &[&Tensor<f32>],
        execution_time: Duration,
    ) -> RusTorchResult<(f64, f64, f64)> {
        let seconds = execution_time.as_secs_f64();
        // Turns an amount of work into giga-units per second, guarding the
        // zero-duration case that would otherwise produce inf/NaN.
        let giga_per_second = |work: f64| {
            if seconds > 0.0 {
                work / seconds / 1e9
            } else {
                0.0
            }
        };

        let total_elements: usize = input_tensors.iter().map(|t| t.data.len()).sum();
        let total_bytes = total_elements * std::mem::size_of::<f32>();
        // Factor 2.0 as in the original estimate (presumably one read plus one write).
        let memory_bandwidth = giga_per_second(total_bytes as f64 * 2.0);

        let occupancy = match operation {
            "matmul" => {
                let matrix_size = (total_elements as f64).sqrt();
                if matrix_size >= 1024.0 {
                    85.0
                } else {
                    70.0
                }
            }
            "conv2d" => {
                let occupancy_base = 75.0;
                if total_elements > 1_000_000 {
                    occupancy_base + 10.0
                } else {
                    occupancy_base
                }
            }
            "elementwise" => 90.0,
            _ => 65.0,
        };

        let flops = match operation {
            "matmul" => match input_tensors {
                [a, b, ..] => {
                    let a_shape = &a.shape;
                    let b_shape = &b.shape;
                    if a_shape.len() >= 2 && b_shape.len() >= 2 {
                        // (.., m, k) x (.., k, n): 2 * m * n * k operations.
                        let m = a_shape[a_shape.len() - 2];
                        let k = a_shape[a_shape.len() - 1];
                        let n = b_shape[b_shape.len() - 1];
                        giga_per_second(2.0 * m as f64 * n as f64 * k as f64)
                    } else {
                        0.0
                    }
                }
                _ => 0.0,
            },
            "conv2d" => {
                // Fixed per-element cost used by the original estimate.
                let operations_per_output = 9.0;
                giga_per_second(total_elements as f64 * operations_per_output)
            }
            "elementwise" => giga_per_second(total_elements as f64),
            _ => 0.0,
        };

        Ok((memory_bandwidth, occupancy, flops))
    }
}
#[cfg(feature = "cuda")]
impl ExecuteGeneric<f64> for CudaUnifiedExecutor {
    /// Dispatches `op` to the matching `execute_cuda_*` helper for `f64`
    /// tensors and records the elapsed time.
    ///
    /// Unlike the measured `execution_time`, the bandwidth, occupancy and
    /// FLOPS figures stored here are fixed placeholder constants.
    fn execute(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f64>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f64>> {
        let timer = Instant::now();
        let outcome = match op {
            KernelOp::Add => self.execute_cuda_add(inputs, params),
            KernelOp::Mul => self.execute_cuda_mul(inputs, params),
            KernelOp::MatMul => self.execute_cuda_matmul(inputs, params),
            KernelOp::Conv2D => self.execute_cuda_conv2d(inputs, params),
            _ => {
                let message = format!("Operation {:?} not implemented for CUDA", op);
                Err(RusTorchError::UnsupportedOperation(message))
            }
        };
        let elapsed = timer.elapsed();

        // Skipped silently if the mutex is poisoned; the result is still returned.
        if let Ok(mut guard) = self.metrics.lock() {
            *guard = KernelMetrics {
                execution_time: elapsed,
                memory_bandwidth: 100.0,
                occupancy: 80.0,
                flops: 1000.0,
            };
        }
        outcome
    }
}
#[cfg(feature = "cuda")]
impl UnifiedKernelExecutor for CudaUnifiedExecutor {
    /// Runs `op` on `f32` tensors via the `ExecuteGeneric<f32>` implementation.
    fn execute_f32(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f32>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f32>> {
        <Self as ExecuteGeneric<f32>>::execute(self, op, inputs, params)
    }

    /// Runs `op` on `f64` tensors via the `ExecuteGeneric<f64>` implementation.
    fn execute_f64(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f64>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f64>> {
        <Self as ExecuteGeneric<f64>>::execute(self, op, inputs, params)
    }

    /// Reports the operations `execute_*` can actually complete.
    ///
    /// `ReLU` is not dispatched at all and `Conv2D` is routed to a stub that
    /// always returns `UnsupportedOperation`, so neither is advertised;
    /// otherwise callers selecting a backend through this method would pick
    /// CUDA for an operation that is guaranteed to fail.
    fn supports_operation(&self, op: KernelOp) -> bool {
        matches!(op, KernelOp::Add | KernelOp::Mul | KernelOp::MatMul)
    }

    /// Returns `DeviceType::Cuda` carrying this executor's device index.
    fn device_type(&self) -> DeviceType {
        DeviceType::Cuda(self.device_id)
    }

    /// Returns a copy of the most recently recorded metrics.
    ///
    /// A poisoned mutex is recovered rather than turned into a panic: the
    /// stored value is plain data, and the write path in `execute` already
    /// tolerates poisoning.
    fn get_metrics(&self) -> KernelMetrics {
        self.metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Inserts CUDA-specific entries into `params.extra_params`.
    ///
    /// `MatMul` gets `cuda_tile_size` and `cuda_use_cublas`; `Conv2D` gets
    /// `cuda_use_cudnn` and `cuda_workspace_size` (64 MiB expressed in
    /// bytes). Other operations leave `params` untouched. Always returns `Ok`.
    fn optimize_params(&self, op: KernelOp, params: &mut KernelParams) -> RusTorchResult<()> {
        match op {
            KernelOp::MatMul => {
                params.extra_params.insert("cuda_tile_size".to_string(), 32.0);
                params.extra_params.insert("cuda_use_cublas".to_string(), 1.0);
            }
            KernelOp::Conv2D => {
                params.extra_params.insert("cuda_use_cudnn".to_string(), 1.0);
                params
                    .extra_params
                    .insert("cuda_workspace_size".to_string(), 64.0 * 1024.0 * 1024.0);
            }
            _ => {}
        }
        Ok(())
    }
}
#[cfg(feature = "cuda")]
impl CudaUnifiedExecutor {
    /// Adds the two tensors in `inputs` via `Tensor::add`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the addition in `KernelExecutionError`.
    fn execute_cuda_add<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .add(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Add operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Multiplies the two tensors in `inputs` via `Tensor::mul`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the multiplication in `KernelExecutionError`.
    fn execute_cuda_mul<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .mul(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Mul operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Matrix-multiplies the two tensors in `inputs` via `Tensor::matmul`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the product in `KernelExecutionError`.
    fn execute_cuda_matmul<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .matmul(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "MatMul operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Placeholder for 2-D convolution: always returns `UnsupportedOperation`
    /// and ignores both arguments.
    fn execute_cuda_conv2d<T>(
        &self,
        _inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        let message = "Conv2D not yet implemented".to_string();
        Err(RusTorchError::UnsupportedOperation(message))
    }
}
/// Executor targeting a Metal device.
pub struct MetalUnifiedExecutor {
    /// Device index reported through `device_type`.
    device_id: usize,
    /// Metrics of the most recent execution, shared behind a mutex.
    metrics: Arc<Mutex<KernelMetrics>>,
}

#[cfg(feature = "metal")]
impl MetalUnifiedExecutor {
    /// Creates an executor labelled with `device_id`.
    ///
    /// NOTE(review): `device_id` is stored but not passed to
    /// `MetalKernelExecutor::new()` — confirm device selection is intended
    /// to happen elsewhere.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::DeviceNotFound(device_id)` when a
    /// `MetalKernelExecutor` cannot be constructed.
    pub fn new(device_id: usize) -> RusTorchResult<Self> {
        use crate::gpu::metal_kernels::MetalKernelExecutor;

        // Constructed only as an availability check; dropped when this function returns.
        let probe = MetalKernelExecutor::new();
        let _probe = probe.map_err(|_| RusTorchError::DeviceNotFound(device_id))?;

        let metrics = Arc::new(Mutex::new(KernelMetrics::default()));
        Ok(Self { device_id, metrics })
    }
}
#[cfg(feature = "metal")]
impl MetalUnifiedExecutor {
    /// Generic implementation shared by `execute_f32` and `execute_f64`.
    ///
    /// Dispatches `Add`, `Mul` and `MatMul` to the `execute_metal_*` helpers;
    /// every other operation yields `RusTorchError::UnsupportedOperation`.
    /// The elapsed time is recorded for every call; the bandwidth, occupancy
    /// and FLOPS figures stored alongside it are fixed placeholder constants.
    ///
    /// This lives in an inherent impl because `UnifiedKernelExecutor` has no
    /// generic `execute` method (it exposes one method per element type).
    pub fn execute<T>(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<T>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        let start_time = Instant::now();
        let result = match op {
            KernelOp::Add => self.execute_metal_add(inputs, params),
            KernelOp::Mul => self.execute_metal_mul(inputs, params),
            KernelOp::MatMul => self.execute_metal_matmul(inputs, params),
            _ => Err(RusTorchError::UnsupportedOperation(format!(
                "Operation {:?} not implemented for Metal",
                op
            ))),
        };
        let execution_time = start_time.elapsed();

        // Skipped silently if the mutex is poisoned; the result is still returned.
        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.execution_time = execution_time;
            metrics.memory_bandwidth = 90.0;
            metrics.occupancy = 75.0;
            metrics.flops = 800.0;
        }
        result
    }
}

#[cfg(feature = "metal")]
impl UnifiedKernelExecutor for MetalUnifiedExecutor {
    /// Runs `op` on `f32` tensors through the generic `execute`.
    fn execute_f32(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f32>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f32>> {
        self.execute::<f32>(op, inputs, params)
    }

    /// Runs `op` on `f64` tensors through the generic `execute`.
    fn execute_f64(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f64>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f64>> {
        self.execute::<f64>(op, inputs, params)
    }

    /// Reports the operations `execute` actually dispatches.
    ///
    /// `ReLU` is not advertised because `execute` rejects it with
    /// `UnsupportedOperation`.
    fn supports_operation(&self, op: KernelOp) -> bool {
        matches!(op, KernelOp::Add | KernelOp::Mul | KernelOp::MatMul)
    }

    /// Returns `DeviceType::Metal` carrying this executor's device index.
    fn device_type(&self) -> DeviceType {
        DeviceType::Metal(self.device_id)
    }

    /// Returns a copy of the most recently recorded metrics.
    ///
    /// A poisoned mutex is recovered rather than turned into a panic: the
    /// stored value is plain data, and the write path in `execute` already
    /// tolerates poisoning.
    fn get_metrics(&self) -> KernelMetrics {
        self.metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Inserts Metal-specific entries into `params.extra_params`.
    ///
    /// `MatMul` gets `metal_threads_per_group` and `metal_use_mps`; other
    /// operations leave `params` untouched. Always returns `Ok`.
    fn optimize_params(&self, op: KernelOp, params: &mut KernelParams) -> RusTorchResult<()> {
        if matches!(op, KernelOp::MatMul) {
            params
                .extra_params
                .insert("metal_threads_per_group".to_string(), 64.0);
            params.extra_params.insert("metal_use_mps".to_string(), 1.0);
        }
        Ok(())
    }
}
#[cfg(feature = "metal")]
impl MetalUnifiedExecutor {
    /// Adds the two tensors in `inputs` via `Tensor::add`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the addition in `KernelExecutionError`.
    fn execute_metal_add<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .add(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Add operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Multiplies the two tensors in `inputs` via `Tensor::mul`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the multiplication in `KernelExecutionError`.
    fn execute_metal_mul<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .mul(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Mul operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Matrix-multiplies the two tensors in `inputs` via `Tensor::matmul`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the product in `KernelExecutionError`.
    fn execute_metal_matmul<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .matmul(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "MatMul operation requires exactly 2 inputs".to_string(),
            )),
        }
    }
}
/// Executor targeting an OpenCL device.
pub struct OpenClUnifiedExecutor {
    /// Index of the OpenCL device this executor was created for.
    device_id: usize,
    /// Metrics of the most recent execution, shared behind a mutex.
    metrics: Arc<Mutex<KernelMetrics>>,
}

#[cfg(feature = "opencl")]
impl OpenClUnifiedExecutor {
    /// Creates an executor for the OpenCL device `device_id`.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::DeviceNotFound(device_id)` when an
    /// `OpenClKernelExecutor` cannot be constructed for that device.
    pub fn new(device_id: usize) -> RusTorchResult<Self> {
        use crate::gpu::opencl_kernels::OpenClKernelExecutor;

        // Constructed only to validate the device; dropped when this function returns.
        let probe = OpenClKernelExecutor::new(device_id);
        let _probe = probe.map_err(|_| RusTorchError::DeviceNotFound(device_id))?;

        let metrics = Arc::new(Mutex::new(KernelMetrics::default()));
        Ok(Self { device_id, metrics })
    }
}
#[cfg(feature = "opencl")]
impl OpenClUnifiedExecutor {
    /// Generic implementation shared by `execute_f32` and `execute_f64`.
    ///
    /// Dispatches `Add` and `Mul` to the `execute_opencl_*` helpers; every
    /// other operation yields `RusTorchError::UnsupportedOperation`. The
    /// elapsed time is recorded for every call; the bandwidth, occupancy and
    /// FLOPS figures stored alongside it are fixed placeholder constants.
    ///
    /// This lives in an inherent impl because `UnifiedKernelExecutor` has no
    /// generic `execute` method (it exposes one method per element type).
    pub fn execute<T>(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<T>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        let start_time = Instant::now();
        let result = match op {
            KernelOp::Add => self.execute_opencl_add(inputs, params),
            KernelOp::Mul => self.execute_opencl_mul(inputs, params),
            _ => Err(RusTorchError::UnsupportedOperation(format!(
                "Operation {:?} not implemented for OpenCL",
                op
            ))),
        };
        let execution_time = start_time.elapsed();

        // Skipped silently if the mutex is poisoned; the result is still returned.
        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.execution_time = execution_time;
            metrics.memory_bandwidth = 70.0;
            metrics.occupancy = 60.0;
            metrics.flops = 600.0;
        }
        result
    }
}

#[cfg(feature = "opencl")]
impl UnifiedKernelExecutor for OpenClUnifiedExecutor {
    /// Runs `op` on `f32` tensors through the generic `execute`.
    fn execute_f32(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f32>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f32>> {
        self.execute::<f32>(op, inputs, params)
    }

    /// Runs `op` on `f64` tensors through the generic `execute`.
    fn execute_f64(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f64>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f64>> {
        self.execute::<f64>(op, inputs, params)
    }

    /// Reports the operations `execute` dispatches: `Add` and `Mul`.
    fn supports_operation(&self, op: KernelOp) -> bool {
        matches!(op, KernelOp::Add | KernelOp::Mul)
    }

    /// Returns `DeviceType::OpenCL` carrying this executor's device index.
    fn device_type(&self) -> DeviceType {
        DeviceType::OpenCL(self.device_id)
    }

    /// Returns a copy of the most recently recorded metrics.
    ///
    /// A poisoned mutex is recovered rather than turned into a panic: the
    /// stored value is plain data, and the write path in `execute` already
    /// tolerates poisoning.
    fn get_metrics(&self) -> KernelMetrics {
        self.metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Inserts `opencl_local_size` into `params.extra_params` for `Add` and
    /// `Mul`; other operations leave `params` untouched. Always returns `Ok`.
    fn optimize_params(&self, op: KernelOp, params: &mut KernelParams) -> RusTorchResult<()> {
        if matches!(op, KernelOp::Add | KernelOp::Mul) {
            params
                .extra_params
                .insert("opencl_local_size".to_string(), 256.0);
        }
        Ok(())
    }
}
#[cfg(feature = "opencl")]
impl OpenClUnifiedExecutor {
    /// Adds the two tensors in `inputs` via `Tensor::add`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the addition in `KernelExecutionError`.
    fn execute_opencl_add<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .add(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Add operation requires exactly 2 inputs".to_string(),
            )),
        }
    }

    /// Multiplies the two tensors in `inputs` via `Tensor::mul`.
    ///
    /// Returns `InvalidOperation` unless exactly two inputs are supplied and
    /// wraps a failure of the multiplication in `KernelExecutionError`.
    fn execute_opencl_mul<T>(
        &self,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => lhs
                .mul(rhs)
                .map_err(|err| RusTorchError::KernelExecutionError(err)),
            _ => Err(RusTorchError::InvalidOperation(
                "Mul operation requires exactly 2 inputs".to_string(),
            )),
        }
    }
}
/// Executor that runs operations on the CPU; it needs no device handle and
/// is available regardless of which GPU features are enabled.
pub struct CpuFallbackExecutor {
    /// Metrics of the most recent execution, shared behind a mutex.
    metrics: Arc<Mutex<KernelMetrics>>,
}

impl CpuFallbackExecutor {
    /// Creates an executor whose metrics start at their default (zero) values.
    pub fn new() -> Self {
        let initial = KernelMetrics::default();
        let metrics = Arc::new(Mutex::new(initial));
        Self { metrics }
    }
}

impl Default for CpuFallbackExecutor {
    /// Equivalent to [`CpuFallbackExecutor::new`].
    fn default() -> Self {
        CpuFallbackExecutor::new()
    }
}
impl CpuFallbackExecutor {
    /// Generic implementation shared by `execute_f32` and `execute_f64`.
    ///
    /// Supports `Add`, `Mul`, `Sub` and `MatMul`, each of which requires
    /// exactly two inputs and is forwarded to the `Tensor` method of the same
    /// name. Every other operation yields
    /// `RusTorchError::UnsupportedOperation`. `_params` is ignored.
    ///
    /// The elapsed time is recorded for every call, including calls rejected
    /// for a wrong input count (matching the GPU executors); the bandwidth,
    /// occupancy and FLOPS figures stored alongside it are fixed placeholder
    /// constants.
    ///
    /// This lives in an inherent impl because `UnifiedKernelExecutor` has no
    /// generic `execute` method (it exposes one method per element type).
    pub fn execute<T>(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<T>],
        _params: &KernelParams,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + 'static + Send + Sync,
    {
        let start_time = Instant::now();
        let result = match op {
            KernelOp::Add => Self::binary_operands("Add", inputs).and_then(|(lhs, rhs)| {
                lhs.add(rhs)
                    .map_err(|e| RusTorchError::KernelExecutionError(e))
            }),
            KernelOp::Mul => Self::binary_operands("Mul", inputs).and_then(|(lhs, rhs)| {
                lhs.mul(rhs)
                    .map_err(|e| RusTorchError::KernelExecutionError(e))
            }),
            KernelOp::Sub => Self::binary_operands("Sub", inputs).and_then(|(lhs, rhs)| {
                lhs.sub(rhs)
                    .map_err(|e| RusTorchError::KernelExecutionError(e))
            }),
            KernelOp::MatMul => Self::binary_operands("MatMul", inputs).and_then(|(lhs, rhs)| {
                lhs.matmul(rhs)
                    .map_err(|e| RusTorchError::KernelExecutionError(e))
            }),
            _ => Err(RusTorchError::UnsupportedOperation(format!(
                "Operation {:?} not implemented for CPU fallback",
                op
            ))),
        };
        let execution_time = start_time.elapsed();

        // Skipped silently if the mutex is poisoned; the result is still returned.
        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.execution_time = execution_time;
            metrics.memory_bandwidth = 50.0;
            metrics.occupancy = 100.0;
            metrics.flops = 100.0;
        }
        result
    }

    /// Splits `inputs` into its two operands, or reports
    /// `"<op_name> operation requires exactly 2 inputs"` as `InvalidOperation`.
    fn binary_operands<'a, T>(
        op_name: &str,
        inputs: &[&'a Tensor<T>],
    ) -> RusTorchResult<(&'a Tensor<T>, &'a Tensor<T>)>
    where
        T: Float + 'static + Send + Sync,
    {
        match inputs {
            &[lhs, rhs] => Ok((lhs, rhs)),
            _ => Err(RusTorchError::InvalidOperation(format!(
                "{} operation requires exactly 2 inputs",
                op_name
            ))),
        }
    }
}

impl UnifiedKernelExecutor for CpuFallbackExecutor {
    /// Runs `op` on `f32` tensors through the generic `execute`.
    fn execute_f32(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f32>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f32>> {
        self.execute::<f32>(op, inputs, params)
    }

    /// Runs `op` on `f64` tensors through the generic `execute`.
    fn execute_f64(
        &self,
        op: KernelOp,
        inputs: &[&Tensor<f64>],
        params: &KernelParams,
    ) -> RusTorchResult<Tensor<f64>> {
        self.execute::<f64>(op, inputs, params)
    }

    /// Reports the operations `execute` dispatches: `Add`, `Mul`, `Sub`, `MatMul`.
    fn supports_operation(&self, op: KernelOp) -> bool {
        matches!(
            op,
            KernelOp::Add | KernelOp::Mul | KernelOp::Sub | KernelOp::MatMul
        )
    }

    /// Always `DeviceType::Cpu`.
    fn device_type(&self) -> DeviceType {
        DeviceType::Cpu
    }

    /// Returns a copy of the most recently recorded metrics.
    ///
    /// A poisoned mutex is recovered rather than turned into a panic: the
    /// stored value is plain data, and the write path in `execute` already
    /// tolerates poisoning.
    fn get_metrics(&self) -> KernelMetrics {
        self.metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// No CPU-specific tuning is applied; `params` is left untouched.
    fn optimize_params(&self, _op: KernelOp, _params: &mut KernelParams) -> RusTorchResult<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// The CPU fallback reports the CPU device and the operations it dispatches.
    #[test]
    fn test_cpu_fallback_executor() {
        let executor = CpuFallbackExecutor::new();
        assert_eq!(executor.device_type(), DeviceType::Cpu);
        assert!(executor.supports_operation(KernelOp::Add));
        assert!(executor.supports_operation(KernelOp::Mul));
        assert!(!executor.supports_operation(KernelOp::Conv2D));
    }

    /// Adding two `f32` vectors through the trait entry point yields the
    /// element-wise sum and records a non-zero execution time.
    #[test]
    fn test_cpu_add_execution() {
        let executor = CpuFallbackExecutor::new();
        let a = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![4.0f32, 5.0, 6.0], vec![3]);
        let params = KernelParams::default();
        // `execute_f32` is the method `UnifiedKernelExecutor` declares for f32 inputs.
        let result = executor
            .execute_f32(KernelOp::Add, &[&a, &b], &params)
            .unwrap();
        let expected = vec![5.0f32, 7.0, 9.0];
        assert_eq!(result.as_slice().unwrap(), &expected);
        let metrics = executor.get_metrics();
        assert!(metrics.execution_time > Duration::ZERO);
    }

    /// `KernelParams` fields can be populated and read back.
    #[test]
    fn test_kernel_params() {
        let mut params = KernelParams::default();
        params.input_shapes = vec![vec![3, 3], vec![3, 3]];
        params.output_shape = vec![3, 3];
        params.extra_params.insert("test_param".to_string(), 42.0);
        assert_eq!(params.input_shapes.len(), 2);
        assert_eq!(params.extra_params.get("test_param"), Some(&42.0));
    }

    /// `KernelMetrics` stores exactly the values it was constructed with.
    #[test]
    fn test_kernel_metrics() {
        let metrics = KernelMetrics {
            execution_time: Duration::from_millis(10),
            memory_bandwidth: 100.0,
            occupancy: 80.0,
            flops: 1000.0,
        };
        assert_eq!(metrics.execution_time, Duration::from_millis(10));
        assert_eq!(metrics.memory_bandwidth, 100.0);
        assert_eq!(metrics.occupancy, 80.0);
        assert_eq!(metrics.flops, 1000.0);
    }
}