opencl_api/api/
buffer.rs

/*
 * buffer.rs - Buffer API wrappers (Part of OpenCL Runtime Layer).
 *
 * Copyright 2020-2021 Naman Bishnoi
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
//!
//! A buffer object stores a one-dimensional collection of elements.
//! Elements of a buffer object can be of a scalar data type (such as int or float), a vector data type, or a user-defined structure.
//!
22use crate::objects::bitfields::{MapFlags, MemFlags};
23use crate::objects::enums::ParamValue;
24use crate::objects::functions::status_update;
25use crate::objects::structs::{BufferCreateType, StatusCode};
26use crate::objects::traits::GetSetGo;
27use crate::objects::types::{APIResult, ContextPtr, EventPtr, LongProperties, MemPtr, QueuePtr};
28use crate::objects::wrappers::*;
29use libc::c_void;
30use opencl_heads::ffi;
31use opencl_heads::types::*;
32use std::ptr;
33
34pub fn create_buffer(
35    context: &ContextPtr,
36    flags: MemFlags,
37    size: size_t,
38    host_ptr: WrappedMutablePointer<c_void>,
39) -> APIResult<MemPtr> {
40    let fn_name = "clCreateBuffer";
41    let mut status_code = StatusCode::INVALID_VALUE;
42    let mem = unsafe {
43        ffi::clCreateBuffer(
44            context.unwrap(),
45            flags.get(),
46            size,
47            host_ptr.unwrap(),
48            &mut status_code,
49        )
50    };
51    eprintln!("POINTER: {:?}", mem);
52    status_update(status_code, fn_name, MemPtr::from_ptr(mem, fn_name)?)
53}
54// #[cfg(feature = "cl_3_0")]
55pub fn create_buffer_with_properties(
56    context: ContextPtr,
57    properties: &LongProperties,
58    flags: MemFlags,
59    size: size_t,
60    host_ptr: WrappedMutablePointer<c_void>,
61) -> APIResult<MemPtr> {
62    let mut status_code = StatusCode::INVALID_VALUE;
63    let fn_name = "clCreateBufferWithProperties";
64    let properties = match properties {
65        Some(x) => x.as_ptr(),
66        None => ptr::null(),
67    };
68    let mem = unsafe {
69        ffi::clCreateBufferWithProperties(
70            context.unwrap(),
71            properties,
72            flags.get(),
73            size,
74            host_ptr.unwrap(),
75            &mut status_code,
76        )
77    };
78    status_update(status_code, fn_name, MemPtr::from_ptr(mem, fn_name)?)
79}
80
81// TODO: buffer_create_info takes cl_buffer_region as input
82/***
83typedef struct cl_buffer_region {
84    size_t    origin;
85    size_t    size;
86} cl_buffer_region;
87 */
88pub fn create_sub_buffer(
89    buffer: MemPtr,
90    flags: MemFlags,
91    buffer_create_type: cl_buffer_create_type,
92    buffer_create_info: WrappedPointer<c_void>,
93) -> APIResult<ParamValue> {
94    let fn_name = "clCreateSubBuffer";
95    match buffer_create_type {
96        BufferCreateType::REGION => {
97            let mut status_code = StatusCode::INVALID_VALUE;
98            let mem = unsafe {
99                ffi::clCreateSubBuffer(
100                    buffer.unwrap(),
101                    flags.get(),
102                    buffer_create_type,
103                    buffer_create_info.unwrap(),
104                    &mut status_code,
105                )
106            };
107            let value = status_update(status_code, fn_name, mem)?;
108            Ok(ParamValue::CPtr(value as isize))
109        }
110        _ => status_update(40404, fn_name, ParamValue::default()),
111    }
112}
113
/******************************/
/*Reading, Writing and Copying*/
/*       Buffer Objects       */
/******************************/
118pub fn enqueue_read_buffer(
119    command_queue: &QueuePtr,
120    buffer: &MemPtr,
121    blocking_read: cl_bool,
122    offset: size_t,
123    size: size_t,
124    ptr: WrappedMutablePointer<c_void>,
125    num_events_in_wait_list: cl_uint,
126    event_wait_list: WrappedPointer<cl_event>,
127) -> APIResult<EventPtr> {
128    let fn_name = "clEnqueueReadBuffer";
129    let mut event = ptr::null_mut();
130    let status_code = unsafe {
131        ffi::clEnqueueReadBuffer(
132            command_queue.unwrap(),
133            buffer.unwrap(),
134            blocking_read,
135            offset,
136            size,
137            ptr.unwrap(),
138            num_events_in_wait_list,
139            event_wait_list.unwrap(),
140            &mut event,
141        )
142    };
143    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
144}
145pub fn enqueue_write_buffer(
146    command_queue: &QueuePtr,
147    buffer: &MemPtr,
148    blocking_write: cl_bool,
149    offset: size_t,
150    size: size_t,
151    ptr: WrappedPointer<c_void>,
152    num_events_in_wait_list: cl_uint,
153    event_wait_list: WrappedPointer<cl_event>,
154) -> APIResult<EventPtr> {
155    let mut event = ptr::null_mut();
156    let fn_name = "clEnqueueWriteBuffer";
157    let status_code = unsafe {
158        ffi::clEnqueueWriteBuffer(
159            command_queue.unwrap(),
160            buffer.unwrap(),
161            blocking_write,
162            offset,
163            size,
164            ptr.unwrap(),
165            num_events_in_wait_list,
166            event_wait_list.unwrap(),
167            &mut event,
168        )
169    };
170    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
171}
172
173pub fn enqueue_read_buffer_rect(
174    command_queue: &QueuePtr,
175    buffer: &MemPtr,
176    blocking_read: cl_bool,
177    buffer_origin: WrappedPointer<size_t>,
178    host_origin: WrappedPointer<size_t>,
179    region: WrappedPointer<size_t>,
180    buffer_row_pitch: size_t,
181    buffer_slice_pitch: size_t,
182    host_row_pitch: size_t,
183    host_slice_pitch: size_t,
184    ptr: WrappedMutablePointer<c_void>,
185    num_events_in_wait_list: cl_uint,
186    event_wait_list: WrappedPointer<cl_event>,
187) -> APIResult<EventPtr> {
188    let mut event = ptr::null_mut();
189    let fn_name = "clEnqueueReadBufferRect";
190    let status_code = unsafe {
191        ffi::clEnqueueReadBufferRect(
192            command_queue.unwrap(),
193            buffer.unwrap(),
194            blocking_read,
195            buffer_origin.unwrap(),
196            host_origin.unwrap(),
197            region.unwrap(),
198            buffer_row_pitch,
199            buffer_slice_pitch,
200            host_row_pitch,
201            host_slice_pitch,
202            ptr.unwrap(),
203            num_events_in_wait_list,
204            event_wait_list.unwrap(),
205            &mut event,
206        )
207    };
208    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
209}
210
211pub fn enqueue_write_buffer_rect(
212    command_queue: &QueuePtr,
213    buffer: &MemPtr,
214    blocking_write: cl_bool,
215    buffer_origin: WrappedPointer<size_t>,
216    host_origin: WrappedPointer<size_t>,
217    region: WrappedPointer<size_t>,
218    buffer_row_pitch: size_t,
219    buffer_slice_pitch: size_t,
220    host_row_pitch: size_t,
221    host_slice_pitch: size_t,
222    ptr: WrappedPointer<c_void>,
223    num_events_in_wait_list: cl_uint,
224    event_wait_list: WrappedPointer<cl_event>,
225) -> APIResult<EventPtr> {
226    let mut event = ptr::null_mut();
227    let fn_name = "clEnqueueWriteBufferRect";
228    let status_code = unsafe {
229        ffi::clEnqueueWriteBufferRect(
230            command_queue.unwrap(),
231            buffer.unwrap(),
232            blocking_write,
233            buffer_origin.unwrap(),
234            host_origin.unwrap(),
235            region.unwrap(),
236            buffer_row_pitch,
237            buffer_slice_pitch,
238            host_row_pitch,
239            host_slice_pitch,
240            ptr.unwrap(),
241            num_events_in_wait_list,
242            event_wait_list.unwrap(),
243            &mut event,
244        )
245    };
246    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
247}
248
249pub fn enqueue_copy_buffer(
250    command_queue: &QueuePtr,
251    src_buffer: &MemPtr,
252    dst_buffer: &MemPtr,
253    src_offset: size_t,
254    dst_offset: size_t,
255    size: size_t,
256    num_events_in_wait_list: cl_uint,
257    event_wait_list: WrappedPointer<cl_event>,
258) -> APIResult<EventPtr> {
259    let mut event = ptr::null_mut();
260    let fn_name = "clEnqueueCopyBuffer";
261    let status_code = unsafe {
262        ffi::clEnqueueCopyBuffer(
263            command_queue.unwrap(),
264            src_buffer.unwrap(),
265            dst_buffer.unwrap(),
266            src_offset,
267            dst_offset,
268            size,
269            num_events_in_wait_list,
270            event_wait_list.unwrap(),
271            &mut event,
272        )
273    };
274    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
275}
276
277pub fn enqueue_copy_buffer_rect(
278    command_queue: &QueuePtr,
279    src_buffer: &MemPtr,
280    dst_buffer: &MemPtr,
281    src_origin: WrappedPointer<size_t>,
282    dst_origin: WrappedPointer<size_t>,
283    region: WrappedPointer<size_t>,
284    src_row_pitch: size_t,
285    src_slice_pitch: size_t,
286    dst_row_pitch: size_t,
287    dst_slice_pitch: size_t,
288    num_events_in_wait_list: cl_uint,
289    event_wait_list: WrappedPointer<cl_event>,
290) -> APIResult<EventPtr> {
291    let mut event = ptr::null_mut();
292    let fn_name = "clEnqueueCopyBufferRect";
293    let status_code = unsafe {
294        ffi::clEnqueueCopyBufferRect(
295            command_queue.unwrap(),
296            src_buffer.unwrap(),
297            dst_buffer.unwrap(),
298            src_origin.unwrap(),
299            dst_origin.unwrap(),
300            region.unwrap(),
301            src_row_pitch,
302            src_slice_pitch,
303            dst_row_pitch,
304            dst_slice_pitch,
305            num_events_in_wait_list,
306            event_wait_list.unwrap(),
307            &mut event,
308        )
309    };
310    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
311}
312
313pub fn enqueue_fill_buffer(
314    command_queue: &QueuePtr,
315    buffer: &MemPtr,
316    pattern: WrappedPointer<c_void>,
317    pattern_size: size_t,
318    offset: size_t,
319    size: size_t,
320    num_events_in_wait_list: cl_uint,
321    event_wait_list: WrappedPointer<cl_event>,
322) -> APIResult<EventPtr> {
323    let mut event = ptr::null_mut();
324    let fn_name = "clEnqueueFillBuffer";
325    let status_code = unsafe {
326        ffi::clEnqueueFillBuffer(
327            command_queue.unwrap(),
328            buffer.unwrap(),
329            pattern.unwrap(),
330            pattern_size,
331            offset,
332            size,
333            num_events_in_wait_list,
334            event_wait_list.unwrap(),
335            &mut event,
336        )
337    };
338    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
339}
340
341pub fn enqueue_map_buffer(
342    command_queue: &QueuePtr,
343    buffer: &MemPtr,
344    blocking_map: cl_bool,
345    map_flags: MapFlags,
346    offset: size_t,
347    size: size_t,
348    num_events_in_wait_list: cl_uint,
349    event_wait_list: WrappedPointer<cl_event>,
350    region_ptr: &mut cl_mem,
351) -> APIResult<EventPtr> {
352    let mut event = ptr::null_mut();
353    let fn_name = "clEnqueueMapBuffer";
354    let mut status_code = StatusCode::INVALID_VALUE;
355    *region_ptr = unsafe {
356        ffi::clEnqueueMapBuffer(
357            command_queue.unwrap(),
358            buffer.unwrap(),
359            blocking_map,
360            map_flags.get(),
361            offset,
362            size,
363            num_events_in_wait_list,
364            event_wait_list.unwrap(),
365            &mut event,
366            &mut status_code,
367        )
368    };
369    status_update(status_code, fn_name, EventPtr::from_ptr(event, fn_name)?)
370}
371
/************************/
/* /\ /\ /\ /\ /\ /\ /\ */
/*|__|__|__|__|__|__|__|*/
/*|  |  |  |  |  |  |  |*/
/*|  |  Unit Tests  |  |*/
/*|__|__|__|__|__|__|__|*/
/*|__|__|__|__|__|__|__|*/
/************************/
380
// Integration-style tests: they unwrap `get_platform_ids()` and index
// `[0]`, so they need at least one OpenCL platform and device at runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::context::{create_context, release_context};
    use crate::api::device::get_device_ids;
    use crate::api::platform::get_platform_ids;
    use crate::api::queue::{create_command_queue_with_properties, release_command_queue};
    use crate::objects::bitfields::{CommandQueueProperties, DeviceType};
    use crate::objects::property::ContextProperties;
    use crate::objects::property::QueueProperties;
    use crate::objects::types::{DevicePtr, PlatformPtr, WrapMutPtr, WrapPtr};
    use std::ptr;

    // Creates a plain READ_WRITE buffer, then a second buffer with
    // READ_WRITE + USE_HOST_PTR, on the first platform's default device.
    // NOTE(review): neither buffer is released before the context is; a
    // memory-object release call presumably belongs here — confirm the API.
    #[test]
    fn test_create_buffer() {
        let platform_ids = get_platform_ids().unwrap();
        let platform_id = PlatformPtr::from_ptr(platform_ids[0], "test_fn").unwrap();
        let device_ids =
            get_device_ids(&platform_id, DeviceType::new(DeviceType::DEFAULT).unwrap()).unwrap();
        assert!(0 < device_ids.len());
        let device_id = DevicePtr::from_ptr(device_ids[0], "test_fn").unwrap();

        let properties = ContextProperties.gen(Some(&platform_id), None);
        let context = create_context(&properties, device_ids, None, WrapMutPtr::null());
        let context = context.unwrap();

        // Command queue with profiling enabled.
        let properties = QueueProperties.gen(
            Some(CommandQueueProperties::new(CommandQueueProperties::PROFILING_ENABLE).unwrap()),
            None,
        );
        let queue =
            create_command_queue_with_properties(&context, &device_id, &properties).unwrap();

        // Buffer test 1: device-allocated READ_WRITE buffer, no host pointer.
        let flags = MemFlags::new(MemFlags::READ_WRITE).unwrap();
        let size = 1048576 * 4; // 4 MiB
        let buffer_mem = create_buffer(&context, flags, size, WrapMutPtr::null()).unwrap();
        assert_ne!(buffer_mem.unwrap(), ptr::null_mut());
        // Buffer test 2: READ_WRITE + USE_HOST_PTR.
        // NOTE: observed to intermittently return a null pointer (the original
        // author reported roughly 21.3% of runs).
        // NOTE(review): `buffer_mem` passed below is the `cl_mem` handle created
        // above, not host memory. With USE_HOST_PTR, OpenCL expects `host_ptr` to
        // reference at least `size` bytes of host memory, so this likely explains
        // the flaky result — consider passing a real host allocation instead.
        let flags = MemFlags::new(MemFlags::READ_WRITE).unwrap()
            + MemFlags::new(MemFlags::USE_HOST_PTR).unwrap();
        let size = 1048576; // 1 MiB
        let buffer_mem = create_buffer(&context, flags, size, buffer_mem).unwrap();
        assert_ne!(buffer_mem.unwrap(), ptr::null_mut());

        release_command_queue(queue).unwrap();
        release_context(context).unwrap();
    }

    // Creates a 4 MiB buffer and a sub-buffer covering its first half, then
    // writes and reads back 1000 floats through the sub-buffer.
    #[test]
    fn test_create_sub_buffer() {
        let platform_ids = get_platform_ids().unwrap();
        let platform_id = PlatformPtr::from_ptr(platform_ids[0], "test_fn").unwrap();
        let device_ids =
            get_device_ids(&platform_id, DeviceType::new(DeviceType::DEFAULT).unwrap()).unwrap();
        assert!(0 < device_ids.len());
        let device_id = DevicePtr::from_ptr(device_ids[0], "test_fn").unwrap();

        let properties = ContextProperties.gen(Some(&platform_id), None);
        let context = create_context(&properties, device_ids, None, WrapMutPtr::null());
        let context = context.unwrap();

        // Command queue with no extra properties.
        let queue = create_command_queue_with_properties(&context, &device_id, &None).unwrap();

        // Parent buffer.
        let flags = MemFlags::new(MemFlags::READ_WRITE).unwrap();
        let size = 1048576 * 4; // 4 MiB
        let buffer_mem = create_buffer(&context, flags.clone(), size, WrapMutPtr::null()).unwrap();
        assert_ne!(buffer_mem.unwrap(), ptr::null_mut());
        // Sub-buffer region: origin 0, first half of the parent buffer.
        let creation_info = cl_buffer_region {
            origin: 0,
            size: size / 2,
        };
        let creation_info_ptr = WrapPtr::from(&creation_info);
        let sub_buffer_mem = create_sub_buffer(
            buffer_mem,
            flags,
            BufferCreateType::REGION,
            creation_info_ptr,
        )
        .unwrap();
        let sub_buffer_mem_ptr = sub_buffer_mem.unwrap_mut_cptr().unwrap();

        const ARRAY_SIZE: usize = 1000;
        let ones: [cl_float; ARRAY_SIZE] = [1.0; ARRAY_SIZE];
        // Blocking write of `ones` into the sub-buffer.
        // NOTE(review): the host float pointer is cast to `cl_mem` only to fit
        // the wrapper's type parameter; it is host memory, not a memory object.
        let write_event = enqueue_write_buffer(
            &queue,
            &sub_buffer_mem_ptr,
            1,
            0,
            ones.len() * std::mem::size_of::<cl_float>(),
            WrappedPointer::from_owned(ones.as_ptr() as cl_mem),
            0,
            WrappedPointer::null(),
        )
        .unwrap();
        assert_ne!(write_event.unwrap(), ptr::null_mut());

        // Blocking read back from the sub-buffer.
        // NOTE(review): `ones` is a `Copy` array, so `from_owned(ones)` receives a
        // copy. The data read back is never compared against the written values —
        // consider asserting on the read buffer's contents.
        let read_event = enqueue_read_buffer(
            &queue,
            &sub_buffer_mem_ptr,
            1,
            0,
            ones.len() * std::mem::size_of::<cl_float>(),
            WrapMutPtr::from_owned(ones),
            0,
            WrappedPointer::null(),
        )
        .unwrap();
        assert_ne!(read_event.unwrap(), ptr::null_mut());

        release_command_queue(queue).unwrap();
        release_context(context).unwrap();
    }
}