ha_ndarray/
buffer.rs

1use std::fmt;
2
3#[cfg(feature = "stream")]
4use destream::{de, en};
5use get_size::GetSize;
6
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::{host, CType, Error};
10
11/// A data buffer
12pub trait BufferInstance<T: CType>: Send + Sync {
13    /// Borrow this buffer as a [`BufferConverter`].
14    fn read(&self) -> BufferConverter<T>;
15
16    /// Read a single value in this buffer.
17    fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19    /// Return the length of this buffer.
20    fn len(&self) -> usize;
21}
22
23/// A mutable data buffer
24pub trait BufferMut<T: CType>: BufferInstance<T> + fmt::Debug {
25    #[cfg(feature = "opencl")]
26    /// Borrow this buffer as an [`ocl::Buffer`], or return an error if this not an OpenCL buffer.
27    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28        Err(Error::Unsupported(format!(
29            "not an OpenCL buffer: {self:?}"
30        )))
31    }
32
33    /// Overwrite this buffer.
34    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36    /// Overwrite this buffer with a single value.
37    fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39    /// Overwrite a single value in this buffer.
40    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
/// A general-purpose buffer which can represent a buffer on any supported platform.
///
/// Which variants exist depends on the enabled crate features.
#[derive(Clone)]
pub enum Buffer<T: CType> {
    /// An OpenCL buffer (only available with the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(ocl::Buffer<T>),
    /// A buffer in host memory.
    Host(host::Buffer<T>),
}
50
51impl<T: CType> Buffer<T> {
52    /// Construct a new [`Buffer`] from a slice of data.
53    pub fn from_slice(slice: &[T]) -> Result<Self, Error> {
54        BufferConverter::from(slice).into_buffer()
55    }
56}
57
58impl<T: CType> GetSize for Buffer<T> {
59    fn get_size(&self) -> usize {
60        self.len() * std::mem::size_of::<T>()
61    }
62}
63
64impl<T: CType> BufferInstance<T> for Buffer<T> {
65    fn read(&self) -> BufferConverter<T> {
66        BufferConverter::from(self)
67    }
68
69    fn read_value(&self, offset: usize) -> Result<T, Error> {
70        match self {
71            #[cfg(feature = "opencl")]
72            Self::CL(buf) => buf.read_value(offset),
73            Self::Host(buf) => buf.read_value(offset),
74        }
75    }
76
77    fn len(&self) -> usize {
78        match self {
79            #[cfg(feature = "opencl")]
80            Self::CL(buf) => buf.len(),
81            Self::Host(buf) => buf.len(),
82        }
83    }
84}
85
86impl<T: CType> BufferMut<T> for Buffer<T> {
87    #[cfg(feature = "opencl")]
88    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
89        match self {
90            #[cfg(feature = "opencl")]
91            Self::CL(buf) => buf.cl(),
92            Self::Host(buf) => buf.cl(),
93        }
94    }
95
96    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
97        match self {
98            #[cfg(feature = "opencl")]
99            Self::CL(buf) => buf.write(data),
100            Self::Host(buf) => buf.write(data),
101        }
102    }
103
104    fn write_value(&mut self, value: T) -> Result<(), Error> {
105        match self {
106            #[cfg(feature = "opencl")]
107            Self::CL(buf) => buf.write_value(value),
108            Self::Host(buf) => buf.write_value(value),
109        }
110    }
111
112    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
113        match self {
114            #[cfg(feature = "opencl")]
115            Self::CL(buf) => buf.write_value_at(offset, value),
116            Self::Host(buf) => buf.write_value_at(offset, value),
117        }
118    }
119}
120
121impl<'a, T: CType> BufferInstance<T> for &'a Buffer<T> {
122    fn read(&self) -> BufferConverter<T> {
123        BufferConverter::from(*self)
124    }
125
126    fn read_value(&self, offset: usize) -> Result<T, Error> {
127        BufferInstance::read_value(*self, offset)
128    }
129
130    fn len(&self) -> usize {
131        BufferInstance::len(*self)
132    }
133}
134
135impl<'a, T: CType> BufferInstance<T> for &'a mut Buffer<T> {
136    fn read(&self) -> BufferConverter<T> {
137        BufferConverter::from(&**self)
138    }
139
140    fn read_value(&self, offset: usize) -> Result<T, Error> {
141        BufferInstance::read_value(&**self, offset)
142    }
143
144    fn len(&self) -> usize {
145        BufferInstance::len(*self)
146    }
147}
148
149impl<'a, T: CType> BufferMut<T> for &'a mut Buffer<T> {
150    #[cfg(feature = "opencl")]
151    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
152        Buffer::<T>::cl(&mut **self)
153    }
154
155    fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
156        Buffer::<T>::write(*self, data)
157    }
158
159    fn write_value(&mut self, value: T) -> Result<(), Error> {
160        Buffer::<T>::write_value(*self, value)
161    }
162
163    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
164        Buffer::<T>::write_value_at(*self, offset, value)
165    }
166}
167
/// Read access to a [`Buffer`] held by a `freqfs` read guard; every method forwards to the
/// guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: CType> BufferInstance<T> for freqfs::FileReadGuardOwned<FE, Buffer<T>> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    fn read(&self) -> BufferConverter<T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }
}
182
/// Read access to a [`Buffer`] held by a `freqfs` write guard; every method forwards to the
/// guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: CType> BufferInstance<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    fn read(&self) -> BufferConverter<T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }
}
197
/// Write access to a [`Buffer`] held by a `freqfs` write guard; every method forwards to the
/// guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: CType> BufferMut<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    #[cfg(feature = "opencl")]
    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
        <Buffer<T> as BufferMut<T>>::cl(self)
    }

    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write(self, data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value(self, value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value_at(self, offset, value)
    }
}
217
#[cfg(feature = "opencl")]
impl<T: CType> From<ocl::Buffer<T>> for Buffer<T> {
    /// Wrap an OpenCL buffer as a [`Buffer`] without copying it.
    fn from(cl_buffer: ocl::Buffer<T>) -> Self {
        Buffer::CL(cl_buffer)
    }
}
224
225impl<T: CType> From<host::StackVec<T>> for Buffer<T> {
226    fn from(buf: host::StackVec<T>) -> Self {
227        Self::Host(buf.into())
228    }
229}
230
231impl<T: CType> From<Vec<T>> for Buffer<T> {
232    fn from(buf: Vec<T>) -> Self {
233        Self::Host(buf.into())
234    }
235}
236
237impl<T: CType> From<host::Buffer<T>> for Buffer<T> {
238    fn from(buf: host::Buffer<T>) -> Self {
239        Self::Host(buf)
240    }
241}
242
#[derive(Clone)]
/// A sequence of elements in a single contiguous block of memory
///
/// Unlike [`Buffer`], a [`BufferConverter`] may borrow its data for the lifetime `'a`;
/// [`BufferConverter::into_buffer`] produces an owned [`Buffer`].
pub enum BufferConverter<'a, T: CType> {
    /// Data in OpenCL memory (only available with the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(opencl::CLConverter<'a, T>),
    /// Data in host memory.
    Host(host::SliceConverter<'a, T>),
}
250
251impl<'a, T: CType> BufferConverter<'a, T> {
252    /// Return an owned [`Buffer`], allocating memory only if this [`BufferConverter`]'s data is borrowed.
253    pub fn into_buffer(self) -> Result<Buffer<T>, Error> {
254        match self {
255            #[cfg(feature = "opencl")]
256            Self::CL(buffer) => buffer.into_buffer().map(Buffer::CL),
257            Self::Host(buffer) => Ok(Buffer::Host(buffer.into_buffer())),
258        }
259    }
260
261    /// Return the number of elements in this [`Buffer`].
262    pub fn len(&self) -> usize {
263        match self {
264            #[cfg(feature = "opencl")]
265            Self::CL(buffer) => buffer.len(),
266            Self::Host(buffer) => buffer.len(),
267        }
268    }
269
270    #[cfg(feature = "opencl")]
271    /// Ensure that this [`Buffer`] is in OpenCL memory by making a copy if necessary.
272    pub fn to_cl(self) -> Result<opencl::CLConverter<'a, T>, ocl::Error> {
273        match self {
274            Self::CL(buffer) => Ok(buffer),
275            Self::Host(buffer) => {
276                opencl::OpenCL::copy_into_buffer(buffer.as_ref()).map(opencl::CLConverter::Owned)
277            }
278        }
279    }
280
281    /// Ensure that this buffer is in host memory by making a copy if necessary.
282    #[inline]
283    pub fn to_slice(self) -> Result<host::SliceConverter<'a, T>, Error> {
284        match self {
285            #[cfg(feature = "opencl")]
286            Self::CL(buffer) => {
287                let mut copy = vec![T::default(); buffer.len()];
288                buffer.read(&mut copy[..]).enq()?;
289                Ok(host::SliceConverter::from(copy))
290            }
291            Self::Host(buffer) => Ok(buffer),
292        }
293    }
294}
295
296impl<T: CType> From<Buffer<T>> for BufferConverter<'static, T> {
297    fn from(buf: Buffer<T>) -> Self {
298        match buf {
299            #[cfg(feature = "opencl")]
300            Buffer::CL(buf) => Self::CL(buf.into()),
301            Buffer::Host(buf) => Self::Host(buf.into()),
302        }
303    }
304}
305
306impl<'a, T: CType> From<&'a Buffer<T>> for BufferConverter<'a, T> {
307    fn from(buf: &'a Buffer<T>) -> Self {
308        match buf {
309            #[cfg(feature = "opencl")]
310            Buffer::CL(buf) => Self::CL(buf.into()),
311            Buffer::Host(buf) => Self::Host(buf.into()),
312        }
313    }
314}
315
316impl<T: CType> From<Vec<T>> for BufferConverter<'static, T> {
317    fn from(buf: Vec<T>) -> Self {
318        Self::Host(buf.into())
319    }
320}
321
322impl<T: CType> From<host::StackVec<T>> for BufferConverter<'static, T> {
323    fn from(buf: host::StackVec<T>) -> Self {
324        Self::Host(buf.into())
325    }
326}
327
328impl<T: CType> From<host::Buffer<T>> for BufferConverter<'static, T> {
329    fn from(buf: host::Buffer<T>) -> Self {
330        Self::Host(buf.into())
331    }
332}
333
334impl<'a, T: CType> From<&'a [T]> for BufferConverter<'a, T> {
335    fn from(buf: &'a [T]) -> Self {
336        Self::Host(buf.into())
337    }
338}
339
#[cfg(feature = "opencl")]
impl<T: CType> From<ocl::Buffer<T>> for BufferConverter<'static, T> {
    /// Take ownership of an OpenCL buffer.
    fn from(cl_buffer: ocl::Buffer<T>) -> Self {
        let converter: opencl::CLConverter<'static, T> = cl_buffer.into();
        Self::CL(converter)
    }
}
346
#[cfg(feature = "opencl")]
impl<'a, T: CType> From<&'a ocl::Buffer<T>> for BufferConverter<'a, T> {
    /// Borrow an OpenCL buffer.
    fn from(cl_buffer: &'a ocl::Buffer<T>) -> Self {
        let converter: opencl::CLConverter<'a, T> = cl_buffer.into();
        Self::CL(converter)
    }
}
353
#[cfg(feature = "stream")]
/// A `destream` visitor which accumulates the elements of a decoded array.
struct BufferVisitor<T> {
    // elements decoded so far; the visitor appends to this before converting it to a `Buffer`
    data: Vec<T>,
}
358
#[cfg(feature = "stream")]
impl<T> BufferVisitor<T> {
    /// Construct a visitor which has not yet decoded any elements.
    fn new() -> Self {
        let data = Vec::new();
        Self { data }
    }
}
365
// Implement `destream` decoding (`FromStream`) and encoding (`ToStream`, `IntoStream`) for a
// `Buffer` of a single element type.
//
// Arguments:
// - `$t`: the element type
// - `$name`: the description returned by the visitor's `expecting` method
// - `$decode`, `$visit`, `$encode`: the `destream` array methods specific to `$t`
#[cfg(feature = "stream")]
macro_rules! decode_buffer {
    ($t:ty, $name:expr, $decode:ident, $visit:ident, $encode:ident) => {
        #[async_trait::async_trait]
        impl de::Visitor for BufferVisitor<$t> {
            type Value = Buffer<$t>;

            fn expecting() -> &'static str {
                $name
            }

            async fn $visit<A: de::ArrayAccess<$t>>(
                self,
                mut array: A,
            ) -> Result<Self::Value, A::Error> {
                use crate::platform::{Convert, PlatformInstance};

                // number of elements requested from the decoder per chunk
                const BUF_SIZE: usize = 4_096;
                let mut data = self.data;

                let mut buf = [<$t>::ZERO; BUF_SIZE];
                loop {
                    // a chunk of length zero means the array is exhausted
                    let len = array.buffer(&mut buf).await?;
                    if len == 0 {
                        break;
                    } else {
                        data.extend_from_slice(&buf[..len]);
                    }
                }

                // pick a platform based on the element count, then convert the host data for it
                crate::Platform::select(data.len())
                    .convert(data.into())
                    .map_err(de::Error::custom)
            }
        }

        #[async_trait::async_trait]
        impl de::FromStream for Buffer<$t> {
            type Context = ();

            async fn from_stream<D: de::Decoder>(
                _cxt: (),
                decoder: &mut D,
            ) -> Result<Self, D::Error> {
                decoder.$decode(BufferVisitor::<$t>::new()).await
            }
        }

        impl<'en> en::ToStream<'en> for Buffer<$t> {
            fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let fut = futures::future::ready(buffer.to_vec());
                        let stream = futures::stream::once(fut);
                        encoder.$encode(stream)
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // allocate the destination at full length (as `to_slice` does): a
                        // `Vec::with_capacity` has length zero, so the read could not fill it
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data[..]).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }

        impl<'en> en::IntoStream<'en> for Buffer<$t> {
            fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let buffer = buffer.to_vec();
                        encoder.$encode(futures::stream::once(futures::future::ready(buffer)))
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // allocate the destination at full length (as `to_slice` does): a
                        // `Vec::with_capacity` has length zero, so the read could not fill it
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data[..]).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }
    };
}
450
// Stream (de)serialization for buffers of integers.

#[cfg(feature = "stream")]
decode_buffer! {
    u8, "byte array",
    decode_array_u8, visit_array_u8, encode_array_u8
}

#[cfg(feature = "stream")]
decode_buffer! {
    u16, "16-bit unsigned int array",
    decode_array_u16, visit_array_u16, encode_array_u16
}

#[cfg(feature = "stream")]
decode_buffer! {
    u32, "32-bit unsigned int array",
    decode_array_u32, visit_array_u32, encode_array_u32
}

#[cfg(feature = "stream")]
decode_buffer! {
    u64, "64-bit unsigned int array",
    decode_array_u64, visit_array_u64, encode_array_u64
}

#[cfg(feature = "stream")]
decode_buffer! {
    i16, "16-bit int array",
    decode_array_i16, visit_array_i16, encode_array_i16
}

#[cfg(feature = "stream")]
decode_buffer! {
    i32, "32-bit int array",
    decode_array_i32, visit_array_i32, encode_array_i32
}

#[cfg(feature = "stream")]
decode_buffer! {
    i64, "64-bit int array",
    decode_array_i64, visit_array_i64, encode_array_i64
}
513
// Stream (de)serialization for buffers of floating-point numbers. The descriptions passed here
// are what the visitor's `expecting` method returns, so they must name the float type.

#[cfg(feature = "stream")]
decode_buffer!(
    f32,
    "32-bit float array",
    decode_array_f32,
    visit_array_f32,
    encode_array_f32
);

#[cfg(feature = "stream")]
decode_buffer!(
    f64,
    "64-bit float array",
    decode_array_f64,
    visit_array_f64,
    encode_array_f64
);
531
532impl<T: CType + fmt::Debug> fmt::Debug for Buffer<T> {
533    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
534        match self {
535            Self::Host(buffer) => fmt::Debug::fmt(buffer, f),
536            #[cfg(feature = "opencl")]
537            Self::CL(buffer) => fmt::Debug::fmt(buffer, f),
538        }
539    }
540}