opencl_core/
buffer.rs

1use std::fmt;
2use std::fmt::Debug;
3use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5use crate::ll::{ClContext, ClMem, MemFlags, MemPtr};
6
7use crate::{
8    BufferCreator, ClNumber, Context, HostAccess, KernelAccess, MemConfig, MemLocation, Output,
9    NumberType, NumberTyped
10};
11
12pub struct Buffer {
13    _t: NumberType,
14    _mem: Arc<RwLock<ClMem>>,
15    _context: Context,
16}
17
18impl NumberTyped for Buffer {
19    fn number_type(&self) -> NumberType {
20        self._t
21    }
22}
23
24unsafe impl Send for Buffer {}
25unsafe impl Sync for Buffer {}
26
27impl Clone for Buffer {
28    fn clone(&self) -> Buffer {
29        Buffer {
30            _t: self._t,
31            _mem: self._mem.clone(),
32            _context: self._context.clone(),
33        }
34    }
35}
36
37impl Debug for Buffer {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        write!(f, "Buffer{{{:?}}}", self._mem)
40    }
41}
42
43impl PartialEq for Buffer {
44    fn eq(&self, other: &Self) -> bool {
45        unsafe {
46            let left = self._mem.read().unwrap().mem_ptr();
47            let right = other._mem.read().unwrap().mem_ptr();
48            std::ptr::eq(left, right)
49        }
50    }
51}
52
53impl Buffer {
54    pub fn new(ll_mem: ClMem, context: Context) -> Buffer {
55        Buffer {
56            _t: ll_mem.number_type(),
57            _mem: Arc::new(RwLock::new(ll_mem)),
58            _context: context,
59        }
60    }
61
62    pub fn create<T: ClNumber, B: BufferCreator<T>>(
63        context: &Context,
64        creator: B,
65        host_access: HostAccess,
66        kernel_access: KernelAccess,
67        mem_location: MemLocation,
68    ) -> Output<Buffer> {
69        let ll_mem = ClMem::create(
70            context.low_level_context(),
71            creator,
72            host_access,
73            kernel_access,
74            mem_location,
75        )?;
76        Ok(Buffer::new(ll_mem, context.clone()))
77    }
78
79    pub fn create_with_len<T: ClNumber>(context: &Context, len: usize) -> Output<Buffer> {
80        Buffer::create_from::<T, usize>(context, len)
81    }
82
83    pub fn create_from_slice<T: ClNumber>(context: &Context, data: &[T]) -> Output<Buffer> {
84        Buffer::create_from(context, data)
85    }
86
87    pub fn create_from<T: ClNumber, B: BufferCreator<T>>(
88        context: &Context,
89        creator: B,
90    ) -> Output<Buffer> {
91        let mem_config = { creator.mem_config() };
92        Buffer::create_with_config(context, creator, mem_config)
93    }
94
95    pub fn create_with_config<T: ClNumber, B: BufferCreator<T>>(
96        context: &Context,
97        creator: B,
98        mem_config: MemConfig,
99    ) -> Output<Buffer> {
100        Buffer::create(
101            context,
102            creator,
103            mem_config.host_access,
104            mem_config.kernel_access,
105            mem_config.mem_location,
106        )
107    }
108
109    pub fn create_from_low_level_context<T: ClNumber, B: BufferCreator<T>>(
110        ll_context: &ClContext,
111        creator: B,
112        host_access: HostAccess,
113        kernel_access: KernelAccess,
114        mem_location: MemLocation,
115    ) -> Output<Buffer> {
116        let ll_mem = ClMem::create(
117            ll_context,
118            creator,
119            host_access,
120            kernel_access,
121            mem_location,
122        )?;
123        let context = Context::from_low_level_context(ll_context)?;
124        Ok(Buffer::new(ll_mem, context))
125    }
126
127    pub fn read_lock(&self) -> RwLockReadGuard<ClMem> {
128        self._mem.read().unwrap()
129    }
130
131    pub fn write_lock(&self) -> RwLockWriteGuard<ClMem> {
132        self._mem.write().unwrap()
133    }
134
135    pub fn context(&self) -> &Context {
136        &self._context
137    }
138
139    pub fn reference_count(&self) -> Output<u32> {
140        unsafe { self.read_lock().reference_count() }
141    }
142
143    pub fn size(&self) -> Output<usize> {
144        unsafe { self.read_lock().size() }
145    }
146
147    /// A non-panicking version of len.
148    pub fn length(&self) -> Output<usize> {
149        unsafe { self.read_lock().len() }
150        // Ok(size / std::mem::size_of::<T>())
151    }
152
153    /// A method for getting the len of the device memory buffer.
154    /// Panics if the buffer size info returns an error.
155    pub fn len(&self) -> usize {
156        self.length().unwrap()
157    }
158
159    pub fn offset(&self) -> Output<usize> {
160        unsafe { self.read_lock().offset() }
161    }
162
163    pub fn flags(&self) -> Output<MemFlags> {
164        unsafe { self.read_lock().flags() }
165    }
166}
167
#[cfg(test)]
mod tests {
    use crate::ll::*;
    use crate::*;

    /// Creating a buffer from just an element count succeeds.
    #[test]
    fn buffer_can_be_created_with_a_length() {
        let ctx = testing::get_context();
        let created = Buffer::create_with_len::<u32>(&ctx, 10);
        let _buffer = created.unwrap();
    }

    /// Creating a buffer from host data with explicit settings succeeds.
    #[test]
    fn buffer_can_be_created_with_a_slice_of_data() {
        let ctx = testing::get_context();
        let host_data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let created = Buffer::create(
            &ctx,
            &host_data[..],
            HostAccess::NoAccess,
            KernelAccess::ReadWrite,
            MemLocation::CopyToDevice,
        );
        let _buffer = created.unwrap();
    }

    /// A freshly created test buffer reports a reference count of one.
    #[test]
    fn buffer_reference_count_works() {
        let buf = testing::get_buffer::<u32>(10);
        let queried = buf.reference_count();
        let count = queried.expect("Failed to call buffer.reference_count()");
        assert_eq!(count, 1);
    }

    /// The test buffer requested as `get_buffer::<u32>(10)` reports a size
    /// of 40.
    #[test]
    fn buffer_size_works() {
        let buf = testing::get_buffer::<u32>(10);
        let queried = buf.size();
        let byte_size = queried.expect("Failed to call buffer.size()");
        assert_eq!(byte_size, 40);
    }

    // Disabled: `Buffer` does not expose a `mem_type()` method in this file.
    //
    // #[test]
    // fn device_mem_method_mem_type_works() {
    //     let buffer = testing::get_buffer::<u32>(10);
    //     let _out: MemObjectType = buffer.mem_type()
    //         .expect("Failed to call device_mem.mem_type()");
    // }

    /// The test buffer reports the expected combination of memory flags.
    #[test]
    fn buffer_flags_works() {
        let buf = testing::get_buffer::<u32>(10);
        let expected = MemFlags::KERNEL_READ_WRITE
            | MemFlags::ALLOC_HOST_PTR
            | MemFlags::READ_WRITE_ALLOC_HOST_PTR;
        let actual = buf.flags().expect("Failed to call buffer.flags()");
        assert_eq!(actual, expected);
    }

    /// The test buffer reports an offset of zero.
    #[test]
    fn buffer_offset_works() {
        let buf = testing::get_buffer::<u32>(10);
        let queried = buf.offset();
        let offset = queried.expect("Failed to call buffer.offset()");
        assert_eq!(offset, 0);
    }
}