open_cl_low_level/
context.rs

1use crate::ffi::{
2    clCreateContext, clGetContextInfo, cl_context, cl_context_info, cl_context_properties,
3    cl_device_id,
4};
5use crate::{
6    build_output, ClDeviceID, ClPointer, ContextInfo, ContextProperties, DevicePtr, Output,
7    ObjectWrapper,
8};
9
10use crate::cl_helpers::cl_get_info5;
11
12#[allow(clippy::transmuting_null)]
13pub unsafe fn cl_create_context(device_ids: &[cl_device_id]) -> Output<cl_context> {
14    let mut err_code = 0;
15    let context = clCreateContext(
16        std::ptr::null(),
17        device_ids.len() as u32,
18        device_ids.as_ptr() as *const cl_device_id,
19        std::mem::transmute(std::ptr::null::<fn()>()),
20        std::ptr::null_mut(),
21        &mut err_code,
22    );
23    build_output(context, err_code)
24}
25
26pub fn cl_get_context_info<T>(context: cl_context, flag: cl_context_info) -> Output<ClPointer<T>>
27where
28    T: Copy,
29{
30    unsafe { cl_get_info5(context, flag as cl_context_info, clGetContextInfo) }
31}
32
33pub unsafe trait ContextPtr: Sized {
34    unsafe fn context_ptr(&self) -> cl_context;
35
36    unsafe fn info<T: Copy>(&self, flag: ContextInfo) -> Output<ClPointer<T>> {
37        cl_get_context_info::<T>(self.context_ptr(), flag.into())
38    }
39
40    unsafe fn reference_count(&self) -> Output<u32> {
41        self.info(ContextInfo::ReferenceCount)
42            .map(|ret| ret.into_one())
43    }
44
45    unsafe fn devices(&self) -> Output<Vec<ClDeviceID>> {
46        self.info(ContextInfo::Devices).map(|ret| {
47            let device_ids: Vec<cl_device_id> = ret.into_vec();
48            device_ids
49                .into_iter()
50                .map(|device_id| ClDeviceID::retain_new(device_id))
51                .filter_map(Result::ok)
52                .collect()
53        })
54    }
55
56    unsafe fn properties(&self) -> Output<Vec<ContextProperties>> {
57        self.info(ContextInfo::Properties)
58            .map(|ret: ClPointer<cl_context_properties>| {
59                ret.into_vec()
60                    .into_iter()
61                    .map(ContextProperties::from)
62                    .collect()
63            })
64    }
65
66    unsafe fn num_devices(&self) -> Output<u32> {
67        self.info(ContextInfo::NumDevices).map(|ret| ret.into_one())
68    }
69}
70
71pub type ClContext = ObjectWrapper<cl_context>;
72
73impl ClContext {
74    pub unsafe fn create<D>(devices: &[D]) -> Output<ClContext>
75    where
76        D: DevicePtr,
77    {
78        let device_ptrs: Vec<cl_device_id> = devices.iter().map(|d| d.device_ptr()).collect();
79        let object = cl_create_context(&device_ptrs[..])?;
80        ClContext::new(object)
81    }
82}
83
// `ClContext` exposes its wrapped `cl_context` to the `ContextPtr` query methods.
// NOTE(review): soundness of this `unsafe impl` relies on `ObjectWrapper::cl_object` returning a
// handle that stays valid while `self` is alive — confirm against `ObjectWrapper`.
unsafe impl ContextPtr for ClContext {
    /// Returns the raw handle held by this wrapper, as reported by `cl_object()`.
    unsafe fn context_ptr(&self) -> cl_context {
        self.cl_object()
    }
}
89
// Smoke tests for the `ContextPtr` queries, run against the context from
// `ll_testing::get_context`.
#[cfg(test)]
mod test_context_ptr {
    use crate::*;

    #[test]
    fn reference_count_works() {
        let (ctx, _devices) = ll_testing::get_context();
        let ref_count = unsafe { ctx.reference_count() }.unwrap();
        // `ctx` should be the only holder of this context, so the count should be exactly 1.
        assert_eq!(ref_count, 1);
    }

    #[test]
    fn devices_works() {
        let (ctx, _devices) = ll_testing::get_context();
        let devices = unsafe { ctx.devices() }.unwrap();
        assert!(!devices.is_empty(), "context reported no devices");
    }

    #[test]
    fn properties_works() {
        let (ctx, _devices) = ll_testing::get_context();
        // Only checks that the query succeeds; no particular property list is asserted.
        let _props = unsafe { ctx.properties() }.unwrap();
    }

    #[test]
    fn num_devices_works() {
        let (ctx, _devices) = ll_testing::get_context();
        let n_devices = unsafe { ctx.num_devices() }.unwrap();
        assert!(n_devices > 0, "context reported zero devices");
    }

    #[test]
    fn devices_len_matches_num_devices() {
        let (ctx, _devices) = ll_testing::get_context();
        let num_devices = unsafe { ctx.num_devices() }.unwrap();
        let devices = unsafe { ctx.devices() }.unwrap();
        assert_eq!(num_devices as usize, devices.len());
    }
}