open_cl_low_level/
platform.rs

1/// Platform has 3 basic functions (other than holding a cl object handle).
2///
3/// Platform is the interface for listing platforms.
4///
5/// Platform is the interface for getting metadata about platforms.
6///
7/// Platform is the interface for listing Devices.
8///
9/// NOTE: ClPlatformID is tested!
10use std::default::Default;
11use std::sync::Mutex;
12
13use crate::cl_enums::PlatformInfo;
14use crate::cl_helpers::cl_get_info5;
15use crate::ffi::{clGetPlatformIDs, clGetPlatformInfo, cl_platform_id, cl_platform_info, cl_uint};
16use crate::{build_output, utils, ClPointer, Error, Output, ObjectWrapper};
17
18lazy_static! {
19    static ref PLATFORM_ACCESS: Mutex<()> = Mutex::new(());
20}
21
22/// Gets the cl_platform_ids of the host machine
23pub fn cl_get_platform_ids() -> Output<ClPointer<cl_platform_id>> {
24    let platform_lock = PLATFORM_ACCESS.lock();
25    // transactional access to the platform Mutex requires one lock.
26    let mut num_platforms: cl_uint = 0;
27    let e1 = unsafe { clGetPlatformIDs(0, std::ptr::null_mut(), &mut num_platforms) };
28    let mut ids: Vec<cl_platform_id> =
29        utils::vec_filled_with(0 as cl_platform_id, num_platforms as usize);
30    build_output((), e1)?;
31    let e2 = unsafe { clGetPlatformIDs(num_platforms, ids.as_mut_ptr(), &mut num_platforms) };
32    build_output((), e2)?;
33    std::mem::drop(platform_lock);
34    Ok(unsafe { ClPointer::from_vec(ids) })
35}
36
37/// Gets platform info for a given cl_platform_id and the given cl_platform_info flag via the
38/// OpenCL FFI call to clGetPlatformInfo.alloc
39///
40/// # Safety
41/// Use of an invalid cl_platform_id is undefined behavior. A mismatch between the
42/// types that the info flag is supposed to result in and the T of the Output<ClPointer<T>> is
43/// undefined behavior. Be careful. There be dragons.
44pub unsafe fn cl_get_platform_info<T: Copy>(
45    platform: cl_platform_id,
46    info_flag: cl_platform_info,
47) -> Output<ClPointer<T>> {
48    cl_get_info5(platform, info_flag, clGetPlatformInfo)
49}
50
51/// An error related to Platform.
52#[derive(Debug, Fail, PartialEq, Eq, Clone)]
53pub enum PlatformError {
54    #[fail(display = "No platforms found!")]
55    NoPlatforms,
56
57    #[fail(display = "The platform has no useable devices!")]
58    NoUsableDevices,
59
60    #[fail(display = "The given platform had no default device!")]
61    NoDefaultDevice,
62}
63
/// A platform handle: the crate's `ObjectWrapper` around a raw `cl_platform_id`.
pub type ClPlatformID = ObjectWrapper<cl_platform_id>;
65
/// Implemented by anything that can hand out a raw `cl_platform_id`.
///
/// The info functions in this file (`platform_info`, `platform_name`, ...) are generic over
/// this trait, so they accept raw ids as well as `ClPlatformID` wrappers and references to them.
pub trait PlatformPtr {
    /// Returns the raw `cl_platform_id` for this value.
    fn platform_ptr(&self) -> cl_platform_id;
}
69
70// pub struct ClPlatformID {
71//     object: cl_platform_id,
72// }
73
/// A raw `cl_platform_id` acts as its own platform pointer.
impl PlatformPtr for cl_platform_id {
    fn platform_ptr(&self) -> cl_platform_id {
        // Hand back a copy of the raw id itself.
        *self
    }
}
79
/// Exposes the raw id held inside a `ClPlatformID` wrapper.
impl PlatformPtr for ClPlatformID {
    fn platform_ptr(&self) -> cl_platform_id {
        // NOTE(review): `cl_object` is an unsafe accessor defined outside this file; presumably
        // it returns the wrapped handle without giving up ownership - confirm.
        unsafe { self.cl_object() }
    }
}
85
86impl PlatformPtr for &ClPlatformID {
87    fn platform_ptr(&self) -> cl_platform_id {
88        unsafe { self.cl_object() }
89    }
90}
91
92pub fn list_platforms() -> Output<Vec<ClPlatformID>> {
93    let mut plats = Vec::new();
94    unsafe {
95        let cl_ptr = cl_get_platform_ids()?;
96        for object in cl_ptr.into_vec().into_iter() {
97            let plat = ClPlatformID::new(object)?;
98            plats.push(plat);
99        }
100    }
101    Ok(plats)
102}
103
104pub fn default_platform() -> Output<ClPlatformID> {
105    let mut platforms = list_platforms()?;
106
107    if platforms.is_empty() {
108        return Err(Error::from(PlatformError::NoPlatforms));
109    }
110    Ok(platforms.remove(0))
111}
112
113pub fn platform_info<P: PlatformPtr, I: Into<cl_platform_info>>(
114    platform: P,
115    info_code: I,
116) -> Output<String> {
117    unsafe {
118        cl_get_platform_info(platform.platform_ptr(), info_code.into()).map(|ret| ret.into_string())
119    }
120}
121
122pub fn platform_name<P: PlatformPtr>(platform: P) -> Output<String> {
123    platform_info(platform, PlatformInfo::Name)
124}
125
126pub fn platform_version<P: PlatformPtr>(platform: P) -> Output<String> {
127    platform_info(platform, PlatformInfo::Version)
128}
129
130pub fn platform_profile<P: PlatformPtr>(platform: P) -> Output<String> {
131    platform_info(platform, PlatformInfo::Profile)
132}
133
134pub fn platform_vendor<P: PlatformPtr>(platform: P) -> Output<String> {
135    platform_info(platform, PlatformInfo::Vendor)
136}
137
138pub fn platform_extensions<P: PlatformPtr>(platform: P) -> Output<Vec<String>> {
139    platform_info(platform, PlatformInfo::Extensions)
140        .map(|exts| exts.split(' ').map(|ext| ext.to_string()).collect())
141}
142
143// v2.1
144// pub fn host_timer_resolution(&self) -> Output<String> {
145//     self.get_info(PlatformInfo::HostTimerResolution)
146// }
147
// NOTE(review): these impls assert that a wrapped `cl_platform_id` may be moved to and shared
// between threads. Nothing in this file establishes that; presumably it relies on platform ids
// being plain handles owned by the OpenCL runtime - confirm against `ObjectWrapper`.
unsafe impl Send for ClPlatformID {}
unsafe impl Sync for ClPlatformID {}
150
151impl Default for ClPlatformID {
152    fn default() -> ClPlatformID {
153        default_platform().unwrap()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ClPointer;
    // use crate::device::{Device, DeviceType, DevicePtr};

    /// The raw FFI listing must yield at least one platform id, none of them null.
    #[test]
    fn test_cl_get_platforms() {
        let cl_pointer: ClPointer<cl_platform_id> = match cl_get_platform_ids() {
            Ok(pointer) => pointer,
            Err(e) => panic!("cl_get_platforms failed with {:?}", e),
        };

        let platforms: Vec<cl_platform_id> = unsafe { cl_pointer.into_vec() };
        assert!(!platforms.is_empty());
        assert!(platforms.iter().all(|p| !p.is_null()));
    }

    /// Building the default platform must not panic.
    #[test]
    fn platform_func_default_works() {
        let _platform: ClPlatformID = ClPlatformID::default();
    }

    /// Listing platforms must succeed and return at least one entry.
    #[test]
    fn platform_func_all_works() {
        let platforms: Vec<ClPlatformID> = list_platforms().expect("list_platforms() failed");
        assert!(!platforms.is_empty());
    }

    /// Every string-valued info getter must succeed and return non-empty text.
    #[test]
    fn platform_has_functions_getting_for_info() {
        let platform = ClPlatformID::default();

        let name = platform_name(&platform).expect("failed to get platform info for name");
        assert!(!name.is_empty());

        let version = platform_version(&platform).expect("failed to get platform info for version");
        assert!(!version.is_empty());

        let profile = platform_profile(&platform).expect("failed to get platform info for profile");
        assert!(!profile.is_empty());

        let vendor = platform_vendor(&platform).expect("failed to get platform info for vendor");
        assert!(!vendor.is_empty());

        let extensions =
            platform_extensions(&platform).expect("failed to get platform info for extensions");
        assert!(extensions.iter().all(|ext| !ext.is_empty()));

        // v2.1
        // let host_timer_resolution = platform.host_timer_resolution()
        //     .expect("failed to get platform info for host_timer_resolution");

        // assert_eq!(host_timer_resolution, "".to_string());
    }
}