Skip to main content

ha_ndarray/
buffer.rs

1use std::fmt;
2
3#[cfg(feature = "stream")]
4use destream::{de, en};
5use get_size::GetSize;
6
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::{host, Error, Number};
10
11/// A data buffer
12pub trait BufferInstance<T: Number>: Send + Sync {
13    /// Borrow this buffer as a [`BufferConverter`].
14    fn read(&self) -> BufferConverter<'_, T>;
15
16    /// Read a single value in this buffer.
17    fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19    /// Return the length of this buffer.
20    fn len(&self) -> usize;
21
22    fn is_empty(&self) -> bool {
23        self.len() == 0
24    }
25}
26
27/// A mutable data buffer
28pub trait BufferMut<T: Number>: BufferInstance<T> + fmt::Debug {
29    #[cfg(feature = "opencl")]
30    /// Borrow this buffer as an [`ocl::Buffer`], or return an error if this not an OpenCL buffer.
31    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
32        Err(Error::unsupported(format!(
33            "not an OpenCL buffer: {self:?}"
34        )))
35    }
36
37    /// Overwrite this buffer.
38    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
39
40    /// Overwrite this buffer with a single value.
41    fn write_value(&mut self, value: T) -> Result<(), Error>;
42
43    /// Overwrite a single value in this buffer.
44    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
45}
46
47/// A general-purpose buffer which can represent a buffer on any supported platform.
48#[derive(Clone)]
49pub enum Buffer<T: Number> {
50    #[cfg(feature = "opencl")]
51    CL(ocl::Buffer<T>),
52    Host(host::Buffer<T>),
53}
54
55impl<T: Number> Buffer<T> {
56    /// Construct a new [`Buffer`] from a slice of data.
57    pub fn from_slice(slice: &[T]) -> Result<Self, Error> {
58        BufferConverter::from(slice).into_buffer()
59    }
60}
61
62impl<T: Number> GetSize for Buffer<T> {
63    fn get_size(&self) -> usize {
64        self.len() * std::mem::size_of::<T>()
65    }
66}
67
68impl<T: Number> BufferInstance<T> for Buffer<T> {
69    fn read(&self) -> BufferConverter<'_, T> {
70        BufferConverter::from(self)
71    }
72
73    fn read_value(&self, offset: usize) -> Result<T, Error> {
74        match self {
75            #[cfg(feature = "opencl")]
76            Self::CL(buf) => buf.read_value(offset),
77            Self::Host(buf) => buf.read_value(offset),
78        }
79    }
80
81    fn len(&self) -> usize {
82        match self {
83            #[cfg(feature = "opencl")]
84            Self::CL(buf) => buf.len(),
85            Self::Host(buf) => buf.len(),
86        }
87    }
88}
89
90impl<T: Number> BufferMut<T> for Buffer<T> {
91    #[cfg(feature = "opencl")]
92    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
93        match self {
94            #[cfg(feature = "opencl")]
95            Self::CL(buf) => BufferMut::<T>::cl(buf),
96            Self::Host(buf) => buf.cl(),
97        }
98    }
99
100    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
101        match self {
102            #[cfg(feature = "opencl")]
103            Self::CL(buf) => buf.write(data),
104            Self::Host(buf) => buf.write(data),
105        }
106    }
107
108    fn write_value(&mut self, value: T) -> Result<(), Error> {
109        match self {
110            #[cfg(feature = "opencl")]
111            Self::CL(buf) => buf.write_value(value),
112            Self::Host(buf) => buf.write_value(value),
113        }
114    }
115
116    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
117        match self {
118            #[cfg(feature = "opencl")]
119            Self::CL(buf) => buf.write_value_at(offset, value),
120            Self::Host(buf) => buf.write_value_at(offset, value),
121        }
122    }
123}
124
125impl<T: Number> BufferInstance<T> for &Buffer<T> {
126    fn read(&self) -> BufferConverter<'_, T> {
127        BufferConverter::from(*self)
128    }
129
130    fn read_value(&self, offset: usize) -> Result<T, Error> {
131        BufferInstance::read_value(*self, offset)
132    }
133
134    fn len(&self) -> usize {
135        BufferInstance::len(*self)
136    }
137}
138
139impl<T: Number> BufferInstance<T> for &mut Buffer<T> {
140    fn read(&self) -> BufferConverter<'_, T> {
141        BufferConverter::from(&**self)
142    }
143
144    fn read_value(&self, offset: usize) -> Result<T, Error> {
145        BufferInstance::read_value(&**self, offset)
146    }
147
148    fn len(&self) -> usize {
149        BufferInstance::len(*self)
150    }
151}
152
153impl<T: Number> BufferMut<T> for &mut Buffer<T> {
154    #[cfg(feature = "opencl")]
155    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
156        Buffer::<T>::cl(&mut **self)
157    }
158
159    fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
160        Buffer::<T>::write(*self, data)
161    }
162
163    fn write_value(&mut self, value: T) -> Result<(), Error> {
164        Buffer::<T>::write_value(*self, value)
165    }
166
167    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
168        Buffer::<T>::write_value_at(*self, offset, value)
169    }
170}
171
/// An owned read guard on a cached [`Buffer`] file reads like the buffer it guards.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferInstance<T> for freqfs::FileReadGuardOwned<FE, Buffer<T>> {
    fn read(&self) -> BufferConverter<'_, T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }
}
186
/// An owned write guard on a cached [`Buffer`] file reads like the buffer it guards.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferInstance<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    fn read(&self) -> BufferConverter<'_, T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }
}
201
/// An owned write guard on a cached [`Buffer`] file can be written like the buffer it guards.
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferMut<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    #[cfg(feature = "opencl")]
    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
        <Buffer<T> as BufferMut<T>>::cl(self)
    }

    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write(self, data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value(self, value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value_at(self, offset, value)
    }
}
221
222#[cfg(feature = "opencl")]
223impl<T: Number> From<ocl::Buffer<T>> for Buffer<T> {
224    fn from(buf: ocl::Buffer<T>) -> Self {
225        Self::CL(buf)
226    }
227}
228
229impl<T: Number> From<host::StackVec<T>> for Buffer<T> {
230    fn from(buf: host::StackVec<T>) -> Self {
231        Self::Host(buf.into())
232    }
233}
234
235impl<T: Number> From<Vec<T>> for Buffer<T> {
236    fn from(buf: Vec<T>) -> Self {
237        Self::Host(buf.into())
238    }
239}
240
241impl<T: Number> From<host::Buffer<T>> for Buffer<T> {
242    fn from(buf: host::Buffer<T>) -> Self {
243        Self::Host(buf)
244    }
245}
246
247#[derive(Clone)]
248/// A sequence of elements in a single contiguous block of memory
249pub enum BufferConverter<'a, T: Number> {
250    #[cfg(feature = "opencl")]
251    CL(opencl::CLConverter<'a, T>),
252    Host(host::SliceConverter<'a, T>),
253}
254
255impl<'a, T: Number> BufferConverter<'a, T> {
256    /// Return an owned [`Buffer`], allocating memory only if this [`BufferConverter`]'s data is borrowed.
257    pub fn into_buffer(self) -> Result<Buffer<T>, Error> {
258        match self {
259            #[cfg(feature = "opencl")]
260            Self::CL(buffer) => buffer.into_buffer().map(Buffer::CL),
261            Self::Host(buffer) => Ok(Buffer::Host(buffer.into_buffer())),
262        }
263    }
264
265    /// Return the number of elements in this [`Buffer`].
266    pub fn len(&self) -> usize {
267        match self {
268            #[cfg(feature = "opencl")]
269            Self::CL(buffer) => buffer.len(),
270            Self::Host(buffer) => buffer.len(),
271        }
272    }
273
274    pub fn is_empty(&self) -> bool {
275        self.len() == 0
276    }
277
278    #[cfg(feature = "opencl")]
279    /// Ensure that this [`Buffer`] is in OpenCL memory by making a copy if necessary.
280    pub fn to_cl(self) -> Result<opencl::CLConverter<'a, T>, ocl::Error> {
281        match self {
282            Self::CL(buffer) => Ok(buffer),
283            Self::Host(buffer) => {
284                opencl::OpenCL::copy_into_buffer::<T>(&buffer).map(opencl::CLConverter::Owned)
285            }
286        }
287    }
288
289    /// Ensure that this buffer is in host memory by making a copy if necessary.
290    #[inline]
291    pub fn to_slice(self) -> Result<host::SliceConverter<'a, T>, Error> {
292        match self {
293            #[cfg(feature = "opencl")]
294            Self::CL(buffer) => {
295                let mut copy = vec![T::default(); buffer.len()];
296                buffer.read(&mut copy[..]).enq()?;
297
298                let copy = copy.into_iter().collect::<Vec<T>>();
299                Ok(host::SliceConverter::from(copy))
300            }
301            Self::Host(buffer) => Ok(buffer),
302        }
303    }
304}
305
306impl<T: Number> From<Buffer<T>> for BufferConverter<'static, T> {
307    fn from(buf: Buffer<T>) -> Self {
308        match buf {
309            #[cfg(feature = "opencl")]
310            Buffer::CL(buf) => Self::CL(buf.into()),
311            Buffer::Host(buf) => Self::Host(buf.into()),
312        }
313    }
314}
315
316impl<'a, T: Number> From<&'a Buffer<T>> for BufferConverter<'a, T> {
317    fn from(buf: &'a Buffer<T>) -> Self {
318        match buf {
319            #[cfg(feature = "opencl")]
320            Buffer::CL(buf) => Self::CL(buf.into()),
321            Buffer::Host(buf) => Self::Host(buf.into()),
322        }
323    }
324}
325
326impl<T: Number> From<Vec<T>> for BufferConverter<'static, T> {
327    fn from(buf: Vec<T>) -> Self {
328        Self::Host(buf.into())
329    }
330}
331
332impl<T: Number> From<host::StackVec<T>> for BufferConverter<'static, T> {
333    fn from(buf: host::StackVec<T>) -> Self {
334        Self::Host(buf.into())
335    }
336}
337
338impl<T: Number> From<host::Buffer<T>> for BufferConverter<'static, T> {
339    fn from(buf: host::Buffer<T>) -> Self {
340        Self::Host(buf.into())
341    }
342}
343
344impl<'a, T: Number> From<&'a [T]> for BufferConverter<'a, T> {
345    fn from(buf: &'a [T]) -> Self {
346        Self::Host(buf.into())
347    }
348}
349
350#[cfg(feature = "opencl")]
351impl<T: Number> From<ocl::Buffer<T>> for BufferConverter<'static, T> {
352    fn from(buf: ocl::Buffer<T>) -> Self {
353        Self::CL(buf.into())
354    }
355}
356
357#[cfg(feature = "opencl")]
358impl<'a, T: Number> From<&'a ocl::Buffer<T>> for BufferConverter<'a, T> {
359    fn from(buf: &'a ocl::Buffer<T>) -> Self {
360        Self::CL(buf.into())
361    }
362}
363
/// Accumulates the elements of a decoded array before they are converted into a [`Buffer`].
#[cfg(feature = "stream")]
struct BufferVisitor<T> {
    data: Vec<T>,
}

#[cfg(feature = "stream")]
impl<T> BufferVisitor<T> {
    /// Construct a visitor that has not yet accumulated any elements.
    fn new() -> Self {
        let data = Vec::new();
        Self { data }
    }
}
375
// Implements `destream` decoding (`de::Visitor` + `de::FromStream`) and encoding
// (`en::ToStream` + `en::IntoStream`) for `Buffer<$t>`.
//
// Arguments: the element type, the description returned by `Visitor::expecting`, and the
// type-specific decoder, visitor, and encoder method names.
#[cfg(feature = "stream")]
macro_rules! decode_buffer {
    ($t:ty, $name:expr, $decode:ident, $visit:ident, $encode:ident) => {
        impl de::Visitor for BufferVisitor<$t> {
            type Value = Buffer<$t>;

            fn expecting() -> &'static str {
                $name
            }

            async fn $visit<A: de::ArrayAccess<$t>>(
                self,
                mut array: A,
            ) -> Result<Self::Value, A::Error> {
                use crate::platform::{Convert, PlatformInstance};

                const BUF_SIZE: usize = 4_096;
                let mut data = self.data;

                // Drain the array in fixed-size chunks until it reports no more elements.
                let mut buf = [<$t>::ZERO; BUF_SIZE];
                loop {
                    let len = array.buffer(&mut buf).await?;
                    if len == 0 {
                        break;
                    } else {
                        data.extend_from_slice(&buf[..len]);
                    }
                }

                // Let the platform chosen for this size take ownership of the data.
                crate::Platform::select(data.len())
                    .convert(data.into())
                    .map_err(de::Error::custom)
            }
        }

        impl de::FromStream for Buffer<$t> {
            type Context = ();

            async fn from_stream<D: de::Decoder>(
                _cxt: (),
                decoder: &mut D,
            ) -> Result<Self, D::Error> {
                decoder.$decode(BufferVisitor::<$t>::new()).await
            }
        }

        impl<'en> en::ToStream<'en> for Buffer<$t> {
            fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let fut = futures::future::ready(buffer.to_vec());
                        let stream = futures::stream::once(fut);
                        encoder.$encode(stream)
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // The read destination must have length `buffer.len()`: a vec from
                        // `Vec::with_capacity` has length zero, so nothing would be copied out.
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }

        impl<'en> en::IntoStream<'en> for Buffer<$t> {
            fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let buffer = buffer.to_vec();
                        encoder.$encode(futures::stream::once(futures::future::ready(buffer)))
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // See `to_stream`: the destination must be sized, not merely reserved.
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }
    };
}
458
// One `destream` implementation per supported element type. The string argument is the
// description returned by `Visitor::expecting` for that element type.
#[cfg(feature = "stream")]
decode_buffer!(
    u8,
    "byte array",
    decode_array_u8,
    visit_array_u8,
    encode_array_u8
);

#[cfg(feature = "stream")]
decode_buffer!(
    u16,
    "16-bit unsigned int array",
    decode_array_u16,
    visit_array_u16,
    encode_array_u16
);

#[cfg(feature = "stream")]
decode_buffer!(
    u32,
    "32-bit unsigned int array",
    decode_array_u32,
    visit_array_u32,
    encode_array_u32
);

#[cfg(feature = "stream")]
decode_buffer!(
    u64,
    "64-bit unsigned int array",
    decode_array_u64,
    visit_array_u64,
    encode_array_u64
);

#[cfg(feature = "stream")]
decode_buffer!(
    i16,
    "16-bit int array",
    decode_array_i16,
    visit_array_i16,
    encode_array_i16
);

#[cfg(feature = "stream")]
decode_buffer!(
    i32,
    "32-bit int array",
    decode_array_i32,
    visit_array_i32,
    encode_array_i32
);

#[cfg(feature = "stream")]
decode_buffer!(
    i64,
    "64-bit int array",
    decode_array_i64,
    visit_array_i64,
    encode_array_i64
);

// Floating-point element types: describe them as float arrays, not int arrays.
#[cfg(feature = "stream")]
decode_buffer!(
    f32,
    "32-bit float array",
    decode_array_f32,
    visit_array_f32,
    encode_array_f32
);

#[cfg(feature = "stream")]
decode_buffer!(
    f64,
    "64-bit float array",
    decode_array_f64,
    visit_array_f64,
    encode_array_f64
);
539
540impl<T: Number + fmt::Debug> fmt::Debug for Buffer<T> {
541    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
542        match self {
543            Self::Host(buffer) => fmt::Debug::fmt(buffer, f),
544            #[cfg(feature = "opencl")]
545            Self::CL(buffer) => fmt::Debug::fmt(buffer, f),
546        }
547    }
548}