ha_ndarray/
access.rs

1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
7use crate::ops::{ReadOp, Write};
8use crate::platform::PlatformInstance;
9use crate::{Buffer, CType, Error, Platform};
10
11/// A type which allows accessing array data
12pub trait Access<T: CType>: Send + Sync {
13    /// Read the data of this accessor as a [`BufferConverter`].
14    fn read(&self) -> Result<BufferConverter<T>, Error>;
15
16    /// Access a single value.
17    fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19    /// Return the data size.
20    fn size(&self) -> usize;
21}
22
23/// A type which allows accessing array data mutably
24pub trait AccessMut<T: CType>: Access<T> + fmt::Debug {
25    #[cfg(feature = "opencl")]
26    /// Borrow the array data as an [`ocl::Buffer`], or return an error if this not an OpenCL buffer.
27    fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28        Err(Error::Unsupported(format!(
29            "not an OpenCL buffer: {self:?}"
30        )))
31    }
32
33    /// Overwrite these data with the given `data`.
34    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36    /// Overwrite these data with a single value.
37    fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39    /// Overwrite a single value.
40    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
/// A wrapper which provides array access to an underlying [`BufferInstance`].
pub struct AccessBuf<B> {
    /// The wrapped buffer (owned or borrowed, depending on `B`).
    buffer: B,
}
47
48impl<B: Clone> Clone for AccessBuf<B> {
49    fn clone(&self) -> Self {
50        Self {
51            buffer: self.buffer.clone(),
52        }
53    }
54}
55
56impl<B> AccessBuf<B> {
57    /// Construct an [`AccessBuf`] from a mutable reference to this buffer.
58    pub fn as_mut<RB: ?Sized>(&mut self) -> AccessBuf<&mut RB>
59    where
60        B: BorrowMut<RB>,
61    {
62        AccessBuf {
63            buffer: self.buffer.borrow_mut(),
64        }
65    }
66
67    /// Construct an [`AccessBuf`] from a reference to this buffer.
68    pub fn as_ref<RB: ?Sized>(&self) -> AccessBuf<&RB>
69    where
70        B: Borrow<RB>,
71    {
72        AccessBuf {
73            buffer: self.buffer.borrow(),
74        }
75    }
76
77    /// Borrow the underlying [`BufferInstance`] of this [`AccessBuf`].
78    pub fn inner(&self) -> &B {
79        &self.buffer
80    }
81
82    /// Borrow the underlying [`BufferInstance`] of this [`AccessBuf`] mutably.
83    pub fn inner_mut(&mut self) -> &mut B {
84        &mut self.buffer
85    }
86
87    /// Destructure this [`AccessBuf`] into its underlying [`BufferInstance`].
88    pub fn into_inner(self) -> B {
89        self.buffer
90    }
91}
92
93impl<B> From<B> for AccessBuf<B> {
94    fn from(buffer: B) -> Self {
95        Self { buffer }
96    }
97}
98
99impl<T, B> Access<T> for AccessBuf<B>
100where
101    T: CType,
102    B: BufferInstance<T>,
103{
104    fn read(&self) -> Result<BufferConverter<T>, Error> {
105        Ok(self.buffer.read())
106    }
107
108    fn read_value(&self, offset: usize) -> Result<T, Error> {
109        self.buffer.read_value(offset)
110    }
111
112    fn size(&self) -> usize {
113        self.buffer.len()
114    }
115}
116
117impl<T, B> AccessMut<T> for AccessBuf<B>
118where
119    T: CType,
120    B: BufferMut<T>,
121{
122    #[cfg(feature = "opencl")]
123    fn cl_buffer(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
124        self.buffer.cl()
125    }
126
127    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
128        self.buffer.write(data)
129    }
130
131    fn write_value(&mut self, value: T) -> Result<(), Error> {
132        self.buffer.write_value(value)
133    }
134
135    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
136        self.buffer.write_value_at(offset, value)
137    }
138}
139
140impl<B: fmt::Debug> fmt::Debug for AccessBuf<B> {
141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142        write!(f, "access {:?}", self.buffer)
143    }
144}
145
/// A wrapper which provides array access to the result of an array operation.
pub struct AccessOp<O, P> {
    /// The wrapped operation.
    op: O,
    /// Type-level marker for the platform `P`; no platform value is stored.
    platform: PhantomData<P>,
}
151
152impl<O, P> AccessOp<O, P> {
153    /// Convert the given [`AccessOp`] to a more general type of [`PlatformIntance`].
154    pub fn wrap<FO, FP>(access: AccessOp<FO, FP>) -> Self
155    where
156        FO: Into<O>,
157        FP: Into<P>,
158    {
159        Self {
160            op: access.op.into(),
161            platform: PhantomData,
162        }
163    }
164}
165
166impl<O, P> From<O> for AccessOp<O, P> {
167    fn from(op: O) -> Self {
168        Self {
169            op,
170            platform: PhantomData,
171        }
172    }
173}
174
175impl<O, P, T> Access<T> for AccessOp<O, P>
176where
177    T: CType,
178    O: ReadOp<P, T>,
179    P: PlatformInstance,
180{
181    fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
182        self.op.enqueue().map(|buffer| buffer.into())
183    }
184
185    fn read_value(&self, offset: usize) -> Result<T, Error> {
186        self.op.read_value(offset)
187    }
188
189    fn size(&self) -> usize {
190        self.op.size()
191    }
192}
193
194impl<'a, O, P, T> Access<T> for &'a AccessOp<O, P>
195where
196    T: CType,
197    O: ReadOp<P, T>,
198    P: PlatformInstance,
199    BufferConverter<'static, T>: From<O::Buffer>,
200{
201    fn read(&self) -> Result<BufferConverter<'static, T>, Error> {
202        self.op.enqueue().map(BufferConverter::from)
203    }
204
205    fn read_value(&self, offset: usize) -> Result<T, Error> {
206        self.op.read_value(offset)
207    }
208
209    fn size(&self) -> usize {
210        self.op.size()
211    }
212}
213
214impl<O, P, T> AccessMut<T> for AccessOp<O, P>
215where
216    T: CType,
217    O: ReadOp<P, T> + Write<P, T>,
218    P: PlatformInstance,
219    BufferConverter<'static, T>: From<O::Buffer>,
220{
221    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
222        self.op.write(data)
223    }
224
225    fn write_value(&mut self, value: T) -> Result<(), Error> {
226        self.op.write_value(value)
227    }
228
229    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
230        self.op.write_value_at(offset, value)
231    }
232}
233
234impl<O, P: fmt::Debug> fmt::Debug for AccessOp<O, P> {
235    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236        write!(
237            f,
238            "access op {:?} on {:?}",
239            std::any::type_name::<O>(),
240            self.platform
241        )
242    }
243}
244
245/// A general-purpose implementor of [`Access`] used to elide recursive types.
246/// Uses an [`Arc`] so that cloning does not allocate.
247#[derive(Clone)]
248pub enum Accessor<T: CType> {
249    Buffer(Arc<dyn BufferInstance<T>>),
250    Op(Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>>),
251}
252
253impl<T: CType> Access<T> for Accessor<T> {
254    fn read(&self) -> Result<BufferConverter<T>, Error> {
255        match self {
256            Self::Buffer(buf) => Ok(buf.read()),
257            Self::Op(op) => op.enqueue().map(BufferConverter::from),
258        }
259    }
260
261    fn read_value(&self, offset: usize) -> Result<T, Error> {
262        match self {
263            Self::Buffer(buf) => buf.read_value(offset),
264            Self::Op(op) => op.read_value(offset),
265        }
266    }
267
268    fn size(&self) -> usize {
269        match self {
270            Self::Buffer(buf) => buf.len(),
271            Self::Op(op) => op.size(),
272        }
273    }
274}
275
276impl<T, B> From<AccessBuf<B>> for Accessor<T>
277where
278    T: CType,
279    B: BufferInstance<T> + 'static,
280{
281    fn from(access: AccessBuf<B>) -> Self {
282        Self::Buffer(Arc::new(access.buffer))
283    }
284}
285
286impl<T, O, P> From<AccessOp<O, P>> for Accessor<T>
287where
288    T: CType,
289    O: ReadOp<Platform, T, Buffer = Buffer<T>> + 'static,
290    P: PlatformInstance + Into<Platform>,
291{
292    fn from(access: AccessOp<O, P>) -> Self {
293        let access: AccessOp<O, Platform> = AccessOp::wrap(access);
294        let op: Arc<dyn ReadOp<Platform, T, Buffer = Buffer<T>>> = Arc::new(access.op);
295        Self::Op(op)
296    }
297}