pub(crate) mod utils;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::ptr;
use opencl3::command_queue::CommandQueue;
use opencl3::context::Context;
use opencl3::error_codes::ClError;
use opencl3::kernel::ExecuteKernel;
use opencl3::memory::CL_MEM_READ_WRITE;
use opencl3::types::CL_BLOCKING;
use log::debug;
use crate::device::{DeviceUuid, PciId, Vendor};
use crate::error::{GPUError, GPUResult};
use crate::LocalBuffer;
/// Alias for the raw device identifier type exposed by `opencl3`.
///
/// This is the type returned by `Device::cl_device_id`.
// The lower-case name mirrors the spelling used by `opencl3::types`, hence the lint allowance.
#[allow(non_camel_case_types)]
pub type cl_device_id = opencl3::types::cl_device_id;
/// A typed handle to a buffer created through an OpenCL `Program`.
///
/// The underlying OpenCL buffer is byte-typed (`u8`); the element type `T` only exists at the
/// type level, so the same raw storage can be handed out for any `T`.
#[derive(Debug)]
pub struct Buffer<T> {
/// The byte-typed OpenCL buffer backing this handle.
buffer: opencl3::memory::Buffer<u8>,
/// Number of elements of type `T` (not bytes) the buffer was created for.
length: usize,
/// Ties the handle to its element type without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}
/// Description of a single OpenCL device together with its `opencl3` handle.
///
/// Hashing and equality only consider `vendor`, `name`, `memory`, `pci_id` and `uuid`.
#[derive(Debug, Clone)]
pub struct Device {
/// Vendor of the device.
vendor: Vendor,
/// Name of the device.
name: String,
/// Amount of device memory (presumably in bytes — TODO confirm against where this is filled in).
memory: u64,
/// Number of compute units of the device.
compute_units: u32,
/// Compute capability, if known (presumably `(major, minor)` — TODO confirm).
compute_capability: Option<(u32, u32)>,
/// PCI identifier of the device.
pci_id: PciId,
/// UUID of the device, if one is available.
uuid: Option<DeviceUuid>,
/// The underlying `opencl3` device handle.
device: opencl3::device::Device,
}
impl Hash for Device {
    /// Feeds the identifying fields of the device into `state`.
    ///
    /// Only `vendor`, `name`, `memory`, `pci_id` and `uuid` take part, in that order; these are
    /// the same fields that `PartialEq` compares, so equal devices hash equally.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Destructure once so the set of hashed fields is visible at a glance; the remaining
        // fields are deliberately ignored.
        let Self {
            vendor,
            name,
            memory,
            pci_id,
            uuid,
            ..
        } = self;

        vendor.hash(state);
        name.hash(state);
        memory.hash(state);
        pci_id.hash(state);
        uuid.hash(state);
    }
}
impl PartialEq for Device {
    /// Two devices are equal when their `vendor`, `name`, `memory`, `pci_id` and `uuid` match.
    ///
    /// The compute unit count, the compute capability and the underlying OpenCL handle are not
    /// part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        // Tuples compare element by element, left to right, stopping at the first mismatch.
        let lhs = (
            &self.vendor,
            &self.name,
            &self.memory,
            &self.pci_id,
            &self.uuid,
        );
        let rhs = (
            &other.vendor,
            &other.name,
            &other.memory,
            &other.pci_id,
            &other.uuid,
        );
        lhs == rhs
    }
}
// Marker impl: declares `Device`'s `PartialEq` to be a full equivalence relation, which is what
// hash-based collections require of their keys.
impl Eq for Device {}
impl Device {
    /// Returns the vendor of the device.
    pub fn vendor(&self) -> Vendor {
        self.vendor
    }

    /// Returns an owned copy of the device name.
    pub fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Returns the amount of memory recorded for the device.
    pub fn memory(&self) -> u64 {
        self.memory
    }

    /// Returns the number of compute units of the device.
    pub fn compute_units(&self) -> u32 {
        self.compute_units
    }

    /// Returns the compute capability of the device, if it is known.
    pub fn compute_capability(&self) -> Option<(u32, u32)> {
        self.compute_capability
    }

    /// Returns the PCI identifier of the device.
    pub fn pci_id(&self) -> PciId {
        self.pci_id
    }

    /// Returns the UUID of the device, if one is available.
    pub fn uuid(&self) -> Option<DeviceUuid> {
        self.uuid
    }

    /// Returns the raw OpenCL device identifier of the underlying `opencl3` device.
    pub fn cl_device_id(&self) -> cl_device_id {
        let handle = &self.device;
        handle.id()
    }
}
/// Everything needed to run OpenCL kernels on one device: a context, a command queue and the
/// kernels of a built program, looked up by name.
// NOTE(review): the lint allowance below presumably exists for documentation links that do not
// resolve in every build configuration — confirm before removing it.
#[allow(rustdoc::broken_intra_doc_links)]
pub struct Program {
/// Name of the device this program was created for.
device_name: String,
/// Command queue used for all buffer transfers and kernel launches.
queue: CommandQueue,
/// OpenCL context the queue, buffers and kernels belong to.
context: Context,
/// All kernels of the built program, keyed by their function name.
kernels_by_name: HashMap<String, opencl3::kernel::Kernel>,
}
impl Program {
pub fn device_name(&self) -> &str {
&self.device_name
}
pub fn from_opencl(device: &Device, src: &str) -> GPUResult<Program> {
debug!("Creating OpenCL program from source.");
let cached = utils::cache_path(device, src)?;
if std::path::Path::exists(&cached) {
let bin = std::fs::read(cached)?;
Program::from_binary(device, bin)
} else {
let context = Context::from_device(&device.device)?;
debug!(
"Building kernel ({}) from source…",
cached.to_string_lossy()
);
let mut program = opencl3::program::Program::create_from_source(&context, src)?;
if let Err(build_error) = program.build(context.devices(), "") {
let log = program.get_build_log(context.devices()[0])?;
return Err(GPUError::Opencl3(build_error, Some(log)));
}
debug!(
"Building kernel ({}) from source: done.",
cached.to_string_lossy()
);
let queue = CommandQueue::create_default(&context, 0)?;
let kernels = opencl3::kernel::create_program_kernels(&program)?;
let kernels_by_name = kernels
.into_iter()
.map(|kernel| {
let name = kernel.function_name()?;
Ok((name, kernel))
})
.collect::<Result<_, ClError>>()?;
let prog = Program {
device_name: device.name(),
queue,
context,
kernels_by_name,
};
let binaries = program
.get_binaries()
.map_err(GPUError::ProgramInfoNotAvailable)?;
std::fs::write(cached, binaries[0].clone())?;
Ok(prog)
}
}
pub fn from_binary(device: &Device, bin: Vec<u8>) -> GPUResult<Program> {
debug!("Creating OpenCL program from binary.");
let context = Context::from_device(&device.device)?;
let bins = vec![&bin[..]];
let mut program = unsafe {
opencl3::program::Program::create_from_binary(&context, context.devices(), &bins)
}?;
if let Err(build_error) = program.build(context.devices(), "") {
let log = program.get_build_log(context.devices()[0])?;
return Err(GPUError::Opencl3(build_error, Some(log)));
}
let queue = CommandQueue::create_default(&context, 0)?;
let kernels = opencl3::kernel::create_program_kernels(&program)?;
let kernels_by_name = kernels
.into_iter()
.map(|kernel| {
let name = kernel.function_name()?;
Ok((name, kernel))
})
.collect::<Result<_, ClError>>()?;
Ok(Program {
device_name: device.name(),
queue,
context,
kernels_by_name,
})
}
pub unsafe fn create_buffer<T>(&self, length: usize) -> GPUResult<Buffer<T>> {
assert!(length > 0);
let mut buff = opencl3::memory::Buffer::create(
&self.context,
CL_MEM_READ_WRITE,
length * std::mem::size_of::<T>(),
ptr::null_mut(),
)?;
self.queue
.enqueue_write_buffer(&mut buff, opencl3::types::CL_BLOCKING, 0, &[0u8], &[])?;
Ok(Buffer::<T> {
buffer: buff,
length,
_phantom: std::marker::PhantomData,
})
}
pub fn create_buffer_from_slice<T>(&self, slice: &[T]) -> GPUResult<Buffer<T>> {
let length = slice.len();
let bytes_len = length * std::mem::size_of::<T>();
let mut buffer = unsafe {
opencl3::memory::Buffer::create(
&self.context,
CL_MEM_READ_WRITE,
bytes_len,
ptr::null_mut(),
)?
};
let bytes = unsafe {
std::slice::from_raw_parts(slice.as_ptr() as *const T as *const u8, bytes_len)
};
unsafe {
self.queue
.enqueue_write_buffer(&mut buffer, CL_BLOCKING, 0, &[0u8], &[])?;
self.queue
.enqueue_write_buffer(&mut buffer, CL_BLOCKING, 0, bytes, &[])?;
};
Ok(Buffer::<T> {
buffer,
length,
_phantom: std::marker::PhantomData,
})
}
pub fn create_kernel(
&self,
name: &str,
global_work_size: usize,
local_work_size: usize,
) -> GPUResult<Kernel> {
let kernel = self
.kernels_by_name
.get(name)
.ok_or_else(|| GPUError::KernelNotFound(name.to_string()))?;
let mut builder = ExecuteKernel::new(kernel);
builder.set_global_work_size(global_work_size * local_work_size);
builder.set_local_work_size(local_work_size);
Ok(Kernel {
builder,
queue: &self.queue,
num_local_buffers: 0,
})
}
pub fn write_from_buffer<T>(
&self,
buffer: &mut Buffer<T>,
data: &[T],
) -> GPUResult<()> {
assert!(data.len() <= buffer.length, "Buffer is too small");
let bytes = unsafe {
std::slice::from_raw_parts(
data.as_ptr() as *const T as *const u8,
data.len() * std::mem::size_of::<T>(),
)
};
unsafe {
self.queue
.enqueue_write_buffer(&mut buffer.buffer, CL_BLOCKING, 0, bytes, &[])?;
}
Ok(())
}
pub fn read_into_buffer<T>(&self, buffer: &Buffer<T>, data: &mut [T]) -> GPUResult<()> {
assert!(data.len() <= buffer.length, "Buffer is too small");
let bytes = unsafe {
std::slice::from_raw_parts_mut(
data.as_mut_ptr() as *mut T as *mut u8,
data.len() * std::mem::size_of::<T>(),
)
};
unsafe {
self.queue
.enqueue_read_buffer(&buffer.buffer, CL_BLOCKING, 0, bytes, &[])?;
};
Ok(())
}
pub fn run<F, R, E, A>(&self, fun: F, arg: A) -> Result<R, E>
where
F: FnOnce(&Self, A) -> Result<R, E>,
E: From<GPUError>,
{
fun(self, arg)
}
}
/// A value that can be passed as an argument to a `Kernel`.
pub trait KernelArgument {
/// Appends `self` as the next argument of `kernel`.
fn push(&self, kernel: &mut Kernel);
}
impl<T> KernelArgument for Buffer<T> {
    /// Passes the underlying device buffer as the next kernel argument.
    fn push(&self, kernel: &mut Kernel) {
        let device_buffer = &self.buffer;
        let builder = &mut kernel.builder;
        // SAFETY: `set_arg` is an unsafe `opencl3` call. NOTE(review): presumably the argument
        // has to match the kernel's parameter at this position — confirm against `opencl3`.
        unsafe {
            builder.set_arg(device_buffer);
        }
    }
}
impl KernelArgument for i32 {
    /// Passes the integer by value as the next kernel argument.
    fn push(&self, kernel: &mut Kernel) {
        let value: &i32 = self;
        let builder = &mut kernel.builder;
        // SAFETY: `set_arg` is an unsafe `opencl3` call. NOTE(review): presumably the argument
        // has to match the kernel's parameter at this position — confirm against `opencl3`.
        unsafe {
            builder.set_arg(value);
        }
    }
}
impl KernelArgument for u32 {
    /// Passes the unsigned integer by value as the next kernel argument.
    fn push(&self, kernel: &mut Kernel) {
        let value: &u32 = self;
        let builder = &mut kernel.builder;
        // SAFETY: `set_arg` is an unsafe `opencl3` call. NOTE(review): presumably the argument
        // has to match the kernel's parameter at this position — confirm against `opencl3`.
        unsafe {
            builder.set_arg(value);
        }
    }
}
impl<T> KernelArgument for LocalBuffer<T> {
    /// Reserves local memory for `length` elements of `T` as the next kernel argument and
    /// records that a local buffer was added.
    fn push(&self, kernel: &mut Kernel) {
        // SAFETY: `set_arg_local_buffer` is an unsafe `opencl3` call; it only receives a size
        // in bytes. NOTE(review): presumably the kernel must declare a matching local-memory
        // parameter at this position — confirm against `opencl3`.
        unsafe {
            kernel
                .builder
                .set_arg_local_buffer(self.length * std::mem::size_of::<T>());
        }
        // Saturate instead of wrapping: a wrapped `u8` counter would fall back to 0 after 256
        // pushes and slip past the "more than one `LocalBuffer`" check in `Kernel::run`.
        kernel.num_local_buffers = kernel.num_local_buffers.saturating_add(1);
    }
}
/// A kernel that is being prepared for launch: arguments are added with `arg` and the launch is
/// enqueued with `run`.
#[derive(Debug)]
pub struct Kernel<'a> {
/// The `opencl3` launch builder that collects the work sizes and arguments.
pub builder: ExecuteKernel<'a>,
/// Command queue the kernel is enqueued on.
queue: &'a CommandQueue,
/// Number of `LocalBuffer` arguments pushed so far; `run` rejects more than one.
num_local_buffers: u8,
}
impl<'a> Kernel<'a> {
pub fn arg<T: KernelArgument>(mut self, t: &'a T) -> Self {
t.push(&mut self);
self
}
pub fn run(mut self) -> GPUResult<()> {
if self.num_local_buffers > 1 {
return Err(GPUError::Generic(
"There cannot be more than one `LocalBuffer`.".to_string(),
));
}
unsafe {
self.builder.enqueue_nd_range(self.queue)?;
}
Ok(())
}
}