open_cl_low_level/command_queue.rs

use crate::ffi::{
    clCreateCommandQueue, clEnqueueNDRangeKernel, clEnqueueReadBuffer, clEnqueueWriteBuffer,
    clFinish, clGetCommandQueueInfo, cl_bool, cl_command_queue, cl_command_queue_info,
    cl_command_queue_properties, cl_context, cl_device_id, cl_event, cl_kernel, cl_mem,
};

use crate::cl_helpers::cl_get_info5;
use crate::CommandQueueInfo as CQInfo;
use crate::{
    build_output, BufferCreator, BufferReadEvent, ClContext, ClDeviceID, ClEvent, ClKernel,
    ClMem, ClNumber, ClPointer, CommandQueueInfo, CommandQueueProperties, ContextPtr, DevicePtr,
    EventPtr, GlobalWorkSize, KernelPtr, LocalWorkSize, MemPtr, MutVecOrSlice, ObjectWrapper,
    Output, VecOrSlice, Waitlist, WaitlistSizeAndPtr, Work,
};

/// Options that control how an operation is enqueued on a command queue:
/// whether the call blocks until the operation completes, the byte offset
/// for buffer reads and writes, and the events that must complete before
/// the operation may begin.
#[derive(Debug, Clone)]
pub struct CommandQueueOptions {
    pub is_blocking: bool,
    pub offset: usize,
    pub waitlist: Vec<ClEvent>,
}

impl Default for CommandQueueOptions {
    /// Default options: a blocking call with zero offset and an empty waitlist.
    fn default() -> CommandQueueOptions {
        CommandQueueOptions {
            is_blocking: true,
            offset: 0,
            waitlist: vec![],
        }
    }
}

impl From<Option<CommandQueueOptions>> for CommandQueueOptions {
    fn from(maybe_cq_opts: Option<CommandQueueOptions>) -> CommandQueueOptions {
        maybe_cq_opts.unwrap_or_default()
    }
}

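// A minimal sketch of building options directly and via the From impl
// (hypothetical; assumes a valid ClEvent named `copy_done` is in scope):
//
//     let opts = CommandQueueOptions {
//         is_blocking: false,
//         offset: 0,
//         waitlist: vec![copy_done],
//     };
//     let defaulted = CommandQueueOptions::from(None);
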
unsafe impl Waitlist for CommandQueueOptions {
    /// Extends the given waitlist with the events of this CommandQueueOptions' waitlist.
    unsafe fn fill_waitlist(&self, waitlist: &mut Vec<cl_event>) {
        waitlist.extend(self.new_waitlist());
    }

    /// Creates a Vec<cl_event> waitlist for use in OpenCL FFI calls.
    unsafe fn new_waitlist(&self) -> Vec<cl_event> {
        self.waitlist.iter().map(|evt| evt.event_ptr()).collect()
    }
}

/// Creates a new cl_command_queue.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior.
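///
/// # Example
///
/// A minimal sketch (`context` and `device` are hypothetical, valid raw objects):
///
/// ```ignore
/// let props = CommandQueueProperties::PROFILING_ENABLE.bits() as cl_command_queue_properties;
/// let queue = unsafe { cl_create_command_queue(context, device, props)? };
/// ```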
pub unsafe fn cl_create_command_queue(
    context: cl_context,
    device: cl_device_id,
    flags: cl_command_queue_properties,
) -> Output<cl_command_queue> {
    device.usability_check()?;
    let mut err_code = 0;
    let command_queue = clCreateCommandQueue(context, device, flags, &mut err_code);
    build_output(command_queue, err_code)
}

/// Blocks until all previously queued tasks are finished.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior.
pub unsafe fn cl_finish(command_queue: cl_command_queue) -> Output<()> {
    build_output((), clFinish(command_queue))
}

/// Queues an n-dimensionally ranged kernel to be executed.
///
/// Blocks until the kernel is finished.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior.
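///
/// # Example
///
/// A minimal sketch (`queue`, `kernel`, and `work` are hypothetical, valid
/// values; CommandQueueOptions implements Waitlist, so default options act
/// as an empty waitlist):
///
/// ```ignore
/// let event: cl_event = unsafe {
///     cl_enqueue_nd_range_kernel(queue, kernel, &work, CommandQueueOptions::default())?
/// };
/// ```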
pub unsafe fn cl_enqueue_nd_range_kernel<W: Waitlist>(
    queue: cl_command_queue,
    kernel: cl_kernel,
    work: &Work,
    waitlist: W,
) -> Output<cl_event> {
    let mut tracking_event: cl_event = new_tracking_event();
    let event_waitlist = waitlist.new_waitlist();
    let wl = event_waitlist.as_slice();

    let gws: GlobalWorkSize = work.global_work_size()?;
    let lws: LocalWorkSize = work.local_work_size()?;
    let err_code = clEnqueueNDRangeKernel(
        queue,
        kernel,
        work.work_dims(),
        work.global_work_offset().as_ptr(),
        gws.as_ptr(),
        lws.as_ptr(),
        wl.waitlist_len(),
        wl.waitlist_ptr(),
        &mut tracking_event,
    );

    build_output((), err_code)?;
    cl_finish(queue)?;

    // TODO: Remove this check when Event checks for null pointer
    debug_assert!(!tracking_event.is_null());
    Ok(tracking_event)
}

/// Returns a null cl_event pointer for the OpenCL API to write a new event into.
fn new_tracking_event() -> cl_event {
    std::ptr::null_mut() as cl_event
}

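/// Enqueues a read from a device cl_mem buffer into the given host buffer.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior.
///
/// A minimal usage sketch (`queue` and `mem` are hypothetical, valid raw objects):
///
/// ```ignore
/// let mut host_data = vec![0u8; 8];
/// let event = unsafe {
///     cl_enqueue_read_buffer(queue, mem, &mut host_data[..], CommandQueueOptions::default())?
/// };
/// ```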
pub unsafe fn cl_enqueue_read_buffer<T>(
    queue: cl_command_queue,
    mem: cl_mem,
    buffer: &mut [T],
    command_queue_opts: CommandQueueOptions,
) -> Output<cl_event>
where
    T: ClNumber,
{
    let mut tracking_event = new_tracking_event();
    let waitlist = command_queue_opts.new_waitlist();
    let wl = waitlist.as_slice();

    // TODO: Make this an Error-returning check
    // debug_assert!(buffer.len() == device_mem.len());

    let err_code = clEnqueueReadBuffer(
        queue,
        mem,
        command_queue_opts.is_blocking as cl_bool,
        command_queue_opts.offset,
        buffer.buffer_byte_size(),
        buffer.buffer_ptr(),
        wl.waitlist_len(),
        wl.waitlist_ptr(),
        &mut tracking_event,
    );
    build_output(tracking_event, err_code)
}

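/// Enqueues a write of the given host buffer into a device cl_mem buffer.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior.
///
/// A minimal usage sketch (`queue` and `mem` are hypothetical, valid raw objects):
///
/// ```ignore
/// let host_data = vec![0u8, 1, 2, 3];
/// let event = unsafe {
///     cl_enqueue_write_buffer(queue, mem, &host_data[..], CommandQueueOptions::default())?
/// };
/// ```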
pub unsafe fn cl_enqueue_write_buffer<T: ClNumber>(
    queue: cl_command_queue,
    mem: cl_mem,
    buffer: &[T],
    command_queue_opts: CommandQueueOptions,
) -> Output<cl_event> {
    let mut tracking_event = new_tracking_event();

    let waitlist = command_queue_opts.new_waitlist();
    let wl = waitlist.as_slice();

    let err_code = clEnqueueWriteBuffer(
        queue,
        mem,
        command_queue_opts.is_blocking as cl_bool,
        command_queue_opts.offset,
        buffer.buffer_byte_size(),
        buffer.buffer_ptr(),
        wl.waitlist_len(),
        wl.waitlist_ptr(),
        &mut tracking_event,
    );

    build_output(tracking_event, err_code)
}

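/// Fetches the command queue info identified by `flag` as a ClPointer<T>.
///
/// # Safety
/// Usage of an invalid ClObject is undefined behavior. The caller must also
/// ensure T matches the type OpenCL returns for the requested flag.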
pub unsafe fn cl_get_command_queue_info<T: Copy>(
    command_queue: cl_command_queue,
    flag: CommandQueueInfo,
) -> Output<ClPointer<T>> {
    cl_get_info5(
        command_queue,
        flag as cl_command_queue_info,
        clGetCommandQueueInfo,
    )
}

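/// Behavior for types that wrap a raw cl_command_queue, providing info
/// queries (context, device, reference count, properties) via the pointer.
///
/// # Safety
/// Implementations must return a valid cl_command_queue from
/// command_queue_ptr; usage of an invalid ClObject is undefined behavior.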
pub unsafe trait CommandQueuePtr: Sized {
    unsafe fn command_queue_ptr(&self) -> cl_command_queue;

    fn address(&self) -> String {
        format!("{:?}", unsafe { self.command_queue_ptr() })
    }

    unsafe fn info<T: Copy>(&self, flag: CQInfo) -> Output<ClPointer<T>> {
        cl_get_command_queue_info(self.command_queue_ptr(), flag.into())
    }

    unsafe fn cl_context(&self) -> Output<cl_context> {
        self.info(CQInfo::Context).map(|cl_ptr| cl_ptr.into_one())
    }

    unsafe fn context(&self) -> Output<ClContext> {
        self.cl_context().and_then(|obj| ClContext::retain_new(obj))
    }

    unsafe fn cl_device_id(&self) -> Output<cl_device_id> {
        self.info(CQInfo::Device).map(|cl_ptr| cl_ptr.into_one())
    }

    unsafe fn device(&self) -> Output<ClDeviceID> {
        self.cl_device_id()
            .and_then(|obj| ClDeviceID::retain_new(obj))
    }

    unsafe fn reference_count(&self) -> Output<u32> {
        self.info(CQInfo::ReferenceCount).map(|ret| ret.into_one())
    }

    unsafe fn cl_command_queue_properties(&self) -> Output<cl_command_queue_properties> {
        self.info::<cl_command_queue_properties>(CQInfo::Properties)
            .map(|ret| ret.into_one())
    }

    unsafe fn properties(&self) -> Output<CommandQueueProperties> {
        self.cl_command_queue_properties().map(|props| {
            CommandQueueProperties::from_bits(props).unwrap_or_else(|| {
                panic!("Failed to convert cl_command_queue_properties");
            })
        })
    }
}

unsafe impl CommandQueuePtr for ObjectWrapper<cl_command_queue> {
    unsafe fn command_queue_ptr(&self) -> cl_command_queue {
        self.cl_object()
    }
}

pub type ClCommandQueue = ObjectWrapper<cl_command_queue>;

impl ObjectWrapper<cl_command_queue> {
    /// Create a new ClCommandQueue in the given ClContext on the given
    /// ClDeviceID with the given CommandQueueProperties (optional).
    ///
    /// # Safety
    /// Calling this function with an invalid ClContext or ClDeviceID
    /// is undefined behavior.
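    ///
    /// # Example
    ///
    /// A minimal sketch, assuming ll_testing::get_context() (as used in the
    /// tests below) yields a usable context and device list:
    ///
    /// ```ignore
    /// let (context, devices) = ll_testing::get_context();
    /// let queue = unsafe { ClCommandQueue::create(&context, &devices[0], None)? };
    /// ```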
    pub unsafe fn create(
        context: &ClContext,
        device: &ClDeviceID,
        opt_props: Option<CommandQueueProperties>,
    ) -> Output<ClCommandQueue> {
        let properties = opt_props.unwrap_or(CommandQueueProperties::PROFILING_ENABLE);
        ClCommandQueue::create_from_raw_pointers(
            context.context_ptr(),
            device.device_ptr(),
            properties.bits() as cl_command_queue_properties,
        )
    }

    /// Creates a ClCommandQueue from raw ClObject pointers.
    ///
    /// # Safety
    /// Passing an invalid ClObject is undefined behavior.
    pub unsafe fn create_from_raw_pointers(
        context: cl_context,
        device: cl_device_id,
        props: cl_command_queue_properties,
    ) -> Output<ClCommandQueue> {
        let cq_object = cl_create_command_queue(context, device, props)?;
        ClCommandQueue::new(cq_object)
    }

    /// Creates a copy of a ClCommandQueue. The copy is, in fact, a completely different
    /// ClCommandQueue that has the same cl_context and cl_device_id as the original.
    ///
    /// # Safety
    /// Calling this function on an invalid ClCommandQueue is undefined behavior.
    pub unsafe fn create_copy(&self) -> Output<ClCommandQueue> {
        let context = self.cl_context()?;
        let device = self.cl_device_id()?;
        let props = self.cl_command_queue_properties()?;
        ClCommandQueue::create_from_raw_pointers(context, device, props)
    }

    /// write_buffer moves data from a host buffer (a &[T] or Vec<T>) into
    /// the mutable OpenCL cl_mem buffer.
    ///
    /// # Safety
    /// Usage of an invalid ClMem or ClCommandQueue is undefined behavior.
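    ///
    /// A minimal sketch (`queue` and `buffer` are hypothetical, valid wrappers):
    ///
    /// ```ignore
    /// let data = vec![0u8, 1, 2, 3];
    /// let event = unsafe { queue.write_buffer(&mut buffer, &data[..], None)? };
    /// event.wait()?;
    /// ```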
    pub unsafe fn write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
        &mut self,
        mem: &mut ClMem,
        host_buffer: H,
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        match host_buffer.into() {
            VecOrSlice::Slice(hb) => self.write_buffer_from_slice(mem, hb, opts),
            VecOrSlice::Vec(hb) => self.write_buffer_from_slice(mem, &hb[..], opts),
        }
    }

    /// Copies data to a ClMem buffer from a host slice of T.
    unsafe fn write_buffer_from_slice<T: ClNumber>(
        &mut self,
        mem: &mut ClMem,
        host_buffer: &[T],
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        let event = cl_enqueue_write_buffer(
            self.command_queue_ptr(),
            mem.mem_ptr(),
            host_buffer,
            opts.into(),
        )?;
        ClEvent::new(event)
    }

    /// Copies data from a ClMem buffer into a &mut [T] or a Vec<T>.
    ///
    /// # Safety
    /// Usage of an invalid ClMem or ClCommandQueue is undefined behavior.
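    ///
    /// A minimal sketch (`queue` and `buffer` are hypothetical, valid wrappers).
    /// Passing a Vec transfers ownership into the returned event; wait() hands
    /// the filled Vec back:
    ///
    /// ```ignore
    /// let mut event = unsafe { queue.read_buffer(&buffer, vec![0u8; 4], None)? };
    /// let read_back: Option<Vec<u8>> = event.wait()?;
    /// ```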
    pub unsafe fn read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
        &mut self,
        mem: &ClMem,
        host_buffer: H,
        opts: Option<CommandQueueOptions>,
    ) -> Output<BufferReadEvent<T>> {
        match host_buffer.into() {
            MutVecOrSlice::Slice(slc) => {
                let event = self.read_buffer_into_slice(mem, slc, opts)?;
                Ok(BufferReadEvent::new(event, None))
            }
            MutVecOrSlice::Vec(mut hb) => {
                let event = self.read_buffer_into_slice(mem, &mut hb[..], opts)?;
                Ok(BufferReadEvent::new(event, Some(hb)))
            }
        }
    }

    /// Copies data from a ClMem buffer to a &mut [T].
    unsafe fn read_buffer_into_slice<T: ClNumber>(
        &mut self,
        mem: &ClMem,
        host_buffer: &mut [T],
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        assert_eq!(mem.len().unwrap(), host_buffer.len());
        let raw_event = cl_enqueue_read_buffer(
            self.command_queue_ptr(),
            mem.mem_ptr(),
            host_buffer,
            opts.into(),
        )?;
        ClEvent::new(raw_event)
    }

    /// Enqueues a ClKernel onto the ClCommandQueue.
    ///
    /// # Safety
    /// Usage of invalid ClObjects is undefined behavior.
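    ///
    /// A minimal sketch (`queue`, `kernel`, and `work` are hypothetical, valid values):
    ///
    /// ```ignore
    /// let event = unsafe { queue.enqueue_kernel(&mut kernel, &work, None)? };
    /// event.wait()?;
    /// ```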
    pub unsafe fn enqueue_kernel(
        &mut self,
        kernel: &mut ClKernel,
        work: &Work,
        opts: Option<CommandQueueOptions>,
    ) -> Output<ClEvent> {
        let cq_opts: CommandQueueOptions = opts.into();
        let event = cl_enqueue_nd_range_kernel(
            self.command_queue_ptr(),
            kernel.kernel_ptr(),
            work,
            &cq_opts.waitlist[..],
        )?;
        ClEvent::new(event)
    }

    /// Blocks until all previously queued tasks on this queue are finished.
    ///
    /// # Safety
    /// Calling this function on an invalid ClCommandQueue is undefined behavior.
    pub unsafe fn finish(&mut self) -> Output<()> {
        cl_finish(self.cl_object())
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn command_queue_can_be_created() {
        let (context, devices) = ll_testing::get_context();
        for d in devices.iter() {
            let cq = unsafe { ClCommandQueue::create(&context, d, None).unwrap() };
            // Ensure the compiler does not optimize the queue away.
            let addr = cq.address();
            assert!(!addr.is_empty());
        }
    }

    #[test]
    fn address_works() {
        let (cqs, _context, _devices) = ll_testing::get_command_queues();
        for cq in cqs.iter() {
            let addr: String = cq.address();
            let expected: String = format!("{:?}", unsafe { cq.command_queue_ptr() });
            assert_eq!(addr, expected);
        }
    }

    #[test]
    fn context_works() {
        let (cqs, context, _devices) = ll_testing::get_command_queues();
        for cq in cqs.iter() {
            let queue_ctx = unsafe { cq.context() }.unwrap();
            assert_eq!(queue_ctx, context);
        }
    }

    #[test]
    fn device_works() {
        let (cqs, _context, devices) = ll_testing::get_command_queues();
        for (cq, device) in cqs.iter().zip(devices.iter()) {
            let queue_device = unsafe { cq.device() }.unwrap();
            assert_eq!(&queue_device, device);
        }
    }

    #[test]
    fn reference_count_works() {
        let (cqs, _context, _devices) = ll_testing::get_command_queues();
        for cq in cqs.iter() {
            let ref_count = unsafe { cq.reference_count() }.unwrap();
            assert_eq!(ref_count, 1);
        }
    }

    #[test]
    fn properties_works() {
        let (cqs, _context, _devices) = ll_testing::get_command_queues();
        for cq in cqs.iter() {
            let props = unsafe { cq.properties() }.unwrap();
            assert_eq!(props, CommandQueueProperties::PROFILING_ENABLE);
        }
    }

    #[test]
    fn create_copy_works() {
        let (cqs, _context, _devices) = ll_testing::get_command_queues();
        for cq in cqs.iter() {
            unsafe {
                let copied_cq = cq.create_copy().unwrap();
                assert_eq!(copied_cq.context().unwrap(), cq.context().unwrap());
                assert_eq!(copied_cq.device().unwrap(), cq.device().unwrap());
                assert_ne!(copied_cq.command_queue_ptr(), cq.command_queue_ptr());
            }
        }
    }

    #[test]
    fn buffer_can_be_written_and_waited() {
        let (mut cqs, context, _devices) = ll_testing::get_command_queues();
        let mut data = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let mut buffer = ll_testing::mem_from_data_and_context::<u8>(&mut data, &context);
        for cq in cqs.iter_mut() {
            unsafe {
                let event = cq.write_buffer(&mut buffer, &data[..], None).unwrap();
                event.wait().unwrap();
            }
        }
    }

    #[test]
    fn buffer_vec_can_be_read_and_waited() {
        let (mut cqs, context, _devices) = ll_testing::get_command_queues();
        let mut data = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let buffer = ll_testing::mem_from_data_and_context(&mut data, &context);
        let data_ref = &data;
        for cq in cqs.iter_mut() {
            unsafe {
                let mut event = cq.read_buffer(&buffer, data_ref.clone(), None).unwrap();
                let data2: Option<Vec<u8>> = event.wait().unwrap();
                assert_eq!(data2, Some(data_ref.clone()));
            }
        }
    }

    #[test]
    fn buffer_slice_can_be_read_and_waited() {
        let (mut cqs, context, _devices) = ll_testing::get_command_queues();
        let mut data = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let buffer = ll_testing::mem_from_data_and_context(&mut data, &context);

        for cq in cqs.iter_mut() {
            unsafe {
                let mut data2 = vec![0u8; 8];
                let mut event = cq.read_buffer(&buffer, &mut data2[..], None).unwrap();
                let data3 = event.wait();
                assert_eq!(data3, Ok(None));
            }
        }
    }
}