1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
7use crate::ops::{ReadOp, Write};
8use crate::platform::PlatformInstance;
9use crate::{Buffer, Error, Number, Platform};
10
11pub trait Access<T: Number>: Send + Sync {
13 fn read(&self) -> Result<BufferConverter<'_, T>, Error>;
15
16 fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19 fn size(&self) -> usize;
21}
22
23pub trait AccessMut<T: Number>: Access<T> + fmt::Debug {
25 #[cfg(feature = "opencl")]
26 fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28 Err(Error::unsupported(format!(
29 "not an OpenCL buffer: {self:?}"
30 )))
31 }
32
33 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36 fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
43pub trait AccessBorrow<'a, T, B>: Access<T>
45where
46 T: Number,
47 B: Access<T> + 'a,
48{
49 fn borrow(&'a self) -> B;
50}
51
52pub trait AccessBorrowMut<'a, T, B>: Access<T>
54where
55 T: Number,
56 B: AccessMut<T> + 'a,
57{
58 fn borrow_mut(&'a mut self) -> B;
59}
60
/// An accessor which wraps a buffer of type `B`.
///
/// Elsewhere in this module it implements [`Access`] when `B: BufferInstance<T>`
/// and [`AccessMut`] when `B: BufferMut<T>`, delegating every call to the wrapped buffer.
pub struct AccessBuf<B> {
    // The wrapped buffer; reachable through `inner`, `inner_mut` and `into_inner`.
    buffer: B,
}
65
66impl<B: Clone> Clone for AccessBuf<B> {
67 fn clone(&self) -> Self {
68 Self {
69 buffer: self.buffer.clone(),
70 }
71 }
72}
73
74impl<B> AccessBuf<B> {
75 pub fn inner(&self) -> &B {
77 &self.buffer
78 }
79
80 pub fn inner_mut(&mut self) -> &mut B {
82 &mut self.buffer
83 }
84
85 pub fn into_inner(self) -> B {
87 self.buffer
88 }
89}
90
91impl<'a, T, B, RB> AccessBorrow<'a, T, AccessBuf<&'a RB>> for AccessBuf<B>
92where
93 T: Number,
94 B: BufferInstance<T> + Borrow<RB>,
95 &'a RB: BufferInstance<T>,
96{
97 fn borrow(&'a self) -> AccessBuf<&'a RB> {
98 AccessBuf {
99 buffer: self.buffer.borrow(),
100 }
101 }
102}
103
104impl<'a, T, B, RB> AccessBorrowMut<'a, T, AccessBuf<&'a mut RB>> for AccessBuf<B>
105where
106 T: Number,
107 B: BufferInstance<T> + BorrowMut<RB>,
108 &'a mut RB: BufferMut<T>,
109{
110 fn borrow_mut(&'a mut self) -> AccessBuf<&'a mut RB> {
111 AccessBuf {
112 buffer: self.buffer.borrow_mut(),
113 }
114 }
115}
116
117impl<B> From<B> for AccessBuf<B> {
118 fn from(buffer: B) -> Self {
119 Self { buffer }
120 }
121}
122
123impl<T, B> Access<T> for AccessBuf<B>
124where
125 T: Number,
126 B: BufferInstance<T>,
127{
128 fn read(&self) -> Result<BufferConverter<'_, T>, Error> {
129 Ok(self.buffer.read())
130 }
131
132 fn read_value(&self, offset: usize) -> Result<T, Error> {
133 self.buffer.read_value(offset)
134 }
135
136 fn size(&self) -> usize {
137 self.buffer.len()
138 }
139}
140
141impl<T, B> AccessMut<T> for AccessBuf<B>
142where
143 T: Number,
144 B: BufferMut<T>,
145{
146 #[cfg(feature = "opencl")]
147 fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
148 self.buffer.cl()
149 }
150
151 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
152 self.buffer.write(data)
153 }
154
155 fn write_value(&mut self, value: T) -> Result<(), Error> {
156 self.buffer.write_value(value)
157 }
158
159 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
160 self.buffer.write_value_at(offset, value)
161 }
162}
163
164impl<B: fmt::Debug> fmt::Debug for AccessBuf<B> {
165 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
166 write!(f, "access {:?}", self.buffer)
167 }
168}
169
/// An accessor which wraps an operation of type `O`, tagged with a platform type `P`.
///
/// Elsewhere in this module it implements [`Access`] when `O: ReadOp<P, T>`
/// and [`AccessMut`] when `O` additionally implements `Write<P, T>`.
pub struct AccessOp<O, P> {
    // The wrapped operation.
    op: O,
    // Compile-time marker only: no value of type `P` is stored.
    platform: PhantomData<P>,
}
175
176impl<'a, T, O, P> AccessBorrow<'a, T, &'a Self> for AccessOp<O, P>
177where
178 T: Number,
179 O: ReadOp<P, T>,
180 P: PlatformInstance,
181 Self: Access<T>,
182 &'a Self: Access<T>,
183{
184 fn borrow(&'a self) -> &'a Self {
185 self
186 }
187}
188
189impl<'a, T, O, P> AccessBorrowMut<'a, T, &'a mut Self> for AccessOp<O, P>
190where
191 T: Number,
192 O: ReadOp<P, T> + Write<P, T>,
193 P: PlatformInstance,
194 Self: AccessMut<T>,
195 &'a mut Self: AccessMut<T>,
196{
197 fn borrow_mut(&'a mut self) -> &'a mut Self {
198 self
199 }
200}
201
202impl<O, P> AccessOp<O, P> {
203 pub fn wrap<FO, FP>(access: AccessOp<FO, FP>) -> Self
205 where
206 FO: Into<O>,
207 FP: Into<P>,
208 {
209 Self {
210 op: access.op.into(),
211 platform: PhantomData,
212 }
213 }
214}
215
216impl<O, P> From<O> for AccessOp<O, P> {
217 fn from(op: O) -> Self {
218 Self {
219 op,
220 platform: PhantomData,
221 }
222 }
223}
224
225impl<O, P, T> Access<T> for AccessOp<O, P>
226where
227 T: Number,
228 O: ReadOp<P, T>,
229 P: PlatformInstance,
230{
231 fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
232 self.op.enqueue().map(|buffer| buffer.into())
233 }
234
235 fn read_value(&self, offset: usize) -> Result<T, Error> {
236 self.op.read_value(offset)
237 }
238
239 fn size(&self) -> usize {
240 self.op.size()
241 }
242}
243
244impl<O, P, T> Access<T> for &AccessOp<O, P>
245where
246 T: Number,
247 O: ReadOp<P, T>,
248 P: PlatformInstance,
249 BufferConverter<'static, T>: From<O::Buffer>,
250{
251 fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
252 self.op.enqueue().map(BufferConverter::from)
253 }
254
255 fn read_value(&self, offset: usize) -> Result<T, Error> {
256 self.op.read_value(offset)
257 }
258
259 fn size(&self) -> usize {
260 self.op.size()
261 }
262}
263
264impl<O, P, T> AccessMut<T> for AccessOp<O, P>
265where
266 T: Number,
267 O: ReadOp<P, T> + Write<P, T>,
268 P: PlatformInstance,
269 BufferConverter<'static, T>: From<O::Buffer>,
270{
271 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
272 self.op.write(data)
273 }
274
275 fn write_value(&mut self, value: T) -> Result<(), Error> {
276 self.op.write_value(value)
277 }
278
279 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
280 self.op.write_value_at(offset, value)
281 }
282}
283
284impl<O, P: fmt::Debug> fmt::Debug for AccessOp<O, P> {
285 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286 write!(
287 f,
288 "access op {:?} on {:?}",
289 std::any::type_name::<O>(),
290 self.platform
291 )
292 }
293}
294
/// A type-erased accessor, backed either by a buffer or by an operation.
///
/// Both variants hold their payload in an [`Arc`], so cloning an `Accessor`
/// shares the payload rather than copying it.
#[derive(Clone)]
pub enum Accessor<'a, T: Number> {
    /// A shared, type-erased buffer.
    Buffer(Arc<dyn BufferInstance<T> + 'a>),
    /// A shared, type-erased operation on [`Platform`] whose output is a [`Buffer`].
    Op(Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>> + 'a>),
}
302
303impl<'a, T: Number> Access<T> for Accessor<'a, T> {
304 fn read(&self) -> Result<BufferConverter<'_, T>, Error> {
305 match self {
306 Self::Buffer(buf) => Ok(buf.read()),
307 Self::Op(op) => op.enqueue().map(BufferConverter::from),
308 }
309 }
310
311 fn read_value(&self, offset: usize) -> Result<T, Error> {
312 match self {
313 Self::Buffer(buf) => buf.read_value(offset),
314 Self::Op(op) => op.read_value(offset),
315 }
316 }
317
318 fn size(&self) -> usize {
319 match self {
320 Self::Buffer(buf) => buf.len(),
321 Self::Op(op) => op.size(),
322 }
323 }
324}
325
326impl<'a, T, B> From<AccessBuf<B>> for Accessor<'a, T>
327where
328 T: Number,
329 B: BufferInstance<T> + 'a,
330{
331 fn from(access: AccessBuf<B>) -> Self {
332 Self::Buffer(Arc::new(access.buffer))
333 }
334}
335
336impl<'a, T, O, P> From<AccessOp<O, P>> for Accessor<'a, T>
337where
338 T: Number,
339 O: ReadOp<Platform, T, Buffer = Buffer<T>> + 'a,
340 P: PlatformInstance + Into<Platform>,
341{
342 fn from(access: AccessOp<O, P>) -> Self {
343 let access: AccessOp<O, Platform> = AccessOp::wrap(access);
344 let op: Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>> = Arc::new(access.op);
345 Self::Op(op)
346 }
347}