use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[cfg(feature = "opencl")]
use opencl3::{
command_queue::{CommandQueue, CL_NON_BLOCKING, CL_QUEUE_PROFILING_ENABLE},
context::Context,
device::{
get_device_info, Device, CL_DEVICE_GLOBAL_MEM_SIZE, CL_DEVICE_LOCAL_MEM_SIZE,
CL_DEVICE_MAX_CLOCK_FREQUENCY, CL_DEVICE_MAX_COMPUTE_UNITS, CL_DEVICE_MAX_WORK_GROUP_SIZE,
CL_DEVICE_NAME, CL_DEVICE_TYPE_ALL, CL_DEVICE_TYPE_GPU, CL_DEVICE_VENDOR,
},
kernel::{ExecuteKernel, Kernel},
memory::{Buffer, ClMem, CL_MEM_READ_ONLY, CL_MEM_READ_WRITE, CL_MEM_WRITE_ONLY},
platform::{get_platforms, Platform},
program::Program,
types::{cl_device_id, cl_device_type, cl_event, cl_platform_id},
};
/// Snapshot of the properties of one OpenCL device.
///
/// The numeric fields hold whatever the corresponding `CL_DEVICE_*` query
/// returned; they are used for device scoring and for sizing kernel tiles.
#[derive(Debug, Clone)]
pub struct OpenClDeviceInfo {
/// Device name (from the `CL_DEVICE_NAME` query).
pub name: String,
/// Vendor string (from the `CL_DEVICE_VENDOR` query); drives vendor-specific choices.
pub vendor: String,
/// Value of `CL_DEVICE_MAX_COMPUTE_UNITS`.
pub compute_units: u32,
/// Value of `CL_DEVICE_MAX_WORK_GROUP_SIZE`; bounds the square tile edge.
pub max_work_group_size: usize,
/// Value of `CL_DEVICE_GLOBAL_MEM_SIZE` (presumably bytes, per the OpenCL spec — the
/// scoring code divides it by 1024^3).
pub global_mem_size: u64,
/// Value of `CL_DEVICE_LOCAL_MEM_SIZE` (presumably bytes — it is divided by
/// `size_of::<f32>()` when sizing tiles).
pub local_mem_size: u64,
/// Value of `CL_DEVICE_MAX_CLOCK_FREQUENCY` (presumably MHz, per the OpenCL spec).
pub max_clock_frequency: u32,
/// Free-form device category label; the code in this file only ever stores `"GPU"`.
pub device_type: String,
}
/// Matrix-multiplication executor backed by a single OpenCL device.
///
/// Only compiled when the `opencl` feature is enabled.
#[cfg(feature = "opencl")]
pub struct OpenClMatrixExecutor {
// Context created from the selected device; used for buffers and program builds.
context: Context,
// Queue on which transfers and kernel launches are enqueued.
command_queue: CommandQueue,
// The selected device; stored but not otherwise read by the code in this file.
device: Device,
// Properties consulted for tile sizing and vendor-specific kernel selection.
device_info: OpenClDeviceInfo,
}
#[cfg(feature = "opencl")]
impl OpenClMatrixExecutor {
/// Creates an executor bound to the highest-scoring OpenCL device.
///
/// The device is chosen by `select_best_device`, its real properties are
/// queried (they drive tile sizing and vendor-specific kernel selection), and
/// a context plus a profiling-enabled command queue are created for it. A
/// short summary of the selected device is printed to stdout.
///
/// # Errors
///
/// Returns `RusTorchError::UnsupportedDevice` when no OpenCL platform is
/// available, and a tensor-op error when platform enumeration, the device
/// query, context creation or queue creation fails.
pub fn new() -> RusTorchResult<Self> {
    let platforms = get_platforms().map_err(|e| {
        RusTorchError::tensor_op(format!("Failed to get OpenCL platforms: {:?}", e))
    })?;
    if platforms.is_empty() {
        return Err(RusTorchError::UnsupportedDevice(
            "No OpenCL platforms available".to_string(),
        ));
    }

    // The owning platform is returned for completeness but is not needed here.
    let (device, _platform) = Self::select_best_device(&platforms)?;

    // Use the properties the device actually reports. Placeholder values would
    // hide the vendor (disabling the vendor-tuned kernels) and could size
    // work-groups beyond what the hardware supports.
    let device_info = Self::get_opencl_device_info_from_id(device.id())?;

    let context = Context::from_device(&device).map_err(|e| {
        RusTorchError::tensor_op(format!("Failed to create OpenCL context: {:?}", e))
    })?;
    let command_queue =
        CommandQueue::create_default_with_properties(&context, CL_QUEUE_PROFILING_ENABLE, 0)
            .map_err(|e| {
                RusTorchError::tensor_op(format!("Failed to create command queue: {:?}", e))
            })?;

    println!(
        "Selected OpenCL device: {} by {}",
        device_info.name, device_info.vendor
    );
    println!(
        "Compute units: {}, Max work group size: {}",
        device_info.compute_units, device_info.max_work_group_size
    );

    Ok(Self {
        context,
        command_queue,
        device,
        device_info,
    })
}
fn select_best_device(platforms: &[Platform]) -> RusTorchResult<(Device, Platform)> {
let mut best_device = None;
let mut best_platform = None;
let mut best_score = 0f64;
for platform in platforms {
let devices = platform
.get_devices(CL_DEVICE_TYPE_GPU)
.or_else(|_| platform.get_devices(CL_DEVICE_TYPE_ALL))
.map_err(|e| RusTorchError::tensor_op(format!("Failed to get devices: {:?}", e)))?;
for device_id in devices {
let device = Device::new(device_id);
let info = Self::get_opencl_device_info_from_id(device_id)?;
let score = Self::score_device(&info);
if score > best_score {
best_score = score;
best_device = Some(device);
best_platform = Some(platform.clone());
}
}
}
match (best_device, best_platform) {
(Some(device), Some(platform)) => Ok((device, platform)),
_ => Err(RusTorchError::UnsupportedDevice(
"No suitable OpenCL device found".to_string(),
)),
}
}
/// Queries the properties of `device_id` and packs them into an `OpenClDeviceInfo`.
///
/// The queries run in field order and the first failing one aborts the call
/// with a tensor-op error naming the property. `device_type` is not queried;
/// it is always set to `"GPU"`.
fn get_opencl_device_info_from_id(device_id: cl_device_id) -> RusTorchResult<OpenClDeviceInfo> {
    let name = match get_device_info(device_id, CL_DEVICE_NAME) {
        Ok(value) => value.to_string(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get device name: {:?}",
                e
            )))
        }
    };
    let vendor = match get_device_info(device_id, CL_DEVICE_VENDOR) {
        Ok(value) => value.to_string(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get device vendor: {:?}",
                e
            )))
        }
    };
    let compute_units = match get_device_info(device_id, CL_DEVICE_MAX_COMPUTE_UNITS) {
        Ok(value) => value.to_uint(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get compute units: {:?}",
                e
            )))
        }
    };
    let max_work_group_size = match get_device_info(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE) {
        Ok(value) => value.to_size(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get max work group size: {:?}",
                e
            )))
        }
    };
    let global_mem_size = match get_device_info(device_id, CL_DEVICE_GLOBAL_MEM_SIZE) {
        Ok(value) => value.to_ulong(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get global memory size: {:?}",
                e
            )))
        }
    };
    let local_mem_size = match get_device_info(device_id, CL_DEVICE_LOCAL_MEM_SIZE) {
        Ok(value) => value.to_ulong(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get local memory size: {:?}",
                e
            )))
        }
    };
    let max_clock_frequency = match get_device_info(device_id, CL_DEVICE_MAX_CLOCK_FREQUENCY) {
        Ok(value) => value.to_uint(),
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to get max clock frequency: {:?}",
                e
            )))
        }
    };

    Ok(OpenClDeviceInfo {
        name,
        vendor,
        compute_units,
        max_work_group_size,
        global_mem_size,
        local_mem_size,
        max_clock_frequency,
        device_type: "GPU".to_string(),
    })
}
/// Computes a heuristic score for a device; higher means more preferred.
///
/// The score is `compute_units * max_clock_frequency / 1000` plus ten points
/// per 1024^3 units of `global_mem_size`, multiplied by a vendor bonus
/// (NVIDIA x1.2, AMD x1.1, anything else x1.0).
///
/// AMD drivers commonly report the vendor as "Advanced Micro Devices, Inc."
/// rather than "AMD", so both spellings are recognised.
fn score_device(info: &OpenClDeviceInfo) -> f64 {
    let compute = info.compute_units as f64 * info.max_clock_frequency as f64 / 1000.0;
    let memory = (info.global_mem_size as f64 / 1024.0 / 1024.0 / 1024.0) * 10.0;

    let vendor = info.vendor.as_str();
    let vendor_bonus = if vendor.contains("NVIDIA") {
        1.2
    } else if vendor.contains("AMD") || vendor.contains("Advanced Micro Devices") {
        1.1
    } else {
        1.0
    };

    (compute + memory) * vendor_bonus
}
/// Computes `C = A * B` for row-major `f32` matrices on the OpenCL device.
///
/// * `a` - left operand, `m * k` elements.
/// * `b` - right operand, `k * n` elements.
/// * `c` - output, `m * n` elements; fully written when `Ok(())` is returned.
/// * `m`, `n`, `k` - matrix dimensions.
///
/// The call is synchronous: it waits for the command queue to drain before
/// returning, so `c` is complete on success and the device no longer touches
/// `a`, `b` or `c` once the function has returned (also on error).
///
/// Degenerate shapes are handled on the host: with `m == 0` or `n == 0` there
/// is nothing to write, and with `k == 0` the product is all zeros.
///
/// # Errors
///
/// Returns a shape-mismatch error when a slice length disagrees with the
/// dimensions, and a tensor-op error when the sizes overflow, exceed what the
/// kernel's 32-bit indexing can address, or any OpenCL call fails.
pub fn matmul_f32(
    &mut self,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<()> {
    // Checked products: a wrapped multiplication could make a wrong length
    // look valid.
    let (a_len, b_len, c_len) =
        match (m.checked_mul(k), k.checked_mul(n), m.checked_mul(n)) {
            (Some(a_len), Some(b_len), Some(c_len)) => (a_len, b_len, c_len),
            _ => {
                return Err(RusTorchError::tensor_op(format!(
                    "Matrix dimensions overflow: m={}, n={}, k={}",
                    m, n, k
                )))
            }
        };
    if a.len() != a_len || b.len() != b_len || c.len() != c_len {
        return Err(RusTorchError::shape_mismatch(
            &[m, k, k, n, m, n],
            &[a.len(), 0, b.len(), 0, c.len(), 0],
        ));
    }

    // Zero-sized device buffers are invalid in OpenCL, so resolve degenerate
    // products without touching the device.
    if m == 0 || n == 0 {
        return Ok(());
    }
    if k == 0 {
        c.fill(0.0);
        return Ok(());
    }

    // The kernel receives the dimensions as 32-bit values and computes flat
    // indices in 32-bit arithmetic; reject anything it cannot address. All
    // dimensions are non-zero here, so bounding the products bounds them too.
    let index_limit = u32::MAX as usize;
    if a_len > index_limit || b_len > index_limit || c_len > index_limit {
        return Err(RusTorchError::tensor_op(format!(
            "Matrix too large for OpenCL kernel indexing: m={}, n={}, k={}",
            m, n, k
        )));
    }

    let kernel_source = self.generate_matmul_kernel(m, n, k)?;
    let kernel = self.compile_kernel("matmul_f32", &kernel_source)?;
    let (global_work_size, local_work_size) = self.calculate_work_sizes(m, n)?;

    // SAFETY: no host pointer is supplied (null), so OpenCL allocates the
    // storage itself; the element count is non-zero.
    let mut buffer_a = unsafe {
        Buffer::<f32>::create(&self.context, CL_MEM_READ_ONLY, a_len, std::ptr::null_mut())
    }
    .map_err(|e| RusTorchError::tensor_op(format!("Failed to create buffer A: {:?}", e)))?;
    // SAFETY: as for buffer A.
    let mut buffer_b = unsafe {
        Buffer::<f32>::create(&self.context, CL_MEM_READ_ONLY, b_len, std::ptr::null_mut())
    }
    .map_err(|e| RusTorchError::tensor_op(format!("Failed to create buffer B: {:?}", e)))?;
    // SAFETY: as for buffer A.
    let buffer_c = unsafe {
        Buffer::<f32>::create(&self.context, CL_MEM_WRITE_ONLY, c_len, std::ptr::null_mut())
    }
    .map_err(|e| RusTorchError::tensor_op(format!("Failed to create buffer C: {:?}", e)))?;

    // Enqueue uploads, the kernel and the read-back. The commands are
    // non-blocking; completion is awaited once, below.
    let enqueued = (|| -> RusTorchResult<()> {
        // SAFETY: `a` and `b` hold exactly as many elements as their device
        // buffers, and the queue is drained before this function returns, so
        // the host slices outlive the asynchronous copies.
        unsafe {
            self.command_queue
                .enqueue_write_buffer(&mut buffer_a, CL_NON_BLOCKING, 0, a, &[])
                .map_err(|e| {
                    RusTorchError::tensor_op(format!("Failed to write buffer A: {:?}", e))
                })?;
            self.command_queue
                .enqueue_write_buffer(&mut buffer_b, CL_NON_BLOCKING, 0, b, &[])
                .map_err(|e| {
                    RusTorchError::tensor_op(format!("Failed to write buffer B: {:?}", e))
                })?;
        }
        // SAFETY: the arguments match the kernel signature (three buffers and
        // three unsigned ints), and the dimensions were checked to fit `u32`.
        unsafe {
            ExecuteKernel::new(&kernel)
                .set_arg(&buffer_a)
                .set_arg(&buffer_b)
                .set_arg(&buffer_c)
                .set_arg(&(m as u32))
                .set_arg(&(n as u32))
                .set_arg(&(k as u32))
                .set_global_work_sizes(&global_work_size)
                .set_local_work_sizes(&local_work_size)
                .enqueue_nd_range(&self.command_queue)
                .map_err(|e| {
                    RusTorchError::tensor_op(format!("Failed to execute kernel: {:?}", e))
                })?;
        }
        // SAFETY: `c` holds exactly `m * n` elements, matching buffer C, and
        // stays borrowed until the queue has been drained.
        unsafe {
            self.command_queue
                .enqueue_read_buffer(&buffer_c, CL_NON_BLOCKING, 0, c, &[])
                .map_err(|e| {
                    RusTorchError::tensor_op(format!("Failed to read result: {:?}", e))
                })?;
        }
        Ok(())
    })();

    // Always wait for the queue, even if an enqueue failed: commands that
    // were already submitted may still be using the host slices.
    let drained = self.command_queue.finish().map_err(|e| {
        RusTorchError::tensor_op(format!("Failed to finish command queue: {:?}", e))
    });

    enqueued?;
    drained
}
/// Builds the OpenCL C source of the `matmul_f32` kernel for this device.
///
/// The tile size comes from `calculate_optimal_tile_size`; the kernel variant
/// is selected from the device's vendor string (AMD, NVIDIA, Intel, otherwise
/// the generic variant).
///
/// AMD drivers commonly report the vendor as "Advanced Micro Devices, Inc."
/// rather than "AMD", so both spellings select the AMD variant.
fn generate_matmul_kernel(&self, m: usize, n: usize, k: usize) -> RusTorchResult<String> {
    let tile_size = self.calculate_optimal_tile_size(m, n, k);
    let vendor = self.device_info.vendor.as_str();
    let is_amd = vendor.contains("AMD") || vendor.contains("Advanced Micro Devices");

    let kernel_source = if is_amd {
        self.generate_amd_optimized_kernel(tile_size)
    } else if vendor.contains("NVIDIA") {
        self.generate_nvidia_optimized_kernel(tile_size)
    } else if vendor.contains("Intel") {
        self.generate_intel_optimized_kernel(tile_size)
    } else {
        self.generate_generic_kernel(tile_size)
    };
    Ok(kernel_source)
}
/// Chooses the edge length of the square tile / work-group used by the kernels.
///
/// The result is the largest power of two that
/// * fits two `f32` tiles into half of the device's local memory,
/// * keeps `tile * tile` within the device's maximum work-group size, and
/// * does not exceed 32.
///
/// It is always at least 1. The value is rounded *down* to a power of two:
/// rounding up could exceed the limits just computed (for example a maximum
/// work-group size of 512 gives an edge of 22, and rounding that up to 32
/// would need 1024 work-items).
///
/// `m`, `n` and `k` are accepted for interface stability but do not currently
/// influence the result.
fn calculate_optimal_tile_size(&self, m: usize, n: usize, k: usize) -> usize {
    // Problem dimensions are intentionally unused for now.
    let _ = (m, n, k);

    const MAX_TILE: usize = 32;

    let max_work_group = self.device_info.max_work_group_size;
    let local_mem_size = self.device_info.local_mem_size as usize;

    // Half of local memory is budgeted for the two f32 tiles.
    let max_tile_from_memory =
        (((local_mem_size / 2) / (2 * std::mem::size_of::<f32>())) as f64).sqrt() as usize;
    // A tile is processed by tile * tile work-items.
    let max_tile_from_workgroup = (max_work_group as f64).sqrt() as usize;

    let limit = max_tile_from_memory
        .min(max_tile_from_workgroup)
        .min(MAX_TILE)
        .max(1);

    // Largest power of two that is <= limit.
    if limit.is_power_of_two() {
        limit
    } else {
        limit.next_power_of_two() / 2
    }
}
/// Returns the OpenCL C source of the `matmul_f32` kernel variant used for AMD devices.
///
/// `tile_size` is substituted for `TILE_SIZE` and for both dimensions of the
/// two `__local` tiles. The tiles are indexed by the local ids, so the kernel
/// expects to be launched with a `tile_size` x `tile_size` work-group.
/// Out-of-range tile elements are loaded as `0.0f`, and the result is only
/// stored when the row/column lie inside `M` x `N`.
///
/// This variant accumulates with `fma` and requests `#pragma unroll 4` on the
/// inner loop. `self` is not read.
fn generate_amd_optimized_kernel(&self, tile_size: usize) -> String {
format!(
r#"
__kernel void matmul_f32(
__global const float* A,
__global const float* B,
__global float* C,
const unsigned int M,
const unsigned int N,
const unsigned int K
) {{
// AMD GCN architecture optimizations
const int TILE_SIZE = {tile_size};
// Get work group and local IDs
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
// Shared memory for tiles
__local float As[{tile_size}][{tile_size}];
__local float Bs[{tile_size}][{tile_size}];
// Calculate global position
const int global_row = group_y * TILE_SIZE + local_y;
const int global_col = group_x * TILE_SIZE + local_x;
float sum = 0.0f;
// Loop over tiles
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {{
// Load tiles into shared memory with bounds checking
int a_row = global_row;
int a_col = tile * TILE_SIZE + local_x;
int b_row = tile * TILE_SIZE + local_y;
int b_col = global_col;
As[local_y][local_x] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
Bs[local_y][local_x] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
// Compute partial sum with loop unrolling for AMD GCN
#pragma unroll 4
for (int k = 0; k < TILE_SIZE; k++) {{
sum = fma(As[local_y][k], Bs[k][local_x], sum);
}}
barrier(CLK_LOCAL_MEM_FENCE);
}}
// Write result
if (global_row < M && global_col < N) {{
C[global_row * N + global_col] = sum;
}}
}}
"#,
tile_size = tile_size
)
}
/// Returns the OpenCL C source of the `matmul_f32` kernel variant used for NVIDIA devices.
///
/// `tile_size` is substituted for `TILE_SIZE` and for the dimensions of the
/// two `__local` tiles, whose rows are declared one element wider than
/// `tile_size` (the embedded comment gives bank-conflict avoidance as the
/// reason). The tiles are indexed by the local ids, so the kernel expects a
/// `tile_size` x `tile_size` work-group. Out-of-range tile elements are
/// loaded as `0.0f`, and the result is only stored inside `M` x `N`.
///
/// This variant accumulates with a plain multiply-add and requests
/// `#pragma unroll 8` on the inner loop. `self` is not read.
fn generate_nvidia_optimized_kernel(&self, tile_size: usize) -> String {
format!(
r#"
__kernel void matmul_f32(
__global const float* A,
__global const float* B,
__global float* C,
const unsigned int M,
const unsigned int N,
const unsigned int K
) {{
// NVIDIA CUDA core optimizations
const int TILE_SIZE = {tile_size};
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
__local float As[{tile_size}][{tile_size} + 1]; // Bank conflict avoidance
__local float Bs[{tile_size}][{tile_size} + 1];
const int global_row = group_y * TILE_SIZE + local_y;
const int global_col = group_x * TILE_SIZE + local_x;
float sum = 0.0f;
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {{
// Coalesced memory access for NVIDIA
int a_row = global_row;
int a_col = tile * TILE_SIZE + local_x;
int b_row = tile * TILE_SIZE + local_y;
int b_col = global_col;
As[local_y][local_x] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
Bs[local_y][local_x] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
// Optimized for NVIDIA warp execution
#pragma unroll 8
for (int k = 0; k < TILE_SIZE; k++) {{
sum += As[local_y][k] * Bs[k][local_x];
}}
barrier(CLK_LOCAL_MEM_FENCE);
}}
if (global_row < M && global_col < N) {{
C[global_row * N + global_col] = sum;
}}
}}
"#,
tile_size = tile_size
)
}
/// Returns the OpenCL C source of the `matmul_f32` kernel variant used for Intel devices.
///
/// `tile_size` is substituted for `TILE_SIZE` and for both dimensions of the
/// two `__local` tiles. The tiles are indexed by the local ids, so the kernel
/// expects a `tile_size` x `tile_size` work-group. Out-of-range tile elements
/// are loaded as `0.0f`, and the result is only stored inside `M` x `N`.
///
/// This variant accumulates with a plain multiply-add and requests
/// `#pragma unroll 2` on the inner loop. NOTE(review): the embedded comment
/// mentions prefetch hints, but the kernel text contains no prefetch call.
/// `self` is not read.
fn generate_intel_optimized_kernel(&self, tile_size: usize) -> String {
format!(
r#"
__kernel void matmul_f32(
__global const float* A,
__global const float* B,
__global float* C,
const unsigned int M,
const unsigned int N,
const unsigned int K
) {{
// Intel GPU optimizations
const int TILE_SIZE = {tile_size};
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
__local float As[{tile_size}][{tile_size}];
__local float Bs[{tile_size}][{tile_size}];
const int global_row = group_y * TILE_SIZE + local_y;
const int global_col = group_x * TILE_SIZE + local_x;
float sum = 0.0f;
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {{
int a_row = global_row;
int a_col = tile * TILE_SIZE + local_x;
int b_row = tile * TILE_SIZE + local_y;
int b_col = global_col;
// Intel-specific prefetch hints
As[local_y][local_x] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
Bs[local_y][local_x] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
// Conservative unrolling for Intel
#pragma unroll 2
for (int k = 0; k < TILE_SIZE; k++) {{
sum += As[local_y][k] * Bs[k][local_x];
}}
barrier(CLK_LOCAL_MEM_FENCE);
}}
if (global_row < M && global_col < N) {{
C[global_row * N + global_col] = sum;
}}
}}
"#,
tile_size = tile_size
)
}
/// Returns the OpenCL C source of the vendor-neutral `matmul_f32` kernel.
///
/// `tile_size` is substituted for `TILE_SIZE` and for both dimensions of the
/// two `__local` tiles. The tiles are indexed by the local ids, so the kernel
/// expects a `tile_size` x `tile_size` work-group. Out-of-range tile elements
/// are loaded as `0.0f`, and the result is only stored inside `M` x `N`.
///
/// Unlike the vendor variants, no unroll pragma is emitted. `self` is not read.
fn generate_generic_kernel(&self, tile_size: usize) -> String {
format!(
r#"
__kernel void matmul_f32(
__global const float* A,
__global const float* B,
__global float* C,
const unsigned int M,
const unsigned int N,
const unsigned int K
) {{
const int TILE_SIZE = {tile_size};
const int group_x = get_group_id(0);
const int group_y = get_group_id(1);
const int local_x = get_local_id(0);
const int local_y = get_local_id(1);
__local float As[{tile_size}][{tile_size}];
__local float Bs[{tile_size}][{tile_size}];
const int global_row = group_y * TILE_SIZE + local_y;
const int global_col = group_x * TILE_SIZE + local_x;
float sum = 0.0f;
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {{
int a_row = global_row;
int a_col = tile * TILE_SIZE + local_x;
int b_row = tile * TILE_SIZE + local_y;
int b_col = global_col;
As[local_y][local_x] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
Bs[local_y][local_x] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
for (int k = 0; k < TILE_SIZE; k++) {{
sum += As[local_y][k] * Bs[k][local_x];
}}
barrier(CLK_LOCAL_MEM_FENCE);
}}
if (global_row < M && global_col < N) {{
C[global_row * N + global_col] = sum;
}}
}}
"#,
tile_size = tile_size
)
}
/// Builds `source` in this executor's context (with empty build options) and
/// returns the kernel called `name` from the resulting program.
///
/// Build and kernel-creation failures are reported as tensor-op errors.
fn compile_kernel(&self, name: &str, source: &str) -> RusTorchResult<Kernel> {
    let built = match Program::create_and_build_from_source(&self.context, source, "") {
        Ok(program) => program,
        Err(e) => {
            return Err(RusTorchError::tensor_op(format!(
                "Failed to compile kernel: {:?}",
                e
            )))
        }
    };

    match Kernel::create(&built, name) {
        Ok(kernel) => Ok(kernel),
        Err(e) => Err(RusTorchError::tensor_op(format!(
            "Failed to create kernel: {:?}",
            e
        ))),
    }
}
/// Computes the 2-D launch geometry as `(global_work_size, local_work_size)`.
///
/// The local size is one `tile x tile` work-group. The global size rounds `n`
/// (dimension 0) and `m` (dimension 1) up to the next multiple of the tile
/// edge.
fn calculate_work_sizes(&self, m: usize, n: usize) -> RusTorchResult<([usize; 2], [usize; 2])> {
    let tile = self.calculate_optimal_tile_size(m, n, m.max(n));
    let round_up_to_tile = |extent: usize| extent.div_ceil(tile) * tile;

    let global = [round_up_to_tile(n), round_up_to_tile(m)];
    let local = [tile, tile];
    Ok((global, local))
}
pub fn get_device_info_ref(&self) -> &OpenClDeviceInfo {
&self.device_info
}
}
/// Placeholder executor compiled when the `opencl` feature is disabled.
///
/// Its constructor and `matmul_f32` always return `RusTorchError::UnsupportedDevice`.
#[cfg(not(feature = "opencl"))]
pub struct OpenClMatrixExecutor;
#[cfg(not(feature = "opencl"))]
impl OpenClMatrixExecutor {
    /// Always fails: this build has no OpenCL support.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::UnsupportedDevice` unconditionally.
    pub fn new() -> RusTorchResult<Self> {
        let reason = String::from("OpenCL not available");
        Err(RusTorchError::UnsupportedDevice(reason))
    }

    /// Always fails: this build has no OpenCL support. All arguments are ignored.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::UnsupportedDevice` unconditionally.
    pub fn matmul_f32(
        &mut self,
        _a: &[f32],
        _b: &[f32],
        _c: &mut [f32],
        _m: usize,
        _n: usize,
        _k: usize,
    ) -> RusTorchResult<()> {
        let reason = String::from("OpenCL not available");
        Err(RusTorchError::UnsupportedDevice(reason))
    }
}
/// Multiplies two `f32` matrices through a freshly created OpenCL executor.
///
/// With the `opencl` feature enabled, a new `OpenClMatrixExecutor` is built
/// for this call and its `matmul_f32` is invoked with the given arguments.
/// Without the feature the arguments are ignored.
///
/// # Errors
///
/// Propagates any error from executor creation or the multiplication. Without
/// the `opencl` feature, always returns `RusTorchError::UnsupportedDevice`.
// `a`, `b`, `c`, `m`, `n`, `k` are the usual matrix-multiplication names.
#[allow(clippy::many_single_char_names)]
pub fn opencl_matmul_f32(
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> RusTorchResult<()> {
    #[cfg(feature = "opencl")]
    {
        OpenClMatrixExecutor::new()?.matmul_f32(a, b, c, m, n, k)
    }
    #[cfg(not(feature = "opencl"))]
    {
        // Mark every argument as used in the feature-less build.
        let _ = (a, b, c, m, n, k);
        let reason = String::from("OpenCL not available");
        Err(RusTorchError::UnsupportedDevice(reason))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without the `opencl` feature, constructing an executor must fail.
    #[test]
    fn test_opencl_executor_creation() {
        let outcome = OpenClMatrixExecutor::new();
        #[cfg(not(feature = "opencl"))]
        assert!(outcome.is_err());
    }

    /// Without the `opencl` feature, the free-function entry point must fail.
    #[test]
    fn test_opencl_matmul_interface() {
        const DIM: usize = 8;
        let lhs = vec![1.0f32; DIM * DIM];
        let rhs = vec![2.0f32; DIM * DIM];
        let mut product = vec![0.0f32; DIM * DIM];

        let outcome = opencl_matmul_f32(&lhs, &rhs, &mut product, DIM, DIM, DIM);
        #[cfg(not(feature = "opencl"))]
        assert!(outcome.is_err());
    }
}