1use std::mem::ManuallyDrop;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::{
5 Buffer, BufferCreator, CommandQueueOptions, Context, Device, DeviceType, Kernel, KernelOpArg,
6 KernelOperation, MemConfig, MutVecOrSlice, Output, Program, VecOrSlice, Waitlist,
7 Work, NumberTyped,
8};
9
10use crate::ll::{
11 list_devices_by_type, list_platforms, BufferReadEvent, ClCommandQueue, ClContext, ClDeviceID,
12 ClEvent, ClKernel, ClMem, ClNumber, ClProgram, CommandQueueProperties, CommandQueuePtr,
13 DevicePtr, KernelArg,
14};
15
/// Bundles the low-level OpenCL objects used together to run work on a single
/// device: the device id, a program, the context and a lock-guarded command
/// queue.
///
/// Every field is wrapped in `ManuallyDrop`; the `Drop` implementation for
/// `Session` drops each of them explicitly, in a fixed order.
#[derive(Debug)]
pub struct Session {
    // Low-level id of the device this session's command queue was created for.
    _device: ManuallyDrop<ClDeviceID>,
    // Low-level program associated with this session.
    _program: ManuallyDrop<ClProgram>,
    // Low-level context the program and the command queue were created from.
    _context: ManuallyDrop<ClContext>,
    // Command queue behind an `Arc<RwLock<..>>`; `Clone` shares this `Arc`,
    // so cloned sessions use the same queue and the same lock.
    _queue: ManuallyDrop<Arc<RwLock<ClCommandQueue>>>,
    // Private unit field; presumably a marker that a `Session` must be built
    // through the constructors rather than a struct literal — TODO confirm.
    _unconstructable: (),
}
24
// NOTE(review): `Send` and `Sync` are asserted by hand here instead of being
// derived from the field types. This presumably relies on the wrapped OpenCL
// handles being usable from any thread and on the command queue only being
// reached through its `RwLock` — confirm against the `ll` types.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
27
28impl Session {
29 pub fn create_with_devices<'a, D>(
30 devices: D,
31 src: &str,
32 cq_props: Option<CommandQueueProperties>,
33 ) -> Output<Vec<Session>>
34 where
35 D: Into<VecOrSlice<'a, Device>>,
36 {
37 let devices: Vec<Device> = devices.into().to_vec();
38 unsafe {
39 let context = ClContext::create(devices.as_slice())?;
40 let mut sessions: Vec<Session> = Vec::with_capacity(devices.len());
41 for device in devices.iter() {
42 let device = ClDeviceID::unchecked_new(device.device_ptr());
43 let mut program = ClProgram::create_with_source(&context, src)?;
44 program.build(devices.as_slice())?;
45
46 let queue = ClCommandQueue::create(&context, &device, cq_props)?;
47 let session = Session {
48 _device: ManuallyDrop::new(device),
49 _context: ManuallyDrop::new(context.clone()),
50 _program: ManuallyDrop::new(program.clone()),
51 _queue: ManuallyDrop::new(Arc::new(RwLock::new(queue))),
52 _unconstructable: (),
53 };
54 sessions.push(session);
55 }
56 Ok(sessions)
57 }
58 }
59
60 pub fn create(src: &str, cq_props: Option<CommandQueueProperties>) -> Output<Vec<Session>> {
61 let platforms = list_platforms()?;
62 let mut devices: Vec<Device> = Vec::new();
63 for platform in platforms.iter() {
64 let platform_devices: Vec<Device> = list_devices_by_type(platform, DeviceType::ALL)
65 .map(|ll_devices| ll_devices.into_iter().map(|d| Device::new(d)).collect())?;
66 devices.extend(platform_devices);
67 }
68 Session::create_with_devices(devices, src, cq_props)
69 }
70
71 pub fn context(&self) -> Context {
72 Context::from_low_level_context(self.low_level_context()).unwrap()
73 }
74
75 pub fn device(&self) -> Device {
76 Device::new(self.low_level_device().clone())
77 }
78
79 pub fn program(&self) -> Program {
80 unsafe { Program::from_low_level_program(self.low_level_program()).unwrap() }
81 }
82
83 pub fn read_queue(&self) -> RwLockReadGuard<ClCommandQueue> {
84 self._queue.read().unwrap()
85 }
86
87 pub fn write_queue(&self) -> RwLockWriteGuard<ClCommandQueue> {
88 self._queue.write().unwrap()
89 }
90
91 pub fn low_level_device(&self) -> &ClDeviceID {
92 &*self._device
93 }
94
95 pub fn low_level_context(&self) -> &ClContext {
96 &self._context
97 }
98
99 pub fn low_level_program(&self) -> &ClProgram {
100 &self._program
101 }
102
103 pub fn create_copy(&self) -> Output<Session> {
104 let cloned_device = self._device.clone();
105 let cloned_context = self._context.clone();
106 let cloned_program = self._program.clone();
107 let ll_queue = self._queue.read().unwrap();
108 let copied_queue = unsafe { ll_queue.create_copy()? };
109
110 Ok(Session {
111 _device: cloned_device,
112 _context: cloned_context,
113 _program: cloned_program,
114 _queue: ManuallyDrop::new(Arc::new(RwLock::new(copied_queue))),
115 _unconstructable: (),
116 })
117 }
118
119 pub fn create_kernel(&self, kernel_name: &str) -> Output<Kernel> {
121 unsafe {
122 let ll_kernel = ClKernel::create(self.low_level_program(), kernel_name)?;
123 Ok(Kernel::new(ll_kernel, self.program()))
124 }
125 }
126
127 pub fn create_buffer<T: ClNumber, B: BufferCreator<T>>(
131 &self,
132 buffer_creator: B,
133 ) -> Output<Buffer> {
134 let cfg = buffer_creator.mem_config();
135 Buffer::create_from_low_level_context(
136 self.low_level_context(),
137 buffer_creator,
138 cfg.host_access,
139 cfg.kernel_access,
140 cfg.mem_location,
141 )
142 }
143
144 pub fn create_buffer_with_config<T: ClNumber, B: BufferCreator<T>>(
147 &self,
148 buffer_creator: B,
149 mem_config: MemConfig,
150 ) -> Output<Buffer> {
151 Buffer::create_from_low_level_context(
152 self.low_level_context(),
153 buffer_creator,
154 mem_config.host_access,
155 mem_config.kernel_access,
156 mem_config.mem_location,
157 )
158 }
159
160 pub fn sync_write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
164 &self,
165 buffer: &Buffer,
166 host_buffer: H,
167 opts: Option<CommandQueueOptions>,
168 ) -> Output<()> {
169 buffer.number_type().type_check(T::number_type())?;
170 let mut queue = self.write_queue();
171 let mut buffer_lock = buffer.write_lock();
172 unsafe {
173 let event: ClEvent = queue.write_buffer(&mut (*buffer_lock), host_buffer, opts)?;
174 event.wait()
175 }
176 }
177
178 pub fn sync_read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
182 &self,
183 buffer: &Buffer,
184 host_buffer: H,
185 opts: Option<CommandQueueOptions>,
186 ) -> Output<Option<Vec<T>>> {
187 buffer.number_type().type_check(T::number_type())?;
188 let mut queue = self.write_queue();
189
190 let buffer_lock = buffer.read_lock();
191 unsafe {
192 let mut event: BufferReadEvent<T> =
193 queue.read_buffer(&(*buffer_lock), host_buffer, opts)?;
194 event.wait()
195 }
196 }
197
198 pub fn sync_enqueue_kernel(
205 &self,
206 kernel: &Kernel,
207 work: &Work,
208 opts: Option<CommandQueueOptions>,
209 ) -> Output<()> {
210 let mut queue = self.write_queue();
211 let mut kernel_lock = kernel.write_lock();
212 unsafe {
213 let event = queue.enqueue_kernel(&mut (*kernel_lock), work, opts)?;
214 event.wait()
215 }
216 }
217
218 pub fn execute_sync_kernel_operation<'a, T>(
219 &self,
220 mut kernel_op: KernelOperation<'a, T>,
221 ) -> Output<()>
222 where
223 T: ClNumber + KernelArg,
224 {
225 unsafe {
226 let kernel = self.create_kernel(kernel_op.name())?;
227 let work = kernel_op.work()?;
228 let command_queue_opts = kernel_op.command_queue_opts();
229 let mut mem_locks: Vec<RwLockWriteGuard<ClMem>> = Vec::new();
230 for (arg_index, arg) in kernel_op.mut_args().iter_mut().enumerate() {
231 match arg {
232 KernelOpArg::Num(ref mut num) => kernel.set_arg(arg_index, num)?,
233 KernelOpArg::Buffer(ref buffer) => {
234 let mut mem = buffer.write_lock();
235 kernel.set_arg(arg_index, &mut *mem)?;
236 mem_locks.push(mem);
237 }
238 }
239 }
240
241 let mut queue = self.write_queue();
242 let mut ll_kernel = kernel.write_lock();
243 let event = queue.enqueue_kernel(&mut ll_kernel, &work, command_queue_opts)?;
244 event.wait()?;
246 std::mem::drop(mem_locks);
248 Ok(())
249 }
250 }
251}
252
253impl Clone for Session {
254 fn clone(&self) -> Session {
255 Session {
256 _device: self._device.clone(),
257 _context: self._context.clone(),
258 _program: self._program.clone(),
259 _queue: self._queue.clone(),
260 _unconstructable: (),
261 }
262 }
263}
264
265impl PartialEq for Session {
266 fn eq(&self, other: &Self) -> bool {
267 unsafe {
268 let self_queue_ptr = self.read_queue().command_queue_ptr();
269 let other_queue_ptr = other.read_queue().command_queue_ptr();
270 std::ptr::eq(self_queue_ptr, other_queue_ptr)
271 }
272 }
273}
274
// Marker impl only: the comparison itself is the `PartialEq` impl for `Session`,
// which compares command queue pointers.
impl Eq for Session {}
276
impl Drop for Session {
    /// Drops the `ManuallyDrop` fields by hand in a fixed order: queue,
    /// program, context, then device.
    fn drop(&mut self) {
        // The queue and program are dropped before the context they were
        // created from, and the device goes last.
        // NOTE(review): presumably this order matters for releasing the
        // underlying OpenCL objects — confirm before reordering.
        unsafe {
            ManuallyDrop::drop(&mut self._queue);
            ManuallyDrop::drop(&mut self._program);
            ManuallyDrop::drop(&mut self._context);
            ManuallyDrop::drop(&mut self._device);
        }
    }
}
287
#[cfg(test)]
mod tests {
    use crate::{testing, Buffer, Kernel, Session, Work};

    // OpenCL source with a single kernel, `test`, that adds one to every
    // element of its buffer argument.
    const SRC: &str = "__kernel void test(__global int *data) {
        data[get_global_id(0)] += 1;
    }";

    // Builds a session for `SRC` through the shared testing helper.
    fn new_session() -> Session {
        testing::get_session(SRC)
    }

    #[test]
    fn session_can_be_created_with_src() {
        let _session = Session::create(SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session: {:?}", e);
        });
    }

    #[test]
    fn session_can_be_created_with_src_and_slice_of_devices() {
        let devices = testing::get_all_devices();
        assert_ne!(devices.len(), 0);
        let _session = Session::create_with_devices(&devices[..], SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session with slice of devices: {:?}", e);
        });
    }

    #[test]
    fn session_can_be_created_with_src_and_vec_of_devices() {
        let devices = testing::get_all_devices();
        assert_ne!(devices.len(), 0);
        let _session = Session::create_with_devices(devices, SRC, None).unwrap_or_else(|e| {
            panic!("Failed to create session with vec of devices: {:?}", e);
        });
    }

    #[test]
    fn session_implements_clone() {
        let _other: Session = new_session().clone();
    }

    #[test]
    fn session_implementation_of_fmt_debug_works() {
        let session = new_session();
        let formatted = format!("{:?}", session);
        assert!(
            formatted.starts_with("Session"),
            "Formatted did not start with the correct value. Got: {:?}",
            formatted
        );
    }

    #[test]
    fn session_create_copy_copies_command_queue_and_clones_the_rest() {
        let session = new_session();

        let session_copy = session.create_copy().unwrap_or_else(|e| {
            panic!("Failed to create_copy of session: {:?}", e);
        });
        let s1_queue = session.read_queue();
        let s2_queue = session_copy.read_queue();
        assert_ne!(*s1_queue, *s2_queue);
        assert_eq!(
            session.low_level_context(),
            session_copy.low_level_context()
        );
        assert_eq!(
            session.low_level_program(),
            session_copy.low_level_program()
        );
        assert_ne!(session, session_copy);
    }

    #[test]
    fn session_can_create_kernel() {
        let src = "__kernel void add_one_i32(__global int *i) { *i += 1; }";
        let session = testing::get_session(src);
        let _kernel: Kernel = session.create_kernel("add_one_i32").unwrap_or_else(|e| {
            panic!("Failed to create kernel for session: {:?}", e);
        });
    }

    #[test]
    fn session_can_create_buffer_from_data() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let _buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
    }

    #[test]
    fn session_can_create_buffer_of_a_given_length() {
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer::<i32, usize>(100)
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 100);
    }

    #[test]
    fn session_can_write_and_read_buffer() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 8);
        let () = session
            .sync_write_buffer(&buffer, &data[..], None)
            .unwrap_or_else(|e| {
                panic!("Failed to write buffer: {:?}", e);
            });
        let data2 = vec![0i32; 8];
        let data3 = session
            .sync_read_buffer(&buffer, data2, None)
            .unwrap_or_else(|e| {
                // This is the read path, so the message names the read.
                panic!("Failed to read buffer: {:?}", e);
            })
            .unwrap();

        assert_eq!(data3.len(), 8);
        assert_eq!(data3, data);
    }

    #[test]
    fn session_sync_enqueue_kernel_and_read_buffer() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let session = new_session();
        let buffer: Buffer = session
            .create_buffer(&data[..])
            .unwrap_or_else(|e| panic!("Session failed to create buffer: {:?}", e));
        assert_eq!(buffer.len(), 8);
        let () = session
            .sync_write_buffer(&buffer, &data[..], None)
            .unwrap_or_else(|e| {
                panic!("Failed to write buffer: {:?}", e);
            });
        let kernel: Kernel = session.create_kernel("test").unwrap();
        // The buffer's write lock is held across the kernel run and released
        // before the buffer is read back.
        let mut buffer_lock = buffer.write_lock();
        unsafe { kernel.set_arg(0, &mut (*buffer_lock)).unwrap() };
        let work = Work::new(data.len());
        session.sync_enqueue_kernel(&kernel, &work, None).unwrap();
        std::mem::drop(buffer_lock);

        let data2 = vec![0i32; 8];
        let data3 = session
            .sync_read_buffer(&buffer, data2, None)
            .unwrap_or_else(|e| {
                // This is the read path, so the message names the read.
                panic!("Failed to read buffer: {:?}", e);
            })
            .unwrap();

        assert_eq!(data3.len(), 8);
        let expected_data: Vec<i32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(data3, expected_data);
    }
}