opencl_api/api/
memory.rs

1/*
 * memory.rs - Memory object and SVM API wrappers (Part of OpenCL Runtime Layer).
3 *
4 * Copyright 2020-2021 Naman Bishnoi
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 *     http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
//!
//! Wrappers around the OpenCL memory object API (retain/release, destructor callbacks,
//! unmapping, migration and info queries) and the shared virtual memory API.
//!
//! Shared virtual memory (a.k.a. SVM) allows the host and kernels executing on devices to directly share complex, pointer-containing data structures such as trees and linked lists.
//! It also eliminates the need to marshal data between the host and devices.
//! As a result, SVM substantially simplifies OpenCL programming and may improve performance.
//!
23use crate::objects::bitfields::{MapFlags, MemFlags, MemMigrationFlags};
24use crate::objects::enums::{ParamValue, Size};
25use crate::objects::functions::status_update;
26use crate::objects::structs::{MemInfo, StatusCode};
27use crate::objects::traits::GetSetGo;
28use crate::objects::types::{APIResult, ContextPtr, EventPtr, MemPtr, QueuePtr, SVMPtr};
29use crate::objects::wrappers::{WrappedMutablePointer, WrappedPointer};
30use crate::{gen_param_value, size_getter};
31use libc::c_void;
32use opencl_heads::ffi;
33use opencl_heads::ffi::clGetMemObjectInfo;
34use opencl_heads::types::*;
35use std::ptr;
36
37pub fn retain_mem_object(memobj: &MemPtr) -> APIResult<()> {
38    let status_code = unsafe { ffi::clRetainMemObject(memobj.unwrap()) };
39    status_update(status_code, "clRetainMemObject", ())
40}
41
42pub fn release_mem_object(memobj: MemPtr) -> APIResult<()> {
43    let status_code = unsafe { ffi::clReleaseMemObject(memobj.unwrap()) };
44    status_update(status_code, "clReleaseMemObject", ())
45}
46
47pub fn set_mem_object_destructor_callback(
48    memobj: &MemPtr,
49    pfn_notify: Option<extern "C" fn(memobj: cl_mem, user_data: *mut c_void)>,
50    user_data: WrappedMutablePointer<c_void>,
51) -> APIResult<()> {
52    let status_code = unsafe {
53        ffi::clSetMemObjectDestructorCallback(memobj.unwrap(), pfn_notify, user_data.unwrap())
54    };
55    status_update(status_code, "clSetMemObjectDestructorCallback", ())
56}
57
58pub fn enqueue_unmap_mem_object(
59    command_queue: &QueuePtr,
60    memobj: &MemPtr,
61    mapped_ptr: WrappedMutablePointer<c_void>,
62    num_events_in_wait_list: cl_uint,
63    event_wait_list: WrappedPointer<cl_event>,
64) -> APIResult<EventPtr> {
65    let fn_name = "clEnqueueUnmapMemObject";
66    let mut event_ptr = ptr::null_mut();
67    let status_code = unsafe {
68        ffi::clEnqueueUnmapMemObject(
69            command_queue.unwrap(),
70            memobj.unwrap(),
71            mapped_ptr.unwrap(),
72            num_events_in_wait_list,
73            event_wait_list.unwrap(),
74            &mut event_ptr,
75        )
76    };
77    status_update(
78        status_code,
79        fn_name,
80        EventPtr::from_ptr(event_ptr, fn_name)?,
81    )
82}
83
84pub fn enqueue_migrate_mem_objects(
85    command_queue: &QueuePtr,
86    num_mem_objects: cl_uint,
87    mem_objects: WrappedPointer<cl_mem>,
88    flags: MemMigrationFlags,
89    num_events_in_wait_list: cl_uint,
90    event_wait_list: WrappedPointer<cl_event>,
91) -> APIResult<EventPtr> {
92    let fn_name = "clEnqueueMigrateMemObjects";
93    let mut event_ptr = ptr::null_mut();
94    let status_code = unsafe {
95        ffi::clEnqueueMigrateMemObjects(
96            command_queue.unwrap(),
97            num_mem_objects,
98            mem_objects.unwrap(),
99            flags.get(),
100            num_events_in_wait_list,
101            event_wait_list.unwrap(),
102            &mut event_ptr,
103        )
104    };
105    status_update(
106        status_code,
107        fn_name,
108        EventPtr::from_ptr(event_ptr, fn_name)?,
109    )
110}
111
112pub fn get_mem_object_info(memobj: MemPtr, param_name: cl_mem_info) -> APIResult<ParamValue> {
113    let memobj = memobj.unwrap();
114    size_getter!(get_mem_object_info_size, clGetMemObjectInfo);
115    match param_name {
116        MemInfo::MAP_COUNT
117        | MemInfo::REFERENCE_COUNT
118        | MemInfo::USES_SVM_POINTER
119        | MemInfo::TYPE => {
120            let param_value = gen_param_value!(clGetMemObjectInfo, u32, memobj, param_name);
121            Ok(ParamValue::UInt(param_value))
122        }
123        MemInfo::HOST_PTR | MemInfo::CONTEXT | MemInfo::ASSOCIATED_MEMOBJECT => {
124            let param_value = gen_param_value!(clGetMemObjectInfo, isize, memobj, param_name);
125            Ok(ParamValue::CPtr(param_value))
126        }
127        MemInfo::FLAGS => {
128            let param_value = gen_param_value!(clGetMemObjectInfo, u64, memobj, param_name);
129            Ok(ParamValue::ULong(param_value))
130        }
131        MemInfo::SIZE | MemInfo::OFFSET => {
132            let param_value = gen_param_value!(clGetMemObjectInfo, usize, memobj, param_name);
133            Ok(ParamValue::CSize(param_value))
134        }
135        MemInfo::PROPERTIES => {
136            let size = get_mem_object_info_size(memobj, param_name)?;
137            let param_value = gen_param_value!(clGetMemObjectInfo, u64, memobj, param_name, size);
138            Ok(ParamValue::ArrULong(param_value))
139        }
140        _ => status_update(40404, "clGetMemObjectInfo", ParamValue::default()),
141    }
142}
143
144/***********************/
145/*Shared Virtual Memory*/
146/***********************/
147pub fn svm_alloc(
148    context: &ContextPtr,
149    flags: &MemFlags,
150    size: size_t,
151    alignment: cl_uint,
152) -> APIResult<SVMPtr> {
153    let fn_name = "clSVMAlloc";
154    let mem_ptr = unsafe { ffi::clSVMAlloc(context.unwrap(), flags.get(), size, alignment) };
155    status_update(
156        StatusCode::SUCCESS,
157        fn_name,
158        SVMPtr::from_ptr(mem_ptr, fn_name)?,
159    )
160}
161
162pub fn svm_free(context: &ContextPtr, svm_pointer: WrappedMutablePointer<c_void>) {
163    unsafe { ffi::clSVMFree(context.unwrap(), svm_pointer.unwrap()) };
164}
165
166pub fn enqueue_svm_free(
167    command_queue: &QueuePtr,
168    num_svm_pointers: cl_uint,
169    svm_pointers: WrappedPointer<*const c_void>,
170    pfn_free_func: Option<
171        extern "C" fn(
172            queue: cl_command_queue,
173            num_svm_pointers: cl_uint,
174            svm_pointes: *const *const c_void,
175            user_data: *mut c_void,
176        ),
177    >,
178    user_data: WrappedMutablePointer<c_void>,
179    num_events_in_wait_list: cl_uint,
180    event_wait_list: WrappedPointer<cl_event>,
181) -> APIResult<EventPtr> {
182    let fn_name = "clEnqueueSVMFree";
183    let mut event_ptr = ptr::null_mut();
184    let status_code = unsafe {
185        ffi::clEnqueueSVMFree(
186            command_queue.unwrap(),
187            num_svm_pointers,
188            svm_pointers.unwrap(),
189            pfn_free_func,
190            user_data.unwrap(),
191            num_events_in_wait_list,
192            event_wait_list.unwrap(),
193            &mut event_ptr,
194        )
195    };
196    status_update(
197        status_code,
198        fn_name,
199        EventPtr::from_ptr(event_ptr, fn_name)?,
200    )
201}
202
203pub fn enqueue_svm_memcpy(
204    command_queue: QueuePtr,
205    blocking_copy: cl_bool,
206    dst_ptr: WrappedMutablePointer<c_void>,
207    src_ptr: WrappedPointer<c_void>,
208    size: size_t,
209    num_events_in_wait_list: cl_uint,
210    event_wait_list: WrappedPointer<cl_event>,
211) -> APIResult<EventPtr> {
212    let fn_name = "clEnqueueSVMMemcpy";
213    let mut event_ptr = ptr::null_mut();
214    let status_code = unsafe {
215        ffi::clEnqueueSVMMemcpy(
216            command_queue.unwrap(),
217            blocking_copy,
218            dst_ptr.unwrap(),
219            src_ptr.unwrap(),
220            size,
221            num_events_in_wait_list,
222            event_wait_list.unwrap(),
223            &mut event_ptr,
224        )
225    };
226    status_update(
227        status_code,
228        fn_name,
229        EventPtr::from_ptr(event_ptr, fn_name)?,
230    )
231}
232
233pub fn enqueue_svm_memfill(
234    command_queue: &QueuePtr,
235    svm_ptr: WrappedMutablePointer<c_void>,
236    pattern: WrappedPointer<c_void>,
237    pattern_size: size_t,
238    size: size_t,
239    num_events_in_wait_list: cl_uint,
240    event_wait_list: WrappedPointer<cl_event>,
241) -> APIResult<EventPtr> {
242    let fn_name = "clEnqueueSVMMemfill";
243    let mut event_ptr = ptr::null_mut();
244    let status_code = unsafe {
245        ffi::clEnqueueSVMMemFill(
246            command_queue.unwrap(),
247            svm_ptr.unwrap(),
248            pattern.unwrap(),
249            pattern_size,
250            size,
251            num_events_in_wait_list,
252            event_wait_list.unwrap(),
253            &mut event_ptr,
254        )
255    };
256    status_update(
257        status_code,
258        fn_name,
259        EventPtr::from_ptr(event_ptr, fn_name)?,
260    )
261}
262
263pub fn enqueue_svm_map(
264    command_queue: &QueuePtr,
265    blocking_map: cl_bool,
266    flags: MapFlags,
267    svm_ptr: WrappedMutablePointer<c_void>,
268    size: size_t,
269    num_events_in_wait_list: cl_uint,
270    event_wait_list: WrappedPointer<cl_event>,
271) -> APIResult<EventPtr> {
272    let fn_name = "clEnqueueSVMMap";
273    let mut event_ptr = ptr::null_mut();
274    let status_code = unsafe {
275        ffi::clEnqueueSVMMap(
276            command_queue.unwrap(),
277            blocking_map,
278            flags.get(),
279            svm_ptr.unwrap(),
280            size,
281            num_events_in_wait_list,
282            event_wait_list.unwrap(),
283            &mut event_ptr,
284        )
285    };
286    status_update(
287        status_code,
288        fn_name,
289        EventPtr::from_ptr(event_ptr, fn_name)?,
290    )
291}
292
293pub fn enqueue_svm_unmap(
294    command_queue: &QueuePtr,
295    svm_ptr: WrappedMutablePointer<c_void>,
296    num_events_in_wait_list: cl_uint,
297    event_wait_list: WrappedPointer<cl_event>,
298) -> APIResult<EventPtr> {
299    let fn_name = "clEnqueueSVMUnmap";
300    let mut event_ptr = ptr::null_mut();
301    let status_code = unsafe {
302        ffi::clEnqueueSVMUnmap(
303            command_queue.unwrap(),
304            svm_ptr.unwrap(),
305            num_events_in_wait_list,
306            event_wait_list.unwrap(),
307            &mut event_ptr,
308        )
309    };
310    status_update(
311        status_code,
312        fn_name,
313        EventPtr::from_ptr(event_ptr, fn_name)?,
314    )
315}
316
317pub fn enqueue_svm_migrate_mem(
318    command_queue: QueuePtr,
319    num_svm_pointers: cl_uint,
320    svm_pointers: WrappedPointer<*const c_void>,
321    sizes: WrappedPointer<size_t>,
322    flags: MemMigrationFlags,
323    num_events_in_wait_list: cl_uint,
324    event_wait_list: WrappedPointer<cl_event>,
325) -> APIResult<EventPtr> {
326    let fn_name = "clEnqueueSVMMigrateMem";
327    let mut event_ptr = ptr::null_mut();
328    let status_code = unsafe {
329        ffi::clEnqueueSVMMigrateMem(
330            command_queue.unwrap(),
331            num_svm_pointers,
332            svm_pointers.unwrap(),
333            sizes.unwrap(),
334            flags.get(),
335            num_events_in_wait_list,
336            event_wait_list.unwrap(),
337            &mut event_ptr,
338        )
339    };
340    status_update(
341        status_code,
342        fn_name,
343        EventPtr::from_ptr(event_ptr, fn_name)?,
344    )
345}
346
347// TODO: Add unit tests for this file.