opencl_core/
session.rs

1use std::mem::ManuallyDrop;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::{
5    Buffer, BufferCreator, CommandQueueOptions, Context, Device, DeviceType, Kernel, KernelOpArg,
6    KernelOperation, MemConfig, MutVecOrSlice, Output, Program, VecOrSlice, Waitlist,
7    Work, NumberTyped,
8};
9
10use crate::ll::{
11    list_devices_by_type, list_platforms, BufferReadEvent, ClCommandQueue, ClContext, ClDeviceID,
12    ClEvent, ClKernel, ClMem, ClNumber, ClProgram, CommandQueueProperties, CommandQueuePtr,
13    DevicePtr, KernelArg,
14};
15
16#[derive(Debug)]
17pub struct Session {
18    _device: ManuallyDrop<ClDeviceID>,
19    _program: ManuallyDrop<ClProgram>,
20    _context: ManuallyDrop<ClContext>,
21    _queue: ManuallyDrop<Arc<RwLock<ClCommandQueue>>>,
22    _unconstructable: (),
23}
24
25unsafe impl Send for Session {}
26unsafe impl Sync for Session {}
27
28impl Session {
29    pub fn create_with_devices<'a, D>(
30        devices: D,
31        src: &str,
32        cq_props: Option<CommandQueueProperties>,
33    ) -> Output<Vec<Session>>
34    where
35        D: Into<VecOrSlice<'a, Device>>,
36    {
37        let devices: Vec<Device> = devices.into().to_vec();
38        unsafe {
39            let context = ClContext::create(devices.as_slice())?;
40            let mut sessions: Vec<Session> = Vec::with_capacity(devices.len());
41            for device in devices.iter() {
42                let device = ClDeviceID::unchecked_new(device.device_ptr());
43                let mut program = ClProgram::create_with_source(&context, src)?;
44                program.build(devices.as_slice())?;
45
46                let queue = ClCommandQueue::create(&context, &device, cq_props)?;
47                let session = Session {
48                    _device: ManuallyDrop::new(device),
49                    _context: ManuallyDrop::new(context.clone()),
50                    _program: ManuallyDrop::new(program.clone()),
51                    _queue: ManuallyDrop::new(Arc::new(RwLock::new(queue))),
52                    _unconstructable: (),
53                };
54                sessions.push(session);
55            }
56            Ok(sessions)
57        }
58    }
59
60    pub fn create(src: &str, cq_props: Option<CommandQueueProperties>) -> Output<Vec<Session>> {
61        let platforms = list_platforms()?;
62        let mut devices: Vec<Device> = Vec::new();
63        for platform in platforms.iter() {
64            let platform_devices: Vec<Device> = list_devices_by_type(platform, DeviceType::ALL)
65                .map(|ll_devices| ll_devices.into_iter().map(|d| Device::new(d)).collect())?;
66            devices.extend(platform_devices);
67        }
68        Session::create_with_devices(devices, src, cq_props)
69    }
70
71    pub fn context(&self) -> Context {
72        Context::from_low_level_context(self.low_level_context()).unwrap()
73    }
74
75    pub fn device(&self) -> Device {
76        Device::new(self.low_level_device().clone())
77    }
78
79    pub fn program(&self) -> Program {
80        unsafe { Program::from_low_level_program(self.low_level_program()).unwrap() }
81    }
82
83    pub fn read_queue(&self) -> RwLockReadGuard<ClCommandQueue> {
84        self._queue.read().unwrap()
85    }
86
87    pub fn write_queue(&self) -> RwLockWriteGuard<ClCommandQueue> {
88        self._queue.write().unwrap()
89    }
90
91    pub fn low_level_device(&self) -> &ClDeviceID {
92        &*self._device
93    }
94
95    pub fn low_level_context(&self) -> &ClContext {
96        &self._context
97    }
98
99    pub fn low_level_program(&self) -> &ClProgram {
100        &self._program
101    }
102
103    pub fn create_copy(&self) -> Output<Session> {
104        let cloned_device = self._device.clone();
105        let cloned_context = self._context.clone();
106        let cloned_program = self._program.clone();
107        let ll_queue = self._queue.read().unwrap();
108        let copied_queue = unsafe { ll_queue.create_copy()? };
109
110        Ok(Session {
111            _device: cloned_device,
112            _context: cloned_context,
113            _program: cloned_program,
114            _queue: ManuallyDrop::new(Arc::new(RwLock::new(copied_queue))),
115            _unconstructable: (),
116        })
117    }
118
119    /// Creates a ClKernel from the session's program.
120    pub fn create_kernel(&self, kernel_name: &str) -> Output<Kernel> {
121        unsafe {
122            let ll_kernel = ClKernel::create(self.low_level_program(), kernel_name)?;
123            Ok(Kernel::new(ll_kernel, self.program()))
124        }
125    }
126
127    /// Creates a ClMem object in the given context, with the given buffer creator
128    /// (either a length or some data). This function uses the BufferCreator's implementation
129    /// to retrieve the appropriate MemConfig.
130    pub fn create_buffer<T: ClNumber, B: BufferCreator<T>>(
131        &self,
132        buffer_creator: B,
133    ) -> Output<Buffer> {
134        let cfg = buffer_creator.mem_config();
135        Buffer::create_from_low_level_context(
136            self.low_level_context(),
137            buffer_creator,
138            cfg.host_access,
139            cfg.kernel_access,
140            cfg.mem_location,
141        )
142    }
143
144    /// Creates a ClMem object in the given context, with the given buffer creator
145    /// (either a length or some data) and a given MemConfig.
146    pub fn create_buffer_with_config<T: ClNumber, B: BufferCreator<T>>(
147        &self,
148        buffer_creator: B,
149        mem_config: MemConfig,
150    ) -> Output<Buffer> {
151        Buffer::create_from_low_level_context(
152            self.low_level_context(),
153            buffer_creator,
154            mem_config.host_access,
155            mem_config.kernel_access,
156            mem_config.mem_location,
157        )
158    }
159
160    /// This function copies data from the host buffer into the device mem buffer. The host
161    /// buffer must be a mutable slice or a vector to ensure the safety of the read_Buffer
162    /// operation.
163    pub fn sync_write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
164        &self,
165        buffer: &Buffer,
166        host_buffer: H,
167        opts: Option<CommandQueueOptions>,
168    ) -> Output<()> {
169        buffer.number_type().type_check(T::number_type())?;
170        let mut queue = self.write_queue();
171        let mut buffer_lock = buffer.write_lock();
172        unsafe {
173            let event: ClEvent = queue.write_buffer(&mut (*buffer_lock), host_buffer, opts)?;
174            event.wait()
175        }
176    }
177
178    /// This function copies data from a device mem buffer into a host buffer. The host
179    /// buffer must be a mutable slice or a vector. For the moment the device mem must also
180    /// be passed as mutable; I don't trust OpenCL.
181    pub fn sync_read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
182        &self,
183        buffer: &Buffer,
184        host_buffer: H,
185        opts: Option<CommandQueueOptions>,
186    ) -> Output<Option<Vec<T>>> {
187        buffer.number_type().type_check(T::number_type())?;
188        let mut queue = self.write_queue();
189
190        let buffer_lock = buffer.read_lock();
191        unsafe {
192            let mut event: BufferReadEvent<T> =
193                queue.read_buffer(&(*buffer_lock), host_buffer, opts)?;
194            event.wait()
195        }
196    }
197
198    /// This function enqueues a CLKernel into a command queue
199    ///
200    /// # Safety
201    /// If the ClKernel is not in a usable state or any of the Kernel's dependent object
202    /// has been release, or the kernel belongs to a different session, or the ClKernel's
203    /// pointer is a null pointer, then calling this function will cause undefined behavior.
204    pub fn sync_enqueue_kernel(
205        &self,
206        kernel: &Kernel,
207        work: &Work,
208        opts: Option<CommandQueueOptions>,
209    ) -> Output<()> {
210        let mut queue = self.write_queue();
211        let mut kernel_lock = kernel.write_lock();
212        unsafe {
213            let event = queue.enqueue_kernel(&mut (*kernel_lock), work, opts)?;
214            event.wait()
215        }
216    }
217
218    pub fn execute_sync_kernel_operation<'a, T>(
219        &self,
220        mut kernel_op: KernelOperation<'a, T>,
221    ) -> Output<()>
222    where
223        T: ClNumber + KernelArg,
224    {
225        unsafe {
226            let kernel = self.create_kernel(kernel_op.name())?;
227            let work = kernel_op.work()?;
228            let command_queue_opts = kernel_op.command_queue_opts();
229            let mut mem_locks: Vec<RwLockWriteGuard<ClMem>> = Vec::new();
230            for (arg_index, arg) in kernel_op.mut_args().iter_mut().enumerate() {
231                match arg {
232                    KernelOpArg::Num(ref mut num) => kernel.set_arg(arg_index, num)?,
233                    KernelOpArg::Buffer(ref buffer) => {
234                        let mut mem = buffer.write_lock();
235                        kernel.set_arg(arg_index, &mut *mem)?;
236                        mem_locks.push(mem);
237                    }
238                }
239            }
240
241            let mut queue = self.write_queue();
242            let mut ll_kernel = kernel.write_lock();
243            let event = queue.enqueue_kernel(&mut ll_kernel, &work, command_queue_opts)?;
244            // Wait until queued mems finish being accessed.
245            event.wait()?;
246            // then drop locks.
247            std::mem::drop(mem_locks);
248            Ok(())
249        }
250    }
251}
252
253impl Clone for Session {
254    fn clone(&self) -> Session {
255        Session {
256            _device: self._device.clone(),
257            _context: self._context.clone(),
258            _program: self._program.clone(),
259            _queue: self._queue.clone(),
260            _unconstructable: (),
261        }
262    }
263}
264
265impl PartialEq for Session {
266    fn eq(&self, other: &Self) -> bool {
267        unsafe {
268            let self_queue_ptr = self.read_queue().command_queue_ptr();
269            let other_queue_ptr = other.read_queue().command_queue_ptr();
270            std::ptr::eq(self_queue_ptr, other_queue_ptr)
271        }
272    }
273}
274
275impl Eq for Session {}
276
277impl Drop for Session {
278    fn drop(&mut self) {
279        unsafe {
280            ManuallyDrop::drop(&mut self._queue);
281            ManuallyDrop::drop(&mut self._program);
282            ManuallyDrop::drop(&mut self._context);
283            ManuallyDrop::drop(&mut self._device);
284        }
285    }
286}
287
#[cfg(test)]
mod tests {
    use crate::{testing, Buffer, Kernel, Session, Work};

    // OpenCL source used by most tests: adds one to every element of `data`.
    const SRC: &str = "__kernel void test(__global int *data) {
        data[get_global_id(0)] += 1;
    }";

    fn new_session() -> Session {
        testing::get_session(SRC)
    }

    #[test]
    fn session_can_be_created_with_src() {
        let _session = Session::create(SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session: {:?}", e);
        });
    }

    #[test]
    fn session_can_be_created_with_src_and_slice_of_devices() {
        let devices = testing::get_all_devices();
        assert_ne!(devices.len(), 0);
        let _session = Session::create_with_devices(&devices[..], SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session with slice of devices: {:?}", e);
        });
    }

    #[test]
    fn session_can_be_created_with_src_and_vec_of_devices() {
        let devices = testing::get_all_devices();
        assert_ne!(devices.len(), 0);
        let _session = Session::create_with_devices(devices, SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session with vec of devices: {:?}", e);
        });
    }

    #[test]
    fn session_implements_clone() {
        let _other: Session = new_session().clone();
    }

    #[test]
    fn session_implementation_of_fmt_debug_works() {
        let session = new_session();
        let formatted = format!("{:?}", session);
        assert!(
            formatted.starts_with("Session"),
            "Formatted did not start with the correct value. Got: {:?}",
            formatted
        );
    }

    #[test]
    fn session_create_copy_copies_command_queue_and_clones_the_rest() {
        let session = new_session();

        let session_copy = session.create_copy().unwrap_or_else(|e| {
            panic!("Failed to create_copy of session: {:?}", e);
        });
        {
            // Scope the queue guards so they are released before `assert_ne!` below
            // re-locks both queues through `PartialEq`. Re-locking an `RwLock` the
            // current thread already holds may panic or deadlock.
            let s1_queue = session.read_queue();
            let s2_queue = session_copy.read_queue();
            assert_ne!(*s1_queue, *s2_queue);
        }
        assert_eq!(
            session.low_level_context(),
            session_copy.low_level_context()
        );
        assert_eq!(
            session.low_level_program(),
            session_copy.low_level_program()
        );
        assert_ne!(session, session_copy);
    }

    #[test]
    fn session_can_create_kernel() {
        let src = "__kernel void add_one_i32(__global int *i) { *i += 1; }";
        let session = testing::get_session(src);
        let _kernel: Kernel = session.create_kernel("add_one_i32").unwrap_or_else(|e| {
            panic!("Failed to create kernel for session: {:?}", e);
        });
    }

    #[test]
    fn session_can_create_buffer_from_data() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let _buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
    }

    #[test]
    fn session_can_create_buffer_of_a_given_length() {
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer::<i32, usize>(100)
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 100);
    }

    #[test]
    fn session_can_write_and_read_buffer() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 8);
        session
            .sync_write_buffer(&buffer, &data[..], None)
            .unwrap_or_else(|e| panic!("Failed to write buffer: {:?}", e));
        let data2 = vec![0i32; 8];
        let data3 = session
            .sync_read_buffer(&buffer, data2, None)
            .unwrap_or_else(|e| panic!("Failed to read buffer: {:?}", e))
            .expect("sync_read_buffer returned no data");

        assert_eq!(data3.len(), 8);
        assert_eq!(data3, data);
    }

    #[test]
    fn session_sync_enqueue_kernel_and_read_buffer() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 8);
        session
            .sync_write_buffer(&buffer, &data[..], None)
            .unwrap_or_else(|e| panic!("Failed to write buffer: {:?}", e));
        let kernel: Kernel = session.create_kernel("test").unwrap();
        // Hold the buffer lock while the kernel runs (buffer before queue), then release
        // it before reading the buffer back.
        let mut buffer_lock = buffer.write_lock();
        unsafe { kernel.set_arg(0, &mut (*buffer_lock)).unwrap() };
        let work = Work::new(data.len());
        session.sync_enqueue_kernel(&kernel, &work, None).unwrap();
        std::mem::drop(buffer_lock);

        let data2 = vec![0i32; 8];
        let data3 = session
            .sync_read_buffer(&buffer, data2, None)
            .unwrap_or_else(|e| panic!("Failed to read buffer: {:?}", e))
            .expect("sync_read_buffer returned no data");

        assert_eq!(data3.len(), 8);
        let expected_data: Vec<i32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(data3, expected_data);
    }
}