use crate::{contexted_call, contexted_new, device::*, error::*, *};
use cuda::*;
use num_traits::ToPrimitive;
use std::{ffi::*, path::*, ptr::null_mut, sync::Arc};
/// Thread-block dimensions for a kernel launch (passed as `blockDimX/Y/Z` to `cuLaunchKernel`).
///
/// Unused axes are 1; see [`Block::x`], [`Block::xy`] and [`Block::xyz`].
/// `Eq` and `Hash` are derived so block shapes can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)]
pub struct Block {
    /// Number of threads along x.
    pub x: u32,
    /// Number of threads along y.
    pub y: u32,
    /// Number of threads along z.
    pub z: u32,
}
impl Block {
    /// One-dimensional block of `x` threads (`y` and `z` are 1).
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" if `x` is not representable as a `u32`.
    pub fn x<I: ToPrimitive>(x: I) -> Self {
        Self::xyz(x, 1u32, 1u32)
    }

    /// Two-dimensional block of `x * y` threads (`z` is 1).
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" if either value is not representable as a `u32`.
    pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
        Self::xyz(x, y, 1u32)
    }

    /// Three-dimensional block of `x * y * z` threads.
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" on the first value (in x, y, z order) that is not
    /// representable as a `u32`.
    pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
        let dim = |converted: Option<u32>| converted.expect("Cannot convert to u32");
        Block {
            x: dim(x.to_u32()),
            y: dim(y.to_u32()),
            z: dim(z.to_u32()),
        }
    }
}
/// Builds a one-dimensional [`Block`] from a 1-tuple, e.g. `(256,)`.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still provides
/// `(n,).into()`, and `Block::from((n,))` works as well.
///
/// # Panics
/// Panics if the value is not representable as a `u32`.
impl<I: ToPrimitive> From<(I,)> for Block {
    fn from((x,): (I,)) -> Self {
        Block::x(x)
    }
}
/// Builds a two-dimensional [`Block`] from an `(x, y)` tuple.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl provides `.into()`.
///
/// # Panics
/// Panics if either value is not representable as a `u32`.
impl<I1: ToPrimitive, I2: ToPrimitive> From<(I1, I2)> for Block {
    fn from((x, y): (I1, I2)) -> Self {
        Block::xy(x, y)
    }
}
/// Builds a three-dimensional [`Block`] from an `(x, y, z)` tuple.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl provides `.into()`.
///
/// # Panics
/// Panics if any value is not representable as a `u32`.
impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> From<(I1, I2, I3)> for Block {
    fn from((x, y, z): (I1, I2, I3)) -> Self {
        Block::xyz(x, y, z)
    }
}
/// Implements `From<$integer> for Block`, producing a one-dimensional block via [`Block::x`].
///
/// Generates `From` (not `Into`) so that both `Block::from(n)` and `n.into()` work; the
/// standard blanket impl supplies `Into`. The generated conversion panics if the value is not
/// representable as a `u32`.
macro_rules! impl_into_block {
    ($integer:ty) => {
        impl From<$integer> for Block {
            fn from(n: $integer) -> Self {
                Block::x(n)
            }
        }
    };
}
// Every primitive integer type converts into a one-dimensional `Block` through `Block::x`,
// which panics if the value is not representable as a `u32` (negative or too large).
impl_into_block!(u8);
impl_into_block!(u16);
impl_into_block!(u32);
impl_into_block!(u64);
impl_into_block!(u128);
impl_into_block!(usize);
impl_into_block!(i8);
impl_into_block!(i16);
impl_into_block!(i32);
impl_into_block!(i64);
impl_into_block!(i128);
impl_into_block!(isize);
/// Grid dimensions for a kernel launch: number of blocks along each axis (passed as
/// `gridDimX/Y/Z` to `cuLaunchKernel`).
///
/// Unused axes are 1; see [`Grid::x`], [`Grid::xy`] and [`Grid::xyz`].
/// `Eq` and `Hash` are derived so grid shapes can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)]
pub struct Grid {
    /// Number of blocks along x.
    pub x: u32,
    /// Number of blocks along y.
    pub y: u32,
    /// Number of blocks along z.
    pub z: u32,
}
impl Grid {
    /// One-dimensional grid of `x` blocks (`y` and `z` are 1).
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" if `x` is not representable as a `u32`.
    pub fn x<I: ToPrimitive>(x: I) -> Self {
        Self::xyz(x, 1u32, 1u32)
    }

    /// Two-dimensional grid of `x * y` blocks (`z` is 1).
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" if either value is not representable as a `u32`.
    pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
        Self::xyz(x, y, 1u32)
    }

    /// Three-dimensional grid of `x * y * z` blocks.
    ///
    /// # Panics
    /// Panics with "Cannot convert to u32" on the first value (in x, y, z order) that is not
    /// representable as a `u32`.
    pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
        let dim = |converted: Option<u32>| converted.expect("Cannot convert to u32");
        Grid {
            x: dim(x.to_u32()),
            y: dim(y.to_u32()),
            z: dim(z.to_u32()),
        }
    }
}
/// Builds a one-dimensional [`Grid`] from a 1-tuple, e.g. `(128,)`.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still provides
/// `(n,).into()`, and `Grid::from((n,))` works as well.
///
/// # Panics
/// Panics if the value is not representable as a `u32`.
impl<I: ToPrimitive> From<(I,)> for Grid {
    fn from((x,): (I,)) -> Self {
        Grid::x(x)
    }
}
/// Builds a two-dimensional [`Grid`] from an `(x, y)` tuple.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl provides `.into()`.
///
/// # Panics
/// Panics if either value is not representable as a `u32`.
impl<I1: ToPrimitive, I2: ToPrimitive> From<(I1, I2)> for Grid {
    fn from((x, y): (I1, I2)) -> Self {
        Grid::xy(x, y)
    }
}
/// Builds a three-dimensional [`Grid`] from an `(x, y, z)` tuple.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl provides `.into()`.
///
/// # Panics
/// Panics if any value is not representable as a `u32`.
impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> From<(I1, I2, I3)> for Grid {
    fn from((x, y, z): (I1, I2, I3)) -> Self {
        Grid::xyz(x, y, z)
    }
}
/// Implements `From<$integer> for Grid`, producing a one-dimensional grid via [`Grid::x`].
///
/// Generates `From` (not `Into`) so that both `Grid::from(n)` and `n.into()` work; the
/// standard blanket impl supplies `Into`. The generated conversion panics if the value is not
/// representable as a `u32`.
macro_rules! impl_into_grid {
    ($integer:ty) => {
        impl From<$integer> for Grid {
            fn from(n: $integer) -> Self {
                Grid::x(n)
            }
        }
    };
}
// Every primitive integer type converts into a one-dimensional `Grid` through `Grid::x`,
// which panics if the value is not representable as a `u32` (negative or too large).
impl_into_grid!(u8);
impl_into_grid!(u16);
impl_into_grid!(u32);
impl_into_grid!(u64);
impl_into_grid!(u128);
impl_into_grid!(usize);
impl_into_grid!(i8);
impl_into_grid!(i16);
impl_into_grid!(i32);
impl_into_grid!(i64);
impl_into_grid!(i128);
impl_into_grid!(isize);
/// Source a CUDA [`Module`] can be loaded from: PTX or cubin, either in memory or on disk.
///
/// In-memory variants are handed to `cuModuleLoadData`; file variants are loaded by path with
/// `cuModuleLoad` (see `Module::load`).
#[derive(Debug)]
pub enum Instruction {
    /// PTX source text held in memory as a NUL-terminated C string.
    PTX(CString),
    /// Path to a file loaded as PTX.
    PTXFile(PathBuf),
    /// Cubin image bytes held in memory.
    Cubin(Vec<u8>),
    /// Path to a file loaded as cubin.
    CubinFile(PathBuf),
}
impl Instruction {
pub fn ptx(s: &str) -> Instruction {
let ptx = CString::new(s).expect("Invalid PTX string");
Instruction::PTX(ptx)
}
pub fn cubin(sl: &[u8]) -> Instruction {
Instruction::Cubin(sl.to_vec())
}
pub fn ptx_file(path: &Path) -> Result<Self> {
if !path.exists() {
return Err(AccelError::FileNotFound {
path: path.to_owned(),
});
}
Ok(Instruction::PTXFile(path.to_owned()))
}
pub fn cubin_file(path: &Path) -> Result<Self> {
if !path.exists() {
return Err(AccelError::FileNotFound {
path: path.to_owned(),
});
}
Ok(Instruction::CubinFile(path.to_owned()))
}
}
impl Instruction {
    /// JIT input type describing this instruction's format: cubin for `Cubin`/`CubinFile`,
    /// PTX for `PTX`/`PTXFile`. Whether the data is in memory or on disk does not matter here.
    pub fn input_type(&self) -> CUjitInputType {
        match self {
            Instruction::Cubin(_) | Instruction::CubinFile(_) => {
                CUjitInputType_enum::CU_JIT_INPUT_CUBIN
            }
            Instruction::PTX(_) | Instruction::PTXFile(_) => {
                CUjitInputType_enum::CU_JIT_INPUT_PTX
            }
        }
    }
}
/// A kernel function handle (`CUfunction`) looked up in a loaded [`Module`].
///
/// It borrows the module it came from, so a `Kernel` cannot outlive that module.
#[derive(Debug)]
pub struct Kernel<'module> {
    // Raw driver handle returned by `cuModuleGetFunction`.
    func: CUfunction,
    // Owning module; also supplies the context the kernel runs in.
    module: &'module Module,
}
impl Contexted for Kernel<'_> {
fn get_context(&self) -> Arc<Context> {
self.module.get_context()
}
}
/// A value that can be passed to a kernel by handing the driver a pointer to its host memory.
///
/// Implemented in this module for raw pointers, `bool`, and the integer and floating-point
/// primitives. NOTE(review): the trait is safe to implement, yet an implementor's in-memory
/// layout must match the corresponding kernel parameter; nothing here checks that.
pub trait DeviceSend: Sized {
    /// Address of `self`, reinterpreted as a byte pointer.
    fn as_ptr(&self) -> *const u8 {
        let this: *const Self = self;
        this.cast::<u8>()
    }
}
impl<T> DeviceSend for *mut T {}
impl<T> DeviceSend for *const T {}
impl DeviceSend for bool {}
impl DeviceSend for i8 {}
impl DeviceSend for i16 {}
impl DeviceSend for i32 {}
impl DeviceSend for i64 {}
impl DeviceSend for isize {}
impl DeviceSend for u8 {}
impl DeviceSend for u16 {}
impl DeviceSend for u32 {}
impl DeviceSend for u64 {}
impl DeviceSend for usize {}
impl DeviceSend for f32 {}
impl DeviceSend for f64 {}
/// A tuple of kernel arguments that can be turned into the `kernelParams` pointer array
/// expected by `cuLaunchKernel`.
///
/// This module implements it for tuples of 0 to 12 `&'arg` references to [`DeviceSend`]
/// values.
pub trait Arguments<'arg> {
    /// Builds the parameter pointer array. The implementations in this module return one
    /// pointer per tuple element, pointing at the referenced value.
    fn kernel_params(&self) -> Vec<*mut c_void>;
}
/// Implements [`Arguments`] for a tuple of `&'arg` references to [`DeviceSend`] values.
///
/// Invoked as `impl_kernel_parameters!(D0, D1; 0, 1)`: the type-parameter names and the
/// matching tuple indices are listed separately because `macro_rules!` cannot count.
macro_rules! impl_kernel_parameters {
    ($($ty:ident),*; $($idx:tt),*) => {
        impl<'arg, $($ty: DeviceSend),*> Arguments<'arg> for ($(&'arg $ty,)*) {
            fn kernel_params(&self) -> Vec<*mut c_void> {
                // One entry per argument: the host address of the referenced value.
                vec![$( DeviceSend::as_ptr(self.$idx) as *mut c_void ),*]
            }
        }
    };
}
// `Arguments` for tuples of 0 through 12 argument references. The first list names the
// generic parameters; the second gives the matching tuple field indices.
impl_kernel_parameters!(;);
impl_kernel_parameters!(D0; 0);
impl_kernel_parameters!(D0, D1; 0, 1);
impl_kernel_parameters!(D0, D1, D2; 0, 1, 2);
impl_kernel_parameters!(D0, D1, D2, D3; 0, 1, 2, 3);
impl_kernel_parameters!(D0, D1, D2, D3, D4; 0, 1, 2, 3, 4);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5; 0, 1, 2, 3, 4, 5);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6; 0, 1, 2, 3, 4, 5, 6);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7; 0, 1, 2, 3, 4, 5, 6, 7);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8; 0, 1, 2, 3, 4, 5, 6, 7, 8);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10, D11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
/// Something that can be launched as a CUDA kernel with a fixed argument tuple type.
pub trait Launchable<'arg> {
    /// Tuple of argument references passed to the kernel (see [`Arguments`]).
    type Args: Arguments<'arg>;

    /// Returns the kernel to launch.
    fn get_kernel(&self) -> Result<Kernel>;

    /// Launches the kernel with the given grid and block dimensions and arguments.
    ///
    /// The launch requests no dynamic shared memory (`sharedMemBytes = 0`), passes a null
    /// stream handle, and no `extra` options. After `cuLaunchKernel` returns, `sync_context`
    /// is called on the kernel; presumably this blocks until the launched work completes
    /// (TODO confirm against `Contexted::sync_context`).
    ///
    /// # Errors
    /// Returns an error if `get_kernel` fails, if `cuLaunchKernel` reports an error, or if the
    /// subsequent `sync_context` call fails.
    ///
    /// # Panics
    /// The `Grid`/`Block` conversions defined in this module panic if a dimension is not
    /// representable as a `u32`.
    fn launch<G: Into<Grid>, B: Into<Block>>(
        &self,
        grid: G,
        block: B,
        args: &Self::Args,
    ) -> Result<()> {
        let grid = grid.into();
        let block = block.into();
        let kernel = self.get_kernel()?;
        // One pointer per argument; must stay alive (and `args` borrowed) until the launch
        // call below returns.
        let mut params = args.kernel_params();
        // SAFETY: `params` points at values borrowed through `args`, and both outlive this
        // call. NOTE(review): the argument count and layout must match the kernel's signature;
        // that is not checked here.
        unsafe {
            contexted_call!(
                &kernel.get_context(),
                cuLaunchKernel,
                kernel.func,
                grid.x,
                grid.y,
                grid.z,
                block.x,
                block.y,
                block.z,
                0,                     // sharedMemBytes
                null_mut(),            // hStream
                params.as_mut_ptr(),   // kernelParams
                null_mut()             // extra
            )?;
        }
        kernel.sync_context()?;
        Ok(())
    }
}
/// A CUDA module loaded into a context. It is unloaded with `cuModuleUnload` when dropped.
#[derive(Debug)]
pub struct Module {
    // Raw driver handle.
    module: CUmodule,
    // Context the module was loaded into; holding the `Arc` keeps it alive as long as the
    // module exists.
    context: Arc<Context>,
}
impl Drop for Module {
fn drop(&mut self) {
if let Err(e) = unsafe { contexted_call!(&self.get_context(), cuModuleUnload, self.module) }
{
log::error!("Failed to unload module: {:?}", e);
}
}
}
impl Contexted for Module {
    /// Returns a new shared handle to the context this module was loaded into.
    fn get_context(&self) -> Arc<Context> {
        Arc::clone(&self.context)
    }
}
impl Module {
    /// Loads a CUDA module into `context` from `data`.
    ///
    /// In-memory PTX text and cubin images go through `cuModuleLoadData`; file-backed
    /// instructions are loaded by path with `cuModuleLoad`.
    ///
    /// # Errors
    /// Propagates any error reported by the driver call.
    ///
    /// # Panics
    /// For file-backed instructions, panics if the path cannot be converted to a C string.
    pub fn load(context: Arc<Context>, data: &Instruction) -> Result<Self> {
        let module = match *data {
            Instruction::PTX(ref ptx) => unsafe {
                contexted_new!(&context, cuModuleLoadData, ptx.as_ptr() as *const _)?
            },
            Instruction::Cubin(ref bin) => unsafe {
                contexted_new!(&context, cuModuleLoadData, bin.as_ptr() as *const _)?
            },
            Instruction::PTXFile(ref path) | Instruction::CubinFile(ref path) => {
                // Keep the C string alive for the duration of the driver call.
                let filename = path_to_cstring(path);
                unsafe { contexted_new!(&context, cuModuleLoad, filename.as_ptr())? }
            }
        };
        Ok(Module { module, context })
    }

    /// Loads a module from in-memory PTX source text.
    ///
    /// # Panics
    /// Panics if `ptx` contains an interior NUL byte.
    pub fn from_str(context: Arc<Context>, ptx: &str) -> Result<Self> {
        Self::load(context, &Instruction::ptx(ptx))
    }

    /// Looks up the kernel function `name` in this module.
    ///
    /// # Errors
    /// Propagates the error from `cuModuleGetFunction` (e.g. when no such function exists).
    ///
    /// # Panics
    /// Panics ("Invalid Kernel name") if `name` contains an interior NUL byte.
    pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
        let c_name = CString::new(name).expect("Invalid Kernel name");
        let context = self.get_context();
        let func = unsafe {
            contexted_new!(&context, cuModuleGetFunction, self.module, c_name.as_ptr())?
        };
        Ok(Kernel { func, module: self })
    }
}
/// Converts `path` into a NUL-terminated C string for driver calls such as `cuModuleLoad`.
///
/// On Unix the raw OS bytes of the path are used, so file names that are not valid UTF-8 are
/// still passed through unchanged. Previously such paths hit an uninformative `unwrap` panic.
/// On other platforms the path must be valid UTF-8.
///
/// # Panics
/// Panics ("Invalid Path") if the path contains an interior NUL byte. On non-Unix platforms,
/// panics ("Path is not valid UTF-8") if the path is not valid UTF-8.
fn path_to_cstring(path: &Path) -> CString {
    #[cfg(unix)]
    let bytes: &[u8] = {
        use std::os::unix::ffi::OsStrExt;
        path.as_os_str().as_bytes()
    };
    #[cfg(not(unix))]
    let bytes: &[u8] = path.to_str().expect("Path is not valid UTF-8").as_bytes();
    CString::new(bytes).expect("Invalid Path")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: loads a minimal PTX module (one empty `do_nothing` entry point) into a
    /// fresh context on device 0. Requires a CUDA device; `Device::nth(0)` presumably returns
    /// an error when none is available (TODO confirm).
    #[test]
    fn load_do_nothing() -> Result<()> {
        let ptx = r#"
.version 3.2
.target sm_30
.address_size 64
.visible .entry do_nothing()
{
ret;
}
"#;
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        // Bound to a named `_mod` (not `_`) so the module lives until the end of the test and
        // is then unloaded by `Drop`.
        let _mod = Module::from_str(ctx, ptx)?;
        Ok(())
    }
}