open_cl_low_level/
context.rs1use crate::ffi::{
2 clCreateContext, clGetContextInfo, cl_context, cl_context_info, cl_context_properties,
3 cl_device_id,
4};
5use crate::{
6 build_output, ClDeviceID, ClPointer, ContextInfo, ContextProperties, DevicePtr, Output,
7 ObjectWrapper,
8};
9
10use crate::cl_helpers::cl_get_info5;
11
12#[allow(clippy::transmuting_null)]
13pub unsafe fn cl_create_context(device_ids: &[cl_device_id]) -> Output<cl_context> {
14 let mut err_code = 0;
15 let context = clCreateContext(
16 std::ptr::null(),
17 device_ids.len() as u32,
18 device_ids.as_ptr() as *const cl_device_id,
19 std::mem::transmute(std::ptr::null::<fn()>()),
20 std::ptr::null_mut(),
21 &mut err_code,
22 );
23 build_output(context, err_code)
24}
25
26pub fn cl_get_context_info<T>(context: cl_context, flag: cl_context_info) -> Output<ClPointer<T>>
27where
28 T: Copy,
29{
30 unsafe { cl_get_info5(context, flag as cl_context_info, clGetContextInfo) }
31}
32
33pub unsafe trait ContextPtr: Sized {
34 unsafe fn context_ptr(&self) -> cl_context;
35
36 unsafe fn info<T: Copy>(&self, flag: ContextInfo) -> Output<ClPointer<T>> {
37 cl_get_context_info::<T>(self.context_ptr(), flag.into())
38 }
39
40 unsafe fn reference_count(&self) -> Output<u32> {
41 self.info(ContextInfo::ReferenceCount)
42 .map(|ret| ret.into_one())
43 }
44
45 unsafe fn devices(&self) -> Output<Vec<ClDeviceID>> {
46 self.info(ContextInfo::Devices).map(|ret| {
47 let device_ids: Vec<cl_device_id> = ret.into_vec();
48 device_ids
49 .into_iter()
50 .map(|device_id| ClDeviceID::retain_new(device_id))
51 .filter_map(Result::ok)
52 .collect()
53 })
54 }
55
56 unsafe fn properties(&self) -> Output<Vec<ContextProperties>> {
57 self.info(ContextInfo::Properties)
58 .map(|ret: ClPointer<cl_context_properties>| {
59 ret.into_vec()
60 .into_iter()
61 .map(ContextProperties::from)
62 .collect()
63 })
64 }
65
66 unsafe fn num_devices(&self) -> Output<u32> {
67 self.info(ContextInfo::NumDevices).map(|ret| ret.into_one())
68 }
69}
70
71pub type ClContext = ObjectWrapper<cl_context>;
72
73impl ClContext {
74 pub unsafe fn create<D>(devices: &[D]) -> Output<ClContext>
75 where
76 D: DevicePtr,
77 {
78 let device_ptrs: Vec<cl_device_id> = devices.iter().map(|d| d.device_ptr()).collect();
79 let object = cl_create_context(&device_ptrs[..])?;
80 ClContext::new(object)
81 }
82}
83
84unsafe impl ContextPtr for ClContext {
85 unsafe fn context_ptr(&self) -> cl_context {
86 self.cl_object()
87 }
88}
89
#[cfg(test)]
mod test_context_ptr {
    use crate::*;

    #[test]
    fn reference_count_works() {
        // Assumes `get_context` hands back a context held only by `ctx`.
        let (ctx, _devices) = ll_testing::get_context();
        let ref_count = unsafe { ctx.reference_count() }.unwrap();
        assert_eq!(ref_count, 1);
    }

    #[test]
    fn devices_works() {
        let (ctx, _devices) = ll_testing::get_context();
        let devices = unsafe { ctx.devices() }.unwrap();
        assert!(!devices.is_empty());
    }

    #[test]
    fn properties_works() {
        // Only checks that the query succeeds; the property list may be empty.
        let (ctx, _devices) = ll_testing::get_context();
        let _props = unsafe { ctx.properties() }.unwrap();
    }

    #[test]
    fn num_devices_works() {
        let (ctx, _devices) = ll_testing::get_context();
        let n_devices = unsafe { ctx.num_devices() }.unwrap();
        assert!(n_devices > 0);
    }

    #[test]
    fn devices_len_matches_num_devices() {
        let (ctx, _devices) = ll_testing::get_context();
        let num_devices = unsafe { ctx.num_devices() }.unwrap();
        let devices = unsafe { ctx.devices() }.unwrap();
        assert_eq!(num_devices as usize, devices.len());
    }
}