use crate::error::{RusTorchError, RusTorchResult};
#[cfg(feature = "opencl")]
use std::collections::HashMap;
#[cfg(feature = "opencl")]
use std::ffi::c_void;
use std::marker::PhantomData;
#[cfg(feature = "opencl")]
use opencl3::{
command_queue::CommandQueue,
context::Context,
device::Device,
kernel::Kernel,
memory::{Buffer, CL_MEM_READ_WRITE},
platform::get_platforms,
program::Program,
};
/// Identifies a family of OpenCL kernels.
///
/// The OpenCL-enabled `OpenClKernelExecutor` uses these values as keys of its
/// kernel table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpenClKernelType {
/// Element-wise operations; the executor registers `elementwise_add_f32` under this key.
ElementWise,
/// Matrix multiplication; the executor registers `matrix_multiply_f32` under this key.
MatMul,
/// Reductions; the executor registers `reduce_sum_f32` under this key.
Reduction,
/// Convolution. NOTE(review): no kernel is registered for this variant in this file.
Convolution,
/// Batch normalization. NOTE(review): no kernel is registered for this variant in this file.
BatchNorm,
}
/// Launch configuration for an OpenCL kernel.
///
/// NOTE(review): nothing in this file consumes these values yet; the field
/// descriptions below follow the field names and should be confirmed against
/// the eventual caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenClKernelParams {
    /// Global work size per dimension.
    pub global_work_size: [usize; 3],
    /// Local (work-group) size per dimension.
    pub local_work_size: [usize; 3],
    /// Index of the command queue the launch is presumably submitted to.
    pub queue_index: usize,
}
impl Default for OpenClKernelParams {
fn default() -> Self {
Self {
global_work_size: [1, 1, 1],
local_work_size: [1, 1, 1],
queue_index: 0,
}
}
}
/// Handle to a device-side buffer of `T` elements.
///
/// With the `opencl` feature enabled the handle stores an `opencl3` `Buffer<T>`;
/// without it the slot is a unit placeholder so the type still exists.
pub struct OpenClBuffer<T> {
// Underlying OpenCL buffer object (only with the `opencl` feature).
#[cfg(feature = "opencl")]
_buffer: Buffer<T>,
// Placeholder used when the crate is built without OpenCL support.
#[cfg(not(feature = "opencl"))]
_buffer: (),
// Element count passed to `new`; reported by `size()`.
size: usize,
// Keeps `T` in use in builds where no `Buffer<T>` field exists.
_phantom: PhantomData<T>,
}
impl<T> OpenClBuffer<T> {
    /// Allocates a read/write device buffer with room for `size` elements of `T`.
    ///
    /// No host data is supplied, so the buffer contents are not initialized here.
    ///
    /// # Errors
    /// Returns a GPU error if the OpenCL runtime rejects the allocation.
    #[cfg(feature = "opencl")]
    pub fn new(size: usize, context: &Context) -> RusTorchResult<Self> {
        // `Buffer::<T>::create` takes an element count and scales it by
        // `size_of::<T>()` itself, so passing a byte count would over-allocate
        // by a factor of `size_of::<T>()`.
        // SAFETY: the host pointer is null and no host-pointer flag is set, so
        // the runtime never dereferences it.
        let buffer = unsafe {
            Buffer::<T>::create(context, CL_MEM_READ_WRITE, size, std::ptr::null_mut())
        }
        .map_err(|e| RusTorchError::gpu(format!("OpenCL buffer creation failed: {:?}", e)))?;
        Ok(Self {
            _buffer: buffer,
            size,
            _phantom: PhantomData,
        })
    }

    /// Stub used when the crate is built without the `opencl` feature.
    ///
    /// # Errors
    /// Always reports the OpenCL backend as unavailable.
    #[cfg(not(feature = "opencl"))]
    pub fn new(_size: usize, _context: &()) -> RusTorchResult<Self> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Not implemented: the signature carries no OpenCL context to allocate from.
    ///
    /// # Errors
    /// Always returns a GPU error stating that a context is required.
    #[cfg(feature = "opencl")]
    pub fn from_host_data(_data: &[T]) -> RusTorchResult<Self> {
        Err(RusTorchError::gpu("OpenCL context required"))
    }

    /// Stub used when the crate is built without the `opencl` feature.
    ///
    /// # Errors
    /// Always reports the OpenCL backend as unavailable.
    #[cfg(not(feature = "opencl"))]
    pub fn from_host_data(_data: &[T]) -> RusTorchResult<Self> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Not implemented: the signature carries no command queue to read through.
    ///
    /// # Errors
    /// Always returns a GPU error stating that a queue is required.
    #[cfg(feature = "opencl")]
    pub fn copy_to_host(&self, _host_data: &mut [T]) -> RusTorchResult<()> {
        Err(RusTorchError::gpu("OpenCL queue required"))
    }

    /// Stub used when the crate is built without the `opencl` feature.
    ///
    /// # Errors
    /// Always reports the OpenCL backend as unavailable.
    #[cfg(not(feature = "opencl"))]
    pub fn copy_to_host(&self, _host_data: &mut [T]) -> RusTorchResult<()> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Returns the number of elements the buffer was created for.
    pub fn size(&self) -> usize {
        self.size
    }
}
/// Runs the OpenCL kernels embedded from `opencl_kernels.cl` on a single GPU device.
///
/// Only compiled with the `opencl` feature; a stub of the same name exists for
/// builds without it.
#[cfg(feature = "opencl")]
pub struct OpenClKernelExecutor {
// Device selected in `new`.
device: Device,
// Context created from `device`; all buffers in this type are allocated in it.
context: Context,
// Command queue on which kernel launches and result reads are enqueued.
queue: CommandQueue,
// Program built from the embedded kernel source.
program: Program,
// Kernels created from `program`, keyed by kernel family.
kernels: HashMap<OpenClKernelType, Kernel>,
}
#[cfg(feature = "opencl")]
impl OpenClKernelExecutor {
/// Creates an executor bound to the GPU at index `device_id` on the first
/// OpenCL platform.
///
/// The constructor creates a context and a profiling-enabled command queue,
/// builds the program embedded from `opencl_kernels.cl`, and registers the
/// element-wise add, matrix multiply and sum-reduction kernels.
///
/// # Errors
/// Returns a GPU error if no platform exists, the device index is out of
/// range, or any OpenCL object (context, queue, program, kernel) cannot be
/// created.
pub fn new(device_id: usize) -> RusTorchResult<Self> {
    let platforms = get_platforms()
        .map_err(|e| RusTorchError::gpu(format!("Failed to get OpenCL platforms: {:?}", e)))?;
    let Some(first_platform) = platforms.first() else {
        return Err(RusTorchError::gpu("No OpenCL platforms found".to_string()));
    };

    let gpu_ids = first_platform
        .get_devices(opencl3::device::CL_DEVICE_TYPE_GPU)
        .map_err(|e| RusTorchError::gpu(format!("Failed to get OpenCL devices: {:?}", e)))?;
    // An empty device list is covered by the same out-of-range lookup.
    let raw_device = gpu_ids.get(device_id).copied().ok_or_else(|| {
        RusTorchError::gpu(format!(
            "OpenCL device {} not found",
            device_id
        ))
    })?;

    let device = opencl3::device::Device::new(raw_device);
    let context = opencl3::context::Context::from_device(&device)
        .map_err(|e| RusTorchError::gpu(format!("Failed to create OpenCL context: {:?}", e)))?;
    let queue = opencl3::command_queue::CommandQueue::create_default_with_properties(
        &context,
        opencl3::command_queue::CL_QUEUE_PROFILING_ENABLE,
        0,
    )
    .map_err(|e| {
        RusTorchError::gpu(format!("Failed to create OpenCL command queue: {:?}", e))
    })?;

    let program = opencl3::program::Program::create_and_build_from_source(
        &context,
        include_str!("opencl_kernels.cl"),
        "",
    )
    .map_err(|e| RusTorchError::gpu(format!("Failed to compile OpenCL kernels: {:?}", e)))?;

    // Kernels are created in the same order as before so the first failure
    // reported is unchanged.
    let add_kernel = opencl3::kernel::Kernel::create(&program, "elementwise_add_f32")
        .map_err(|e| RusTorchError::gpu(format!("Failed to create add kernel: {:?}", e)))?;
    let matmul_kernel = opencl3::kernel::Kernel::create(&program, "matrix_multiply_f32")
        .map_err(|e| RusTorchError::gpu(format!("Failed to create matmul kernel: {:?}", e)))?;
    let reduce_kernel = opencl3::kernel::Kernel::create(&program, "reduce_sum_f32")
        .map_err(|e| RusTorchError::gpu(format!("Failed to create reduce kernel: {:?}", e)))?;
    let kernels = HashMap::from([
        (OpenClKernelType::ElementWise, add_kernel),
        (OpenClKernelType::MatMul, matmul_kernel),
        (OpenClKernelType::Reduction, reduce_kernel),
    ]);

    Ok(Self {
        device,
        context,
        queue,
        program,
        kernels,
    })
}
/// Adds two `f32` slices element-wise on the device, writing the result into `c`.
///
/// The inputs are copied into fresh device buffers, the `ElementWise` kernel
/// is launched over `a.len()` work-items, and the result is read back with a
/// blocking read. Empty inputs return `Ok(())` without touching the device.
///
/// # Errors
/// * invalid-parameter error if the slice lengths differ, or if the length
///   does not fit the kernel's 32-bit size argument;
/// * GPU error if a device buffer cannot be created or the result cannot be
///   read back;
/// * `KernelExecutionError` if the kernel is not registered or the launch fails.
pub fn elementwise_add_f32(&self, a: &[f32], b: &[f32], c: &mut [f32]) -> RusTorchResult<()> {
    const OP: &str = "elementwise_add";
    const PREFERRED_LOCAL_SIZE: usize = 256;

    let size = a.len();
    if b.len() != size || c.len() != size {
        return Err(RusTorchError::invalid_params(
            OP,
            "Array size mismatch in element-wise addition".to_string(),
        ));
    }
    // OpenCL rejects zero-sized buffers and zero-sized launches, and there is
    // nothing to compute anyway.
    if size == 0 {
        return Ok(());
    }
    // The kernel receives the length as a 32-bit integer; refuse to truncate.
    if size > u32::MAX as usize {
        return Err(RusTorchError::invalid_params(
            OP,
            format!(
                "Array length {} does not fit in the kernel's 32-bit size argument",
                size
            ),
        ));
    }
    let element_count = size as u32;

    // SAFETY: `a` holds exactly `size` f32 values and CL_MEM_COPY_HOST_PTR makes
    // the runtime copy from the pointer during creation; it is never written.
    let a_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_READ_ONLY | opencl3::memory::CL_MEM_COPY_HOST_PTR,
            size,
            a.as_ptr() as *mut std::ffi::c_void,
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer A: {:?}", e)))?;
    // SAFETY: same reasoning as for `a`; `b.len() == size` was checked above.
    let b_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_READ_ONLY | opencl3::memory::CL_MEM_COPY_HOST_PTR,
            size,
            b.as_ptr() as *mut std::ffi::c_void,
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer B: {:?}", e)))?;
    // SAFETY: no host pointer is supplied, so nothing is dereferenced.
    let c_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_WRITE_ONLY,
            size,
            std::ptr::null_mut(),
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer C: {:?}", e)))?;

    let kernel = self
        .kernels
        .get(&OpenClKernelType::ElementWise)
        .ok_or_else(|| {
            RusTorchError::KernelExecutionError("ElementWise kernel not found".to_string())
        })?;

    let global_work_size = [size];
    let local_work_size = [PREFERRED_LOCAL_SIZE.min(size)];
    // An explicit local size must divide the global size evenly, otherwise the
    // launch is rejected. When it does not, leave the work-group size to the
    // runtime instead of failing.
    let local_divides_global = size % local_work_size[0] == 0;

    // SAFETY: the three buffers hold `size` elements each and `element_count`
    // equals `size`, matching the argument list this kernel is launched with.
    unsafe {
        let mut launch = opencl3::kernel::ExecuteKernel::new(kernel);
        launch
            .set_arg(&a_buffer)
            .set_arg(&b_buffer)
            .set_arg(&c_buffer)
            .set_arg(&element_count)
            .set_global_work_sizes(&global_work_size);
        if local_divides_global {
            launch.set_local_work_sizes(&local_work_size);
        }
        launch.enqueue_nd_range(&self.queue)
    }
    .map_err(|e| {
        RusTorchError::KernelExecutionError(format!("Kernel execution failed: {:?}", e))
    })?;

    // SAFETY: the read is blocking, so `c` is only borrowed for the duration of
    // the call, and `c.len()` equals the element count of `c_buffer`.
    unsafe {
        self.queue
            .enqueue_read_buffer(&c_buffer, opencl3::types::CL_TRUE, 0, c, &[])
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to read result: {:?}", e)))?;
    Ok(())
}
/// Multiplies an `m`×`k` matrix `a` by a `k`×`n` matrix `b` on the device and
/// writes the `m`×`n` product into `c`.
///
/// Only the first `m * k` elements of `a` and the first `k * n` elements of
/// `b` are uploaded. NOTE(review): the element layout (presumably row-major)
/// is defined by the `matrix_multiply_f32` kernel source, which is not part of
/// this file.
///
/// Degenerate shapes are handled on the host: if `m` or `n` is zero there is
/// nothing to write, and if `k` is zero the product is all zeros.
///
/// # Errors
/// * invalid-parameter error if the dimension products overflow, `a` or `b`
///   is shorter than its dimensions require, `c.len() != m * n`, or a
///   dimension does not fit the kernel's 32-bit arguments;
/// * GPU error if a device buffer cannot be created or the result cannot be
///   read back;
/// * `KernelExecutionError` if the kernel is not registered or the launch fails.
pub fn matmul_f32(
    &self,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<()> {
    const OP: &str = "matmul";
    const MAX_TILE: usize = 16;

    let (Some(a_len), Some(b_len), Some(c_len)) =
        (m.checked_mul(k), k.checked_mul(n), m.checked_mul(n))
    else {
        return Err(RusTorchError::invalid_params(
            OP,
            format!("Matrix dimensions m={}, n={}, k={} overflow usize", m, n, k),
        ));
    };
    // The uploads below read `a_len` / `b_len` elements through raw pointers,
    // so shorter slices would be read out of bounds.
    if a.len() < a_len || b.len() < b_len || c.len() != c_len {
        return Err(RusTorchError::invalid_params(
            OP,
            format!(
                "Buffer sizes do not match m={}, n={}, k={}: a has {} (needs {}), b has {} (needs {}), c has {} (needs exactly {})",
                m,
                n,
                k,
                a.len(),
                a_len,
                b.len(),
                b_len,
                c.len(),
                c_len
            ),
        ));
    }
    // OpenCL rejects zero-sized buffers and launches; resolve those shapes here.
    if c_len == 0 {
        return Ok(());
    }
    if k == 0 {
        c.fill(0.0);
        return Ok(());
    }
    // The kernel receives the dimensions as 32-bit integers; refuse to truncate.
    let limit = u32::MAX as usize;
    if m > limit || n > limit || k > limit {
        return Err(RusTorchError::invalid_params(
            OP,
            format!(
                "Matrix dimensions m={}, n={}, k={} do not fit in 32-bit kernel arguments",
                m, n, k
            ),
        ));
    }
    let (rows, cols, inner) = (m as u32, n as u32, k as u32);

    // SAFETY: `a` holds at least `a_len` values (checked above) and
    // CL_MEM_COPY_HOST_PTR copies from the pointer during creation only.
    let a_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_READ_ONLY | opencl3::memory::CL_MEM_COPY_HOST_PTR,
            a_len,
            a.as_ptr() as *mut std::ffi::c_void,
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer A: {:?}", e)))?;
    // SAFETY: same reasoning as for `a`, with `b_len` checked above.
    let b_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_READ_ONLY | opencl3::memory::CL_MEM_COPY_HOST_PTR,
            b_len,
            b.as_ptr() as *mut std::ffi::c_void,
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer B: {:?}", e)))?;
    // SAFETY: no host pointer is supplied, so nothing is dereferenced.
    let c_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_WRITE_ONLY,
            c_len,
            std::ptr::null_mut(),
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create buffer C: {:?}", e)))?;

    let kernel = self.kernels.get(&OpenClKernelType::MatMul).ok_or_else(|| {
        RusTorchError::KernelExecutionError("MatMul kernel not found".to_string())
    })?;

    // An explicit local size must divide the global size evenly, so pick the
    // largest divisor not exceeding the 16-wide tile. For extents that 16 (or
    // the extent itself) already divides this is the same value as before.
    let largest_tile = |extent: usize| {
        (1..=MAX_TILE.min(extent))
            .rev()
            .find(|tile| extent % tile == 0)
            .unwrap_or(1)
    };
    let global_work_size = [n, m];
    let local_work_size = [largest_tile(n), largest_tile(m)];

    // SAFETY: buffer element counts match the dimensions passed as arguments.
    unsafe {
        opencl3::kernel::ExecuteKernel::new(kernel)
            .set_arg(&a_buffer)
            .set_arg(&b_buffer)
            .set_arg(&c_buffer)
            .set_arg(&rows)
            .set_arg(&cols)
            .set_arg(&inner)
            .set_global_work_sizes(&global_work_size)
            .set_local_work_sizes(&local_work_size)
            .enqueue_nd_range(&self.queue)
    }
    .map_err(|e| {
        RusTorchError::KernelExecutionError(format!("Kernel execution failed: {:?}", e))
    })?;

    // SAFETY: the read is blocking, and `c.len()` equals the element count of
    // `c_buffer`.
    unsafe {
        self.queue
            .enqueue_read_buffer(&c_buffer, opencl3::types::CL_TRUE, 0, c, &[])
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to read result: {:?}", e)))?;
    Ok(())
}
/// Multiplies an `m`×`k` matrix `a` by a `k`×`n` matrix `b` and returns the
/// `m`×`n` product in a newly allocated vector.
///
/// This is a convenience wrapper around `matmul_f32` that validates the input
/// lengths and allocates the output.
///
/// # Errors
/// Returns `InvalidOperation` if `a.len() != m * k`, `b.len() != k * n`, or a
/// dimension product overflows `usize`; otherwise propagates any error from
/// `matmul_f32`.
pub fn matrix_multiply_f32(
    &self,
    a: &[f32],
    b: &[f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<Vec<f32>> {
    // Checked products: a wrapped multiplication could otherwise make
    // mismatched inputs pass the length comparison.
    let (Some(a_len), Some(b_len), Some(c_len)) =
        (m.checked_mul(k), k.checked_mul(n), m.checked_mul(n))
    else {
        return Err(RusTorchError::InvalidOperation(
            "Matrix dimension mismatch".to_string(),
        ));
    };
    if a.len() != a_len || b.len() != b_len {
        return Err(RusTorchError::InvalidOperation(
            "Matrix dimension mismatch".to_string(),
        ));
    }
    let mut result = vec![0.0f32; c_len];
    self.matmul_f32(a, b, &mut result, m, n, k)?;
    Ok(result)
}
/// Sums all elements of `input` using the `Reduction` kernel.
///
/// The launch is padded up to a whole number of 256-item work-groups and the
/// real length is passed to the kernel. One value per work-group is read back
/// and the final sum is taken on the host. NOTE(review): this assumes the
/// kernel writes exactly one partial sum per work-group — confirm against
/// `opencl_kernels.cl`.
///
/// An empty input returns `Ok(0.0)` without touching the device.
///
/// # Errors
/// * invalid-parameter error if the length does not fit the kernel's 32-bit
///   size argument;
/// * GPU error if a device buffer cannot be created or the partial results
///   cannot be read back;
/// * `KernelExecutionError` if the kernel is not registered or the launch fails.
pub fn reduce_sum_f32(&self, input: &[f32]) -> RusTorchResult<f32> {
    const OP: &str = "reduce_sum";
    const LOCAL_SIZE: usize = 256;

    let size = input.len();
    // OpenCL rejects zero-sized buffers and launches; the empty sum is zero.
    if size == 0 {
        return Ok(0.0);
    }
    // The kernel receives the length as a 32-bit integer; refuse to truncate.
    if size > u32::MAX as usize {
        return Err(RusTorchError::invalid_params(
            OP,
            format!(
                "Input length {} does not fit in the kernel's 32-bit size argument",
                size
            ),
        ));
    }
    let element_count = size as u32;

    let num_groups = size.div_ceil(LOCAL_SIZE);
    let global_size = num_groups * LOCAL_SIZE;

    // SAFETY: `input` holds exactly `size` f32 values and CL_MEM_COPY_HOST_PTR
    // copies from the pointer during creation only.
    let input_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_READ_ONLY | opencl3::memory::CL_MEM_COPY_HOST_PTR,
            size,
            input.as_ptr() as *mut std::ffi::c_void,
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create input buffer: {:?}", e)))?;
    // SAFETY: no host pointer is supplied, so nothing is dereferenced.
    let output_buffer = unsafe {
        opencl3::memory::Buffer::<f32>::create(
            &self.context,
            opencl3::memory::CL_MEM_WRITE_ONLY,
            num_groups,
            std::ptr::null_mut(),
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to create output buffer: {:?}", e)))?;

    let kernel = self
        .kernels
        .get(&OpenClKernelType::Reduction)
        .ok_or_else(|| {
            RusTorchError::KernelExecutionError("Reduction kernel not found".to_string())
        })?;

    let global_work_size = [global_size];
    let local_work_size = [LOCAL_SIZE];
    // SAFETY: the input buffer holds `size` elements, the output buffer holds
    // one element per work-group, and `element_count` equals `size`.
    unsafe {
        opencl3::kernel::ExecuteKernel::new(kernel)
            .set_arg(&input_buffer)
            .set_arg(&output_buffer)
            .set_arg(&element_count)
            .set_global_work_sizes(&global_work_size)
            .set_local_work_sizes(&local_work_size)
            .enqueue_nd_range(&self.queue)
    }
    .map_err(|e| {
        RusTorchError::KernelExecutionError(format!("Kernel execution failed: {:?}", e))
    })?;

    let mut partial_results = vec![0.0f32; num_groups];
    // SAFETY: the read is blocking, and `partial_results` has as many elements
    // as `output_buffer`.
    unsafe {
        self.queue.enqueue_read_buffer(
            &output_buffer,
            opencl3::types::CL_TRUE,
            0,
            &mut partial_results,
            &[],
        )
    }
    .map_err(|e| RusTorchError::gpu(format!("Failed to read partial results: {:?}", e)))?;

    Ok(partial_results.iter().sum())
}
}
/// Placeholder executor compiled when the `opencl` feature is disabled.
///
/// Its `new` constructor always fails and every method reports the OpenCL
/// backend as unavailable.
#[cfg(not(feature = "opencl"))]
pub struct OpenClKernelExecutor;
/// Fallback API used when the crate is built without the `opencl` feature.
///
/// It mirrors the public methods of the OpenCL-enabled executor so callers
/// compile in both configurations; every method reports the backend as
/// unavailable.
#[cfg(not(feature = "opencl"))]
impl OpenClKernelExecutor {
    /// Always fails: OpenCL support is not compiled in.
    pub fn new(_device_id: usize) -> RusTorchResult<Self> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Always fails: OpenCL support is not compiled in.
    pub fn elementwise_add_f32(
        &self,
        _a: &[f32],
        _b: &[f32],
        _c: &mut [f32],
    ) -> RusTorchResult<()> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Always fails: OpenCL support is not compiled in.
    pub fn matmul_f32(
        &self,
        _a: &[f32],
        _b: &[f32],
        _c: &mut [f32],
        _m: usize,
        _n: usize,
        _k: usize,
    ) -> RusTorchResult<()> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Always fails: OpenCL support is not compiled in.
    ///
    /// Present so the method set matches the OpenCL-enabled executor.
    pub fn matrix_multiply_f32(
        &self,
        _a: &[f32],
        _b: &[f32],
        _m: usize,
        _n: usize,
        _k: usize,
    ) -> RusTorchResult<Vec<f32>> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }

    /// Always fails: OpenCL support is not compiled in.
    pub fn reduce_sum_f32(&self, _input: &[f32]) -> RusTorchResult<f32> {
        Err(RusTorchError::backend_unavailable("OpenCL"))
    }
}
/// Convenience wrapper that runs `matmul_f32` on OpenCL device 0.
///
/// With the `opencl` feature enabled, a fresh `OpenClKernelExecutor` is
/// constructed on every call and the arguments are forwarded unchanged.
///
/// # Errors
/// Propagates executor construction and execution errors; without the
/// `opencl` feature it always reports the backend as unavailable.
pub fn opencl_matmul_f32(
    _a: &[f32],
    _b: &[f32],
    _c: &mut [f32],
    _m: usize,
    _n: usize,
    _k: usize,
) -> RusTorchResult<()> {
    #[cfg(feature = "opencl")]
    let outcome =
        OpenClKernelExecutor::new(0).and_then(|gpu| gpu.matmul_f32(_a, _b, _c, _m, _n, _k));
    #[cfg(not(feature = "opencl"))]
    let outcome = Err(RusTorchError::backend_unavailable("OpenCL"));
    outcome
}
/// Convenience wrapper that runs `elementwise_add_f32` on OpenCL device 0.
///
/// With the `opencl` feature enabled, a fresh `OpenClKernelExecutor` is
/// constructed on every call and the arguments are forwarded unchanged.
///
/// # Errors
/// Propagates executor construction and execution errors; without the
/// `opencl` feature it always reports the backend as unavailable.
pub fn opencl_elementwise_add_f32(_a: &[f32], _b: &[f32], _c: &mut [f32]) -> RusTorchResult<()> {
    #[cfg(feature = "opencl")]
    let outcome = OpenClKernelExecutor::new(0).and_then(|gpu| gpu.elementwise_add_f32(_a, _b, _c));
    #[cfg(not(feature = "opencl"))]
    let outcome = Err(RusTorchError::backend_unavailable("OpenCL"));
    outcome
}
/// Convenience wrapper that runs `reduce_sum_f32` on OpenCL device 0.
///
/// With the `opencl` feature enabled, a fresh `OpenClKernelExecutor` is
/// constructed on every call and the input is forwarded unchanged.
///
/// # Errors
/// Propagates executor construction and execution errors; without the
/// `opencl` feature it always reports the backend as unavailable.
pub fn opencl_reduce_sum_f32(_input: &[f32]) -> RusTorchResult<f32> {
    #[cfg(feature = "opencl")]
    let outcome = OpenClKernelExecutor::new(0).and_then(|gpu| gpu.reduce_sum_f32(_input));
    #[cfg(not(feature = "opencl"))]
    let outcome = Err(RusTorchError::backend_unavailable("OpenCL"));
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default launch configuration has unit work sizes and uses queue 0.
    #[test]
    fn test_opencl_kernel_params() {
        let params = OpenClKernelParams::default();
        assert_eq!(params.global_work_size, [1, 1, 1]);
        assert_eq!(params.local_work_size, [1, 1, 1]);
        assert_eq!(params.queue_index, 0);
    }

    /// Without the `opencl` feature construction must fail; with it the
    /// outcome depends on the host hardware, so only absence of a panic is
    /// checked.
    #[test]
    fn test_opencl_executor_creation() {
        let result = OpenClKernelExecutor::new(0);
        #[cfg(not(feature = "opencl"))]
        assert!(result.is_err());
        // Consume the value so the binding is not reported as unused when the
        // feature is enabled.
        #[cfg(feature = "opencl")]
        let _ = result;
    }

    /// Kernel type values compare by variant.
    #[test]
    fn test_opencl_kernel_types() {
        assert_eq!(OpenClKernelType::ElementWise, OpenClKernelType::ElementWise);
        assert_ne!(OpenClKernelType::ElementWise, OpenClKernelType::MatMul);
    }
}