opencl_api/api/
context.rs

1/*
2 * context.rs - Context API wrappers (Part of OpenCL Platform Layer).
3 *
4 * Copyright 2020-2021 Naman Bishnoi
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 *     http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18use crate::objects::bitfields::DeviceType;
19use crate::objects::enums::{ParamValue, Size};
20use crate::objects::functions::status_update;
21use crate::objects::structs::ContextInfo;
22use crate::objects::traits::GetSetGo;
23use crate::objects::types::{APIResult, ContextPtr, DeviceList, Properties};
24use crate::objects::wrappers::WrappedMutablePointer;
25use crate::{gen_param_value, size_getter};
26use libc::c_void;
27use opencl_heads::ffi;
28use opencl_heads::ffi::clGetContextInfo;
29use opencl_heads::types::*;
30use std::ptr;
31
32pub fn create_context(
33    properties: &Properties,
34    devices: DeviceList,
35    pfn_notify: Option<extern "C" fn(*const c_char, *const c_void, size_t, *mut c_void)>,
36    user_data: WrappedMutablePointer<c_void>,
37) -> APIResult<ContextPtr> {
38    let fn_name = "clCreateContext";
39    let mut status_code = 0;
40    let properties = match properties {
41        Some(x) => x.as_ptr(),
42        None => ptr::null(),
43    };
44    let context = unsafe {
45        ffi::clCreateContext(
46            properties,
47            devices.len() as cl_uint,
48            devices.as_ptr(),
49            pfn_notify,
50            user_data.unwrap(),
51            &mut status_code,
52        )
53    };
54    status_update(
55        status_code,
56        fn_name,
57        ContextPtr::from_ptr(context, fn_name)?,
58    )
59}
60
61pub fn create_context_from_type(
62    properties: &Properties,
63    device_type: DeviceType,
64    pfn_notify: Option<extern "C" fn(*const c_char, *const c_void, size_t, *mut c_void)>,
65    user_data: WrappedMutablePointer<c_void>,
66) -> APIResult<ContextPtr> {
67    let fn_name = "clCreateContextFromType";
68    let mut status_code = 0;
69    let properties = match properties {
70        Some(x) => x.as_ptr(),
71        None => ptr::null(),
72    };
73    let context = unsafe {
74        ffi::clCreateContextFromType(
75            properties,
76            device_type.get(),
77            pfn_notify,
78            user_data.unwrap(),
79            &mut status_code,
80        )
81    };
82    status_update(
83        status_code,
84        fn_name,
85        ContextPtr::from_ptr(context, fn_name)?,
86    )
87}
88
89pub fn retain_context(context: &ContextPtr) -> APIResult<()> {
90    let status_code = unsafe { ffi::clRetainContext(context.unwrap()) };
91    status_update(status_code, "clRetainContext", ())
92}
93
94pub fn release_context(context: ContextPtr) -> APIResult<()> {
95    let status_code = unsafe { ffi::clReleaseContext(context.unwrap()) };
96    status_update(status_code, "clReleaseContext", ())
97}
98
/// Queries information about `context` via `clGetContextInfo`.
///
/// Supported `param_name` values and the `ParamValue` variant returned:
/// * `ContextInfo::REFERENCE_COUNT`, `ContextInfo::NUM_DEVICES` -> `ParamValue::UInt`
/// * `ContextInfo::DEVICES`, `ContextInfo::PROPERTIES` -> `ParamValue::ArrCPtr`
///   (array elements are read as `isize`)
///
/// # Errors
/// Any other `param_name` is answered with `status_update` called on the
/// crate-specific code `40404`. Errors from the size query are propagated with `?`.
pub fn get_context_info(
    context: &ContextPtr,
    param_name: cl_context_info,
) -> APIResult<ParamValue> {
    // Short alias so the match arms below stay compact.
    type C = ContextInfo;
    let context = context.unwrap();
    // Crate macro (defined elsewhere) that declares a local helper
    // `get_context_info_size`; presumably it queries the byte size of a
    // variable-length parameter value — confirm against the macro definition.
    size_getter!(get_context_info_size, clGetContextInfo);
    match param_name {
        // Scalar parameters, read as a single u32 value.
        C::REFERENCE_COUNT | C::NUM_DEVICES => {
            let param_value = gen_param_value!(clGetContextInfo, u32, context, param_name);
            Ok(ParamValue::UInt(param_value))
        }
        // Array parameters: the size is queried first, then the values are read
        // as isize elements (device handles / property entries).
        C::DEVICES | C::PROPERTIES => {
            let size = get_context_info_size(context, param_name)?;
            let param_value = gen_param_value!(clGetContextInfo, isize, context, param_name, size);
            Ok(ParamValue::ArrCPtr(param_value))
        }
        // Unsupported parameter name: reported through status_update with the
        // crate's custom code 40404 (presumably mapped to an error — confirm).
        _ => status_update(40404, "clGetContextInfo", ParamValue::default()),
    }
}
119
120pub fn set_context_destructor_callback(
121    context: &ContextPtr,
122    pfn_notify: extern "C" fn(context: cl_context, user_data: *mut c_void),
123    user_data: WrappedMutablePointer<c_void>,
124) -> APIResult<()> {
125    let status_code = unsafe {
126        ffi::clSetContextDestructorCallback(context.unwrap(), pfn_notify, user_data.unwrap())
127    };
128    status_update(status_code, "clSetContextDestructorCallback", ())
129}
130
131/************************/
132/* /\ /\ /\ /\ /\ /\ /\ */
133/*|__|__|__|__|__|__|__|*/
134/*|  |  |  |  |  |  |  |*/
135/*|  |  Unit Tests  |  |*/
136/*|__|__|__|__|__|__|__|*/
137/*|__|__|__|__|__|__|__|*/
138/************************/
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::device::get_device_ids;
    use crate::api::platform::get_platform_ids;
    use crate::objects::bitfields::DeviceType;
    use crate::objects::property::ContextProperties;
    use crate::objects::types::{PlatformPtr, WrapMutPtr};

    /// Returns the first platform reported by the OpenCL runtime.
    fn first_platform() -> PlatformPtr {
        let platform_ids = get_platform_ids().unwrap();
        PlatformPtr::from_ptr(platform_ids[0], "test_fn").unwrap()
    }

    /// Returns the `DEFAULT` device-type bitfield.
    fn default_device_type() -> DeviceType {
        DeviceType::new(DeviceType::DEFAULT).unwrap()
    }

    /// Creates a context on `platform` for the default device type and
    /// asserts that the returned handle is non-null.
    fn default_context(platform: &PlatformPtr) -> ContextPtr {
        let properties = ContextProperties.gen(Some(platform), None);
        let context =
            create_context_from_type(&properties, default_device_type(), None, WrapMutPtr::null())
                .unwrap();
        eprintln!("CL_CONTEXT_PTR: {:?}", context);
        assert_ne!(context.unwrap(), ptr::null_mut());
        context
    }

    #[test]
    fn test_create_context() {
        let platform_id = first_platform();
        let device_ids = get_device_ids(&platform_id, default_device_type()).unwrap();
        assert!(!device_ids.is_empty());

        let context = create_context(&None, device_ids, None, WrapMutPtr::null()).unwrap();
        eprintln!("CL_CONTEXT_PTR: {:?}", context);
        assert_ne!(context.unwrap(), ptr::null_mut());
        release_context(context).unwrap();
    }

    #[test]
    fn test_create_context_from_type() {
        let platform_id = first_platform();
        let context = default_context(&platform_id);
        release_context(context).unwrap();
    }

    #[test]
    fn test_get_context_info_1() {
        let platform_id = first_platform();
        let context = default_context(&platform_id);

        let device = get_context_info(&context, ContextInfo::DEVICES);
        eprintln!("CL_CONTEXT_DEVICE: {:?}", device);
        assert_ne!(device.unwrap().unwrap_arr_cptr().unwrap()[0], 0);

        // The generated property list presumably starts with
        // [PLATFORM, <platform id>, ...], so the platform handle is element 1.
        let properties = get_context_info(&context, ContextInfo::PROPERTIES);
        eprintln!("CL_CONTEXT_PROPERTIES: {:?}", properties);
        let re_platform_id = properties.unwrap().unwrap_arr_cptr().unwrap()[1];
        assert_eq!(re_platform_id, platform_id.unwrap() as isize);
        release_context(context).unwrap();
    }

    #[test]
    fn test_get_context_info_2() {
        let platform_id = first_platform();
        let context = default_context(&platform_id);

        let device_count = get_context_info(&context, ContextInfo::NUM_DEVICES);
        eprintln!("CL_CONTEXT_DEVICE_COUNT: {:?}", device_count);
        assert_ne!(device_count.unwrap().unwrap_uint().unwrap(), 0);

        let ref_count = get_context_info(&context, ContextInfo::REFERENCE_COUNT);
        eprintln!("CL_CONTEXT_REFERENCE_COUNT: {:?}", ref_count);
        assert_ne!(ref_count.unwrap().unwrap_uint().unwrap(), 0);

        release_context(context).unwrap();
    }
}