1use std::fmt;
2use std::fmt::Debug;
3use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5use crate::ll::{ClContext, ClMem, MemFlags, MemPtr};
6
7use crate::{
8 BufferCreator, ClNumber, Context, HostAccess, KernelAccess, MemConfig, MemLocation, Output,
9 NumberType, NumberTyped
10};
11
/// A shareable handle to a low-level memory object (`ClMem`, presumably an OpenCL `cl_mem`)
/// together with the [`Context`] it was created in.
///
/// The memory object sits behind `Arc<RwLock<_>>`, so cloned `Buffer`s refer to the same
/// underlying `ClMem` rather than copying it.
pub struct Buffer {
    // Element number type, captured from the `ClMem` when the buffer is constructed.
    _t: NumberType,
    // Shared, lock-protected low-level memory object.
    // Rust drops fields in declaration order, so this is released before `_context`;
    // presumably that ordering is intended (memory object before its context) — keep it.
    _mem: Arc<RwLock<ClMem>>,
    // Context the memory object belongs to.
    _context: Context,
}
17
18impl NumberTyped for Buffer {
19 fn number_type(&self) -> NumberType {
20 self._t
21 }
22}
23
// SAFETY: NOTE(review): these impls assert that `Buffer` may be moved to, and shared between,
// threads. The memory object is wrapped in `Arc<RwLock<_>>`, but these manual impls are
// presumably needed because `ClMem` and/or `Context` hold raw low-level handles that are not
// auto-`Send`/`Sync`. Soundness therefore depends on those underlying handles being safe to use
// from multiple threads — confirm against the `ll` layer and the OpenCL thread-safety rules.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}
26
27impl Clone for Buffer {
28 fn clone(&self) -> Buffer {
29 Buffer {
30 _t: self._t,
31 _mem: self._mem.clone(),
32 _context: self._context.clone(),
33 }
34 }
35}
36
37impl Debug for Buffer {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 write!(f, "Buffer{{{:?}}}", self._mem)
40 }
41}
42
43impl PartialEq for Buffer {
44 fn eq(&self, other: &Self) -> bool {
45 unsafe {
46 let left = self._mem.read().unwrap().mem_ptr();
47 let right = other._mem.read().unwrap().mem_ptr();
48 std::ptr::eq(left, right)
49 }
50 }
51}
52
53impl Buffer {
54 pub fn new(ll_mem: ClMem, context: Context) -> Buffer {
55 Buffer {
56 _t: ll_mem.number_type(),
57 _mem: Arc::new(RwLock::new(ll_mem)),
58 _context: context,
59 }
60 }
61
62 pub fn create<T: ClNumber, B: BufferCreator<T>>(
63 context: &Context,
64 creator: B,
65 host_access: HostAccess,
66 kernel_access: KernelAccess,
67 mem_location: MemLocation,
68 ) -> Output<Buffer> {
69 let ll_mem = ClMem::create(
70 context.low_level_context(),
71 creator,
72 host_access,
73 kernel_access,
74 mem_location,
75 )?;
76 Ok(Buffer::new(ll_mem, context.clone()))
77 }
78
79 pub fn create_with_len<T: ClNumber>(context: &Context, len: usize) -> Output<Buffer> {
80 Buffer::create_from::<T, usize>(context, len)
81 }
82
83 pub fn create_from_slice<T: ClNumber>(context: &Context, data: &[T]) -> Output<Buffer> {
84 Buffer::create_from(context, data)
85 }
86
87 pub fn create_from<T: ClNumber, B: BufferCreator<T>>(
88 context: &Context,
89 creator: B,
90 ) -> Output<Buffer> {
91 let mem_config = { creator.mem_config() };
92 Buffer::create_with_config(context, creator, mem_config)
93 }
94
95 pub fn create_with_config<T: ClNumber, B: BufferCreator<T>>(
96 context: &Context,
97 creator: B,
98 mem_config: MemConfig,
99 ) -> Output<Buffer> {
100 Buffer::create(
101 context,
102 creator,
103 mem_config.host_access,
104 mem_config.kernel_access,
105 mem_config.mem_location,
106 )
107 }
108
109 pub fn create_from_low_level_context<T: ClNumber, B: BufferCreator<T>>(
110 ll_context: &ClContext,
111 creator: B,
112 host_access: HostAccess,
113 kernel_access: KernelAccess,
114 mem_location: MemLocation,
115 ) -> Output<Buffer> {
116 let ll_mem = ClMem::create(
117 ll_context,
118 creator,
119 host_access,
120 kernel_access,
121 mem_location,
122 )?;
123 let context = Context::from_low_level_context(ll_context)?;
124 Ok(Buffer::new(ll_mem, context))
125 }
126
127 pub fn read_lock(&self) -> RwLockReadGuard<ClMem> {
128 self._mem.read().unwrap()
129 }
130
131 pub fn write_lock(&self) -> RwLockWriteGuard<ClMem> {
132 self._mem.write().unwrap()
133 }
134
135 pub fn context(&self) -> &Context {
136 &self._context
137 }
138
139 pub fn reference_count(&self) -> Output<u32> {
140 unsafe { self.read_lock().reference_count() }
141 }
142
143 pub fn size(&self) -> Output<usize> {
144 unsafe { self.read_lock().size() }
145 }
146
147 pub fn length(&self) -> Output<usize> {
149 unsafe { self.read_lock().len() }
150 }
152
153 pub fn len(&self) -> usize {
156 self.length().unwrap()
157 }
158
159 pub fn offset(&self) -> Output<usize> {
160 unsafe { self.read_lock().offset() }
161 }
162
163 pub fn flags(&self) -> Output<MemFlags> {
164 unsafe { self.read_lock().flags() }
165 }
166}
167
#[cfg(test)]
mod tests {
    use crate::ll::*;
    use crate::*;

    /// `create_with_len` succeeds for a small `u32` buffer.
    #[test]
    fn buffer_can_be_created_with_a_length() {
        let context = testing::get_context();
        let _buffer = Buffer::create_with_len::<u32>(&context, 10).unwrap();
    }

    /// `create` succeeds when initialised from a slice with explicit access settings.
    #[test]
    fn buffer_can_be_created_with_a_slice_of_data() {
        let context = testing::get_context();
        // The data is only ever borrowed as a slice, so a fixed array avoids a heap allocation.
        let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let _buffer = Buffer::create(
            &context,
            &data[..],
            HostAccess::NoAccess,
            KernelAccess::ReadWrite,
            MemLocation::CopyToDevice,
        )
        .unwrap();
    }

    /// A freshly created buffer has exactly one reference.
    #[test]
    fn buffer_reference_count_works() {
        let buffer = testing::get_buffer::<u32>(10);

        let ref_count = buffer
            .reference_count()
            .expect("Failed to call buffer.reference_count()");
        assert_eq!(ref_count, 1);
    }

    /// Ten `u32` elements occupy 40 bytes.
    #[test]
    fn buffer_size_works() {
        let buffer = testing::get_buffer::<u32>(10);
        let size = buffer.size().expect("Failed to call buffer.size()");
        assert_eq!(size, 40);
    }

    /// The test helper's buffer reports the expected creation flags.
    #[test]
    fn buffer_flags_works() {
        let buffer = testing::get_buffer::<u32>(10);
        let flags = buffer.flags().expect("Failed to call buffer.flags()");
        assert_eq!(
            flags,
            MemFlags::KERNEL_READ_WRITE
                | MemFlags::ALLOC_HOST_PTR
                | MemFlags::READ_WRITE_ALLOC_HOST_PTR
        );
    }

    /// A buffer that is not a sub-buffer starts at offset zero.
    #[test]
    fn buffer_offset_works() {
        let buffer = testing::get_buffer::<u32>(10);
        let offset = buffer.offset().expect("Failed to call buffer.offset()");
        assert_eq!(offset, 0);
    }
}