open_cl_low_level/
program.rs

1use std::convert::TryInto;
2
3use crate::cl_helpers::{cl_get_info5, cl_get_info6};
4use crate::ffi::{
5    clBuildProgram, clCreateProgramWithBinary, clCreateProgramWithSource, clGetProgramBuildInfo,
6    clGetProgramInfo, cl_context, cl_device_id, cl_program, cl_program_build_info, cl_program_info,
7};
8use crate::{
9    build_output, strings, ClContext, ClDeviceID, ClPointer, ContextPtr, DevicePtr, Error,
10    Output, ProgramBuildInfo, ProgramInfo, ObjectWrapper
11};
12
13pub const DEVICE_LIST_CANNOT_BE_EMPTY: Error =
14    Error::ProgramError(ProgramError::CannotBuildProgramWithEmptyDevicesList);
15
16/// An error related to Program.
17#[derive(Debug, Fail, PartialEq, Eq, Clone)]
18pub enum ProgramError {
19    #[fail(display = "The given source code was not a valid CString")]
20    CStringInvalidSourceCode,
21
22    #[fail(display = "The given program binary was not a valid CString")]
23    CStringInvalidProgramBinary,
24
25    #[fail(display = "Cannot build a program with an empty list of devices")]
26    CannotBuildProgramWithEmptyDevicesList,
27}
28
29/// A low-level helper function for calling the OpenCL FFI function clBuildProgram.
30///
31/// # Safety
32/// if the devices or the program are in an invalid state this function call results in
33/// undefined behavior.
34#[allow(clippy::transmuting_null)]
35#[allow(unused_mut)]
36pub unsafe fn cl_build_program(program: cl_program, device_ids: &[cl_device_id]) -> Output<()> {
37    let err_code = clBuildProgram(
38        program,
39        1u32,
40        device_ids.as_ptr() as *const cl_device_id,
41        std::ptr::null(),
42        std::mem::transmute(std::ptr::null::<fn()>()), // pfn_notify
43        std::ptr::null_mut(),                          // user_data
44    );
45    build_output((), err_code)
46}
47
48/// Low level helper function for clGetProgramBuildInfo.
49///
50/// # Safety
51/// If the program or device is in an invalid state this function call is undefined behavior.
52pub unsafe fn cl_get_program_build_log(
53    program: cl_program,
54    device: cl_device_id,
55    info_flag: cl_program_build_info,
56) -> Output<ClPointer<u8>> {
57    device.usability_check()?;
58    cl_get_info6(program, device, info_flag, clGetProgramBuildInfo)
59}
60
61/// Low level helper function for calling the OpenCL FFI function clCreateProgramWithSource.
62///
63/// # Safety
64/// If the context or device is in an invalid state this function will cause undefined
65/// behavior.
66pub unsafe fn cl_create_program_with_source(context: cl_context, src: &str) -> Output<cl_program> {
67    let src = strings::to_c_string(src).ok_or_else(|| ProgramError::CStringInvalidSourceCode)?;
68    let mut src_list = vec![src.as_ptr()];
69
70    let mut err_code = 0;
71    let program: cl_program = clCreateProgramWithSource(
72        context,
73        src_list.len().try_into().unwrap(),
74        // const char **strings
75        // mut pointer to const pointer of char. Great.
76        src_list.as_mut_ptr() as *mut *const libc::c_char,
77        // null pointer here indicates that all strings in the src
78        // are NULL-terminated.
79        std::ptr::null(),
80        &mut err_code,
81    );
82    build_output(program, err_code)
83}
84
85/// Low level helper function for calling the OpenCL FFI function clCreateProgramWithBinary.
86///
87/// # Safety
88/// If the context or device is in an invalid state this function will cause undefined
89/// behavior. WRT the clippy::cast_ptr_alignment below the dereferncing of the pointer
90/// happens on the _other_ _side_ of the C FFI. So it cannot be any more unsafe that
91/// in already is...
92#[allow(clippy::cast_ptr_alignment)]
93pub unsafe fn cl_create_program_with_binary(
94    context: cl_context,
95    device: cl_device_id,
96    binary: &[u8],
97) -> Output<cl_program> {
98    device.usability_check()?;
99    let mut err_code = 0;
100    let program = clCreateProgramWithBinary(
101        context,
102        1,
103        device as *const cl_device_id,
104        binary.len() as *const libc::size_t,
105        binary.as_ptr() as *mut *const u8,
106        std::ptr::null_mut(),
107        &mut err_code,
108    );
109    build_output(program, err_code)
110}
111
112/// Low level helper function for the FFI call to clGetProgramInfo
113///
114/// # Safety
115/// Calling this function with a cl_program that is not in a valid state is
116/// undefined behavior.
117pub unsafe fn cl_get_program_info<T: Copy>(
118    program: cl_program,
119    flag: cl_program_info,
120) -> Output<ClPointer<T>> {
121    cl_get_info5(program, flag, clGetProgramInfo)
122}
123
124pub type ClProgram = ObjectWrapper<cl_program>;
125
126impl ClProgram {
127    /// Creates a new ClProgram on the context and device with the given OpenCL source code.
128    ///
129    /// # Safety
130    /// The provided ClContext and ClDeviceID must be in valid state or else undefined behavior is
131    /// expected.
132    pub unsafe fn create_with_source(context: &ClContext, src: &str) -> Output<ClProgram> {
133        let prog = cl_create_program_with_source(context.context_ptr(), src)?;
134        Ok(ClProgram::unchecked_new(prog))
135    }
136
137    /// Creates a new ClProgram on the context and device with the given executable binary.
138    ///
139    /// # Safety
140    /// The provided ClContext and ClDeviceID must be in valid state or else undefined behavior is
141    /// expected.
142    pub unsafe fn create_with_binary(
143        context: &ClContext,
144        device: &ClDeviceID,
145        bin: &[u8],
146    ) -> Output<ClProgram> {
147        let prog = cl_create_program_with_binary(context.context_ptr(), device.device_ptr(), bin)?;
148        Ok(ClProgram::unchecked_new(prog))
149    }
150
151    pub fn build<D>(&mut self, devices: &[D]) -> Output<()>
152    where
153        D: DevicePtr,
154    {
155        if devices.is_empty() {
156            return Err(DEVICE_LIST_CANNOT_BE_EMPTY);
157        }
158        unsafe {
159            let device_ptrs: Vec<cl_device_id> = devices.iter().map(|d| d.device_ptr()).collect();
160            cl_build_program(self.program_ptr(), &device_ptrs[..])
161        }
162    }
163
164    pub fn get_log<D: DevicePtr>(&self, device: &D) -> Output<String> {
165        unsafe {
166            cl_get_program_build_log(
167                self.program_ptr(),
168                device.device_ptr(),
169                ProgramBuildInfo::Log.into(),
170            )
171            .map(|ret| ret.into_string())
172        }
173    }
174}
175
176unsafe impl ProgramPtr for ClProgram {
177    unsafe fn program_ptr(&self) -> cl_program {
178        self.cl_object()
179    }
180}
181
182fn get_info<T: Copy, P: ProgramPtr>(program: &P, flag: ProgramInfo) -> Output<ClPointer<T>> {
183    unsafe { cl_get_program_info(program.program_ptr(), flag.into()) }
184}
185
186/// ProgramPtr is the trait to access a cl_program for wrappers of that cl_program.
187///
188/// # Safety
189/// Direct interaction with any OpenCL pointer is unsafe so this trait is unsafe.
190pub unsafe trait ProgramPtr: Sized {
191    /// program_ptr is the trait to access a cl_program for wrappers of that cl_program.
192    ///
193    /// # Safety
194    /// Direct interaction with any OpenCL pointer is unsafe so this trait is unsafe.
195    unsafe fn program_ptr(&self) -> cl_program;
196
197    /// The OpenCL reference count of the cl_program.
198    fn reference_count(&self) -> Output<u32> {
199        get_info(self, ProgramInfo::ReferenceCount).map(|ret| unsafe { ret.into_one() })
200    }
201
202    /// The number of devices that this cl_program is built on.
203    fn num_devices(&self) -> Output<usize> {
204        get_info(self, ProgramInfo::NumDevices).map(|ret| unsafe {
205            let num32: u32 = ret.into_one();
206            num32 as usize
207        })
208    }
209
210    /// The source code String of this OpenCL program.
211    fn source(&self) -> Output<String> {
212        get_info(self, ProgramInfo::Source).map(|ret| unsafe { ret.into_string() })
213    }
214
215    /// The size of the binaries for this OpenCL program.
216    fn binary_sizes(&self) -> Output<Vec<usize>> {
217        get_info(self, ProgramInfo::BinarySizes).map(|ret| unsafe { ret.into_vec() })
218    }
219
220    /// The executable binaries for this OpenCL program.
221    fn binaries(&self) -> Output<Vec<u8>> {
222        get_info(self, ProgramInfo::Binaries).map(|ret| unsafe { ret.into_vec() })
223    }
224
225    /// The number of kernels (defined functions) in this OpenCL program.
226    fn num_kernels(&self) -> Output<usize> {
227        get_info(self, ProgramInfo::NumKernels).map(|ret| unsafe { ret.into_one() })
228    }
229
230    /// The names of the kernels (defined functions) in this OpenCL program.
231    fn kernel_names(&self) -> Output<Vec<String>> {
232        get_info(self, ProgramInfo::KernelNames).map(|ret| {
233            let kernels: String = unsafe { ret.into_string() };
234            kernels.split(';').map(|s| s.to_string()).collect()
235        })
236    }
237
238    fn devices(&self) -> Output<Vec<ClDeviceID>> {
239        get_info(self, ProgramInfo::Devices).map(|ret| unsafe {
240            ret.into_vec()
241                .into_iter()
242                .map(|d| ClDeviceID::retain_new(d).unwrap())
243                .collect()
244        })
245    }
246
247    fn context(&self) -> Output<ClContext> {
248        get_info(self, ProgramInfo::Context)
249            .and_then(|ret| unsafe { ClContext::retain_new(ret.into_one()) })
250    }
251}
252
#[cfg(test)]
mod tests {
    use crate::*;

    // A program containing exactly one kernel, named `test123`.
    const SRC: &str = "
    __kernel void test123(__global int *i) {
        *i += 1;
    }";

    #[test]
    fn program_ptr_reference_count() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.reference_count().unwrap(), 1);
    }

    #[test]
    fn cloning_increments_reference_count() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let first_clone = program.clone();
        let second_clone = program.clone();
        // Both clones are still alive here, so three references exist.
        assert_eq!(program.reference_count().unwrap(), 3);
        assert_eq!(program, first_clone);
        assert_eq!(program, second_clone);
    }

    #[test]
    fn program_ptr_num_devices() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert!(program.num_devices().unwrap() > 0);
    }

    #[test]
    fn program_ptr_devices() {
        let (program, devices, _context) = ll_testing::get_program(SRC);
        let program_devices = program.devices().unwrap();
        assert_eq!(program.num_devices().unwrap(), program_devices.len());
        assert_eq!(program_devices.len(), devices.len());
    }

    #[test]
    fn program_ptr_context() {
        let (program, _devices, context) = ll_testing::get_program(SRC);
        assert_eq!(program.context().unwrap(), context);
    }

    #[test]
    fn num_devices_matches_devices_len() {
        let (program, devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.num_devices().unwrap(), devices.len());
    }

    #[test]
    fn program_ptr_source_matches_creates_src() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.source().unwrap(), SRC);
    }

    #[test]
    fn program_ptr_num_kernels() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.num_kernels().unwrap(), 1);
    }

    #[test]
    fn program_ptr_kernel_names() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.kernel_names().unwrap(), vec!["test123"]);
    }

    #[test]
    fn num_kernels_matches_kernel_names_len() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let names = program.kernel_names().unwrap();
        assert_eq!(program.num_kernels().unwrap(), names.len());
    }
}