open_cl_low_level/
session.rs

1use std::marker::PhantomData;
2use std::mem::ManuallyDrop;
3use std::convert::TryInto;
4
5use crate::vec_or_slice::VecOrSlice;
6use crate::*;
7
8/// An error related to Session Building.
9#[derive(Debug, Fail, PartialEq, Eq, Clone)]
10pub enum SessionError {
11    #[fail(display = "The given queue index {} was out of range", _0)]
12    QueueIndexOutOfRange(usize),
13}
14
15/// Session is the structure that is responsible for Dropping
16/// Low-Level OpenCL pointer wrappers in the correct order. Dropping OpenCL
17/// pointers in the wrong order can lead to undefined behavior.
18#[derive(Debug)]
19pub struct Session {
20    devices: ManuallyDrop<Vec<ClDeviceID>>,
21    context: ManuallyDrop<ClContext>,
22    program: ManuallyDrop<ClProgram>,
23    queues: ManuallyDrop<Vec<ClCommandQueue>>,
24}
25
26impl Session {
27    pub fn create_with_devices<'a, D>(devices: D, src: &str) -> Output<Session>
28    where
29        D: Into<VecOrSlice<'a, ClDeviceID>>,
30    {
31        unsafe {
32            let devices = devices.into();
33            let context = ClContext::create(devices.as_slice())?;
34            let mut program = ClProgram::create_with_source(&context, src)?;
35            program.build(devices.as_slice())?;
36            let props = CommandQueueProperties::default();
37            let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
38                .iter()
39                .map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
40                .collect();
41
42            let queues = maybe_queues?;
43
44            let sess = Session {
45                devices: ManuallyDrop::new(devices.to_vec()),
46                context: ManuallyDrop::new(context),
47                program: ManuallyDrop::new(program),
48                queues: ManuallyDrop::new(queues),
49            };
50            Ok(sess)
51        }
52    }
53
54    /// Given a string slice of OpenCL source code this function creates a session for
55    /// all available platforms and devices. A Session consists of:
56    ///
57    /// one or more devices
58    /// one context (for sharing mem objects between devices)
59    /// one program (build on each of the devices)
60    /// one or more queues (each queue belongs to exactly one of the devices)
61    pub fn create(src: &str) -> Output<Session> {
62        let platforms = list_platforms()?;
63        let mut devices = Vec::new();
64        for platform in platforms.iter() {
65            let platform_devices = list_devices_by_type(platform, DeviceType::ALL)?;
66            devices.extend(platform_devices);
67        }
68        Session::create_with_devices(devices, src)
69    }
70
71    /// Consumes the session returning the parts as individual parts.
72    ///
73    /// # Safety
74    /// Moving the components of a Session out of the Session can easily lead to
75    /// undefined behavior. The Session has a carefully implemented drop that ensures
76    /// the an Object is dropped before it's dependencies. If any of the dependencies of an object are ever dropped
77    /// in the incorrect order or any dependency of an object is dropped and the object is then used the result is undefined behavior.
78    pub unsafe fn decompose(
79        mut self,
80    ) -> (Vec<ClDeviceID>, ClContext, ClProgram, Vec<ClCommandQueue>) {
81        let devices: Vec<ClDeviceID> = utils::take_manually_drop(&mut self.devices);
82        let context: ClContext = utils::take_manually_drop(&mut self.context);
83        let program: ClProgram = utils::take_manually_drop(&mut self.program);
84        let queues: Vec<ClCommandQueue> = utils::take_manually_drop(&mut self.queues);
85        std::mem::forget(self);
86        (devices, context, program, queues)
87    }
88
89    /// A slice of the ClDeviceIDs of this Session.
90    pub fn devices(&self) -> &[ClDeviceID] {
91        &(*self.devices)[..]
92    }
93
94    /// A reference to the ClContext of this Session.
95    pub fn context(&self) -> &ClContext {
96        &(*self.context)
97    }
98
99    /// A reference to the ClProgram of this Session.
100    pub fn program(&self) -> &ClProgram {
101        &(*self.program)
102    }
103
104    /// A slice of the ClCommandQueues of this Session.
105    pub fn queues(&self) -> &[ClCommandQueue] {
106        &(*self.queues)[..]
107    }
108
109    /// Creates a ClKernel from the session's program.
110    ///
111    /// # Safety
112    /// Note: This function may, in fact, be safe. However, creating a kernel with a
113    /// program object that is in an invalid state can lead to undefined behavior.
114    /// Using the ClKernel after the session has been released can lead to undefined behavior.
115    /// Using the ClKernel outside it's own context/program can lead to undefined behavior.
116    pub unsafe fn create_kernel(&self, kernel_name: &str) -> Output<ClKernel> {
117        ClKernel::create(self.program(), kernel_name)
118    }
119
120    /// Creates a ClMem object in the given context, with the given buffer creator
121    /// (either a length or some data). This function uses the BufferCreator's implementation
122    /// to retrieve the appropriate MemConfig.
123    ///
124    /// # Safety
125    /// This function can cause undefined behavior if the OpenCL context object that
126    /// is passed is not in a valid state (null, released, etc.)
127    pub unsafe fn create_mem<T: ClNumber, B: BufferCreator<T>>(
128        &self,
129        buffer_creator: B,
130    ) -> Output<ClMem> {
131        let cfg = buffer_creator.mem_config();
132        ClMem::create_with_config(self.context(), buffer_creator, cfg)
133    }
134
135    /// Creates a ClMem object in the given context, with the given buffer creator
136    /// (either a length or some data) and a given MemConfig.
137    ///
138    /// # Safety
139    /// This function can cause undefined behavior if the OpenCL context object that
140    /// is passed is not in a valid state (null, released, etc.)
141    pub unsafe fn create_mem_with_config<T: ClNumber, B: BufferCreator<T>>(
142        &self,
143        buffer_creator: B,
144        mem_config: MemConfig,
145    ) -> Output<ClMem> {
146        ClMem::create_with_config(self.context(), buffer_creator, mem_config)
147    }
148
149    #[inline]
150    fn get_queue_by_index(&mut self, index: usize) -> Output<&mut ClCommandQueue> {
151        self.queues
152            .get_mut(index)
153            .ok_or_else(|| SessionError::QueueIndexOutOfRange(index).into())
154    }
155
156    /// This function copies data from the host buffer into the device mem buffer. The host
157    /// buffer must be a mutable slice or a vector to ensure the safety of the read_Buffer
158    /// operation.
159    ///
160    /// # Safety
161    /// This function call is safe only if the ClMem object's dependencies are still valid, if the
162    /// ClMem is valid, if the ClCommandQueue's dependencies are valid, if the ClCommandQueue's object
163    /// itself still valid, if the device's size and type exactly match the host buffer's size and type,
164    /// if the waitlist's events are in a valid state and the list goes on...
165    pub unsafe fn write_buffer<'a, T: ClNumber, H: Into<VecOrSlice<'a, T>>>(
166        &mut self,
167        queue_index: usize,
168        mem: &mut ClMem,
169        host_buffer: H,
170        opts: Option<CommandQueueOptions>,
171    ) -> Output<ClEvent> {
172        mem.number_type().match_or_panic(T::number_type());
173        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
174        queue.write_buffer(mem, host_buffer, opts)
175    }
176
177    /// This function copies data from a device mem buffer into a host buffer. The host
178    /// buffer must be a mutable slice or a vector. For the moment the device mem must also
179    /// be passed as mutable; I don't trust OpenCL.
180    ///
181    /// # Safety
182    /// This function call is safe only if the ClMem object's dependencies are still valid, if the
183    /// ClMem is valid, if the ClCommandQueue's dependencies are valid, if the ClCommandQueue's object
184    /// itself still valid, if the device's size and type exactly match the host buffer's size and type,
185    /// if the waitlist's events are in a valid state and the list goes on...
186    pub unsafe fn read_buffer<'a, T: ClNumber, H: Into<MutVecOrSlice<'a, T>>>(
187        &mut self,
188        queue_index: usize,
189        mem: &mut ClMem,
190        host_buffer: H,
191        opts: Option<CommandQueueOptions>,
192    ) -> Output<BufferReadEvent<T>> {
193        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
194        queue.read_buffer(mem, host_buffer, opts)
195    }
196
197    /// This function enqueues a CLKernel into a command queue
198    ///
199    /// # Safety
200    /// If the ClKernel is not in a usable state or any of the Kernel's dependent object
201    /// has been release, or the kernel belongs to a different session, or the ClKernel's
202    /// pointer is a null pointer, then calling this function will cause undefined behavior.
203    pub unsafe fn enqueue_kernel(
204        &mut self,
205        queue_index: usize,
206        kernel: &mut ClKernel,
207        work: &Work,
208        opts: Option<CommandQueueOptions>,
209    ) -> Output<ClEvent> {
210        let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
211        let cq_opts: CommandQueueOptions = opts.into();
212        let event = cl_enqueue_nd_range_kernel(
213            queue.command_queue_ptr(),
214            kernel.kernel_ptr(),
215            work,
216            &cq_opts.waitlist[..],
217        )?;
218        ClEvent::new(event)
219    }
220
221    pub fn execute_sync_kernel_operation(
222        &mut self,
223        queue_index: usize,
224        mut kernel_op: KernelOperation,
225    ) -> Output<()> {
226        unsafe {
227            let mut kernel = self.create_kernel(kernel_op.name())?;
228            let queue: &mut ClCommandQueue = self.get_queue_by_index(queue_index)?;
229            for (arg_index, (arg_size, arg_ptr)) in kernel_op.mut_args().iter_mut().enumerate() {
230                kernel.set_arg_raw(
231                    arg_index.try_into().unwrap(),
232                    *arg_size,
233                    *arg_ptr
234                )?;
235            }
236            let work = kernel_op.work()?;
237            let event = queue.enqueue_kernel(&mut kernel, &work, kernel_op.command_queue_opts())?;
238            event.wait()?;
239            Ok(())
240        }
241    }
242}
243
244/// Session can be safely sent between threads.
245///
246/// # Safety
247/// All the contained OpenCL objects Session are Send so Session is Send. However,
248/// The low level Session has ZERO Synchronization for mutable objects Program and
249/// CommandQueue. Therefore the low level Session is not Sync. If a Sync Session is
250/// required, the Session of opencl_core is Sync by synchronizing mutations of it's
251/// objects via RwLocks.
252unsafe impl Send for Session {}
253// unsafe impl Sync for Session {}
254
255// impl fmt::Debug for Session {
256//     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257//         write!(f, "Session{{{:?}}}", self.address())
258//     }
259// }
260
261// preserve the ordering of these fields
262// The drop order must be:
263// 1) program
264// 2) command_queue
265// 3) context
266// 4) device
267// Else... SEGFAULT :(
268impl Drop for Session {
269    fn drop(&mut self) {
270        unsafe {
271            ManuallyDrop::drop(&mut self.queues);
272            ManuallyDrop::drop(&mut self.program);
273            ManuallyDrop::drop(&mut self.context);
274            ManuallyDrop::drop(&mut self.devices);
275        }
276    }
277}
278
279#[derive(Debug, Clone, Copy, PartialEq)]
280pub struct SessionQueue<'a> {
281    phantom: PhantomData<&'a ClCommandQueue>,
282    index: usize,
283}
284
285impl<'a> SessionQueue<'a> {
286    pub fn new(index: usize) -> SessionQueue<'a> {
287        SessionQueue {
288            index,
289            phantom: PhantomData,
290        }
291    }
292}
293
294/// An error related to Session Building.
295#[derive(Debug, Fail, PartialEq, Eq, Clone)]
296pub enum SessionBuilderError {
297    #[fail(display = "Given ClMem has no associated cl_mem object")]
298    NoAssociatedMemObject,
299
300    #[fail(
301        display = "For session building platforms AND devices cannot be specifed together; they are mutually exclusive."
302    )]
303    CannotSpecifyPlatformsAndDevices,
304
305    #[fail(
306        display = "For session building program src AND binaries cannot be specifed together; they are mutually exclusive."
307    )]
308    CannotSpecifyProgramSrcAndProgramBinaries,
309
310    #[fail(
311        display = "For session building either program src or program binaries must be specified."
312    )]
313    MustSpecifyProgramSrcOrProgramBinaries,
314
315    #[fail(
316        display = "Building a session with program binaries requires exactly 1 device: Got {:?} devices",
317        _0
318    )]
319    BinaryProgramRequiresExactlyOneDevice(usize),
320}
321
322const CANNOT_SPECIFY_SRC_AND_BINARIES: Error =
323    Error::SessionBuilderError(SessionBuilderError::CannotSpecifyProgramSrcAndProgramBinaries);
324const MUST_SPECIFY_SRC_OR_BINARIES: Error =
325    Error::SessionBuilderError(SessionBuilderError::MustSpecifyProgramSrcOrProgramBinaries);
326
327#[derive(Default)]
328pub struct SessionBuilder<'a> {
329    pub program_src: Option<&'a str>,
330    pub program_binaries: Option<&'a [u8]>,
331    pub device_type: Option<DeviceType>,
332    pub platforms: Option<&'a [ClPlatformID]>,
333    pub devices: Option<&'a [ClDeviceID]>,
334    pub command_queue_properties: Option<CommandQueueProperties>,
335}
336
337impl<'a> SessionBuilder<'a> {
338    pub fn new() -> SessionBuilder<'a> {
339        SessionBuilder {
340            program_src: None,
341            program_binaries: None,
342            device_type: None,
343            platforms: None,
344            devices: None,
345            command_queue_properties: None,
346        }
347    }
348
349    pub fn with_program_src(mut self, src: &'a str) -> SessionBuilder<'a> {
350        self.program_src = Some(src);
351        self
352    }
353
354    pub fn with_program_binaries(mut self, bins: &'a [u8]) -> SessionBuilder<'a> {
355        self.program_binaries = Some(bins);
356        self
357    }
358
359    pub fn with_platforms(mut self, platforms: &'a [ClPlatformID]) -> SessionBuilder<'a> {
360        self.platforms = Some(platforms);
361        self
362    }
363
364    pub fn with_devices(mut self, devices: &'a [ClDeviceID]) -> SessionBuilder<'a> {
365        self.devices = Some(devices);
366        self
367    }
368
369    pub fn with_device_type(mut self, device_type: DeviceType) -> SessionBuilder<'a> {
370        self.device_type = Some(device_type);
371        self
372    }
373
374    pub fn with_command_queue_properties(
375        mut self,
376        props: CommandQueueProperties,
377    ) -> SessionBuilder<'a> {
378        self.command_queue_properties = Some(props);
379        self
380    }
381    fn check_for_error_state(&self) -> Output<()> {
382        match self {
383            Self {
384                program_src: Some(_),
385                program_binaries: Some(_),
386                ..
387            } => return Err(CANNOT_SPECIFY_SRC_AND_BINARIES),
388            Self {
389                program_src: None,
390                program_binaries: None,
391                ..
392            } => return Err(MUST_SPECIFY_SRC_OR_BINARIES),
393            _ => Ok(()),
394        }
395    }
396
397    /// Builds a SessionBuilder into a Session
398    ///
399    /// # Safety
400    /// This function may, in fact, be safe, mismanagement of objects and lifetimes
401    /// are not possible as long as the underlying function calls are implemented
402    /// as intended. However, this claim needs to be reviewed. For now it remains
403    /// marked as unsafe.
404    pub unsafe fn build(self) -> Output<Session> {
405        self.check_for_error_state()?;
406        let context_builder = ClContextBuilder {
407            devices: self.devices,
408            device_type: self.device_type,
409            platforms: self.platforms,
410        };
411        let built_context = context_builder.build()?;
412        let (context, devices): (ClContext, Vec<ClDeviceID>) = match built_context {
413            BuiltClContext::Context(ctx) => (ctx, self.devices.unwrap().to_vec()),
414            BuiltClContext::ContextWithDevices(ctx, owned_devices) => (ctx, owned_devices),
415        };
416        let program: ClProgram = match (&self, devices.len()) {
417            (
418                Self {
419                    program_src: Some(src),
420                    ..
421                },
422                _,
423            ) => {
424                let mut prog: ClProgram = ClProgram::create_with_source(&context, src)?;
425                prog.build(&devices[..])?;
426                Ok(prog)
427            }
428            (
429                Self {
430                    program_binaries: Some(bins),
431                    ..
432                },
433                1,
434            ) => {
435                let mut prog: ClProgram =
436                    ClProgram::create_with_binary(&context, &devices[0], *bins)?;
437                prog.build(&devices[..])?;
438                Ok(prog)
439            }
440            (
441                Self {
442                    program_binaries: Some(_),
443                    ..
444                },
445                n_devices,
446            ) => {
447                let e = SessionBuilderError::BinaryProgramRequiresExactlyOneDevice(n_devices);
448                Err(Error::SessionBuilderError(e))
449            }
450            _ => unreachable!(),
451        }?;
452
453        let props = CommandQueueProperties::default();
454        let maybe_queues: Result<Vec<ClCommandQueue>, Error> = devices
455            .iter()
456            .map(|dev| ClCommandQueue::create(&context, dev, Some(props)))
457            .collect();
458        let queues = maybe_queues?;
459
460        let sess = Session {
461            devices: ManuallyDrop::new(devices),
462            context: ManuallyDrop::new(context),
463            program: ManuallyDrop::new(program),
464            queues: ManuallyDrop::new(queues),
465        };
466        Ok(sess)
467    }
468}
469
#[cfg(test)]
mod tests {
    use crate::{BufferReadEvent, KernelOperation, Session};

    /// A kernel that increments every element of an int buffer by one.
    const SRC: &str = "__kernel void test(__global int *data) {
        data[get_global_id(0)] += 1;
    }";

    /// Creates a session for `src`, panicking with the error on failure.
    fn get_session(src: &str) -> Session {
        match Session::create(src) {
            Ok(session) => session,
            Err(e) => panic!("Failed to get_session {:?}", e),
        }
    }

    #[test]
    fn session_execute_sync_kernel_operation_works() {
        let mut session = get_session(SRC);
        let input: Vec<i32> = vec![1, 2, 3, 4, 5];
        let dims = input.len();
        let mut device_buffer = unsafe { session.create_mem(&input[..]) }.unwrap();

        // Run the kernel synchronously on the first queue.
        let kernel_op = KernelOperation::new("test")
            .with_dims(dims)
            .add_arg(&mut device_buffer);
        if let Err(e) = session.execute_sync_kernel_operation(0, kernel_op) {
            panic!("Failed to execute sync kernel operation: {:?}", e);
        }

        // Read the buffer back and check that every element was incremented.
        let host_buffer = vec![0i32; 5];
        unsafe {
            let mut read_event: BufferReadEvent<i32> = session
                .read_buffer(0, &mut device_buffer, host_buffer, None)
                .unwrap_or_else(|e| {
                    panic!("Failed to read buffer: {:?}", e);
                });
            let read_back = read_event
                .wait()
                .unwrap_or_else(|e| panic!("Failed to wait for read event: {:?}", e))
                .unwrap();
            assert_eq!(read_back, vec![2, 3, 4, 5, 6]);
        }
    }
}