//! open_cl_low_level/mem.rs — low-level cl_mem buffer creation and info utilities.

1use std::fmt;
2use libc::c_void;
3
4use crate::ffi::{
5    clCreateBuffer, clGetMemObjectInfo, cl_context, cl_int, cl_mem, cl_mem_flags, cl_mem_info,
6};
7
8use crate::cl_helpers::cl_get_info5;
9use crate::{
10    build_output, ClContext, ClNumber, ClPointer, ContextPtr, HostAccessMemFlags,
11    KernelAccessMemFlags, MemFlags, MemInfo, MemLocationMemFlags, Output, NumberType,
12    NumberTyped, ObjectWrapper,
13};
14
15
16/// Low-level helper for creating a cl_mem buffer from a context, mem flags, and a buffer creator.
17///
18/// # Safety
19/// Use of a invalid cl_context in this function call is undefined behavior.
20pub unsafe fn cl_create_buffer_with_creator<T: ClNumber, B: BufferCreator<T>>(
21    context: cl_context,
22    mem_flags: cl_mem_flags,
23    buffer_creator: B,
24) -> Output<cl_mem> {
25    cl_create_buffer(
26        context,
27        mem_flags,
28        buffer_creator.buffer_byte_size(),
29        buffer_creator.buffer_ptr()
30    )
31}
32
33/// Low level helper functin for creating cl_mem buffer.
34///
35/// # Safety
36/// Calling this function with an invalid context, or an incorrect size in bytes,
37/// or an invalid host pointer is undefined behavior.
38pub unsafe fn cl_create_buffer(
39    context: cl_context,
40    mem_flags: cl_mem_flags,
41    size_in_bytes: usize,
42    ptr: *mut c_void,
43) -> Output<cl_mem> {
44    let mut err_code: cl_int = 0;
45    let cl_mem_object: cl_mem =
46        clCreateBuffer(context, mem_flags, size_in_bytes, ptr, &mut err_code);
47    build_output(cl_mem_object, err_code)
48}
49
50pub fn cl_get_mem_object_info<T>(device_mem: cl_mem, flag: cl_mem_info) -> Output<ClPointer<T>>
51where
52    T: Copy,
53{
54    unsafe { cl_get_info5(device_mem, flag, clGetMemObjectInfo) }
55}
56
/// A type that can drive cl_mem buffer creation by supplying a byte size, a
/// (possibly null) host pointer, and a default memory configuration.
pub trait BufferCreator<T: ClNumber>: Sized {
    /// The SizeAndPtr of a buffer creation arg.
    ///
    /// Implemented in this file for `usize` (an element count with no host
    /// data), and for `&[T]` / `&mut [T]` (existing host data) where T: ClNumber.
    fn buffer_byte_size(&self) -> usize;
    /// The host pointer passed to clCreateBuffer; null when only a size is given.
    fn buffer_ptr(&self) -> *mut c_void;
    /// The MemConfig appropriate for this creator (data vs. size-only).
    fn mem_config(&self) -> MemConfig;
}
66
67impl<T: ClNumber> BufferCreator<T> for &[T] {
68    fn buffer_byte_size(&self) -> usize {
69        std::mem::size_of::<T>() * self.len()
70    }
71
72    fn buffer_ptr(&self) -> *mut c_void {
73        self.as_ptr() as *const _ as *mut c_void
74
75    }
76
77    fn mem_config(&self) -> MemConfig {
78        MemConfig::for_data()
79    }
80}
81
82impl<T: ClNumber> BufferCreator<T> for &mut [T] {
83    fn buffer_byte_size(&self) -> usize {
84        std::mem::size_of::<T>() * self.len()
85    }
86
87    fn buffer_ptr(&self) -> *mut c_void {
88        self.as_ptr() as *const _ as *mut c_void
89
90    }
91
92    fn mem_config(&self) -> MemConfig {
93        MemConfig::for_data()
94    }
95}
96
97
98impl<T: ClNumber> BufferCreator<T> for usize {
99    fn buffer_byte_size(&self) -> usize {
100        std::mem::size_of::<T>() * *self
101    }
102
103    fn buffer_ptr(&self) -> *mut c_void {
104        std::ptr::null_mut()
105    }
106
107    fn mem_config(&self) -> MemConfig {
108        MemConfig::for_size()
109    }
110}
111
/// The MemPtr trait gives access to the cl_mem of a wrapping object and provides
/// functions for cl_mem info.
///
/// # Safety
/// This trait is unsafe because it allows access to an un-reference-counted raw pointer.
pub unsafe trait MemPtr: NumberTyped {
    /// Returns a copy of the cl_mem of the implementor.
    ///
    /// # Safety
    /// This function is unsafe because it returns an uncounted cl_mem
    /// object and gives access to a raw pointer.
    unsafe fn mem_ptr(&self) -> cl_mem;

    /// Returns a reference to the cl_mem of the implementor.
    ///
    /// # Safety
    /// This function is unsafe because it results in an uncounted copy of
    /// a cl_mem if the user dereferences the reference.
    unsafe fn mem_ptr_ref(&self) -> &cl_mem;

    /// Returns the ClPointer of the info type of a given MemInfo flag.
    ///
    /// # Safety
    /// Calling this function with a mismatch between the MemInfo's expected
    /// type and I is undefined behavior.
    unsafe fn get_info<I: Copy>(&self, flag: MemInfo) -> Output<ClPointer<I>> {
        cl_get_mem_object_info::<I>(self.mem_ptr(), flag.into())
    }

    /// Returns the len (element count) of the ClMem: the buffer's byte size
    /// divided by the byte size of its recorded number type.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn len(&self) -> Output<usize> {
        let mem_size_in_bytes = self.size()?;
        Ok(mem_size_in_bytes / self.number_type().size_of_t())
    }

    /// Determines if ClMem is empty (len == 0) or not.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn is_empty(&self) -> Output<bool> {
        self.len().map(|l| l == 0)
    }

    // /// This is SUPER unsafe. Leave this out.
    // /// Someone: "But elbow-jason you can use this to make a slice!"
    // /// Me: "A slice with what lifetime? Is it safe to read?"
    // /// Me: "If you want the underlying data read the buffer like a human being."
    // fn host_ptr(&self) -> Output<Option<Vec<T>>>
    // where
    //     T: Copy,
    // {
    //     unsafe {
    //         self.get_info::<T>(MemInfo::HostPtr).map(|ret| {
    //             // let host_vec =
    //             if ret.is_null() {
    //                 return None;
    //             }
    //             // if host_vec.as_ptr() as usize == 1 {
    //             //     return None;
    //             // }
    //             Some(ret.into_vec())
    //         })
    //     }
    // }

    // /// Returns the associated_memobject of the ClMem.
    // ///
    // /// # Safety
    // /// associated_memobject is unsafe because this method grants access to a
    // /// cl_mem object that already exists as an owned cl_mem object. Without
    // /// synchronized access, the use of these objects can lead to undefined
    // /// behavior.
    // unsafe fn associated_memobject(&self) -> Output<ClMem<T>> {
    //     self.get_info::<cl_mem>(MemInfo::AssociatedMemobject)
    //         .map(|ret| {
    //             let mem_obj: cl_mem = ret.into_one();
    //             retain_mem(mem_obj);
    //             ClMem::new(mem_obj)
    //         })
    //         .map_err(|e| match e {
    //             Error::ClObjectCannotBeNull => NO_ASSOCIATED_MEM_OBJECT,
    //             other => other,
    //         })
    // }

    /// Returns the ClContext of the ClMem, retaining the underlying cl_context.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn context(&self) -> Output<ClContext> {
        self.get_info::<cl_context>(MemInfo::Context)
            .and_then(|cl_ptr| ClContext::retain_new(cl_ptr.into_one()))
    }

    /// Returns the reference count info for the ClMem.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn reference_count(&self) -> Output<u32> {
        self.get_info(MemInfo::ReferenceCount)
            .map(|ret| ret.into_one())
    }

    /// Returns the size info (in bytes) for the ClMem.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn size(&self) -> Output<usize> {
        self.get_info(MemInfo::Size).map(|ret| ret.into_one())
    }

    /// Returns the offset info for the ClMem.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn offset(&self) -> Output<usize> {
        self.get_info(MemInfo::Offset).map(|ret| ret.into_one())
    }

    /// Returns the MemFlags info for the ClMem.
    ///
    /// # Safety
    /// Calling this function with an invalid ClMem is undefined behavior.
    unsafe fn flags(&self) -> Output<MemFlags> {
        self.get_info(MemInfo::Flags).map(|ret| ret.into_one())
    }

    // // TODO: figure out what this is...
    // fn mem_type(&self) -> Output<MemType> {
    //     unsafe { self.get_info(MemInfo::Type).map(|ret| ret.into_one()) }
    // }
}
247
/// An owned wrapper around a cl_mem object paired with the runtime NumberType
/// of the elements it was created with. The wrapped cl_mem is released when
/// this value is dropped (see ClMem::new).
#[derive(Eq, PartialEq)]
pub struct ClMem {
    // Manages the lifetime of the underlying cl_mem.
    inner: ObjectWrapper<cl_mem>,
    // Runtime tag of the element type T chosen at creation time.
    t: NumberType,
}
253
impl NumberTyped for ClMem {
    /// Returns the runtime NumberType tag recorded when this ClMem was created.
    fn number_type(&self) -> NumberType {
        self.t
    }
}
259
260impl ClMem {
261    /// Instantiates a new ClMem of type T.
262    ///
263    /// # Safety
264    /// This function does not retain its cl_mem, but will release its cl_mem
265    /// when it is dropped. Mismanagement of a cl_mem's lifetime.  Therefore,
266    /// this function is unsafe.
267    pub unsafe fn new<T: ClNumber>(object: cl_mem) -> Output<ClMem> {
268        Ok(ClMem {
269            inner: ObjectWrapper::new(object)?,
270            t: T::number_type()
271        })
272    }
273
274    pub fn create<T: ClNumber, B: BufferCreator<T>>(
275        context: &ClContext,
276        buffer_creator: B,
277        host_access: HostAccess,
278        kernel_access: KernelAccess,
279        mem_location: MemLocation,
280    ) -> Output<ClMem> {
281        unsafe {
282            let mem_object = cl_create_buffer_with_creator(
283                context.context_ptr(),
284                cl_mem_flags::from(host_access)
285                    | cl_mem_flags::from(kernel_access)
286                    | cl_mem_flags::from(mem_location),
287                buffer_creator,
288            )?;
289            ClMem::new::<T>(mem_object)
290        }
291    }
292
293    /// Created a device memory buffer given the context, the buffer creator and some config.
294    /// There are some buffer creators that are not valid for some MemConfigs. However, a
295    /// mismatch of type and configuration between a buffer creator and the MemConfig will,
296    /// at worst, result in this function call returning an error.
297    ///
298    /// # Safety
299    /// Using an invalid context in this function call is undefined behavior.
300    pub unsafe fn create_with_config<T: ClNumber, B: BufferCreator<T>>(
301        context: &ClContext,
302        buffer_creator: B,
303        mem_config: MemConfig,
304    ) -> Output<ClMem> {
305        let mem_object = cl_create_buffer_with_creator(
306            context.context_ptr(),
307            mem_config.into(),
308            buffer_creator,
309        )?;
310        ClMem::new::<T>(mem_object)
311    }
312}
313
// SAFETY: ClMem owns its cl_mem through ObjectWrapper; handing out the raw
// pointer/reference here follows the MemPtr contract — callers must not
// release it or let it outlive this ClMem.
unsafe impl MemPtr for ClMem {
    unsafe fn mem_ptr(&self) -> cl_mem {
        self.inner.cl_object()
    }

    unsafe fn mem_ptr_ref(&self) -> &cl_mem {
        self.inner.cl_object_ref()
    }
}
323
// SAFETY: NOTE(review) — presumably sound because the wrapped cl_mem is only
// retained/released through ObjectWrapper, but confirm the targeted OpenCL
// runtimes allow cl_mem handles to move across threads.
unsafe impl Send for ClMem {}
325
326impl fmt::Debug for ClMem {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        write!(f, "{:?}", unsafe { self.mem_ptr() })
329    }
330}
331
332
#[cfg(test)]
mod tests {
    use crate::*;

    // Creating a buffer from a bare length: device allocation, no host data.
    #[test]
    fn mem_can_be_created_with_len() {
        let (context, _devices) = ll_testing::get_context();
        let mem_config = MemConfig::default();
        let _mem: ClMem =
            unsafe { ClMem::create_with_config::<u32, usize>(&context, 10, mem_config).unwrap() };
    }

    // Creating a buffer from a slice: host data copied to the device.
    #[test]
    fn mem_can_be_created_with_slice() {
        let (context, _devices) = ll_testing::get_context();
        let data: Vec<u32> = vec![0, 1, 2, 3, 4];
        let mem_config = MemConfig::for_data();
        let _mem: ClMem =
            unsafe { ClMem::create_with_config(&context, &data[..], mem_config).unwrap() };
    }

    // Exercises the MemPtr info accessors on a fresh 10-element u32 buffer.
    mod mem_ptr_trait {
        use crate::*;

        #[test]
        fn len_method_works() {
            let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
            let len = unsafe { ll_mem.len().unwrap() };
            assert_eq!(len, 10);
        }

        #[test]
        fn reference_count_method_works() {
            let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
            let ref_count = unsafe { ll_mem.reference_count().unwrap() };
            assert_eq!(ref_count, 1);
        }

        #[test]
        fn size_method_returns_size_in_bytes() {
            let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
            let bytes_size = unsafe { ll_mem.size().unwrap() };
            assert_eq!(bytes_size, 10 * std::mem::size_of::<u32>());
        }

        #[test]
        fn offset_method_works() {
            let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
            let offset = unsafe { ll_mem.offset().unwrap() };
            assert_eq!(offset, 0);
        }

        #[test]
        fn flags_method_works() {
            let (_devices, _context, ll_mem) = ll_testing::get_mem::<u32>(10);
            let flags = unsafe { ll_mem.flags().unwrap() };
            assert_eq!(flags, MemFlags::READ_WRITE_ALLOC_HOST_PTR);
        }
    }
}
393
/// How a kernel may access a memory buffer (read, write, or both).
pub enum KernelAccess {
    ReadOnly,
    WriteOnly,
    ReadWrite,
}

impl From<KernelAccess> for KernelAccessMemFlags {
    /// Maps each KernelAccess variant to its kernel-access mem-flag bit.
    fn from(kernel_access: KernelAccess) -> KernelAccessMemFlags {
        match kernel_access {
            KernelAccess::ReadOnly => KernelAccessMemFlags::READ_ONLY,
            KernelAccess::WriteOnly => KernelAccessMemFlags::WRITE_ONLY,
            KernelAccess::ReadWrite => KernelAccessMemFlags::READ_WRITE,
        }
    }
}

impl From<KernelAccess> for MemFlags {
    /// Widens the kernel-access flag into the general MemFlags type.
    fn from(kernel_access: KernelAccess) -> MemFlags {
        MemFlags::from(KernelAccessMemFlags::from(kernel_access))
    }
}

impl From<KernelAccess> for cl_mem_flags {
    /// Lowers the kernel-access flag to the raw OpenCL bitfield type.
    fn from(kernel_access: KernelAccess) -> cl_mem_flags {
        cl_mem_flags::from(MemFlags::from(kernel_access))
    }
}
421
/// How the host may access a memory buffer (read, write, both, or not at all).
pub enum HostAccess {
    ReadOnly,
    WriteOnly,
    NoAccess,
    ReadWrite,
}

impl From<HostAccess> for HostAccessMemFlags {
    /// Maps each HostAccess variant to its host-access mem-flag bit.
    fn from(host_access: HostAccess) -> HostAccessMemFlags {
        match host_access {
            HostAccess::ReadOnly => HostAccessMemFlags::READ_ONLY,
            HostAccess::WriteOnly => HostAccessMemFlags::WRITE_ONLY,
            HostAccess::NoAccess => HostAccessMemFlags::NO_ACCESS,
            HostAccess::ReadWrite => HostAccessMemFlags::READ_WRITE,
        }
    }
}

impl From<HostAccess> for MemFlags {
    /// Widens the host-access flag into the general MemFlags type.
    fn from(host_access: HostAccess) -> MemFlags {
        MemFlags::from(HostAccessMemFlags::from(host_access))
    }
}

impl From<HostAccess> for cl_mem_flags {
    /// Lowers the host-access flag to the raw OpenCL bitfield type.
    fn from(host_access: HostAccess) -> cl_mem_flags {
        cl_mem_flags::from(MemFlags::from(host_access))
    }
}
451
/// The enumeration of how memory allocation (or not) can be directed.
///
/// This forum post has some good explanations:
///   <https://software.intel.com/en-us/forums/opencl/topic/708049>
pub enum MemLocation {
    KeepInPlace,
    AllocOnDevice,
    CopyToDevice,
    ForceCopyToDevice,
}

impl From<MemLocation> for MemLocationMemFlags {
    /// Maps each MemLocation variant to its mem-location mem-flag bit.
    fn from(mem_location: MemLocation) -> MemLocationMemFlags {
        match mem_location {
            MemLocation::KeepInPlace => MemLocationMemFlags::KEEP_IN_PLACE,
            MemLocation::AllocOnDevice => MemLocationMemFlags::ALLOC_ON_DEVICE,
            MemLocation::CopyToDevice => MemLocationMemFlags::COPY_TO_DEVICE,
            MemLocation::ForceCopyToDevice => MemLocationMemFlags::FORCE_COPY_TO_DEVICE,
        }
    }
}

impl From<MemLocation> for MemFlags {
    /// Widens the mem-location flag into the general MemFlags type.
    fn from(mem_location: MemLocation) -> MemFlags {
        MemFlags::from(MemLocationMemFlags::from(mem_location))
    }
}

impl From<MemLocation> for cl_mem_flags {
    /// Lowers the mem-location flag to the raw OpenCL bitfield type.
    fn from(mem_location: MemLocation) -> cl_mem_flags {
        cl_mem_flags::from(MemFlags::from(mem_location))
    }
}
485
/// The three-part configuration (host access, kernel access, memory location)
/// that is combined into the cl_mem_flags used for buffer creation.
pub struct MemConfig {
    pub host_access: HostAccess,
    pub kernel_access: KernelAccess,
    pub mem_location: MemLocation,
}

impl MemConfig {
    /// Assembles a MemConfig from its three flag parts.
    pub fn build(
        host_access: HostAccess,
        kernel_access: KernelAccess,
        mem_location: MemLocation,
    ) -> MemConfig {
        MemConfig {
            host_access,
            kernel_access,
            mem_location,
        }
    }
}
505
506impl From<MemConfig> for MemFlags {
507    fn from(mem_config: MemConfig) -> MemFlags {
508        unsafe { MemFlags::from_bits_unchecked(cl_mem_flags::from(mem_config)) }
509    }
510}
511
512impl From<MemConfig> for cl_mem_flags {
513    fn from(mem_config: MemConfig) -> cl_mem_flags {
514        cl_mem_flags::from(mem_config.host_access)
515            | cl_mem_flags::from(mem_config.kernel_access)
516            | cl_mem_flags::from(mem_config.mem_location)
517    }
518}
519
impl Default for MemConfig {
    /// Read-write access for both host and kernel, with storage allocated on
    /// the device.
    fn default() -> MemConfig {
        MemConfig {
            host_access: HostAccess::ReadWrite,
            kernel_access: KernelAccess::ReadWrite,
            mem_location: MemLocation::AllocOnDevice,
        }
    }
}
529
530impl MemConfig {
531    pub fn for_data() -> MemConfig {
532        MemConfig {
533            mem_location: MemLocation::CopyToDevice,
534            ..MemConfig::default()
535        }
536    }
537
538    pub fn for_size() -> MemConfig {
539        MemConfig {
540            mem_location: MemLocation::AllocOnDevice,
541            ..MemConfig::default()
542        }
543    }
544}
545
#[cfg(test)]
mod mem_flags_tests {
    use super::*;
    use crate::KernelAccessMemFlags;

    // KernelAccess -> KernelAccessMemFlags conversions.
    #[test]
    fn kernel_access_read_only_conversion_into_kernel_access_mem_flag() {
        let kernel_access = KernelAccess::ReadOnly;
        assert_eq!(
            KernelAccessMemFlags::from(kernel_access),
            KernelAccessMemFlags::READ_ONLY
        );
    }

    #[test]
    fn kernel_access_write_only_conversion_into_kernel_access_mem_flag() {
        let kernel_access = KernelAccess::WriteOnly;
        assert_eq!(
            KernelAccessMemFlags::from(kernel_access),
            KernelAccessMemFlags::WRITE_ONLY
        );
    }

    #[test]
    fn kernel_access_convert_read_write_into_kernel_access_mem_flag() {
        let kernel_access = KernelAccess::ReadWrite;
        assert_eq!(
            KernelAccessMemFlags::from(kernel_access),
            KernelAccessMemFlags::READ_WRITE
        );
    }

    // HostAccess -> HostAccessMemFlags conversions.
    #[test]
    fn host_access_read_only_conversion_into_host_access_mem_flag() {
        let host_access = HostAccess::ReadOnly;
        assert_eq!(
            HostAccessMemFlags::from(host_access),
            HostAccessMemFlags::READ_ONLY
        );
    }

    #[test]
    fn host_access_write_only_conversion_into_host_access_mem_flag() {
        let host_access = HostAccess::WriteOnly;
        assert_eq!(
            HostAccessMemFlags::from(host_access),
            HostAccessMemFlags::WRITE_ONLY
        );
    }

    #[test]
    fn host_access_read_write_conversion_into_host_access_mem_flag() {
        let host_access = HostAccess::ReadWrite;
        assert_eq!(
            HostAccessMemFlags::from(host_access),
            HostAccessMemFlags::READ_WRITE
        );
    }

    #[test]
    fn host_access_no_access_conversion_into_host_access_mem_flag() {
        let host_access = HostAccess::NoAccess;
        assert_eq!(
            HostAccessMemFlags::from(host_access),
            HostAccessMemFlags::NO_ACCESS
        );
    }

    // MemLocation -> MemLocationMemFlags conversions.
    #[test]
    fn mem_location_keep_in_place_conversion_into_mem_location_mem_flag() {
        let mem_location = MemLocation::KeepInPlace;
        assert_eq!(
            MemLocationMemFlags::from(mem_location),
            MemLocationMemFlags::KEEP_IN_PLACE
        );
    }

    #[test]
    fn mem_location_alloc_on_device_conversion_into_mem_location_mem_flag() {
        let mem_location = MemLocation::AllocOnDevice;
        assert_eq!(
            MemLocationMemFlags::from(mem_location),
            MemLocationMemFlags::ALLOC_ON_DEVICE
        );
    }

    #[test]
    fn mem_location_copy_to_device_conversion_into_mem_location_mem_flag() {
        let mem_location = MemLocation::CopyToDevice;
        assert_eq!(
            MemLocationMemFlags::from(mem_location),
            MemLocationMemFlags::COPY_TO_DEVICE
        );
    }

    #[test]
    fn mem_location_force_copy_to_device_conversion_into_mem_location_mem_flag() {
        let mem_location = MemLocation::ForceCopyToDevice;
        assert_eq!(
            MemLocationMemFlags::from(mem_location),
            MemLocationMemFlags::FORCE_COPY_TO_DEVICE
        );
    }

    // The full MemConfig -> cl_mem_flags conversion is the OR of the three parts.
    #[test]
    fn mem_config_conversion_into_cl_mem_flags() {
        let mem_location = MemLocation::AllocOnDevice;
        let host_access = HostAccess::ReadWrite;
        let kernel_access = KernelAccess::ReadWrite;
        let mem_config = MemConfig {
            mem_location,
            host_access,
            kernel_access,
        };
        let expected = MemFlags::ALLOC_HOST_PTR.bits()
            | MemFlags::HOST_READ_WRITE.bits()
            | MemFlags::KERNEL_READ_WRITE.bits();

        assert_eq!(cl_mem_flags::from(mem_config), expected);
    }
}