min_cl/
cl_device.rs

1use std::{cell::RefCell, fmt::Debug, time::Duration};
2
3use crate::{
4    api::{
5        clEnqueueNDRangeKernel, cl_event, create_command_queue, create_context, get_device_ids,
6        get_platforms, wait_for_events, CLIntDevice, CommandQueue, Context, DeviceType, Event,
7        Kernel, OCLErrorKind, Platform,
8    },
9    init_devices,
10    measure_perf::measure_perf,
11    Error, DEVICES,
12};
13
14pub fn all_devices() -> Result<Vec<Vec<CLIntDevice>>, Error> {
15    Ok(get_platforms()?
16        .into_iter()
17        .map(all_devices_of_platform)
18        .collect())
19}
20
21pub fn all_devices_of_platform(platform: Platform) -> Vec<CLIntDevice> {
22    [
23        DeviceType::GPU as u64 | DeviceType::ACCELERATOR as u64,
24        DeviceType::CPU as u64,
25    ]
26    .into_iter()
27    .flat_map(|device_type| get_device_ids(platform, &device_type))
28    .flatten()
29    .collect()
30}
31
32/// Internal representation of an OpenCL Device.
33pub struct CLDevice {
34    pub device: CLIntDevice,
35    pub ctx: Context,
36    pub queue: CommandQueue,
37    pub unified_mem: bool,
38    pub event_wait_list: RefCell<Vec<Event>>,
39}
40
41impl Debug for CLDevice {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("CLDevice")
44            .field("name", &self.device.get_name().unwrap())
45            .field("unified_mem", &self.unified_mem)
46            .field("event_wait_list", &self.event_wait_list.borrow())
47            .finish()
48    }
49}
50
51unsafe impl Sync for CLDevice {}
52unsafe impl Send for CLDevice {}
53
54impl TryFrom<CLIntDevice> for CLDevice {
55    type Error = Error;
56
57    fn try_from(device: CLIntDevice) -> Result<Self, Self::Error> {
58        let ctx = create_context(&[device])?;
59        let queue = create_command_queue(&ctx, device)?;
60        let unified_mem = device.unified_mem()?;
61
62        Ok(CLDevice {
63            device,
64            ctx,
65            queue,
66            unified_mem,
67            event_wait_list: Default::default(),
68        })
69    }
70}
71
72pub fn measured_devices() -> Result<Vec<(Duration, usize, usize, CLDevice)>, Error> {
73    Ok(all_devices()?
74        .into_iter()
75        .enumerate()
76        .map(|(platform_idx, devices)| {
77            devices
78                .into_iter()
79                .map(TryInto::try_into)
80                .enumerate()
81                .filter_map(move |(device_idx, device)| {
82                    Some((
83                        measure_perf(device.as_ref().ok()?).ok()?,
84                        platform_idx,
85                        device_idx,
86                        device.ok()?,
87                    ))
88                })
89        })
90        .flatten()
91        .collect::<Vec<_>>())
92}
93
94pub fn extract_indices_from_device_idx(device_idx: usize) -> Result<(usize, usize), Error> {
95    let rwlock_guard = DEVICES.read().map_err(|_| OCLErrorKind::InvalidDevice)?;
96    let devices = rwlock_guard.as_ref().unwrap();
97
98    let (_, platform_idx, device_idx, _) = &devices
99        .get(device_idx)
100        .ok_or(OCLErrorKind::InvalidDeviceIdx)?;
101    Ok((*platform_idx, *device_idx))
102}
103
104impl CLDevice {
105    pub fn from_indices(platform_idx: usize, device_idx: usize) -> Result<CLDevice, Error> {
106        let platform = get_platforms()?[platform_idx];
107        let devices = get_device_ids(platform, &(DeviceType::GPU as u64))?;
108
109        devices[device_idx].try_into()
110    }
111    pub fn new(device_idx: usize) -> Result<CLDevice, Error> {
112        init_devices();
113
114        let (platform_idx, device_idx) = extract_indices_from_device_idx(device_idx)?;
115        CLDevice::from_indices(platform_idx, device_idx)
116    }
117
118    pub fn fastest() -> Result<CLDevice, Error> {
119        Ok(measured_devices()?
120            .into_iter()
121            .min_by_key(|(dur, _, _, _)| *dur)
122            .ok_or(OCLErrorKind::InvalidDevice)
123            .map(|(_, _, _, device)| device)?)
124    }
125
126    pub fn enqueue_nd_range_kernel(
127        &self,
128        kernel: &Kernel,
129        wd: usize,
130        gws: &[usize; 3],
131        lws: Option<&[usize; 3]>,
132        offset: Option<[usize; 3]>,
133    ) -> Result<(), Error> {
134        let mut event = [std::ptr::null_mut(); 1];
135        let lws = match lws {
136            Some(lws) => lws.as_ptr(),
137            None => std::ptr::null(),
138        };
139        let offset = match offset {
140            Some(offset) => offset.as_ptr(),
141            None => std::ptr::null(),
142        };
143
144        let value = unsafe {
145            clEnqueueNDRangeKernel(
146                self.queue.0,
147                kernel.0,
148                wd as u32,
149                offset,
150                gws.as_ptr(),
151                lws,
152                0,
153                std::ptr::null(),
154                event.as_mut_ptr() as *mut cl_event,
155            )
156        };
157
158        if value != 0 {
159            return Err(Error::from(OCLErrorKind::from_value(value)));
160        }
161
162        self.event_wait_list.borrow_mut().push(Event(event[0]));
163
164        Ok(())
165    }
166
167    #[inline]
168    pub fn wait_for_events(&self) -> Result<(), Error> {
169        wait_for_events(&self.event_wait_list.borrow())?;
170        self.event_wait_list.borrow_mut().clear();
171        Ok(())
172    }
173}
174
#[cfg(test)]
mod tests {
    use crate::{
        api::{create_buffer, MemFlags},
        CLDevice,
    };

    /// Picks the fastest device, allocates two buffers on its context and
    /// prints the device name and its global memory scaled by 1e-9.
    #[test]
    fn test_get_fastest() {
        let device = CLDevice::fastest().unwrap();
        let flags = MemFlags::MemReadWrite as u64;

        // The buffers are intentionally not bound to a variable, so each one
        // is dropped right after its allocation succeeded.
        create_buffer::<f32>(&device.ctx, flags, 10000, None).unwrap();

        let name = device.device.get_name().unwrap();
        println!("device name: {}", name);

        create_buffer::<f32>(&device.ctx, flags, 9423 * 123, None).unwrap();

        let global_mem = device.device.get_global_mem().unwrap() as f32 * 10f32.powf(-9.);
        println!("{}", global_mem)
    }
}