Skip to main content

ha_ndarray/
access.rs

1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
7use crate::ops::{ReadOp, Write};
8use crate::platform::PlatformInstance;
9use crate::{Buffer, Error, Number, Platform};
10
11/// A type which allows accessing array data
12pub trait Access<T: Number>: Send + Sync {
13    /// Read the data of this accessor as a [`BufferConverter`].
14    fn read(&self) -> Result<BufferConverter<'_, T>, Error>;
15
16    /// Access a single value.
17    fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19    /// Return the data size.
20    fn size(&self) -> usize;
21}
22
23/// A type which allows accessing array data mutably
24pub trait AccessMut<T: Number>: Access<T> + fmt::Debug {
25    #[cfg(feature = "opencl")]
26    /// Borrow the array data as an [`ocl::Buffer`], or return an error if this not an OpenCL buffer.
27    fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28        Err(Error::unsupported(format!(
29            "not an OpenCL buffer: {self:?}"
30        )))
31    }
32
33    /// Overwrite these data with the given `data`.
34    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36    /// Overwrite these data with a single value.
37    fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39    /// Overwrite a single value.
40    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
43/// Borrow an accessor immutably
44pub trait AccessBorrow<'a, T, B>: Access<T>
45where
46    T: Number,
47    B: Access<T> + 'a,
48{
49    fn borrow(&'a self) -> B;
50}
51
52/// Borrow an accessor mutably
53pub trait AccessBorrowMut<'a, T, B>: Access<T>
54where
55    T: Number,
56    B: AccessMut<T> + 'a,
57{
58    fn borrow_mut(&'a mut self) -> B;
59}
60
/// A struct which provides n-dimensional access to an underlying [`BufferInstance`].
///
/// `Clone` is derived, so an `AccessBuf<B>` is `Clone` exactly when `B: Clone`.
/// Cloning clones the wrapped buffer.
#[derive(Clone)]
pub struct AccessBuf<B> {
    /// The wrapped buffer.
    buffer: B,
}
73
74impl<B> AccessBuf<B> {
75    /// Borrow the underlying [`BufferInstance`] of this [`AccessBuf`].
76    pub fn inner(&self) -> &B {
77        &self.buffer
78    }
79
80    /// Borrow the underlying [`BufferInstance`] of this [`AccessBuf`] mutably.
81    pub fn inner_mut(&mut self) -> &mut B {
82        &mut self.buffer
83    }
84
85    /// Destructure this [`AccessBuf`] into its underlying [`BufferInstance`].
86    pub fn into_inner(self) -> B {
87        self.buffer
88    }
89}
90
91impl<'a, T, B, RB> AccessBorrow<'a, T, AccessBuf<&'a RB>> for AccessBuf<B>
92where
93    T: Number,
94    B: BufferInstance<T> + Borrow<RB>,
95    &'a RB: BufferInstance<T>,
96{
97    fn borrow(&'a self) -> AccessBuf<&'a RB> {
98        AccessBuf {
99            buffer: self.buffer.borrow(),
100        }
101    }
102}
103
104impl<'a, T, B, RB> AccessBorrowMut<'a, T, AccessBuf<&'a mut RB>> for AccessBuf<B>
105where
106    T: Number,
107    B: BufferInstance<T> + BorrowMut<RB>,
108    &'a mut RB: BufferMut<T>,
109{
110    fn borrow_mut(&'a mut self) -> AccessBuf<&'a mut RB> {
111        AccessBuf {
112            buffer: self.buffer.borrow_mut(),
113        }
114    }
115}
116
117impl<B> From<B> for AccessBuf<B> {
118    fn from(buffer: B) -> Self {
119        Self { buffer }
120    }
121}
122
123impl<T, B> Access<T> for AccessBuf<B>
124where
125    T: Number,
126    B: BufferInstance<T>,
127{
128    fn read(&self) -> Result<BufferConverter<'_, T>, Error> {
129        Ok(self.buffer.read())
130    }
131
132    fn read_value(&self, offset: usize) -> Result<T, Error> {
133        self.buffer.read_value(offset)
134    }
135
136    fn size(&self) -> usize {
137        self.buffer.len()
138    }
139}
140
141impl<T, B> AccessMut<T> for AccessBuf<B>
142where
143    T: Number,
144    B: BufferMut<T>,
145{
146    #[cfg(feature = "opencl")]
147    fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
148        self.buffer.cl()
149    }
150
151    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
152        self.buffer.write(data)
153    }
154
155    fn write_value(&mut self, value: T) -> Result<(), Error> {
156        self.buffer.write_value(value)
157    }
158
159    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
160        self.buffer.write_value_at(offset, value)
161    }
162}
163
164impl<B: fmt::Debug> fmt::Debug for AccessBuf<B> {
165    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
166        write!(f, "access {:?}", self.buffer)
167    }
168}
169
/// A struct which provides n-dimensional access to the result of an array operation.
///
/// The platform `P` exists only at the type level; no platform value is stored.
pub struct AccessOp<O, P> {
    /// The operation whose result this accessor exposes.
    op: O,
    /// Zero-sized marker for the platform type `P`.
    platform: PhantomData<P>,
}
175
176impl<'a, T, O, P> AccessBorrow<'a, T, &'a Self> for AccessOp<O, P>
177where
178    T: Number,
179    O: ReadOp<P, T>,
180    P: PlatformInstance,
181    Self: Access<T>,
182    &'a Self: Access<T>,
183{
184    fn borrow(&'a self) -> &'a Self {
185        self
186    }
187}
188
189impl<'a, T, O, P> AccessBorrowMut<'a, T, &'a mut Self> for AccessOp<O, P>
190where
191    T: Number,
192    O: ReadOp<P, T> + Write<P, T>,
193    P: PlatformInstance,
194    Self: AccessMut<T>,
195    &'a mut Self: AccessMut<T>,
196{
197    fn borrow_mut(&'a mut self) -> &'a mut Self {
198        self
199    }
200}
201
202impl<O, P> AccessOp<O, P> {
203    /// Convert the given [`AccessOp`] to a more general type of [`PlatformIntance`].
204    pub fn wrap<FO, FP>(access: AccessOp<FO, FP>) -> Self
205    where
206        FO: Into<O>,
207        FP: Into<P>,
208    {
209        Self {
210            op: access.op.into(),
211            platform: PhantomData,
212        }
213    }
214}
215
216impl<O, P> From<O> for AccessOp<O, P> {
217    fn from(op: O) -> Self {
218        Self {
219            op,
220            platform: PhantomData,
221        }
222    }
223}
224
225impl<O, P, T> Access<T> for AccessOp<O, P>
226where
227    T: Number,
228    O: ReadOp<P, T>,
229    P: PlatformInstance,
230{
231    fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
232        self.op.enqueue().map(|buffer| buffer.into())
233    }
234
235    fn read_value(&self, offset: usize) -> Result<T, Error> {
236        self.op.read_value(offset)
237    }
238
239    fn size(&self) -> usize {
240        self.op.size()
241    }
242}
243
244impl<O, P, T> Access<T> for &AccessOp<O, P>
245where
246    T: Number,
247    O: ReadOp<P, T>,
248    P: PlatformInstance,
249    BufferConverter<'static, T>: From<O::Buffer>,
250{
251    fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
252        self.op.enqueue().map(BufferConverter::from)
253    }
254
255    fn read_value(&self, offset: usize) -> Result<T, Error> {
256        self.op.read_value(offset)
257    }
258
259    fn size(&self) -> usize {
260        self.op.size()
261    }
262}
263
264impl<O, P, T> AccessMut<T> for AccessOp<O, P>
265where
266    T: Number,
267    O: ReadOp<P, T> + Write<P, T>,
268    P: PlatformInstance,
269    BufferConverter<'static, T>: From<O::Buffer>,
270{
271    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
272        self.op.write(data)
273    }
274
275    fn write_value(&mut self, value: T) -> Result<(), Error> {
276        self.op.write_value(value)
277    }
278
279    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
280        self.op.write_value_at(offset, value)
281    }
282}
283
284impl<O, P: fmt::Debug> fmt::Debug for AccessOp<O, P> {
285    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
286        write!(
287            f,
288            "access op {:?} on {:?}",
289            std::any::type_name::<O>(),
290            self.platform
291        )
292    }
293}
294
295/// A general-purpose implementor of [`Access`] used to elide recursive types.
296/// Uses an [`Arc`] so that cloning does not allocate.
297#[derive(Clone)]
298pub enum Accessor<'a, T: Number> {
299    Buffer(Arc<dyn BufferInstance<T> + 'a>),
300    Op(Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>> + 'a>),
301}
302
303impl<'a, T: Number> Access<T> for Accessor<'a, T> {
304    fn read(&self) -> Result<BufferConverter<'_, T>, Error> {
305        match self {
306            Self::Buffer(buf) => Ok(buf.read()),
307            Self::Op(op) => op.enqueue().map(BufferConverter::from),
308        }
309    }
310
311    fn read_value(&self, offset: usize) -> Result<T, Error> {
312        match self {
313            Self::Buffer(buf) => buf.read_value(offset),
314            Self::Op(op) => op.read_value(offset),
315        }
316    }
317
318    fn size(&self) -> usize {
319        match self {
320            Self::Buffer(buf) => buf.len(),
321            Self::Op(op) => op.size(),
322        }
323    }
324}
325
326impl<'a, T, B> From<AccessBuf<B>> for Accessor<'a, T>
327where
328    T: Number,
329    B: BufferInstance<T> + 'a,
330{
331    fn from(access: AccessBuf<B>) -> Self {
332        Self::Buffer(Arc::new(access.buffer))
333    }
334}
335
336impl<'a, T, O, P> From<AccessOp<O, P>> for Accessor<'a, T>
337where
338    T: Number,
339    O: ReadOp<Platform, T, Buffer = Buffer<T>> + 'a,
340    P: PlatformInstance + Into<Platform>,
341{
342    fn from(access: AccessOp<O, P>) -> Self {
343        let access: AccessOp<O, Platform> = AccessOp::wrap(access);
344        let op: Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>> = Arc::new(access.op);
345        Self::Op(op)
346    }
347}