1use std::marker::PhantomData;
2use std::mem::ManuallyDrop;
3use std::convert::TryInto;
4
5use crate::vec_or_slice::VecOrSlice;
6use crate::*;
7
8#[derive(Debug, Fail, PartialEq, Eq, Clone)]
10pub enum SessionError {
11 #[fail(display = "The given queue index {} was out of range", _0)]
12 QueueIndexOutOfRange(usize),
13}
14
15#[derive(Debug)]
19pub struct Session {
20 devices: ManuallyDrop<Vec<ClDeviceID>>,
21 context: ManuallyDrop<ClContext>,
22 program: ManuallyDrop<ClProgram>,
23 queues: ManuallyDrop<Vec<ClCommandQueue>>,
24}
25
26impl Session {
27 pub fn create_with_devices<'a, D>(devices: D, src: &str) -> Output<Session>
28 where
29 D: Into<VecOrSlice<'a, ClDeviceID>>,
30 {
31 unsafe {
32 let devices = devices.into();
33 let context = ClContext::create(devices.as_slice())?;
34 let mut program = ClProgram::create_with_source(&context, src)?;
35 program.build(devices.as_slice())?;
36 let props = CommandQueueProperties::default();
37 let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
38 .iter()
39 .map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
40 .collect();
41
42 let queues = maybe_queues?;
43
44 let sess = Session {
45 devices: ManuallyDrop::new(devices.to_vec()),
46 context: ManuallyDrop::new(context),
47 program: ManuallyDrop::new(program),
48 queues: ManuallyDrop::new(queues),
49 };
50 Ok(sess)
51 }
52 }
53
54 pub fn create(src: &str) -> Output<Session> {
62 let platforms = list_platforms()?;
63 let mut devices = Vec::new();
64 for platform in platforms.iter() {
65 let platform_devices = list_devices_by_type(platform, DeviceType::ALL)?;
66 devices.extend(platform_devices);
67 }
68 Session::create_with_devices(devices, src)
69 }
70
71 pub unsafe fn decompose(
79 mut self,
80 ) -> (Vec<ClDeviceID>, ClContext, ClProgram, Vec<ClCommandQueue>) {
81 let devices: Vec<ClDeviceID> = utils::take_manually_drop(&mut self.devices);
82 let context: ClContext = utils::take_manually_drop(&mut self.context);
83 let program: ClProgram = utils::take_manually_drop(&mut self.program);
84 let queues: Vec<ClCommandQueue> = utils::take_manually_drop(&mut self.queues);
85 std::mem::forget(self);
86 (devices, context, program, queues)
87 }
88
89 pub fn devices(&self) -> &[ClDeviceID] {
91 &(*self.devices)[..]
92 }
93
94 pub fn context(&self) -> &ClContext {
96 &(*self.context)
97 }
98
99 pub fn program(&self) -> &ClProgram {
101 &(*self.program)
102 }
103
104 pub fn queues(&self) -> &[ClCommandQueue] {
106 &(*self.queues)[..]
107 }
108
109 pub unsafe fn create_kernel(&self, kernel_name: &str) -> Output<ClKernel> {
117 ClKernel::create(self.program(), kernel_name)
118 }
119
120 pub unsafe fn create_mem<T: ClNumber, B: BufferCreator<T>>(
128 &self,
129 buffer_creator: B,
130 ) -> Output<ClMem> {
131 let cfg = buffer_creator.mem_config();
132 ClMem::create_with_config(self.context(), buffer_creator, cfg)
133 }
134
135 pub unsafe fn create_mem_with_config<T: ClNumber, B: BufferCreator<T>>(
142 &self,
143 buffer_creator: B,
144 mem_config: MemConfig,
145 ) -> Output<ClMem> {
146 ClMem::create_with_config(self.context(), buffer_creator, mem_config)
147 }
148
149 #[inline]
150 fn get_queue_by_index(&mut self, index: usize) -> Output<&mut ClCommandQueue> {
151 self.queues
152 .get_mut(index)
153 .ok_or_else(|| SessionError::QueueIndexOutOfRange(index).into())
154 }
155
156 pub unsafe fn write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
166 &mut self,
167 queue_index: usize,
168 mem: &mut ClMem,
169 host_buffer: H,
170 opts: Option<CommandQueueOptions>,
171 ) -> Output<ClEvent> {
172 mem.number_type().match_or_panic(T::number_type());
173 let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
174 queue.write_buffer(mem, host_buffer, opts)
175 }
176
177 pub unsafe fn read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
187 &mut self,
188 queue_index: usize,
189 mem: &mut ClMem,
190 host_buffer: H,
191 opts: Option<CommandQueueOptions>,
192 ) -> Output<BufferReadEvent<T>> {
193 let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
194 queue.read_buffer(mem, host_buffer, opts)
195 }
196
197 pub unsafe fn enqueue_kernel(
204 &mut self,
205 queue_index: usize,
206 kernel: &mut ClKernel,
207 work: &Work,
208 opts: Option<CommandQueueOptions>,
209 ) -> Output<ClEvent> {
210 let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
211 let cq_opts: CommandQueueOptions = opts.into();
212 let event = cl_enqueue_nd_range_kernel(
213 queue.command_queue_ptr(),
214 kernel.kernel_ptr(),
215 work,
216 &cq_opts.waitlist[..],
217 )?;
218 ClEvent::new(event)
219 }
220
221 pub fn execute_sync_kernel_operation(
222 &mut self,
223 queue_index: usize,
224 mut kernel_op: KernelOperation,
225 ) -> Output<()> {
226 unsafe {
227 let mut kernel = self.create_kernel(kernel_op.name())?;
228 let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
229 for (arg_index, (arg_size, arg_ptr)) in kernel_op.mut_args().iter_mut().enumerate() {
230 kernel.set_arg_raw(
231 arg_index.try_into().unwrap(),
232 *arg_size,
233 *arg_ptr
234 )?;
235 }
236 let work = kernel_op.work()?;
237 let event = queue.enqueue_kernel(&mut kernel, &work, kernel_op.command_queue_opts())?;
238 event.wait()?;
239 Ok(())
240 }
241 }
242}
243
244unsafe impl Send for Session {}
253impl Drop for Session {
269 fn drop(&mut self) {
270 unsafe {
271 ManuallyDrop::drop(&mut self.queues);
272 ManuallyDrop::drop(&mut self.program);
273 ManuallyDrop::drop(&mut self.context);
274 ManuallyDrop::drop(&mut self.devices);
275 }
276 }
277}
278
279#[derive(Debug, Clone, Copy, PartialEq)]
280pub struct SessionQueue<'a> {
281 phantom: PhantomData<&'a ClCommandQueue>,
282 index: usize,
283}
284
285impl<'a> SessionQueue<'a> {
286 pub fn new(index: usize) -> SessionQueue<'a> {
287 SessionQueue {
288 index,
289 phantom: PhantomData,
290 }
291 }
292}
293
294#[derive(Debug, Fail, PartialEq, Eq, Clone)]
296pub enum SessionBuilderError {
297 #[fail(display = "Given ClMem has no associated cl_mem object")]
298 NoAssociatedMemObject,
299
300 #[fail(
301 display = "For session building platforms AND devices cannot be specifed together; they are mutually exclusive."
302 )]
303 CannotSpecifyPlatformsAndDevices,
304
305 #[fail(
306 display = "For session building program src AND binaries cannot be specifed together; they are mutually exclusive."
307 )]
308 CannotSpecifyProgramSrcAndProgramBinaries,
309
310 #[fail(
311 display = "For session building either program src or program binaries must be specified."
312 )]
313 MustSpecifyProgramSrcOrProgramBinaries,
314
315 #[fail(
316 display = "Building a session with program binaries requires exactly 1 device: Got {:?} devices",
317 _0
318 )]
319 BinaryProgramRequiresExactlyOneDevice(usize),
320}
321
322const CANNOT_SPECIFY_SRC_AND_BINARIES: Error =
323 Error::SessionBuilderError(SessionBuilderError::CannotSpecifyProgramSrcAndProgramBinaries);
324const MUST_SPECIFY_SRC_OR_BINARIES: Error =
325 Error::SessionBuilderError(SessionBuilderError::MustSpecifyProgramSrcOrProgramBinaries);
326
327#[derive(Default)]
328pub struct SessionBuilder<'a> {
329 pub program_src: Option<&'a str>,
330 pub program_binaries: Option<&'a [u8]>,
331 pub device_type: Option<DeviceType>,
332 pub platforms: Option<&'a [ClPlatformID]>,
333 pub devices: Option<&'a [ClDeviceID]>,
334 pub command_queue_properties: Option<CommandQueueProperties>,
335}
336
337impl<'a> SessionBuilder<'a> {
338 pub fn new() -> SessionBuilder<'a> {
339 SessionBuilder {
340 program_src: None,
341 program_binaries: None,
342 device_type: None,
343 platforms: None,
344 devices: None,
345 command_queue_properties: None,
346 }
347 }
348
349 pub fn with_program_src(mut self, src: &'a str) -> SessionBuilder<'a> {
350 self.program_src = Some(src);
351 self
352 }
353
354 pub fn with_program_binaries(mut self, bins: &'a [u8]) -> SessionBuilder<'a> {
355 self.program_binaries = Some(bins);
356 self
357 }
358
359 pub fn with_platforms(mut self, platforms: &'a [ClPlatformID]) -> SessionBuilder<'a> {
360 self.platforms = Some(platforms);
361 self
362 }
363
364 pub fn with_devices(mut self, devices: &'a [ClDeviceID]) -> SessionBuilder<'a> {
365 self.devices = Some(devices);
366 self
367 }
368
369 pub fn with_device_type(mut self, device_type: DeviceType) -> SessionBuilder<'a> {
370 self.device_type = Some(device_type);
371 self
372 }
373
374 pub fn with_command_queue_properties(
375 mut self,
376 props: CommandQueueProperties,
377 ) -> SessionBuilder<'a> {
378 self.command_queue_properties = Some(props);
379 self
380 }
381 fn check_for_error_state(&self) -> Output<()> {
382 match self {
383 Self {
384 program_src: Some(_),
385 program_binaries: Some(_),
386 ..
387 } => return Err(CANNOT_SPECIFY_SRC_AND_BINARIES),
388 Self {
389 program_src: None,
390 program_binaries: None,
391 ..
392 } => return Err(MUST_SPECIFY_SRC_OR_BINARIES),
393 _ => Ok(()),
394 }
395 }
396
397 pub unsafe fn build(self) -> Output<Session> {
405 self.check_for_error_state()?;
406 let context_builder = ClContextBuilder {
407 devices: self.devices,
408 device_type: self.device_type,
409 platforms: self.platforms,
410 };
411 let built_context = context_builder.build()?;
412 let (context, devices): (ClContext, Vec<ClDeviceID>) = match built_context {
413 BuiltClContext::Context(ctx) => (ctx, self.devices.unwrap().to_vec()),
414 BuiltClContext::ContextWithDevices(ctx, owned_devices) => (ctx, owned_devices),
415 };
416 let program: ClProgram = match (&self, devices.len()) {
417 (
418 Self {
419 program_src: Some(src),
420 ..
421 },
422 _,
423 ) => {
424 let mut prog: ClProgram = ClProgram::create_with_source(&context, src)?;
425 prog.build(&devices[..])?;
426 Ok(prog)
427 }
428 (
429 Self {
430 program_binaries: Some(bins),
431 ..
432 },
433 1,
434 ) => {
435 let mut prog: ClProgram =
436 ClProgram::create_with_binary(&context, &devices[0], *bins)?;
437 prog.build(&devices[..])?;
438 Ok(prog)
439 }
440 (
441 Self {
442 program_binaries: Some(_),
443 ..
444 },
445 n_devices,
446 ) => {
447 let e = SessionBuilderError::BinaryProgramRequiresExactlyOneDevice(n_devices);
448 Err(Error::SessionBuilderError(e))
449 }
450 _ => unreachable!(),
451 }?;
452
453 let props = CommandQueueProperties::default();
454 let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
455 .iter()
456 .map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
457 .collect();
458 let queues = maybe_queues?;
459
460 let sess = Session {
461 devices: ManuallyDrop::new(devices),
462 context: ManuallyDrop::new(context),
463 program: ManuallyDrop::new(program),
464 queues: ManuallyDrop::new(queues),
465 };
466 Ok(sess)
467 }
468}
469
#[cfg(test)]
mod tests {
    use crate::{BufferReadEvent, KernelOperation, Session};

    /// Kernel that adds 1 to each element of `data` in place.
    const SRC: &'static str = "__kernel void test(__global int *data) {
        data[get_global_id(0)] += 1;
    }";

    /// Creates a session on all available devices, panicking on failure.
    fn get_session(src: &str) -> Session {
        match Session::create(src) {
            Ok(session) => session,
            Err(e) => panic!("Failed to get_session {:?}", e),
        }
    }

    #[test]
    fn session_execute_sync_kernel_operation_works() {
        let mut session = get_session(SRC);
        let data: Vec<i32> = vec![1, 2, 3, 4, 5];
        let mut buff = unsafe { session.create_mem(&data[..]) }.unwrap();

        let kernel_op = KernelOperation::new("test")
            .with_dims(data.len())
            .add_arg(&mut buff);
        if let Err(e) = session.execute_sync_kernel_operation(0, kernel_op) {
            panic!("Failed to execute sync kernel operation: {:?}", e);
        }

        let host = vec![0i32; 5];
        unsafe {
            let mut read_event: BufferReadEvent<i32> =
                match session.read_buffer(0, &mut buff, host, None) {
                    Ok(event) => event,
                    Err(e) => panic!("Failed to read buffer: {:?}", e),
                };
            let incremented = match read_event.wait() {
                Ok(maybe_data) => maybe_data.unwrap(),
                Err(e) => panic!("Failed to wait for read event: {:?}", e),
            };
            assert_eq!(incremented, vec![2, 3, 4, 5, 6]);
        }
    }
}