accel/
module.rs

1//! CUDA Module (i.e. loaded PTX or cubin)
2
3use crate::{contexted_call, contexted_new, device::*, error::*, *};
4use cuda::*;
5use num_traits::ToPrimitive;
6use std::{ffi::*, path::*, ptr::null_mut, sync::Arc};
7
/// Size of Block (thread block) in [CUDA thread hierarchy]( http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programming-model )
///
/// Every integer or float input is converted into `u32` using [ToPrimitive].
/// If the conversion is impossible, e.g. for negative or too large integers, the conversion panics.
///
/// [ToPrimitive]: https://docs.rs/num-traits/0.2.11/num_traits/cast/trait.ToPrimitive.html
///
/// Examples
/// --------
///
/// - Explicit creation
///
/// ```
/// # use accel::*;
/// let block1d = Block::x(64);
/// assert_eq!(block1d.x, 64);
///
/// let block2d = Block::xy(64, 128);
/// assert_eq!(block2d.x, 64);
/// assert_eq!(block2d.y, 128);
///
/// let block3d = Block::xyz(64, 128, 256);
/// assert_eq!(block3d.x, 64);
/// assert_eq!(block3d.y, 128);
/// assert_eq!(block3d.z, 256);
/// ```
///
/// - From single integer (unsigned and signed)
///
/// ```
/// # use accel::*;
/// let block1d: Block = 64_usize.into();
/// assert_eq!(block1d.x, 64);
///
/// let block1d: Block = 64_i32.into();
/// assert_eq!(block1d.x, 64);
/// ```
///
/// - From tuple
///
/// ```
/// # use accel::*;
/// let block1d: Block = (64,).into();
/// assert_eq!(block1d.x, 64);
///
/// let block2d: Block = (64, 128).into();
/// assert_eq!(block2d.x, 64);
/// assert_eq!(block2d.y, 128);
///
/// let block3d: Block = (64, 128, 256).into();
/// assert_eq!(block3d.x, 64);
/// assert_eq!(block3d.y, 128);
/// assert_eq!(block3d.z, 256);
/// ```
// `Eq`/`Hash` are sound because every field is a plain `u32`; they let a `Block` be used as a
// map/set key (e.g. for caching launch configurations).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)]
pub struct Block {
    /// Number of threads along the x axis
    pub x: u32,
    /// Number of threads along the y axis (1 for a 1D block)
    pub y: u32,
    /// Number of threads along the z axis (1 for 1D/2D blocks)
    pub z: u32,
}
68
69impl Block {
70    /// 1D Block
71    ///
72    /// Panic
73    /// -----
74    /// - If input values cannot convert to u32
75    pub fn x<I: ToPrimitive>(x: I) -> Self {
76        Block {
77            x: x.to_u32().expect("Cannot convert to u32"),
78            y: 1,
79            z: 1,
80        }
81    }
82
83    /// 2D Block
84    ///
85    /// Panic
86    /// -----
87    /// - If input values cannot convert to u32
88    pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
89        Block {
90            x: x.to_u32().expect("Cannot convert to u32"),
91            y: y.to_u32().expect("Cannot convert to u32"),
92            z: 1,
93        }
94    }
95
96    /// 3D Block
97    ///
98    /// Panic
99    /// -----
100    /// - If input values cannot convert to u32
101    pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
102        Block {
103            x: x.to_u32().expect("Cannot convert to u32"),
104            y: y.to_u32().expect("Cannot convert to u32"),
105            z: z.to_u32().expect("Cannot convert to u32"),
106        }
107    }
108}
109
110impl<I: ToPrimitive> Into<Block> for (I,) {
111    fn into(self) -> Block {
112        Block::x(self.0)
113    }
114}
115
116impl<I1: ToPrimitive, I2: ToPrimitive> Into<Block> for (I1, I2) {
117    fn into(self) -> Block {
118        Block::xy(self.0, self.1)
119    }
120}
121
122impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> Into<Block> for (I1, I2, I3) {
123    fn into(self) -> Block {
124        Block::xyz(self.0, self.1, self.2)
125    }
126}
127
128macro_rules! impl_into_block {
129    ($integer:ty) => {
130        impl Into<Block> for $integer {
131            fn into(self) -> Block {
132                Block::x(self)
133            }
134        }
135    };
136}
137
138impl_into_block!(u8);
139impl_into_block!(u16);
140impl_into_block!(u32);
141impl_into_block!(u64);
142impl_into_block!(u128);
143impl_into_block!(usize);
144impl_into_block!(i8);
145impl_into_block!(i16);
146impl_into_block!(i32);
147impl_into_block!(i64);
148impl_into_block!(i128);
149impl_into_block!(isize);
150
/// Size of Grid (grid of blocks) in [CUDA thread hierarchy]( http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programming-model )
///
/// Every integer or float input is converted into `u32` using [ToPrimitive].
/// If the conversion is impossible, e.g. for negative or too large integers, the conversion panics.
///
/// [ToPrimitive]: https://docs.rs/num-traits/0.2.11/num_traits/cast/trait.ToPrimitive.html
///
/// Examples
/// --------
///
/// - Explicit creation
///
/// ```
/// # use accel::*;
/// let grid1d = Grid::x(64);
/// assert_eq!(grid1d.x, 64);
///
/// let grid2d = Grid::xy(64, 128);
/// assert_eq!(grid2d.x, 64);
/// assert_eq!(grid2d.y, 128);
///
/// let grid3d = Grid::xyz(64, 128, 256);
/// assert_eq!(grid3d.x, 64);
/// assert_eq!(grid3d.y, 128);
/// assert_eq!(grid3d.z, 256);
/// ```
///
/// - From single integer (unsigned and signed)
///
/// ```
/// # use accel::*;
/// let grid1d: Grid = 64_usize.into();
/// assert_eq!(grid1d.x, 64);
///
/// let grid1d: Grid = 64_i32.into();
/// assert_eq!(grid1d.x, 64);
/// ```
///
/// - From tuple
///
/// ```
/// # use accel::*;
/// let grid1d: Grid = (64,).into();
/// assert_eq!(grid1d.x, 64);
///
/// let grid2d: Grid = (64, 128).into();
/// assert_eq!(grid2d.x, 64);
/// assert_eq!(grid2d.y, 128);
///
/// let grid3d: Grid = (64, 128, 256).into();
/// assert_eq!(grid3d.x, 64);
/// assert_eq!(grid3d.y, 128);
/// assert_eq!(grid3d.z, 256);
/// ```
// `Eq`/`Hash` are sound because every field is a plain `u32`; they let a `Grid` be used as a
// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)]
pub struct Grid {
    /// Number of blocks along the x axis
    pub x: u32,
    /// Number of blocks along the y axis (1 for a 1D grid)
    pub y: u32,
    /// Number of blocks along the z axis (1 for 1D/2D grids)
    pub z: u32,
}
211
212impl Grid {
213    /// 1D Grid
214    ///
215    /// Panic
216    /// -----
217    /// - If input values cannot convert to u32
218    pub fn x<I: ToPrimitive>(x: I) -> Self {
219        Grid {
220            x: x.to_u32().expect("Cannot convert to u32"),
221            y: 1,
222            z: 1,
223        }
224    }
225
226    /// 2D Grid
227    ///
228    /// Panic
229    /// -----
230    /// - If input values cannot convert to u32
231    pub fn xy<I1: ToPrimitive, I2: ToPrimitive>(x: I1, y: I2) -> Self {
232        Grid {
233            x: x.to_u32().expect("Cannot convert to u32"),
234            y: y.to_u32().expect("Cannot convert to u32"),
235            z: 1,
236        }
237    }
238
239    /// 3D Grid
240    ///
241    /// Panic
242    /// -----
243    /// - If input values cannot convert to u32
244    pub fn xyz<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive>(x: I1, y: I2, z: I3) -> Self {
245        Grid {
246            x: x.to_u32().expect("Cannot convert to u32"),
247            y: y.to_u32().expect("Cannot convert to u32"),
248            z: z.to_u32().expect("Cannot convert to u32"),
249        }
250    }
251}
252
253impl<I: ToPrimitive> Into<Grid> for (I,) {
254    fn into(self) -> Grid {
255        Grid::x(self.0)
256    }
257}
258
259impl<I1: ToPrimitive, I2: ToPrimitive> Into<Grid> for (I1, I2) {
260    fn into(self) -> Grid {
261        Grid::xy(self.0, self.1)
262    }
263}
264
265impl<I1: ToPrimitive, I2: ToPrimitive, I3: ToPrimitive> Into<Grid> for (I1, I2, I3) {
266    fn into(self) -> Grid {
267        Grid::xyz(self.0, self.1, self.2)
268    }
269}
270
271macro_rules! impl_into_grid {
272    ($integer:ty) => {
273        impl Into<Grid> for $integer {
274            fn into(self) -> Grid {
275                Grid::x(self)
276            }
277        }
278    };
279}
280
281impl_into_grid!(u8);
282impl_into_grid!(u16);
283impl_into_grid!(u32);
284impl_into_grid!(u64);
285impl_into_grid!(u128);
286impl_into_grid!(usize);
287impl_into_grid!(i8);
288impl_into_grid!(i16);
289impl_into_grid!(i32);
290impl_into_grid!(i64);
291impl_into_grid!(i128);
292impl_into_grid!(isize);
293
/// Represent the resource of CUDA middle-IR (PTX/cubin)
///
/// The resource is held either in memory or as a path to a file on disk.
#[derive(Debug)]
pub enum Instruction {
    /// PTX text held in memory (as a NUL-terminated `CString`)
    PTX(CString),
    /// Path to a PTX file
    PTXFile(PathBuf),
    /// cubin binary image held in memory
    Cubin(Vec<u8>),
    /// Path to a cubin file
    CubinFile(PathBuf),
}
302
303impl Instruction {
304    /// Constructor for `Instruction::PTX`
305    pub fn ptx(s: &str) -> Instruction {
306        let ptx = CString::new(s).expect("Invalid PTX string");
307        Instruction::PTX(ptx)
308    }
309
310    /// Constructor for `Instruction::Cubin`
311    pub fn cubin(sl: &[u8]) -> Instruction {
312        Instruction::Cubin(sl.to_vec())
313    }
314
315    /// Constructor for `Instruction::PTXFile`
316    pub fn ptx_file(path: &Path) -> Result<Self> {
317        if !path.exists() {
318            return Err(AccelError::FileNotFound {
319                path: path.to_owned(),
320            });
321        }
322        Ok(Instruction::PTXFile(path.to_owned()))
323    }
324
325    /// Constructor for `Instruction::CubinFile`
326    pub fn cubin_file(path: &Path) -> Result<Self> {
327        if !path.exists() {
328            return Err(AccelError::FileNotFound {
329                path: path.to_owned(),
330            });
331        }
332        Ok(Instruction::CubinFile(path.to_owned()))
333    }
334}
335
336impl Instruction {
337    /// Get type of PTX/cubin
338    pub fn input_type(&self) -> CUjitInputType {
339        match *self {
340            Instruction::PTX(_) | Instruction::PTXFile(_) => CUjitInputType_enum::CU_JIT_INPUT_PTX,
341            Instruction::Cubin(_) | Instruction::CubinFile(_) => {
342                CUjitInputType_enum::CU_JIT_INPUT_CUBIN
343            }
344        }
345    }
346}
347
348/// CUDA Kernel function
349#[derive(Debug)]
350pub struct Kernel<'module> {
351    func: CUfunction,
352    module: &'module Module,
353}
354
355impl Contexted for Kernel<'_> {
356    fn get_context(&self) -> Arc<Context> {
357        self.module.get_context()
358    }
359}
360
/// Type which can be sent to the device as kernel argument
///
/// The default `as_ptr` returns the address of the value itself (not the value), reinterpreted
/// as a byte pointer.
///
/// NOTE(review): this is a safe trait, so any `Sized` type may implement it; nothing here checks
/// that the type's bytes are meaningful to the device-side kernel.
///
/// ```
/// # use accel::*;
/// let a: i32 = 10;
/// let p = &a as *const i32;
/// assert_eq!(
///     DeviceSend::as_ptr(&p),
///     &p as *const *const i32 as *const u8
/// );
/// assert!(std::ptr::eq(
///     unsafe { *(DeviceSend::as_ptr(&p) as *mut *const i32) },
///     p
/// ));
/// ```
pub trait DeviceSend: Sized {
    /// Get the address of this value as a byte pointer
    fn as_ptr(&self) -> *const u8 {
        let addr: *const Self = self;
        addr as *const u8
    }
}
383
384// Use default impl
385impl<T> DeviceSend for *mut T {}
386impl<T> DeviceSend for *const T {}
387impl DeviceSend for bool {}
388impl DeviceSend for i8 {}
389impl DeviceSend for i16 {}
390impl DeviceSend for i32 {}
391impl DeviceSend for i64 {}
392impl DeviceSend for isize {}
393impl DeviceSend for u8 {}
394impl DeviceSend for u16 {}
395impl DeviceSend for u32 {}
396impl DeviceSend for u64 {}
397impl DeviceSend for usize {}
398impl DeviceSend for f32 {}
399impl DeviceSend for f64 {}
400
/// Tuple of kernel arguments
///
/// Implemented in this module for tuples of references `(&D0, &D1, ...)` with 0 to 12
/// elements, where every element type is [DeviceSend].
///
/// ```
/// # use accel::*;
/// let a: i32 = 10;
/// let b: f32 = 1.0;
/// assert_eq!(
///   Arguments::kernel_params(&(&a, &b)),
///   vec![&a as *const i32 as *mut _, &b as *const f32 as *mut _, ]
/// );
/// ```
///
/// [DeviceSend]: trait.DeviceSend.html
pub trait Arguments<'arg> {
    /// Build the list of parameter pointers handed to [cuLaunchKernel] as `kernelParams`,
    /// one entry per kernel argument, in argument order.
    ///
    /// [cuLaunchKernel]: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
    fn kernel_params(&self) -> Vec<*mut c_void>;
}
419
420macro_rules! impl_kernel_parameters {
421    ($($name:ident),*; $($num:tt),*) => {
422        impl<'arg, $($name : DeviceSend),*> Arguments<'arg> for ($( &'arg $name, )*) {
423            fn kernel_params(&self) -> Vec<*mut c_void> {
424                vec![$( self.$num.as_ptr() as *mut c_void ),*]
425            }
426        }
427    }
428}
429
430impl_kernel_parameters!(;);
431impl_kernel_parameters!(D0; 0);
432impl_kernel_parameters!(D0, D1; 0, 1);
433impl_kernel_parameters!(D0, D1, D2; 0, 1, 2);
434impl_kernel_parameters!(D0, D1, D2, D3; 0, 1, 2, 3);
435impl_kernel_parameters!(D0, D1, D2, D3, D4; 0, 1, 2, 3, 4);
436impl_kernel_parameters!(D0, D1, D2, D3, D4, D5; 0, 1, 2, 3, 4, 5);
437impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6; 0, 1, 2, 3, 4, 5, 6);
438impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7; 0, 1, 2, 3, 4, 5, 6, 7);
439impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8; 0, 1, 2, 3, 4, 5, 6, 7, 8);
440impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
441impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
442impl_kernel_parameters!(D0, D1, D2, D3, D4, D5, D6, D7, D8, D9, D10, D11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
443
444/// Typed CUDA Kernel launcher
445///
446/// This will be automatically implemented in [accel_derive::kernel] for autogenerated wrapper
447/// module of [Module].
448///
449/// ```
450/// #[accel_derive::kernel]
451/// fn f(a: i32) {}
452/// ```
453///
454/// will create a submodule `f`:
455///
456/// ```
457/// mod f {
458///     pub const PTX_STR: &str = "PTX string generated by rustc/nvptx64-nvidia-cuda";
459///     pub struct Module(::accel::Module);
460///     /* impl Module { ... } */
461///     /* impl Launchable for Module { ... } */
462/// }
463/// ```
464///
465/// Implementation of `Launchable` for `f::Module` is also generated by [accel_derive::kernel]
466/// proc-macro.
467///
468/// [accel_derive::kernel]: https://docs.rs/accel-derive/0.3.0-alpha.1/accel_derive/attr.kernel.html
469/// [Module]: struct.Module.html
470pub trait Launchable<'arg> {
471    /// Arguments for the kernel to be launched.
472    /// This must be a tuple of [DeviceSend] types.
473    ///
474    /// [DeviceSend]: trait.DeviceSend.html
475    type Args: Arguments<'arg>;
476
477    fn get_kernel(&self) -> Result<Kernel>;
478
479    /// Launch CUDA Kernel synchronously
480    ///
481    /// ```
482    /// use accel::*;
483    ///
484    /// #[accel_derive::kernel]
485    /// fn f(a: i32) {}
486    ///
487    /// let device = Device::nth(0)?;
488    /// let ctx = device.create_context();
489    /// let module = f::Module::new(ctx)?;
490    /// let a = 12;
491    /// module.launch((1,) /* grid */, (4,) /* block */, &(&a,))?; // wait until kernel execution ends
492    /// # Ok::<(), ::accel::error::AccelError>(())
493    /// ```
494    fn launch<G: Into<Grid>, B: Into<Block>>(
495        &self,
496        grid: G,
497        block: B,
498        args: &Self::Args,
499    ) -> Result<()> {
500        let grid = grid.into();
501        let block = block.into();
502        let kernel = self.get_kernel()?;
503        let mut params = args.kernel_params();
504        unsafe {
505            contexted_call!(
506                &kernel.get_context(),
507                cuLaunchKernel,
508                kernel.func,
509                grid.x,
510                grid.y,
511                grid.z,
512                block.x,
513                block.y,
514                block.z,
515                0,          /* FIXME: no shared memory */
516                null_mut(), /* use default stream */
517                params.as_mut_ptr(),
518                null_mut() /* no extra */
519            )?;
520        }
521        kernel.sync_context()?;
522        Ok(())
523    }
524}
525
526/// OOP-like wrapper of `cuModule*` APIs
527#[derive(Debug)]
528pub struct Module {
529    module: CUmodule,
530    context: Arc<Context>,
531}
532
533impl Drop for Module {
534    fn drop(&mut self) {
535        if let Err(e) = unsafe { contexted_call!(&self.get_context(), cuModuleUnload, self.module) }
536        {
537            log::error!("Failed to unload module: {:?}", e);
538        }
539    }
540}
541
542impl Contexted for Module {
543    fn get_context(&self) -> Arc<Context> {
544        self.context.clone()
545    }
546}
547
548impl Module {
549    /// integrated loader of Instruction
550    pub fn load(context: Arc<Context>, data: &Instruction) -> Result<Self> {
551        match *data {
552            Instruction::PTX(ref ptx) => {
553                let module = unsafe {
554                    contexted_new!(&context, cuModuleLoadData, ptx.as_ptr() as *const _)?
555                };
556                Ok(Module { module, context })
557            }
558            Instruction::Cubin(ref bin) => {
559                let module = unsafe {
560                    contexted_new!(&context, cuModuleLoadData, bin.as_ptr() as *const _)?
561                };
562                Ok(Module { module, context })
563            }
564            Instruction::PTXFile(ref path) | Instruction::CubinFile(ref path) => {
565                let filename = path_to_cstring(path);
566                let module = unsafe { contexted_new!(&context, cuModuleLoad, filename.as_ptr())? };
567                Ok(Module { module, context })
568            }
569        }
570    }
571
572    pub fn from_str(context: Arc<Context>, ptx: &str) -> Result<Self> {
573        let data = Instruction::ptx(ptx);
574        Self::load(context, &data)
575    }
576
577    /// Wrapper of `cuModuleGetFunction`
578    pub fn get_kernel(&self, name: &str) -> Result<Kernel> {
579        let name = CString::new(name).expect("Invalid Kernel name");
580        let func = unsafe {
581            contexted_new!(
582                &self.get_context(),
583                cuModuleGetFunction,
584                self.module,
585                name.as_ptr()
586            )
587        }?;
588        Ok(Kernel { func, module: self })
589    }
590}
591
/// Convert a filesystem path into a NUL-terminated C string for the CUDA driver API.
///
/// On Unix the raw path bytes are used directly, so non-UTF-8 paths are supported.
///
/// Panic
/// -----
/// - If the path contains an interior NUL byte
#[cfg(unix)]
fn path_to_cstring(path: &Path) -> CString {
    use std::os::unix::ffi::OsStrExt;
    CString::new(path.as_os_str().as_bytes()).expect("Invalid Path")
}

/// Convert a filesystem path into a NUL-terminated C string for the CUDA driver API.
///
/// Panic
/// -----
/// - If the path is not valid UTF-8 or contains an interior NUL byte
#[cfg(not(unix))]
fn path_to_cstring(path: &Path) -> CString {
    let utf8 = path.to_str().expect("Invalid Path: not valid UTF-8");
    CString::new(utf8).expect("Invalid Path")
}
595
#[cfg(test)]
mod tests {
    use super::*;

    // Empty kernel; generated by the do_nothing example in accel-derive
    const DO_NOTHING_PTX: &str = r#"
        .version 3.2
        .target sm_30
        .address_size 64
        .visible .entry do_nothing()
        {
          ret;
        }
        "#;

    #[test]
    fn load_do_nothing() -> Result<()> {
        let dev = Device::nth(0)?;
        let context = dev.create_context();
        let _module = Module::from_str(context, DO_NOTHING_PTX)?;
        Ok(())
    }
}