1use std::{cell::RefCell, fmt::Debug, time::Duration};
2
3use crate::{
4 api::{
5 clEnqueueNDRangeKernel, cl_event, create_command_queue, create_context, get_device_ids,
6 get_platforms, wait_for_events, CLIntDevice, CommandQueue, Context, DeviceType, Event,
7 Kernel, OCLErrorKind, Platform,
8 },
9 init_devices,
10 measure_perf::measure_perf,
11 Error, DEVICES,
12};
13
14pub fn all_devices() -> Result<Vec<Vec<CLIntDevice>>, Error> {
15 Ok(get_platforms()?
16 .into_iter()
17 .map(all_devices_of_platform)
18 .collect())
19}
20
21pub fn all_devices_of_platform(platform: Platform) -> Vec<CLIntDevice> {
22 [
23 DeviceType::GPU as u64 | DeviceType::ACCELERATOR as u64,
24 DeviceType::CPU as u64,
25 ]
26 .into_iter()
27 .flat_map(|device_type| get_device_ids(platform, &device_type))
28 .flatten()
29 .collect()
30}
31
/// An OpenCL device bundled with the context and command queue created for it.
///
/// Instances are built through `TryFrom<CLIntDevice>`, which creates the
/// context, the queue and queries the unified-memory flag.
pub struct CLDevice {
    /// The underlying device handle.
    pub device: CLIntDevice,
    /// Context created for `device` alone.
    pub ctx: Context,
    /// Command queue created on `ctx` for `device`.
    pub queue: CommandQueue,
    /// Value reported by `device.unified_mem()` at construction time.
    pub unified_mem: bool,
    /// Events of kernels enqueued via `enqueue_nd_range_kernel` that have not
    /// been waited on yet; emptied by `wait_for_events`.
    pub event_wait_list: RefCell<Vec<Event>>,
}
40
41impl Debug for CLDevice {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("CLDevice")
44 .field("name", &self.device.get_name().unwrap())
45 .field("unified_mem", &self.unified_mem)
46 .field("event_wait_list", &self.event_wait_list.borrow())
47 .finish()
48 }
49}
50
// NOTE(review): `CLDevice` contains a `RefCell` (`event_wait_list`), and
// `RefCell` is not `Sync`. This impl therefore asserts that a `&CLDevice` is
// never used from several threads at the same time; nothing in this file
// enforces that. Confirm the callers uphold it, or replace the `RefCell` with
// a lock.
unsafe impl Sync for CLDevice {}
// NOTE(review): presumably needed because the wrapped OpenCL handles are raw
// pointers and thus not `Send` automatically — TODO confirm the handles may be
// moved between threads.
unsafe impl Send for CLDevice {}
53
54impl TryFrom<CLIntDevice> for CLDevice {
55 type Error = Error;
56
57 fn try_from(device: CLIntDevice) -> Result<Self, Self::Error> {
58 let ctx = create_context(&[device])?;
59 let queue = create_command_queue(&ctx, device)?;
60 let unified_mem = device.unified_mem()?;
61
62 Ok(CLDevice {
63 device,
64 ctx,
65 queue,
66 unified_mem,
67 event_wait_list: Default::default(),
68 })
69 }
70}
71
72pub fn measured_devices() -> Result<Vec<(Duration, usize, usize, CLDevice)>, Error> {
73 Ok(all_devices()?
74 .into_iter()
75 .enumerate()
76 .map(|(platform_idx, devices)| {
77 devices
78 .into_iter()
79 .map(TryInto::try_into)
80 .enumerate()
81 .filter_map(move |(device_idx, device)| {
82 Some((
83 measure_perf(device.as_ref().ok()?).ok()?,
84 platform_idx,
85 device_idx,
86 device.ok()?,
87 ))
88 })
89 })
90 .flatten()
91 .collect::<Vec<_>>())
92}
93
94pub fn extract_indices_from_device_idx(device_idx: usize) -> Result<(usize, usize), Error> {
95 let rwlock_guard = DEVICES.read().map_err(|_| OCLErrorKind::InvalidDevice)?;
96 let devices = rwlock_guard.as_ref().unwrap();
97
98 let (_, platform_idx, device_idx, _) = &devices
99 .get(device_idx)
100 .ok_or(OCLErrorKind::InvalidDeviceIdx)?;
101 Ok((*platform_idx, *device_idx))
102}
103
104impl CLDevice {
105 pub fn from_indices(platform_idx: usize, device_idx: usize) -> Result<CLDevice, Error> {
106 let platform = get_platforms()?[platform_idx];
107 let devices = get_device_ids(platform, &(DeviceType::GPU as u64))?;
108
109 devices[device_idx].try_into()
110 }
111 pub fn new(device_idx: usize) -> Result<CLDevice, Error> {
112 init_devices();
113
114 let (platform_idx, device_idx) = extract_indices_from_device_idx(device_idx)?;
115 CLDevice::from_indices(platform_idx, device_idx)
116 }
117
118 pub fn fastest() -> Result<CLDevice, Error> {
119 Ok(measured_devices()?
120 .into_iter()
121 .min_by_key(|(dur, _, _, _)| *dur)
122 .ok_or(OCLErrorKind::InvalidDevice)
123 .map(|(_, _, _, device)| device)?)
124 }
125
126 pub fn enqueue_nd_range_kernel(
127 &self,
128 kernel: &Kernel,
129 wd: usize,
130 gws: &[usize; 3],
131 lws: Option<&[usize; 3]>,
132 offset: Option<[usize; 3]>,
133 ) -> Result<(), Error> {
134 let mut event = [std::ptr::null_mut(); 1];
135 let lws = match lws {
136 Some(lws) => lws.as_ptr(),
137 None => std::ptr::null(),
138 };
139 let offset = match offset {
140 Some(offset) => offset.as_ptr(),
141 None => std::ptr::null(),
142 };
143
144 let value = unsafe {
145 clEnqueueNDRangeKernel(
146 self.queue.0,
147 kernel.0,
148 wd as u32,
149 offset,
150 gws.as_ptr(),
151 lws,
152 0,
153 std::ptr::null(),
154 event.as_mut_ptr() as *mut cl_event,
155 )
156 };
157
158 if value != 0 {
159 return Err(Error::from(OCLErrorKind::from_value(value)));
160 }
161
162 self.event_wait_list.borrow_mut().push(Event(event[0]));
163
164 Ok(())
165 }
166
167 #[inline]
168 pub fn wait_for_events(&self) -> Result<(), Error> {
169 wait_for_events(&self.event_wait_list.borrow())?;
170 self.event_wait_list.borrow_mut().clear();
171 Ok(())
172 }
173}
174
#[cfg(test)]
mod tests {
    use crate::{
        api::{create_buffer, MemFlags},
        CLDevice,
    };

    /// Smoke test: picks the fastest device, allocates two read/write buffers
    /// on its context and prints the device name and its global memory size
    /// scaled by 1e-9. Requires a working OpenCL installation.
    #[test]
    fn test_get_fastest() {
        let fastest = CLDevice::fastest().unwrap();
        let read_write = MemFlags::MemReadWrite as u64;

        create_buffer::<f32>(&fastest.ctx, read_write, 10000, None).unwrap();
        println!("device name: {}", fastest.device.get_name().unwrap());
        create_buffer::<f32>(&fastest.ctx, read_write, 9423 * 123, None).unwrap();

        let global_mem = fastest.device.get_global_mem().unwrap();
        println!("{}", global_mem as f32 * 10f32.powf(-9.));
    }
}