// kn_cuda_sys/wrapper/rtc/args.rs

1use std::cmp::max;
2use std::slice;
3
/// Marker trait for values that can be copied byte-for-byte into a [`KernelArgs`] buffer.
///
/// # Safety
///
/// [`KernelArgs::push`] reinterprets a value of the implementing type as a `[u8]` slice of
/// length `size_of::<Self>()` and copies those bytes verbatim. Implementors must guarantee that:
///
/// * every byte of the type is initialized (no padding or otherwise uninitialized bytes), and
/// * its raw in-memory representation is what the kernel expects for the matching parameter
///   (e.g. fixed-width integers, floats, and thin raw pointers).
pub unsafe trait KernelArg {}

/// A CUDA kernel argument builder.
///
/// Arguments are appended in order. Each one is placed at the next offset that is a multiple
/// of its own alignment, and any gap is filled with zero bytes. [`KernelArgs::finish`] returns
/// the packed bytes without padding the tail.
///
/// ```
/// # use kn_cuda_sys::wrapper::rtc::args::KernelArgs;
/// let mut args = KernelArgs::new();
/// args.push(std::ptr::null::<f32>());
/// args.push(16);
/// args.push(32);
/// args.push(1.0f32);
/// let arg_bytes = args.finish();
/// // One pointer followed by three 4-byte values; the pointer's size is already a multiple
/// // of 4, so no padding is inserted between them.
/// assert_eq!(arg_bytes.len(), std::mem::size_of::<*const f32>() + 3 * 4);
/// ```
#[derive(Debug)]
pub struct KernelArgs {
    /// Packed argument bytes, including zero padding inserted to align each argument.
    buffer: Vec<u8>,
    /// Largest alignment of any argument pushed so far (starts at 1).
    max_alignment: usize,
}

22impl KernelArgs {
23    pub fn new() -> Self {
24        Self {
25            buffer: vec![],
26            max_alignment: 1,
27        }
28    }
29
30    pub fn push<T: KernelArg>(&mut self, value: T) {
31        // handle alignment
32        let alignment = std::mem::align_of::<T>();
33        self.max_alignment = max(self.max_alignment, alignment);
34        self.pad_to(alignment);
35
36        // append bytes
37        unsafe {
38            let bytes = slice::from_raw_parts(&value as *const T as *const u8, std::mem::size_of::<T>());
39            self.buffer.extend_from_slice(bytes);
40        }
41    }
42
43    pub fn push_int(&mut self, value: i32) {
44        self.push(value)
45    }
46
47    pub fn finish(self) -> Vec<u8> {
48        // we're not supposed to pad until alignment here
49        self.buffer
50    }
51
52    fn pad_to(&mut self, alignment: usize) {
53        while self.buffer.len() % alignment != 0 {
54            self.buffer.push(0);
55        }
56    }
57}
58
59unsafe impl KernelArg for u8 {}
60
61unsafe impl KernelArg for u16 {}
62
63unsafe impl KernelArg for u32 {}
64
65unsafe impl KernelArg for u64 {}
66
67unsafe impl KernelArg for i8 {}
68
69unsafe impl KernelArg for i16 {}
70
71unsafe impl KernelArg for i32 {}
72
73unsafe impl KernelArg for i64 {}
74
75unsafe impl KernelArg for f32 {}
76
77unsafe impl KernelArg for f64 {}
78
79unsafe impl<T> KernelArg for *const T {}
80
81unsafe impl<T> KernelArg for *mut T {}