1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
use super::MemObjectType;
use crate::{buffer::flags::MemFlags, context::RawContext, core::*, non_null_const};
use blaze_proc::docfg;
use opencl_sys::*;
use std::{ffi::c_void, mem::MaybeUninit, ptr::NonNull};

/// A raw OpenCL memory object.
///
/// Each value owns one OpenCL reference count on the wrapped `cl_mem` handle: `Clone`
/// retains an additional reference and `Drop` releases one. The derived `PartialEq`, `Eq`
/// and `Hash` compare the handle pointer, i.e. object identity, not buffer contents.
/// `repr(transparent)` keeps the layout identical to the single `NonNull` field.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct RawMemObject(NonNull<c_void>);

impl RawMemObject {
    /// Wraps a raw `cl_mem` handle without checking it for null.
    ///
    /// No retain is performed: the returned value adopts one existing reference to the
    /// object and releases it when dropped.
    ///
    /// # Safety
    /// `id` must be a non-null, valid memory object, and the caller must transfer ownership
    /// of one reference count to the returned value.
    #[inline(always)]
    pub const unsafe fn from_id_unchecked(id: cl_mem) -> Self {
        Self(NonNull::new_unchecked(id))
    }

    /// Wraps a raw `cl_mem` handle, returning `None` if it is null.
    ///
    /// Like [`from_id_unchecked`](Self::from_id_unchecked), no retain is performed.
    ///
    /// # Safety
    /// If `id` is non-null it must be a valid memory object, and the caller must transfer
    /// ownership of one reference count to the returned value.
    #[inline(always)]
    pub const unsafe fn from_id(id: cl_mem) -> Option<Self> {
        // `match` instead of `Option::map` because this is a `const fn`.
        match non_null_const(id) {
            Some(ptr) => Some(Self(ptr)),
            None => None,
        }
    }

    /// Increments the memory object's OpenCL reference count.
    ///
    /// # Safety
    /// The extra reference is not tracked by this wrapper; the caller must balance it
    /// (for example by handing it to another `RawMemObject` through
    /// [`from_id_unchecked`](Self::from_id_unchecked)), otherwise the object leaks.
    ///
    /// # Errors
    /// Returns the error reported by `clRetainMemObject`.
    #[inline(always)]
    pub unsafe fn retain(&self) -> Result<()> {
        tri!(clRetainMemObject(self.id()));
        Ok(())
    }

    /// Returns the raw `cl_mem` handle without touching its reference count.
    #[inline(always)]
    pub const fn id(&self) -> cl_mem {
        self.0.as_ptr()
    }

    /// Returns a reference to the raw `cl_mem` handle stored in this wrapper.
    #[inline(always)]
    pub const fn id_ref(&self) -> &cl_mem {
        // SAFETY: `NonNull<c_void>` is `#[repr(transparent)]` over a raw pointer, so it has
        // the same layout as `cl_mem`.
        unsafe { core::mem::transmute(&self.0) }
    }

    /// Returns a mutable reference to the raw `cl_mem` handle stored in this wrapper.
    ///
    /// Writing a null pointer through the returned reference violates the non-null
    /// invariant of the underlying `NonNull` and must not be done.
    #[inline(always)]
    pub fn id_ref_mut(&mut self) -> &mut cl_mem {
        // SAFETY: same layout argument as in `id_ref`.
        unsafe { core::mem::transmute(&mut self.0) }
    }

    /// Returns the memory object's type.
    #[inline(always)]
    pub fn ty(&self) -> Result<MemObjectType> {
        self.get_info(CL_MEM_TYPE)
    }

    /// Returns the memory object from which this one was created, or `None` if the query
    /// reports a null handle.
    ///
    /// The returned object is retained, so it owns its own reference count.
    #[docfg(feature = "cl1_1")]
    #[inline(always)]
    pub fn associated_memobject(&self) -> Result<Option<RawMemObject>> {
        let id = self.get_info::<cl_mem>(opencl_sys::CL_MEM_ASSOCIATED_MEMOBJECT)?;
        if id.is_null() {
            return Ok(None);
        }

        unsafe {
            // Retain *before* wrapping: if the retain fails, no `RawMemObject` exists yet, so
            // nothing releases a reference we never acquired.
            tri!(clRetainMemObject(id));
            // SAFETY: `id` is non-null and we now own the reference taken just above.
            Ok(Some(Self::from_id_unchecked(id)))
        }
    }

    /// Returns the flags argument value specified when the memory object was created.
    #[inline(always)]
    pub fn flags(&self) -> Result<MemFlags> {
        let flags = self.get_info(CL_MEM_FLAGS)?;
        Ok(MemFlags::from_bits(flags))
    }

    /// Returns the actual size in bytes of the data store associated with the memory object.
    #[inline(always)]
    pub fn size(&self) -> Result<usize> {
        self.get_info(CL_MEM_SIZE)
    }

    /// If the memory object was created with a `host_ptr`, returns that pointer; otherwise
    /// returns `None`.
    #[inline(always)]
    pub fn host_ptr(&self) -> Result<Option<NonNull<c_void>>> {
        self.get_info(CL_MEM_HOST_PTR).map(NonNull::new)
    }

    /// Map count. The map count returned should be considered immediately stale. It is unsuitable for general use in applications. This feature is provided for debugging.
    #[inline(always)]
    pub fn map_count(&self) -> Result<u32> {
        self.get_info(CL_MEM_MAP_COUNT)
    }

    /// Returns the memory object's reference count. The reference count returned should be considered immediately stale. It is unsuitable for general use in applications. This feature is provided for identifying memory leaks.
    #[inline(always)]
    pub fn reference_count(&self) -> Result<u32> {
        self.get_info(CL_MEM_REFERENCE_COUNT)
    }

    /// Returns the context specified when the memory object was created.
    ///
    /// The context is retained, so the returned value owns its own reference count.
    #[inline(always)]
    pub fn context(&self) -> Result<RawContext> {
        let ctx = self.get_info::<cl_context>(CL_MEM_CONTEXT)?;
        unsafe {
            tri!(clRetainContext(ctx));
            // SAFETY: Context checked to be valid by `clRetainContext`.
            Ok(RawContext::from_id_unchecked(ctx))
        }
    }

    /// Returns the offset if the memory object is a sub-buffer object created using [create_sub_buffer](crate::buffer::RawBuffer::create_sub_buffer). Returns 0 if it is not a sub-buffer object.
    #[docfg(feature = "cl1_1")]
    #[inline(always)]
    pub fn offset(&self) -> Result<usize> {
        self.get_info(opencl_sys::CL_MEM_OFFSET)
    }

    /// Returns ```true``` if the memory object is a buffer object that was created with CL_MEM_USE_HOST_PTR, or is a sub-buffer object of a buffer object that was created with CL_MEM_USE_HOST_PTR, and the host_ptr specified when the buffer object was created is a SVM pointer; otherwise returns ```false```.
    #[docfg(feature = "cl2")]
    #[inline(always)]
    pub fn uses_svm_pointer(&self) -> Result<bool> {
        let v = self.get_info::<opencl_sys::cl_bool>(opencl_sys::CL_MEM_USES_SVM_POINTER)?;
        Ok(v != 0)
    }

    /// Queries a fixed-size parameter of the memory object via `clGetMemObjectInfo`.
    ///
    /// `O` must match the size and layout of the parameter identified by `ty`. The value is
    /// treated as fully initialized once the call succeeds. Debug builds assert that the
    /// size reported by OpenCL equals `size_of::<O>()`.
    #[inline]
    pub(super) fn get_info<O: Copy>(&self, ty: cl_mem_info) -> Result<O> {
        let mut result = MaybeUninit::<O>::uninit();
        let mut size_ret: usize = 0;

        unsafe {
            tri!(clGetMemObjectInfo(
                self.id(),
                ty,
                core::mem::size_of::<O>(),
                result.as_mut_ptr().cast(),
                &mut size_ret
            ));
            // A smaller reported size would leave trailing bytes of `result` uninitialized.
            debug_assert_eq!(
                size_ret,
                core::mem::size_of::<O>(),
                "clGetMemObjectInfo returned a different size than the requested type"
            );
            Ok(result.assume_init())
        }
    }
}

#[docfg(feature = "cl1_1")]
impl RawMemObject {
    /// Registers `f` to run when OpenCL destroys this memory object.
    ///
    /// The closure is boxed and forwarded to [`on_destruct_boxed`](Self::on_destruct_boxed).
    /// It must be `Send`, presumably because OpenCL may invoke destructor callbacks from a
    /// thread other than the caller's.
    ///
    /// # Errors
    /// Returns the error reported by `clSetMemObjectDestructorCallback` if registration fails.
    /// In that case the closure is dropped without being called.
    #[inline(always)]
    pub fn on_destruct(&self, f: impl 'static + FnOnce() + Send) -> Result<()> {
        let f: Box<dyn FnOnce() + Send> = Box::new(f);
        self.on_destruct_boxed(f)
    }

    /// Registers an already-boxed closure to run when OpenCL destroys this memory object.
    ///
    /// # Errors
    /// Returns the error reported by `clSetMemObjectDestructorCallback` if registration fails.
    /// In that case the closure is dropped without being called, rather than leaked.
    #[inline(always)]
    pub fn on_destruct_boxed(&self, f: Box<dyn FnOnce() + Send>) -> Result<()> {
        // Double-box so the callback receives a thin pointer that fits in `*mut c_void`.
        let data = Box::into_raw(Box::new(f));
        // SAFETY: `destructor_callback` expects exactly a `*mut Box<dyn FnOnce() + Send>`
        // and takes ownership of it when invoked.
        let registered = unsafe { self.on_destruct_raw(destructor_callback, data.cast()) };
        if registered.is_err() {
            // Registration failed, so the callback was never registered and `data` was never
            // handed to OpenCL. Reclaim the allocation instead of leaking the closure.
            // SAFETY: `data` came from `Box::into_raw` above and has no other owner.
            drop(unsafe { Box::from_raw(data) });
        }
        registered
    }

    /// Registers a raw C callback to run when OpenCL destroys this memory object.
    ///
    /// # Safety
    /// `f` will later be called with this object's handle and `user_data`, possibly from
    /// another thread. `user_data` must remain valid until that call, `f` must not unwind
    /// across the FFI boundary, and if this function returns an error the caller remains
    /// responsible for any resources behind `user_data`.
    ///
    /// # Errors
    /// Returns the error reported by `clSetMemObjectDestructorCallback`.
    #[inline(always)]
    pub unsafe fn on_destruct_raw(
        &self,
        f: unsafe extern "C" fn(memobj: cl_mem, user_data: *mut c_void),
        user_data: *mut c_void,
    ) -> Result<()> {
        tri!(opencl_sys::clSetMemObjectDestructorCallback(
            self.id(),
            Some(f),
            user_data
        ));
        Ok(())
    }
}

impl Clone for RawMemObject {
    /// Creates another handle to the same memory object.
    ///
    /// The OpenCL reference count is bumped first, so each handle releases exactly one
    /// reference when dropped. NOTE(review): `tri_panic!` presumably panics if
    /// `clRetainMemObject` reports an error. Confirm against the macro definition.
    #[inline(always)]
    fn clone(&self) -> Self {
        let handle = self.0;
        unsafe { tri_panic!(clRetainMemObject(handle.as_ptr())) }
        Self(handle)
    }
}

impl Drop for RawMemObject {
    #[inline(always)]
    fn drop(&mut self) {
        unsafe { tri_panic!(clReleaseMemObject(self.id())) }
    }
}

// SAFETY: `RawMemObject` only stores an OpenCL handle. The methods in this file go through
// OpenCL API calls and never dereference the pointer themselves. NOTE(review): soundness
// relies on the OpenCL runtime's thread-safety guarantees for memory-object API calls
// (OpenCL 1.1+). Confirm for the targeted platform.
unsafe impl Send for RawMemObject {}
unsafe impl Sync for RawMemObject {}

/// C-ABI callback registered with `clSetMemObjectDestructorCallback` by
/// `RawMemObject::on_destruct_boxed`.
///
/// `user_data` must be the pointer produced by `Box::into_raw(Box::new(f))`, where `f` is a
/// `Box<dyn FnOnce() + Send>`. This function takes back ownership of that allocation, so it
/// must run at most once per registration.
#[cfg(feature = "cl1_1")]
unsafe extern "C" fn destructor_callback(_memobj: cl_mem, user_data: *mut c_void) {
    let f = *Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>);
    // A panic unwinding out of an `extern "C"` function is undefined behaviour on older Rust
    // toolchains (and an abort on newer ones). Catch it here and abort explicitly so the
    // outcome is well-defined everywhere.
    if std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)).is_err() {
        std::process::abort();
    }
}