1use std::convert::TryInto;
2
3use crate::cl_helpers::{cl_get_info5, cl_get_info6};
4use crate::ffi::{
5 clBuildProgram, clCreateProgramWithBinary, clCreateProgramWithSource, clGetProgramBuildInfo,
6 clGetProgramInfo, cl_context, cl_device_id, cl_program, cl_program_build_info, cl_program_info,
7};
8use crate::{
9 build_output, strings, ClContext, ClDeviceID, ClPointer, ContextPtr, DevicePtr, Error,
10 Output, ProgramBuildInfo, ProgramInfo, ObjectWrapper
11};
12
/// Error returned when a program build is requested with an empty list of devices.
pub const DEVICE_LIST_CANNOT_BE_EMPTY: Error =
    Error::ProgramError(ProgramError::CannotBuildProgramWithEmptyDevicesList);
15
/// Errors specific to creating and building OpenCL program objects.
#[derive(Debug, Fail, PartialEq, Eq, Clone)]
pub enum ProgramError {
    /// The source text passed to program creation could not be converted to a C string.
    #[fail(display = "The given source code was not a valid CString")]
    CStringInvalidSourceCode,

    /// A program binary could not be converted to a C string.
    // NOTE(review): this variant is not constructed anywhere in this file.
    #[fail(display = "The given program binary was not a valid CString")]
    CStringInvalidProgramBinary,

    /// A build was requested with no target devices.
    #[fail(display = "Cannot build a program with an empty list of devices")]
    CannotBuildProgramWithEmptyDevicesList,
}
28
29#[allow(clippy::transmuting_null)]
35#[allow(unused_mut)]
36pub unsafe fn cl_build_program(program: cl_program, device_ids: &[cl_device_id]) -> Output<()> {
37 let err_code = clBuildProgram(
38 program,
39 1u32,
40 device_ids.as_ptr() as *const cl_device_id,
41 std::ptr::null(),
42 std::mem::transmute(std::ptr::null::<fn()>()), std::ptr::null_mut(), );
45 build_output((), err_code)
46}
47
/// Reads the `cl_program_build_info` value selected by `info_flag` for `program`
/// on `device`, via `clGetProgramBuildInfo`.
///
/// Despite the name, the caller chooses which build-info flag is queried. The
/// result is returned as raw bytes in a `ClPointer<u8>`.
///
/// # Errors
/// Fails if `device` does not pass its usability check, or if `cl_get_info6`
/// reports an error for the query.
///
/// # Safety
/// `program` must be a valid OpenCL program and `device` a valid device handle.
pub unsafe fn cl_get_program_build_log(
    program: cl_program,
    device: cl_device_id,
    info_flag: cl_program_build_info,
) -> Output<ClPointer<u8>> {
    device.usability_check()?;
    cl_get_info6(program, device, info_flag, clGetProgramBuildInfo)
}
60
61pub unsafe fn cl_create_program_with_source(context: cl_context, src: &str) -> Output<cl_program> {
67 let src = strings::to_c_string(src).ok_or_else(|| ProgramError::CStringInvalidSourceCode)?;
68 let mut src_list = vec![src.as_ptr()];
69
70 let mut err_code = 0;
71 let program: cl_program = clCreateProgramWithSource(
72 context,
73 src_list.len().try_into().unwrap(),
74 src_list.as_mut_ptr() as *mut *const libc::c_char,
77 std::ptr::null(),
80 &mut err_code,
81 );
82 build_output(program, err_code)
83}
84
85#[allow(clippy::cast_ptr_alignment)]
93pub unsafe fn cl_create_program_with_binary(
94 context: cl_context,
95 device: cl_device_id,
96 binary: &[u8],
97) -> Output<cl_program> {
98 device.usability_check()?;
99 let mut err_code = 0;
100 let program = clCreateProgramWithBinary(
101 context,
102 1,
103 device as *const cl_device_id,
104 binary.len() as *const libc::size_t,
105 binary.as_ptr() as *mut *const u8,
106 std::ptr::null_mut(),
107 &mut err_code,
108 );
109 build_output(program, err_code)
110}
111
/// Queries `clGetProgramInfo` for `flag` on `program` and returns the raw result
/// as a `ClPointer<T>`.
///
/// # Safety
/// `program` must be a valid OpenCL program. `T` must match the element type
/// OpenCL returns for `flag`; for example, this file reads
/// `CL_PROGRAM_REFERENCE_COUNT` as `u32`.
pub unsafe fn cl_get_program_info<T: Copy>(
    program: cl_program,
    flag: cl_program_info,
) -> Output<ClPointer<T>> {
    cl_get_info5(program, flag, clGetProgramInfo)
}
123
/// `ObjectWrapper` around a raw OpenCL `cl_program` handle. Cloning a `ClProgram`
/// increments the program's OpenCL reference count.
pub type ClProgram = ObjectWrapper<cl_program>;
125
126impl ClProgram {
127 pub unsafe fn create_with_source(context: &ClContext, src: &str) -> Output<ClProgram> {
133 let prog = cl_create_program_with_source(context.context_ptr(), src)?;
134 Ok(ClProgram::unchecked_new(prog))
135 }
136
137 pub unsafe fn create_with_binary(
143 context: &ClContext,
144 device: &ClDeviceID,
145 bin: &[u8],
146 ) -> Output<ClProgram> {
147 let prog = cl_create_program_with_binary(context.context_ptr(), device.device_ptr(), bin)?;
148 Ok(ClProgram::unchecked_new(prog))
149 }
150
151 pub fn build<D>(&mut self, devices: &[D]) -> Output<()>
152 where
153 D: DevicePtr,
154 {
155 if devices.is_empty() {
156 return Err(DEVICE_LIST_CANNOT_BE_EMPTY);
157 }
158 unsafe {
159 let device_ptrs: Vec<cl_device_id> = devices.iter().map(|d| d.device_ptr()).collect();
160 cl_build_program(self.program_ptr(), &device_ptrs[..])
161 }
162 }
163
164 pub fn get_log<D: DevicePtr>(&self, device: &D) -> Output<String> {
165 unsafe {
166 cl_get_program_build_log(
167 self.program_ptr(),
168 device.device_ptr(),
169 ProgramBuildInfo::Log.into(),
170 )
171 .map(|ret| ret.into_string())
172 }
173 }
174}
175
// SAFETY: `program_ptr` returns the `cl_program` stored in the wrapper via
// `cl_object()`. This assumes `ObjectWrapper` keeps that handle valid for as long
// as the wrapper is alive — TODO confirm against `ObjectWrapper`'s implementation.
unsafe impl ProgramPtr for ClProgram {
    unsafe fn program_ptr(&self) -> cl_program {
        self.cl_object()
    }
}
181
/// Queries the `ProgramInfo` value `flag` for any `ProgramPtr` implementor.
///
/// This is safe to call because it relies on the `unsafe trait ProgramPtr`
/// implementor returning a valid program handle. `T` must match the type OpenCL
/// returns for `flag`.
fn get_info<T: Copy, P: ProgramPtr>(program: &P, flag: ProgramInfo) -> Output<ClPointer<T>> {
    unsafe { cl_get_program_info(program.program_ptr(), flag.into()) }
}
185
186pub unsafe trait ProgramPtr: Sized {
191 unsafe fn program_ptr(&self) -> cl_program;
196
197 fn reference_count(&self) -> Output<u32> {
199 get_info(self, ProgramInfo::ReferenceCount).map(|ret| unsafe { ret.into_one() })
200 }
201
202 fn num_devices(&self) -> Output<usize> {
204 get_info(self, ProgramInfo::NumDevices).map(|ret| unsafe {
205 let num32: u32 = ret.into_one();
206 num32 as usize
207 })
208 }
209
210 fn source(&self) -> Output<String> {
212 get_info(self, ProgramInfo::Source).map(|ret| unsafe { ret.into_string() })
213 }
214
215 fn binary_sizes(&self) -> Output<Vec<usize>> {
217 get_info(self, ProgramInfo::BinarySizes).map(|ret| unsafe { ret.into_vec() })
218 }
219
220 fn binaries(&self) -> Output<Vec<u8>> {
222 get_info(self, ProgramInfo::Binaries).map(|ret| unsafe { ret.into_vec() })
223 }
224
225 fn num_kernels(&self) -> Output<usize> {
227 get_info(self, ProgramInfo::NumKernels).map(|ret| unsafe { ret.into_one() })
228 }
229
230 fn kernel_names(&self) -> Output<Vec<String>> {
232 get_info(self, ProgramInfo::KernelNames).map(|ret| {
233 let kernels: String = unsafe { ret.into_string() };
234 kernels.split(';').map(|s| s.to_string()).collect()
235 })
236 }
237
238 fn devices(&self) -> Output<Vec<ClDeviceID>> {
239 get_info(self, ProgramInfo::Devices).map(|ret| unsafe {
240 ret.into_vec()
241 .into_iter()
242 .map(|d| ClDeviceID::retain_new(d).unwrap())
243 .collect()
244 })
245 }
246
247 fn context(&self) -> Output<ClContext> {
248 get_info(self, ProgramInfo::Context)
249 .and_then(|ret| unsafe { ClContext::retain_new(ret.into_one()) })
250 }
251}
252
#[cfg(test)]
mod tests {
    use crate::*;

    // A single trivial kernel, used to exercise program queries.
    const SRC: &str = "
    __kernel void test123(__global int *i) {
        *i += 1;
    }";

    #[test]
    fn program_ptr_reference_count() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.reference_count().unwrap(), 1);
    }

    #[test]
    fn cloning_increments_reference_count() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let first_clone = program.clone();
        let second_clone = program.clone();
        assert_eq!(program.reference_count().unwrap(), 3);
        assert_eq!(program, first_clone);
        assert_eq!(program, second_clone);
    }

    #[test]
    fn program_ptr_num_devices() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert!(program.num_devices().unwrap() > 0);
    }

    #[test]
    fn program_ptr_devices() {
        let (program, devices, _context) = ll_testing::get_program(SRC);
        let program_devices = program.devices().unwrap();
        assert_eq!(program.num_devices().unwrap(), program_devices.len());
        assert_eq!(program_devices.len(), devices.len());
    }

    #[test]
    fn program_ptr_context() {
        let (program, _devices, context) = ll_testing::get_program(SRC);
        assert_eq!(program.context().unwrap(), context);
    }

    #[test]
    fn num_devices_matches_devices_len() {
        let (program, devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.num_devices().unwrap(), devices.len());
    }

    #[test]
    fn program_ptr_source_matches_creates_src() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.source().unwrap(), SRC.to_string());
    }

    #[test]
    fn program_ptr_num_kernels() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.num_kernels().unwrap(), 1);
    }

    #[test]
    fn program_ptr_kernel_names() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        assert_eq!(program.kernel_names().unwrap(), vec!["test123"]);
    }

    #[test]
    fn num_kernels_matches_kernel_names_len() {
        let (program, _devices, _context) = ll_testing::get_program(SRC);
        let names = program.kernel_names().unwrap();
        assert_eq!(program.num_kernels().unwrap(), names.len());
    }
}