pub(crate) mod utils;
use std::convert::TryFrom;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
use std::hash::{Hash, Hasher};
use log::debug;
use rustacuda::memory::{AsyncCopyDestination, DeviceBuffer};
use rustacuda::stream::{Stream, StreamFlags};
use crate::device::{DeviceUuid, PciId, Vendor};
use crate::error::{GPUError, GPUResult};
use crate::LocalBuffer;
/// A typed handle to memory allocated on the device.
///
/// The underlying allocation is untyped (`u8`); the element type `T` is only
/// tracked at the type level so that reads and writes are checked against the
/// element count recorded in `length`.
#[derive(Debug)]
pub struct Buffer<T> {
// Raw device allocation, sized in bytes.
buffer: DeviceBuffer<u8>,
// Number of `T` elements (not bytes) this buffer was created for.
length: usize,
// Ties the buffer to its element type without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}
/// Description of a CUDA device together with the context used to talk to it.
///
/// Equality and hashing (implemented separately in this file) only consider
/// `vendor`, `name`, `memory`, `pci_id` and `uuid`.
#[derive(Debug, Clone)]
pub struct Device {
// Hardware vendor of the device.
vendor: Vendor,
// Device name as reported when the device was enumerated.
name: String,
// Amount of device memory (presumably in bytes — confirm where it is populated).
memory: u64,
// Number of compute units of the device.
compute_units: u32,
// Compute capability as a pair (presumably `(major, minor)` — confirm at the construction site).
compute_capability: (u32, u32),
// PCI identifier of the device.
pci_id: PciId,
// Device UUID, if one is available.
uuid: Option<DeviceUuid>,
// Context that is made current before work is issued to this device.
context: rustacuda::context::UnownedContext,
}
/// Hashes the identifying properties of a device.
///
/// Only the fields that also take part in `PartialEq` are fed into the
/// hasher (vendor, name, memory, PCI id and UUID), in that order.
impl Hash for Device {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.vendor, state);
        Hash::hash(&self.name, state);
        Hash::hash(&self.memory, state);
        Hash::hash(&self.pci_id, state);
        Hash::hash(&self.uuid, state);
    }
}
/// Two devices are equal when vendor, name, memory, PCI id and UUID all match.
///
/// The compute properties and the context are deliberately not part of the
/// comparison, mirroring the `Hash` implementation.
impl PartialEq for Device {
    fn eq(&self, other: &Self) -> bool {
        let mine = (
            &self.vendor,
            &self.name,
            &self.memory,
            &self.pci_id,
            &self.uuid,
        );
        let theirs = (
            &other.vendor,
            &other.name,
            &other.memory,
            &other.pci_id,
            &other.uuid,
        );
        mine == theirs
    }
}

impl Eq for Device {}
/// Read-only accessors for the properties of a [`Device`].
impl Device {
    /// Returns the vendor of the device.
    pub fn vendor(&self) -> Vendor {
        self.vendor
    }

    /// Returns an owned copy of the device name.
    pub fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Returns the recorded amount of device memory.
    pub fn memory(&self) -> u64 {
        self.memory
    }

    /// Returns the number of compute units of the device.
    pub fn compute_units(&self) -> u32 {
        self.compute_units
    }

    /// Returns the compute capability pair of the device.
    pub fn compute_capability(&self) -> (u32, u32) {
        self.compute_capability
    }

    /// Returns the PCI identifier of the device.
    pub fn pci_id(&self) -> PciId {
        self.pci_id
    }

    /// Returns the UUID of the device, if one is known.
    pub fn uuid(&self) -> Option<DeviceUuid> {
        self.uuid
    }
}
/// A loaded CUDA module together with the stream and context it is used with.
///
/// Buffers and kernels are created through a `Program`; `Program::run` makes
/// the stored context current around a user-supplied closure.
#[allow(rustdoc::broken_intra_doc_links)]
#[derive(Debug)]
pub struct Program {
// Context that is made current while the program is in use.
context: rustacuda::context::UnownedContext,
// The loaded module that kernels are looked up in.
module: rustacuda::module::Module,
// Stream on which copies and kernel launches are enqueued.
stream: Stream,
// Name of the device the program was created for.
device_name: String,
}
impl Program {
    /// Returns the name of the device this program was created for.
    pub fn device_name(&self) -> &str {
        &self.device_name
    }

    /// Creates a program for `device` from a module stored in the file `filename`.
    ///
    /// The device's context is made current while the module is loaded and the
    /// stream is created, and is popped again before returning, whether or not
    /// loading succeeded.
    ///
    /// # Errors
    ///
    /// Returns an error if the context cannot be set, the module cannot be
    /// loaded or the stream cannot be created.
    pub fn from_binary(device: &Device, filename: &CStr) -> GPUResult<Program> {
        debug!("Creating CUDA program from binary file.");
        Self::build_in_context(device, || {
            rustacuda::module::Module::load_from_file(filename).map_err(GPUError::from)
        })
    }

    /// Creates a program for `device` from an in-memory module image.
    ///
    /// Context handling is identical to [`Program::from_binary`].
    ///
    /// # Errors
    ///
    /// Returns an error if the context cannot be set, the module cannot be
    /// loaded or the stream cannot be created.
    pub fn from_bytes(device: &Device, bytes: &[u8]) -> GPUResult<Program> {
        debug!("Creating CUDA program from bytes.");
        Self::build_in_context(device, || {
            rustacuda::module::Module::load_from_bytes(bytes).map_err(GPUError::from)
        })
    }

    /// Shared constructor logic: makes the device context current, loads the
    /// module via `load_module`, creates the stream, and always pops the
    /// context again before handing back the outcome.
    fn build_in_context<L>(device: &Device, load_module: L) -> GPUResult<Program>
    where
        L: FnOnce() -> GPUResult<rustacuda::module::Module>,
    {
        rustacuda::context::CurrentContext::set_current(&device.context)?;

        // Everything that needs the context to be current lives in this
        // closure, so that there is exactly one place where it is popped.
        let assemble = || -> GPUResult<Program> {
            let module = load_module()?;
            let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
            Ok(Program {
                module,
                stream,
                device_name: device.name(),
                context: device.context.clone(),
            })
        };
        let built = assemble();

        Self::pop_context();
        built
    }

    /// Allocates an uninitialized device buffer holding `length` elements of `T`.
    ///
    /// # Safety
    ///
    /// The device memory is not initialized; the caller must write to the
    /// buffer before its contents are read.
    ///
    /// # Panics
    ///
    /// Panics if `length` is zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the byte size of the buffer does not fit into a
    /// `usize`, or if the allocation fails.
    pub unsafe fn create_buffer<T>(&self, length: usize) -> GPUResult<Buffer<T>> {
        assert!(length > 0);
        // A plain multiplication would silently wrap in release builds and
        // allocate far less memory than `length` promises.
        let byte_len = length
            .checked_mul(std::mem::size_of::<T>())
            .ok_or_else(|| {
                GPUError::Generic(format!(
                    "Buffer of {} elements exceeds the addressable size.",
                    length
                ))
            })?;
        let buffer = DeviceBuffer::<u8>::uninitialized(byte_len)?;
        Ok(Buffer::<T> {
            buffer,
            length,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Allocates a device buffer and fills it with a copy of `slice`.
    ///
    /// The copy is enqueued on the program's stream and the stream is
    /// synchronized before returning.
    ///
    /// # Errors
    ///
    /// Returns an error if the allocation, the copy or the synchronization fails.
    pub fn create_buffer_from_slice<T>(&self, slice: &[T]) -> GPUResult<Buffer<T>> {
        let bytes_len = std::mem::size_of_val(slice);
        // SAFETY: the byte view covers exactly the memory occupied by `slice`
        // and is only read from while `slice` is borrowed.
        let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<u8>(), bytes_len) };

        // SAFETY: the allocation is completely overwritten by the copy below
        // before the buffer is handed out.
        let mut buffer = unsafe { DeviceBuffer::<u8>::uninitialized(bytes_len)? };
        // SAFETY: the stream is synchronized before `bytes` goes out of scope.
        unsafe { buffer.async_copy_from(bytes, &self.stream)? };
        self.stream.synchronize()?;

        Ok(Buffer::<T> {
            buffer,
            length: slice.len(),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Looks up the kernel `name` in the program's module and prepares it for launch.
    ///
    /// `gws` and `lws` are stored as the kernel's global and local work size.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains a nul byte.
    ///
    /// # Errors
    ///
    /// Returns an error if the module has no function called `name`.
    pub fn create_kernel(&self, name: &str, gws: usize, lws: usize) -> GPUResult<Kernel> {
        let function_name = CString::new(name).expect("Kernel name must not contain nul bytes");
        let function = self.module.get_function(&function_name)?;

        Ok(Kernel {
            function,
            global_work_size: gws,
            local_work_size: lws,
            stream: &self.stream,
            args: Vec::new(),
        })
    }

    /// Copies `data` from the host into the beginning of `buffer`.
    ///
    /// `data` may be shorter than the buffer; only the first `data.len()`
    /// elements of the buffer are overwritten. The stream is synchronized
    /// before returning.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds more elements than `buffer`.
    ///
    /// # Errors
    ///
    /// Returns an error if the copy or the synchronization fails.
    pub fn write_from_buffer<T>(&self, buffer: &mut Buffer<T>, data: &[T]) -> GPUResult<()> {
        assert!(data.len() <= buffer.length, "Buffer is too small");

        let byte_len = std::mem::size_of_val(data);
        // SAFETY: the byte view covers exactly the memory occupied by `data`
        // and is only read from while `data` is borrowed.
        let bytes = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };

        // The device-side copy requires source and destination of equal
        // length, so only the matching prefix of the buffer is targeted.
        // SAFETY: the stream is synchronized before `bytes` goes out of scope.
        unsafe { buffer.buffer[..byte_len].async_copy_from(bytes, &self.stream)? };
        self.stream.synchronize()?;

        Ok(())
    }

    /// Copies the beginning of `buffer` from the device into `data`.
    ///
    /// `data` may be shorter than the buffer; exactly `data.len()` elements
    /// are read. The stream is synchronized before returning.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds more elements than `buffer`.
    ///
    /// # Errors
    ///
    /// Returns an error if the copy or the synchronization fails.
    pub fn read_into_buffer<T>(&self, buffer: &Buffer<T>, data: &mut [T]) -> GPUResult<()> {
        assert!(data.len() <= buffer.length, "Buffer is too small");

        let byte_len = std::mem::size_of_val(data);
        // SAFETY: the byte view covers exactly the memory occupied by `data`
        // and `data` is exclusively borrowed for the duration of the copy.
        let bytes =
            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<u8>(), byte_len) };

        // The device-side copy requires source and destination of equal
        // length, so only the matching prefix of the buffer is read.
        // SAFETY: the stream is synchronized before `bytes` goes out of scope.
        unsafe { buffer.buffer[..byte_len].async_copy_to(bytes, &self.stream)? };
        self.stream.synchronize()?;

        Ok(())
    }

    /// Runs `fun` with this program's context made current.
    ///
    /// The context is popped again once `fun` returns, independent of whether
    /// it returned `Ok` or `Err`. If `fun` panics, the context is not popped.
    ///
    /// # Errors
    ///
    /// Returns an error if the context cannot be made current, otherwise
    /// whatever `fun` returns.
    pub fn run<F, R, E, A>(&self, fun: F, arg: A) -> Result<R, E>
    where
        F: FnOnce(&Self, A) -> Result<R, E>,
        E: From<GPUError>,
    {
        rustacuda::context::CurrentContext::set_current(&self.context).map_err(Into::into)?;
        let result = fun(self, arg);
        Self::pop_context();
        result
    }

    /// Pops the current context from the context stack.
    ///
    /// # Panics
    ///
    /// Panics if the context cannot be popped.
    fn pop_context() {
        rustacuda::context::ContextStack::pop().expect("Cannot remove context.");
    }
}
// NOTE(review): this manual impl asserts that a `Program` (context, module and
// stream handles) may be moved to another thread. The justification is not
// visible in this file — confirm against the rustacuda handle types.
unsafe impl Send for Program {}
/// A value that can be passed as an argument to a kernel launch.
pub trait KernelArgument {
    /// Returns the untyped pointer that is handed to the kernel launch for
    /// this argument.
    fn as_c_void(&self) -> *mut c_void;

    /// Returns the number of bytes of dynamic shared memory this argument
    /// asks for. Arguments that need none keep the default of zero.
    fn shared_mem(&self) -> u32 {
        0
    }
}
impl<T> KernelArgument for Buffer<T> {
fn as_c_void(&self) -> *mut c_void {
&self.buffer as *const _ as _
}
}
/// A signed 32-bit integer is passed to kernels by the address of the value.
impl KernelArgument for i32 {
    fn as_c_void(&self) -> *mut c_void {
        let value: *const i32 = self;
        value as *mut c_void
    }
}
/// An unsigned 32-bit integer is passed to kernels by the address of the value.
impl KernelArgument for u32 {
    fn as_c_void(&self) -> *mut c_void {
        let value: *const u32 = self;
        value as *mut c_void
    }
}
/// A `LocalBuffer` is the one argument kind that requests shared memory:
/// `length` elements of `T`, expressed in bytes.
impl<T> KernelArgument for LocalBuffer<T> {
    fn as_c_void(&self) -> *mut c_void {
        self as *const Self as *mut c_void
    }

    /// # Panics
    ///
    /// Panics if the byte size does not fit into a `u32`.
    fn shared_mem(&self) -> u32 {
        let byte_count = self.length * std::mem::size_of::<T>();
        u32::try_from(byte_count).expect("__shared__ memory allocation is too big.")
    }
}
/// A kernel that is ready to be launched, together with its launch
/// configuration and the arguments collected so far.
pub struct Kernel<'a> {
// The function looked up in the program's module.
function: rustacuda::function::Function<'a>,
// First size argument of the launch (see `Kernel::run`).
global_work_size: usize,
// Second size argument of the launch (see `Kernel::run`).
local_work_size: usize,
// Stream the kernel is launched on.
stream: &'a Stream,
// Arguments in the order they are passed to the kernel.
args: Vec<&'a dyn KernelArgument>,
}
impl fmt::Debug for Kernel<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let args = self
.args
.iter()
.map(|arg| (arg.as_c_void(), arg.shared_mem()))
.collect::<Vec<_>>();
f.debug_struct("Kernel")
.field("function", &self.function)
.field("global_work_size", &self.global_work_size)
.field("local_work_size", &self.local_work_size)
.field("stream", &self.stream)
.field("args", &args)
.finish()
}
}
impl<'a> Kernel<'a> {
pub fn arg<T: KernelArgument>(mut self, t: &'a T) -> Self {
self.args.push(t);
self
}
pub fn run(self) -> GPUResult<()> {
let shared_mem = self
.args
.iter()
.try_fold(0, |acc, &arg| -> GPUResult<u32> {
let mem = arg.shared_mem();
match (mem, acc) {
(0, _) => Ok(acc),
(_, 0) => Ok(mem),
(_, _) => Err(GPUError::Generic(
"There cannot be more than one `LocalBuffer`.".to_string(),
)),
}
})?;
let args = self
.args
.iter()
.map(|arg| arg.as_c_void())
.collect::<Vec<_>>();
unsafe {
self.stream.launch(
&self.function,
self.global_work_size as u32,
self.local_work_size as u32,
shared_mem,
&args,
)?;
};
self.stream.synchronize()?;
Ok(())
}
}