1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
7use crate::ops::{ReadOp, Write};
8use crate::platform::PlatformInstance;
9use crate::{Buffer, CType, Error, Platform};
10
11pub trait Access<T: CType>: Send + Sync {
13 fn read(&self) -> Result<BufferConverter<T>, Error>;
15
16 fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19 fn size(&self) -> usize;
21}
22
23pub trait AccessMut<T: CType>: Access<T> + fmt::Debug {
25 #[cfg(feature = "opencl")]
26 fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28 Err(Error::Unsupported(format!(
29 "not an OpenCL buffer: {self:?}"
30 )))
31 }
32
33 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36 fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
/// An accessor that wraps a buffer of type `B` directly.
///
/// `B` may be an owned buffer or a reference to one (see `as_ref` / `as_mut`).
pub struct AccessBuf<B> {
    /// The wrapped buffer; every trait implementation forwards to it.
    buffer: B,
}
47
48impl<B: Clone> Clone for AccessBuf<B> {
49 fn clone(&self) -> Self {
50 Self {
51 buffer: self.buffer.clone(),
52 }
53 }
54}
55
56impl<B> AccessBuf<B> {
57 pub fn as_mut<RB: ?Sized>(&mut self) -> AccessBuf<&mut RB>
59 where
60 B: BorrowMut<RB>,
61 {
62 AccessBuf {
63 buffer: self.buffer.borrow_mut(),
64 }
65 }
66
67 pub fn as_ref<RB: ?Sized>(&self) -> AccessBuf<&RB>
69 where
70 B: Borrow<RB>,
71 {
72 AccessBuf {
73 buffer: self.buffer.borrow(),
74 }
75 }
76
77 pub fn inner(&self) -> &B {
79 &self.buffer
80 }
81
82 pub fn inner_mut(&mut self) -> &mut B {
84 &mut self.buffer
85 }
86
87 pub fn into_inner(self) -> B {
89 self.buffer
90 }
91}
92
93impl<B> From<B> for AccessBuf<B> {
94 fn from(buffer: B) -> Self {
95 Self { buffer }
96 }
97}
98
99impl<T, B> Access<T> for AccessBuf<B>
100where
101 T: CType,
102 B: BufferInstance<T>,
103{
104 fn read(&self) -> Result<BufferConverter<T>, Error> {
105 Ok(self.buffer.read())
106 }
107
108 fn read_value(&self, offset: usize) -> Result<T, Error> {
109 self.buffer.read_value(offset)
110 }
111
112 fn size(&self) -> usize {
113 self.buffer.len()
114 }
115}
116
117impl<T, B> AccessMut<T> for AccessBuf<B>
118where
119 T: CType,
120 B: BufferMut<T>,
121{
122 #[cfg(feature = "opencl")]
123 fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
124 self.buffer.cl()
125 }
126
127 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
128 self.buffer.write(data)
129 }
130
131 fn write_value(&mut self, value: T) -> Result<(), Error> {
132 self.buffer.write_value(value)
133 }
134
135 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
136 self.buffer.write_value_at(offset, value)
137 }
138}
139
140impl<B: fmt::Debug> fmt::Debug for AccessBuf<B> {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142 write!(f, "access {:?}", self.buffer)
143 }
144}
145
/// An accessor that wraps an op of type `O`, tagged with a platform type `P`.
pub struct AccessOp<O, P> {
    /// The wrapped op; the trait implementations forward to it.
    op: O,
    /// Zero-sized marker recording the platform type; no `P` value is stored.
    platform: PhantomData<P>,
}
151
152impl<O, P> AccessOp<O, P> {
153 pub fn wrap<FO, FP>(access: AccessOp<FO, FP>) -> Self
155 where
156 FO: Into<O>,
157 FP: Into<P>,
158 {
159 Self {
160 op: access.op.into(),
161 platform: PhantomData,
162 }
163 }
164}
165
166impl<O, P> From<O> for AccessOp<O, P> {
167 fn from(op: O) -> Self {
168 Self {
169 op,
170 platform: PhantomData,
171 }
172 }
173}
174
175impl<O, P, T> Access<T> for AccessOp<O, P>
176where
177 T: CType,
178 O: ReadOp<P, T>,
179 P: PlatformInstance,
180{
181 fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
182 self.op.enqueue().map(|buffer| buffer.into())
183 }
184
185 fn read_value(&self, offset: usize) -> Result<T, Error> {
186 self.op.read_value(offset)
187 }
188
189 fn size(&self) -> usize {
190 self.op.size()
191 }
192}
193
194impl<'a, O, P, T> Access<T> for &'a AccessOp<O, P>
195where
196 T: CType,
197 O: ReadOp<P, T>,
198 P: PlatformInstance,
199 BufferConverter<'static, T>: From<O::Buffer>,
200{
201 fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
202 self.op.enqueue().map(BufferConverter::from)
203 }
204
205 fn read_value(&self, offset: usize) -> Result<T, Error> {
206 self.op.read_value(offset)
207 }
208
209 fn size(&self) -> usize {
210 self.op.size()
211 }
212}
213
214impl<O, P, T> AccessMut<T> for AccessOp<O, P>
215where
216 T: CType,
217 O: ReadOp<P, T> + Write<P, T>,
218 P: PlatformInstance,
219 BufferConverter<'static, T>: From<O::Buffer>,
220{
221 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
222 self.op.write(data)
223 }
224
225 fn write_value(&mut self, value: T) -> Result<(), Error> {
226 self.op.write_value(value)
227 }
228
229 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
230 self.op.write_value_at(offset, value)
231 }
232}
233
234impl<O, P: fmt::Debug> fmt::Debug for AccessOp<O, P> {
235 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236 write!(
237 f,
238 "access op {:?} on {:?}",
239 std::any::type_name::<O>(),
240 self.platform
241 )
242 }
243}
244
245#[derive(Clone)]
248pub enum Accessor<T: CType> {
249 Buffer(Arc<dyn BufferInstance<T>>),
250 Op(Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>>),
251}
252
253impl<T: CType> Access<T> for Accessor<T> {
254 fn read(&self) -> Result<BufferConverter<T>, Error> {
255 match self {
256 Self::Buffer(buf) => Ok(buf.read()),
257 Self::Op(op) => op.enqueue().map(BufferConverter::from),
258 }
259 }
260
261 fn read_value(&self, offset: usize) -> Result<T, Error> {
262 match self {
263 Self::Buffer(buf) => buf.read_value(offset),
264 Self::Op(op) => op.read_value(offset),
265 }
266 }
267
268 fn size(&self) -> usize {
269 match self {
270 Self::Buffer(buf) => buf.len(),
271 Self::Op(op) => op.size(),
272 }
273 }
274}
275
276impl<T, B> From<AccessBuf<B>> for Accessor<T>
277where
278 T: CType,
279 B: BufferInstance<T> + 'static,
280{
281 fn from(access: AccessBuf<B>) -> Self {
282 Self::Buffer(Arc::new(access.buffer))
283 }
284}
285
286impl<T, O, P> From<AccessOp<O, P>> for Accessor<T>
287where
288 T: CType,
289 O: ReadOp<Platform, T, Buffer = Buffer<T>> + 'static,
290 P: PlatformInstance + Into<Platform>,
291{
292 fn from(access: AccessOp<O, P>) -> Self {
293 let access: AccessOp<O, Platform> = AccessOp::wrap(access);
294 let op: Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>> = Arc::new(access.op);
295 Self::Op(op)
296 }
297}