ha_ndarray/host/buffer.rs

1use std::borrow::Borrow;
2use std::fmt;
3use std::ops::Deref;
4
5use smallvec::SmallVec;
6
7use crate::buffer::{BufferConverter, BufferInstance, BufferMut};
8use crate::{CType, Error};
9
10use super::VEC_MIN_SIZE;
11
12/// A stack-allocated buffer.
13pub type StackVec<T> = SmallVec<[T; VEC_MIN_SIZE]>;
14
15impl<T: CType> BufferInstance<T> for StackVec<T> {
16    fn read(&self) -> BufferConverter<T> {
17        self.as_slice().into()
18    }
19
20    fn read_value(&self, offset: usize) -> Result<T, Error> {
21        BufferInstance::read_value(&self.as_slice(), offset)
22    }
23
24    fn len(&self) -> usize {
25        StackVec::len(self)
26    }
27}
28
29impl<T: CType> BufferMut<T> for StackVec<T> {
30    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
31        self.as_mut_slice().write(data)
32    }
33
34    fn write_value(&mut self, value: T) -> Result<(), Error> {
35        self.as_mut_slice().write_value(value)
36    }
37
38    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
39        self.as_mut_slice().write_value_at(offset, value)
40    }
41}
42
43impl<T: CType> BufferInstance<T> for Vec<T> {
44    fn read(&self) -> BufferConverter<T> {
45        self.as_slice().into()
46    }
47
48    fn read_value(&self, offset: usize) -> Result<T, Error> {
49        BufferInstance::read_value(&self.as_slice(), offset)
50    }
51
52    fn len(&self) -> usize {
53        Vec::len(self)
54    }
55}
56
57impl<T: CType> BufferMut<T> for Vec<T> {
58    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
59        self.as_mut_slice().write(data)
60    }
61
62    fn write_value(&mut self, value: T) -> Result<(), Error> {
63        self.as_mut_slice().write_value(value)
64    }
65
66    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
67        self.as_mut_slice().write_value_at(offset, value)
68    }
69}
70
71impl<'a, T: CType> BufferInstance<T> for &'a [T] {
72    fn read(&self) -> BufferConverter<T> {
73        (*self).into()
74    }
75
76    fn read_value(&self, offset: usize) -> Result<T, Error> {
77        self.get(offset).copied().ok_or_else(|| {
78            Error::Bounds(format!(
79                "invalid offset {offset} for a buffer of length {}",
80                self.len()
81            ))
82        })
83    }
84
85    fn len(&self) -> usize {
86        <[T]>::len(self)
87    }
88}
89
90impl<'a, T: CType> BufferInstance<T> for &'a mut [T] {
91    fn read(&self) -> BufferConverter<T> {
92        (&**self).into()
93    }
94
95    fn read_value(&self, offset: usize) -> Result<T, Error> {
96        BufferInstance::read_value(&&**self, offset)
97    }
98
99    fn len(&self) -> usize {
100        <[T]>::len(self)
101    }
102}
103
104impl<'a, T: CType> BufferMut<T> for &'a mut [T] {
105    fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
106        if data.len() == self.len() {
107            let data = data.to_slice()?;
108            self.copy_from_slice(&*data);
109            Ok(())
110        } else {
111            Err(Error::Bounds(format!(
112                "cannot overwrite a buffer of size {} with one of size {}",
113                self.len(),
114                data.len()
115            )))
116        }
117    }
118
119    fn write_value(&mut self, value: T) -> Result<(), Error> {
120        self.fill(value);
121        Ok(())
122    }
123
124    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
125        if offset < self.len() {
126            self[offset] = value;
127            Ok(())
128        } else {
129            Err(Error::Bounds(format!(
130                "invalid offset {offset} for a buffer of length {}",
131                self.len()
132            )))
133        }
134    }
135}
136
137#[derive(Clone, Eq, PartialEq, Debug)]
138pub enum Buffer<T> {
139    Heap(Vec<T>),
140    Stack(StackVec<T>),
141}
142
143impl<T: Clone> Buffer<T> {
144    pub fn into_vec(self) -> Vec<T> {
145        match self {
146            Self::Heap(data) => data,
147            Self::Stack(data) => data.into_vec(),
148        }
149    }
150
151    pub fn to_vec(&self) -> Vec<T> {
152        match self {
153            Self::Heap(data) => data.to_vec(),
154            Self::Stack(data) => data.to_vec(),
155        }
156    }
157}
158
159impl<T> Borrow<[T]> for Buffer<T> {
160    fn borrow(&self) -> &[T] {
161        match self {
162            Self::Heap(buf) => buf.borrow(),
163            Self::Stack(buf) => buf.borrow(),
164        }
165    }
166}
167
168impl<T> AsMut<[T]> for Buffer<T> {
169    fn as_mut(&mut self) -> &mut [T] {
170        match self {
171            Self::Heap(buf) => buf.as_mut_slice(),
172            Self::Stack(buf) => buf.as_mut_slice(),
173        }
174    }
175}
176
177impl<T: CType> BufferInstance<T> for Buffer<T> {
178    fn read(&self) -> BufferConverter<T> {
179        BufferConverter::Host(self.into())
180    }
181
182    fn read_value(&self, offset: usize) -> Result<T, Error> {
183        match self {
184            Self::Heap(buf) => buf.read_value(offset),
185            Self::Stack(buf) => buf.read_value(offset),
186        }
187    }
188
189    fn len(&self) -> usize {
190        match self {
191            Self::Heap(buf) => buf.len(),
192            Self::Stack(buf) => buf.len(),
193        }
194    }
195}
196
197impl<T: CType> BufferMut<T> for Buffer<T> {
198    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
199        match self {
200            Self::Heap(buf) => buf.write(data),
201            Self::Stack(buf) => buf.write(data),
202        }
203    }
204
205    fn write_value(&mut self, value: T) -> Result<(), Error> {
206        match self {
207            Self::Heap(buf) => buf.write_value(value),
208            Self::Stack(buf) => buf.write_value(value),
209        }
210    }
211
212    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
213        match self {
214            Self::Heap(buf) => buf.write_value_at(offset, value),
215            Self::Stack(buf) => buf.write_value_at(offset, value),
216        }
217    }
218}
219
220impl<T> From<StackVec<T>> for Buffer<T> {
221    fn from(buf: StackVec<T>) -> Self {
222        Self::Stack(buf)
223    }
224}
225
226impl<T> From<Vec<T>> for Buffer<T> {
227    fn from(buf: Vec<T>) -> Self {
228        Self::Heap(buf)
229    }
230}
231
232#[derive(Clone)]
233/// A buffer in host memory, either borrowed or owned
234pub enum SliceConverter<'a, T> {
235    Heap(Vec<T>),
236    Stack(StackVec<T>),
237    Slice(&'a [T]),
238}
239
240impl<'a, T> SliceConverter<'a, T> {
241    /// Return the number of elements in this buffer.
242    pub fn len(&self) -> usize {
243        match self {
244            Self::Heap(vec) => vec.len(),
245            Self::Stack(vec) => vec.len(),
246            Self::Slice(slice) => slice.len(),
247        }
248    }
249}
250
251impl<'a, T: Copy> SliceConverter<'a, T> {
252    /// Return this buffer as an owned [`Vec`].
253    /// This will allocate a new [`Vec`] if this buffer is a [`StackVec`] or borrowed slice.
254    pub fn into_vec(self) -> Vec<T> {
255        match self {
256            Self::Heap(vec) => vec,
257            Self::Stack(vec) => vec.into_vec(),
258            Self::Slice(slice) => slice.to_vec(),
259        }
260    }
261
262    /// Return this buffer as an owned [`StackVec`].
263    pub fn into_stackvec(self) -> StackVec<T> {
264        match self {
265            Self::Heap(vec) => vec.into(),
266            Self::Stack(vec) => vec,
267            Self::Slice(slice) => StackVec::from_slice(slice),
268        }
269    }
270
271    /// Return this buffer as an owned host [`Buffer`].
272    pub fn into_buffer(self) -> Buffer<T> {
273        match self {
274            Self::Heap(vec) => Buffer::Heap(vec),
275            Self::Stack(vec) => Buffer::Stack(vec),
276            Self::Slice(slice) => {
277                if slice.len() < VEC_MIN_SIZE {
278                    Buffer::Stack(StackVec::from_slice(slice))
279                } else {
280                    Buffer::Heap(slice.to_vec())
281                }
282            }
283        }
284    }
285}
286
287impl<T> From<Buffer<T>> for SliceConverter<'static, T> {
288    fn from(buf: Buffer<T>) -> Self {
289        match buf {
290            Buffer::Heap(buf) => SliceConverter::Heap(buf),
291            Buffer::Stack(buf) => SliceConverter::Stack(buf),
292        }
293    }
294}
295
296impl<'a, T> From<&'a Buffer<T>> for SliceConverter<'a, T> {
297    fn from(buf: &'a Buffer<T>) -> Self {
298        match buf {
299            Buffer::Heap(slice) => slice.as_slice().into(),
300            Buffer::Stack(slice) => slice.as_slice().into(),
301        }
302    }
303}
304
305impl<T> From<StackVec<T>> for SliceConverter<'static, T> {
306    fn from(vec: StackVec<T>) -> Self {
307        Self::Stack(vec)
308    }
309}
310
311impl<T> From<Vec<T>> for SliceConverter<'static, T> {
312    fn from(vec: Vec<T>) -> Self {
313        Self::Heap(vec)
314    }
315}
316
317impl<'a, T> From<&'a [T]> for SliceConverter<'a, T> {
318    fn from(slice: &'a [T]) -> Self {
319        Self::Slice(slice)
320    }
321}
322
323impl<'a, T> Deref for SliceConverter<'a, T> {
324    type Target = [T];
325
326    fn deref(&self) -> &Self::Target {
327        match self {
328            Self::Heap(data) => data.as_slice(),
329            Self::Stack(data) => data.as_slice(),
330            Self::Slice(slice) => slice,
331        }
332    }
333}
334
335impl<'a, T: fmt::Debug> fmt::Debug for SliceConverter<'a, T> {
336    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
337        fmt::Debug::fmt(self.deref(), f)
338    }
339}