1use crate::{contexted_call, contexted_new, device::*, error::*, *};
4use cuda::*;
5use num_traits::ToPrimitive;
6use std::{ffi::*, path::*, ptr::null_mut, sync::Arc};
7
/// Block dimensions, one `u32` extent per axis.
///
/// `Launchable::launch` forwards `x`, `y` and `z` to `cuLaunchKernel` as the
/// block dimensions. Values are usually built with `Block::x`, `Block::xy` or
/// `Block::xyz`, or obtained through `Into<Block>` from an integer or a tuple
/// of integers.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` because every field
/// is a plain `u32`; this lets a `Block` be used as a map/set key and satisfy
/// `Eq` bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct Block {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}
68
69impl Block {
70 pub fn x<I: ToPrimitive>(x: I) -> Self {
76 Block {
77 x: x.to_u32().expect("Cannot convert to u32"),
78 y: 1,
79 z: 1,
80 }
81 }
82
83 pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
89 Block {
90 x: x.to_u32().expect("Cannot convert to u32"),
91 y: y.to_u32().expect("Cannot convert to u32"),
92 z: 1,
93 }
94 }
95
96 pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
102 Block {
103 x: x.to_u32().expect("Cannot convert to u32"),
104 y: y.to_u32().expect("Cannot convert to u32"),
105 z: z.to_u32().expect("Cannot convert to u32"),
106 }
107 }
108}
109
110impl<I: ToPrimitive> Into<Block> for (I,) {
111 fn into(self) -> Block {
112 Block::x(self.0)
113 }
114}
115
116impl<I1: ToPrimitive, I2: ToPrimitive> Into<Block> for (I1, I2) {
117 fn into(self) -> Block {
118 Block::xy(self.0, self.1)
119 }
120}
121
122impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> Into<Block> for (I1, I2, I3) {
123 fn into(self) -> Block {
124 Block::xyz(self.0, self.1, self.2)
125 }
126}
127
128macro_rules! impl_into_block {
129 ($integer:ty) => {
130 impl Into<Block> for $integer {
131 fn into(self) -> Block {
132 Block::x(self)
133 }
134 }
135 };
136}
137
138impl_into_block!(u8);
139impl_into_block!(u16);
140impl_into_block!(u32);
141impl_into_block!(u64);
142impl_into_block!(u128);
143impl_into_block!(usize);
144impl_into_block!(i8);
145impl_into_block!(i16);
146impl_into_block!(i32);
147impl_into_block!(i64);
148impl_into_block!(i128);
149impl_into_block!(isize);
150
/// Grid dimensions, one `u32` extent per axis.
///
/// `Launchable::launch` forwards `x`, `y` and `z` to `cuLaunchKernel` as the
/// grid dimensions. Values are usually built with `Grid::x`, `Grid::xy` or
/// `Grid::xyz`, or obtained through `Into<Grid>` from an integer or a tuple of
/// integers.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` because every field
/// is a plain `u32`; this lets a `Grid` be used as a map/set key and satisfy
/// `Eq` bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct Grid {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}
211
212impl Grid {
213 pub fn x<I: ToPrimitive>(x: I) -> Self {
219 Grid {
220 x: x.to_u32().expect("Cannot convert to u32"),
221 y: 1,
222 z: 1,
223 }
224 }
225
226 pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
232 Grid {
233 x: x.to_u32().expect("Cannot convert to u32"),
234 y: y.to_u32().expect("Cannot convert to u32"),
235 z: 1,
236 }
237 }
238
239 pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
245 Grid {
246 x: x.to_u32().expect("Cannot convert to u32"),
247 y: y.to_u32().expect("Cannot convert to u32"),
248 z: z.to_u32().expect("Cannot convert to u32"),
249 }
250 }
251}
252
253impl<I: ToPrimitive> Into<Grid> for (I,) {
254 fn into(self) -> Grid {
255 Grid::x(self.0)
256 }
257}
258
259impl<I1: ToPrimitive, I2: ToPrimitive> Into<Grid> for (I1, I2) {
260 fn into(self) -> Grid {
261 Grid::xy(self.0, self.1)
262 }
263}
264
265impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> Into<Grid> for (I1, I2, I3) {
266 fn into(self) -> Grid {
267 Grid::xyz(self.0, self.1, self.2)
268 }
269}
270
271macro_rules! impl_into_grid {
272 ($integer:ty) => {
273 impl Into<Grid> for $integer {
274 fn into(self) -> Grid {
275 Grid::x(self)
276 }
277 }
278 };
279}
280
281impl_into_grid!(u8);
282impl_into_grid!(u16);
283impl_into_grid!(u32);
284impl_into_grid!(u64);
285impl_into_grid!(u128);
286impl_into_grid!(usize);
287impl_into_grid!(i8);
288impl_into_grid!(i16);
289impl_into_grid!(i32);
290impl_into_grid!(i64);
291impl_into_grid!(i128);
292impl_into_grid!(isize);
293
/// Device code that `Module::load` can turn into a module, either held in
/// memory or referenced by file path.
///
/// `Clone`, `PartialEq`, `Eq` and `Hash` are derived alongside `Debug`: every
/// payload type supports them, and callers can then duplicate and compare
/// instructions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Instruction {
    /// PTX text held in memory as a NUL-terminated C string.
    PTX(CString),
    /// Path of a file containing PTX.
    PTXFile(PathBuf),
    /// Cubin image held in memory.
    Cubin(Vec<u8>),
    /// Path of a file containing a cubin image.
    CubinFile(PathBuf),
}
302
303impl Instruction {
304 pub fn ptx(s: &str) -> Instruction {
306 let ptx = CString::new(s).expect("Invalid PTX string");
307 Instruction::PTX(ptx)
308 }
309
310 pub fn cubin(sl: &[u8]) -> Instruction {
312 Instruction::Cubin(sl.to_vec())
313 }
314
315 pub fn ptx_file(path: &Path) -> Result<Self> {
317 if !path.exists() {
318 return Err(AccelError::FileNotFound {
319 path: path.to_owned(),
320 });
321 }
322 Ok(Instruction::PTXFile(path.to_owned()))
323 }
324
325 pub fn cubin_file(path: &Path) -> Result<Self> {
327 if !path.exists() {
328 return Err(AccelError::FileNotFound {
329 path: path.to_owned(),
330 });
331 }
332 Ok(Instruction::CubinFile(path.to_owned()))
333 }
334}
335
336impl Instruction {
337 pub fn input_type(&self) -> CUjitInputType {
339 match *self {
340 Instruction::PTX(_) | Instruction::PTXFile(_) => CUjitInputType_enum::CU_JIT_INPUT_PTX,
341 Instruction::Cubin(_) | Instruction::CubinFile(_) => {
342 CUjitInputType_enum::CU_JIT_INPUT_CUBIN
343 }
344 }
345 }
346}
347
348#[derive(Debug)]
350pub struct Kernel<'module> {
351 func: CUfunction,
352 module: &'module Module,
353}
354
355impl Contexted for Kernel<'_> {
356 fn get_context(&self) -> Arc<Context> {
357 self.module.get_context()
358 }
359}
360
/// Types whose values are handed to a kernel by address.
///
/// The tuple impls of `Arguments` call [`DeviceSend::as_ptr`] on every
/// argument to build the kernel parameter array.
pub trait DeviceSend: Sized {
    /// Returns the address of `self`, typed as a byte pointer.
    ///
    /// The pointer refers to the value itself, so it is only meaningful while
    /// `self` is alive and not moved.
    fn as_ptr(&self) -> *const u8 {
        let this: *const Self = self;
        this as *const u8
    }
}

// Raw pointers of either mutability.
impl<T> DeviceSend for *const T {}
impl<T> DeviceSend for *mut T {}

/// Emits an empty `DeviceSend` impl (default `as_ptr`) for each listed type.
macro_rules! impl_device_send {
    ($($scalar:ty),*) => {
        $(impl DeviceSend for $scalar {})*
    };
}

impl_device_send!(bool);
impl_device_send!(i8, i16, i32, i64, isize);
impl_device_send!(u8, u16, u32, u64, usize);
impl_device_send!(f32, f64);
400
401pub trait Arguments<'arg> {
414 fn kernel_params(&self) -> Vec<*mut c_void>;
418}
419
420macro_rules! impl_kernel_parameters {
421 ($($name:ident),*; $($num:tt),*) => {
422 impl<'arg, $($name : DeviceSend),*> Arguments<'arg> for ($( &'arg $name, )*) {
423 fn kernel_params(&self) -> Vec<*mut c_void> {
424 vec![$( self.$num.as_ptr() as *mut c_void ),*]
425 }
426 }
427 }
428}
429
430impl_kernel_parameters!(;);
431impl_kernel_parameters!(D0; 0);
432impl_kernel_parameters!(D0, D1; 0, 1);
433impl_kernel_parameters!(D0, D1, D2; 0, 1, 2);
434impl_kernel_parameters!(D0, D1, D2, D3; 0, 1, 2, 3);
435impl_kernel_parameters!(D0, D1, D2, D3, D4; 0, 1, 2, 3, 4);
436impl_kernel_parameters!(D0, D1, D2, D3, D4, D5; 0, 1, 2, 3, 4, 5);
437impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6; 0, 1, 2, 3, 4, 5, 6);
438impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7; 0, 1, 2, 3, 4, 5, 6, 7);
439impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8; 0, 1, 2, 3, 4, 5, 6, 7, 8);
440impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
441impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
442impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10, D11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
443
444pub trait Launchable<'arg> {
471 type Args: Arguments<'arg>;
476
477 fn get_kernel(&self) -> Result<Kernel>;
478
479 fn launch<G: Into<Grid>, B: Into<Block>>(
495 &self,
496 grid: G,
497 block: B,
498 args: &Self::Args,
499 ) -> Result<()> {
500 let grid = grid.into();
501 let block = block.into();
502 let kernel = self.get_kernel()?;
503 let mut params = args.kernel_params();
504 unsafe {
505 contexted_call!(
506 &kernel.get_context(),
507 cuLaunchKernel,
508 kernel.func,
509 grid.x,
510 grid.y,
511 grid.z,
512 block.x,
513 block.y,
514 block.z,
515 0, null_mut(), params.as_mut_ptr(),
518 null_mut() )?;
520 }
521 kernel.sync_context()?;
522 Ok(())
523 }
524}
525
526#[derive(Debug)]
528pub struct Module {
529 module: CUmodule,
530 context: Arc<Context>,
531}
532
533impl Drop for Module {
534 fn drop(&mut self) {
535 if let Err(e) = unsafe { contexted_call!(&self.get_context(), cuModuleUnload, self.module) }
536 {
537 log::error!("Failed to unload module: {:?}", e);
538 }
539 }
540}
541
542impl Contexted for Module {
543 fn get_context(&self) -> Arc<Context> {
544 self.context.clone()
545 }
546}
547
548impl Module {
549 pub fn load(context: Arc<Context>, data: &Instruction) -> Result<Self> {
551 match *data {
552 Instruction::PTX(ref ptx) => {
553 let module = unsafe {
554 contexted_new!(&context, cuModuleLoadData, ptx.as_ptr() as *const _)?
555 };
556 Ok(Module { module, context })
557 }
558 Instruction::Cubin(ref bin) => {
559 let module = unsafe {
560 contexted_new!(&context, cuModuleLoadData, bin.as_ptr() as *const _)?
561 };
562 Ok(Module { module, context })
563 }
564 Instruction::PTXFile(ref path) | Instruction::CubinFile(ref path) => {
565 let filename = path_to_cstring(path);
566 let module = unsafe { contexted_new!(&context, cuModuleLoad, filename.as_ptr())? };
567 Ok(Module { module, context })
568 }
569 }
570 }
571
572 pub fn from_str(context: Arc<Context>, ptx: &str) -> Result<Self> {
573 let data = Instruction::ptx(ptx);
574 Self::load(context, &data)
575 }
576
577 pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
579 let name = CString::new(name).expect("Invalid Kernel name");
580 let func = unsafe {
581 contexted_new!(
582 &self.get_context(),
583 cuModuleGetFunction,
584 self.module,
585 name.as_ptr()
586 )
587 }?;
588 Ok(Kernel { func, module: self })
589 }
590}
591
/// Converts `path` into a NUL-terminated C string.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8, or with `Invalid Path` if it contains
/// an interior NUL byte.
fn path_to_cstring(path: &Path) -> CString {
    let utf8 = path.to_str().unwrap();
    CString::new(utf8).expect("Invalid Path")
}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a PTX module containing a single empty `do_nothing` entry into a
    /// context created on device 0.
    #[test]
    fn load_do_nothing() -> Result<()> {
        let source = r#"
        .version 3.2
        .target sm_30
        .address_size 64
        .visible .entry do_nothing()
        {
          ret;
        }
        "#;
        let gpu = Device::nth(0)?;
        let context = gpu.create_context();
        let _module = Module::from_str(context, source)?;
        Ok(())
    }
}