1use std::fmt;
2
3#[cfg(feature = "stream")]
4use destream::{de, en};
5use get_size::GetSize;
6
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::{host, CType, Error};
10
11pub trait BufferInstance<T: CType>: Send + Sync {
13 fn read(&self) -> BufferConverter<T>;
15
16 fn read_value(&self, offset: usize) -> Result<T, Error>;
18
19 fn len(&self) -> usize;
21}
22
23pub trait BufferMut<T: CType>: BufferInstance<T> + fmt::Debug {
25 #[cfg(feature = "opencl")]
26 fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
28 Err(Error::Unsupported(format!(
29 "not an OpenCL buffer: {self:?}"
30 )))
31 }
32
33 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
35
36 fn write_value(&mut self, value: T) -> Result<(), Error>;
38
39 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
41}
42
43#[derive(Clone)]
45pub enum Buffer<T: CType> {
46 #[cfg(feature = "opencl")]
47 CL(ocl::Buffer<T>),
48 Host(host::Buffer<T>),
49}
50
51impl<T: CType> Buffer<T> {
52 pub fn from_slice(slice: &[T]) -> Result<Self, Error> {
54 BufferConverter::from(slice).into_buffer()
55 }
56}
57
58impl<T: CType> GetSize for Buffer<T> {
59 fn get_size(&self) -> usize {
60 self.len() * std::mem::size_of::<T>()
61 }
62}
63
64impl<T: CType> BufferInstance<T> for Buffer<T> {
65 fn read(&self) -> BufferConverter<T> {
66 BufferConverter::from(self)
67 }
68
69 fn read_value(&self, offset: usize) -> Result<T, Error> {
70 match self {
71 #[cfg(feature = "opencl")]
72 Self::CL(buf) => buf.read_value(offset),
73 Self::Host(buf) => buf.read_value(offset),
74 }
75 }
76
77 fn len(&self) -> usize {
78 match self {
79 #[cfg(feature = "opencl")]
80 Self::CL(buf) => buf.len(),
81 Self::Host(buf) => buf.len(),
82 }
83 }
84}
85
86impl<T: CType> BufferMut<T> for Buffer<T> {
87 #[cfg(feature = "opencl")]
88 fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
89 match self {
90 #[cfg(feature = "opencl")]
91 Self::CL(buf) => buf.cl(),
92 Self::Host(buf) => buf.cl(),
93 }
94 }
95
96 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
97 match self {
98 #[cfg(feature = "opencl")]
99 Self::CL(buf) => buf.write(data),
100 Self::Host(buf) => buf.write(data),
101 }
102 }
103
104 fn write_value(&mut self, value: T) -> Result<(), Error> {
105 match self {
106 #[cfg(feature = "opencl")]
107 Self::CL(buf) => buf.write_value(value),
108 Self::Host(buf) => buf.write_value(value),
109 }
110 }
111
112 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
113 match self {
114 #[cfg(feature = "opencl")]
115 Self::CL(buf) => buf.write_value_at(offset, value),
116 Self::Host(buf) => buf.write_value_at(offset, value),
117 }
118 }
119}
120
121impl<'a, T: CType> BufferInstance<T> for &'a Buffer<T> {
122 fn read(&self) -> BufferConverter<T> {
123 BufferConverter::from(*self)
124 }
125
126 fn read_value(&self, offset: usize) -> Result<T, Error> {
127 BufferInstance::read_value(*self, offset)
128 }
129
130 fn len(&self) -> usize {
131 BufferInstance::len(*self)
132 }
133}
134
135impl<'a, T: CType> BufferInstance<T> for &'a mut Buffer<T> {
136 fn read(&self) -> BufferConverter<T> {
137 BufferConverter::from(&**self)
138 }
139
140 fn read_value(&self, offset: usize) -> Result<T, Error> {
141 BufferInstance::read_value(&**self, offset)
142 }
143
144 fn len(&self) -> usize {
145 BufferInstance::len(*self)
146 }
147}
148
149impl<'a, T: CType> BufferMut<T> for &'a mut Buffer<T> {
150 #[cfg(feature = "opencl")]
151 fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
152 Buffer::<T>::cl(&mut **self)
153 }
154
155 fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
156 Buffer::<T>::write(*self, data)
157 }
158
159 fn write_value(&mut self, value: T) -> Result<(), Error> {
160 Buffer::<T>::write_value(*self, value)
161 }
162
163 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
164 Buffer::<T>::write_value_at(*self, offset, value)
165 }
166}
167
/// A `freqfs` read guard over a [`Buffer`] reads through to the guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE, T> BufferInstance<T> for freqfs::FileReadGuardOwned<FE, Buffer<T>>
where
    FE: Send + Sync,
    T: CType,
{
    fn read(&self) -> BufferConverter<T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }

    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }
}
182
/// A `freqfs` write guard over a [`Buffer`] reads through to the guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE, T> BufferInstance<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>>
where
    FE: Send + Sync,
    T: CType,
{
    fn read(&self) -> BufferConverter<T> {
        <Buffer<T> as BufferInstance<T>>::read(self)
    }

    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(self)
    }

    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(self, offset)
    }
}
197
/// A `freqfs` write guard over a [`Buffer`] writes through to the guarded buffer.
#[cfg(feature = "freqfs")]
impl<FE, T> BufferMut<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>>
where
    FE: Send + Sync,
    T: CType,
{
    #[cfg(feature = "opencl")]
    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
        <Buffer<T> as BufferMut<T>>::cl(self)
    }

    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write(self, data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value(self, value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value_at(self, offset, value)
    }
}
217
218#[cfg(feature = "opencl")]
219impl<T: CType> From<ocl::Buffer<T>> for Buffer<T> {
220 fn from(buf: ocl::Buffer<T>) -> Self {
221 Self::CL(buf)
222 }
223}
224
225impl<T: CType> From<host::StackVec<T>> for Buffer<T> {
226 fn from(buf: host::StackVec<T>) -> Self {
227 Self::Host(buf.into())
228 }
229}
230
231impl<T: CType> From<Vec<T>> for Buffer<T> {
232 fn from(buf: Vec<T>) -> Self {
233 Self::Host(buf.into())
234 }
235}
236
237impl<T: CType> From<host::Buffer<T>> for Buffer<T> {
238 fn from(buf: host::Buffer<T>) -> Self {
239 Self::Host(buf)
240 }
241}
242
243#[derive(Clone)]
244pub enum BufferConverter<'a, T: CType> {
246 #[cfg(feature = "opencl")]
247 CL(opencl::CLConverter<'a, T>),
248 Host(host::SliceConverter<'a, T>),
249}
250
251impl<'a, T: CType> BufferConverter<'a, T> {
252 pub fn into_buffer(self) -> Result<Buffer<T>, Error> {
254 match self {
255 #[cfg(feature = "opencl")]
256 Self::CL(buffer) => buffer.into_buffer().map(Buffer::CL),
257 Self::Host(buffer) => Ok(Buffer::Host(buffer.into_buffer())),
258 }
259 }
260
261 pub fn len(&self) -> usize {
263 match self {
264 #[cfg(feature = "opencl")]
265 Self::CL(buffer) => buffer.len(),
266 Self::Host(buffer) => buffer.len(),
267 }
268 }
269
270 #[cfg(feature = "opencl")]
271 pub fn to_cl(self) -> Result<opencl::CLConverter<'a, T>, ocl::Error> {
273 match self {
274 Self::CL(buffer) => Ok(buffer),
275 Self::Host(buffer) => {
276 opencl::OpenCL::copy_into_buffer(buffer.as_ref()).map(opencl::CLConverter::Owned)
277 }
278 }
279 }
280
281 #[inline]
283 pub fn to_slice(self) -> Result<host::SliceConverter<'a, T>, Error> {
284 match self {
285 #[cfg(feature = "opencl")]
286 Self::CL(buffer) => {
287 let mut copy = vec![T::default(); buffer.len()];
288 buffer.read(&mut copy[..]).enq()?;
289 Ok(host::SliceConverter::from(copy))
290 }
291 Self::Host(buffer) => Ok(buffer),
292 }
293 }
294}
295
296impl<T: CType> From<Buffer<T>> for BufferConverter<'static, T> {
297 fn from(buf: Buffer<T>) -> Self {
298 match buf {
299 #[cfg(feature = "opencl")]
300 Buffer::CL(buf) => Self::CL(buf.into()),
301 Buffer::Host(buf) => Self::Host(buf.into()),
302 }
303 }
304}
305
306impl<'a, T: CType> From<&'a Buffer<T>> for BufferConverter<'a, T> {
307 fn from(buf: &'a Buffer<T>) -> Self {
308 match buf {
309 #[cfg(feature = "opencl")]
310 Buffer::CL(buf) => Self::CL(buf.into()),
311 Buffer::Host(buf) => Self::Host(buf.into()),
312 }
313 }
314}
315
316impl<T: CType> From<Vec<T>> for BufferConverter<'static, T> {
317 fn from(buf: Vec<T>) -> Self {
318 Self::Host(buf.into())
319 }
320}
321
322impl<T: CType> From<host::StackVec<T>> for BufferConverter<'static, T> {
323 fn from(buf: host::StackVec<T>) -> Self {
324 Self::Host(buf.into())
325 }
326}
327
328impl<T: CType> From<host::Buffer<T>> for BufferConverter<'static, T> {
329 fn from(buf: host::Buffer<T>) -> Self {
330 Self::Host(buf.into())
331 }
332}
333
334impl<'a, T: CType> From<&'a [T]> for BufferConverter<'a, T> {
335 fn from(buf: &'a [T]) -> Self {
336 Self::Host(buf.into())
337 }
338}
339
340#[cfg(feature = "opencl")]
341impl<T: CType> From<ocl::Buffer<T>> for BufferConverter<'static, T> {
342 fn from(buf: ocl::Buffer<T>) -> Self {
343 Self::CL(buf.into())
344 }
345}
346
347#[cfg(feature = "opencl")]
348impl<'a, T: CType> From<&'a ocl::Buffer<T>> for BufferConverter<'a, T> {
349 fn from(buf: &'a ocl::Buffer<T>) -> Self {
350 Self::CL(buf.into())
351 }
352}
353
/// Stream-decoding visitor that collects elements before building a [`Buffer`].
#[cfg(feature = "stream")]
struct BufferVisitor<T> {
    // Elements decoded so far; the `decode_buffer!` visitor appends each chunk here.
    data: Vec<T>,
}

#[cfg(feature = "stream")]
impl<T> BufferVisitor<T> {
    /// Create a visitor with no accumulated elements.
    fn new() -> Self {
        let data = Vec::new();
        Self { data }
    }
}
365
/// Implements stream decoding (`de::Visitor`, `de::FromStream`) and encoding
/// (`en::ToStream`, `en::IntoStream`) for `Buffer<$t>`.
///
/// * `$t` — the element type
/// * `$name` — the description returned by `Visitor::expecting`
/// * `$decode` — the `Decoder` method used to decode an array of `$t`
/// * `$visit` — the `Visitor` method that receives that array
/// * `$encode` — the `Encoder` method used to encode an array of `$t`
#[cfg(feature = "stream")]
macro_rules! decode_buffer {
    ($t:ty, $name:expr, $decode:ident, $visit:ident, $encode:ident) => {
        #[async_trait::async_trait]
        impl de::Visitor for BufferVisitor<$t> {
            type Value = Buffer<$t>;

            fn expecting() -> &'static str {
                $name
            }

            async fn $visit<A: de::ArrayAccess<$t>>(
                self,
                mut array: A,
            ) -> Result<Self::Value, A::Error> {
                use crate::platform::{Convert, PlatformInstance};

                const BUF_SIZE: usize = 4_096;
                let mut data = self.data;

                // Drain the array in chunks of up to BUF_SIZE elements; a returned
                // length of 0 is treated as the end of the array.
                let mut buf = [<$t>::ZERO; BUF_SIZE];
                loop {
                    let len = array.buffer(&mut buf).await?;
                    if len == 0 {
                        break;
                    } else {
                        data.extend_from_slice(&buf[..len]);
                    }
                }

                // `Platform::select` is given the element count — presumably to
                // choose between host and OpenCL storage; confirm in `platform`.
                crate::Platform::select(data.len())
                    .convert(data.into())
                    .map_err(de::Error::custom)
            }
        }

        #[async_trait::async_trait]
        impl de::FromStream for Buffer<$t> {
            type Context = ();

            async fn from_stream<D: de::Decoder>(
                _cxt: (),
                decoder: &mut D,
            ) -> Result<Self, D::Error> {
                decoder.$decode(BufferVisitor::<$t>::new()).await
            }
        }

        impl<'en> en::ToStream<'en> for Buffer<$t> {
            fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let fut = futures::future::ready(buffer.to_vec());
                        let stream = futures::stream::once(fut);
                        encoder.$encode(stream)
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // The read destination must already contain `buffer.len()`
                        // elements: `Vec::with_capacity` has length 0 and would hand
                        // OpenCL a zero-length destination, encoding no data.
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data[..]).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }

        impl<'en> en::IntoStream<'en> for Buffer<$t> {
            fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let buffer = buffer.to_vec();
                        encoder.$encode(futures::stream::once(futures::future::ready(buffer)))
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // As in `to_stream`: size the destination to `buffer.len()`
                        // elements so the whole device buffer is read.
                        let mut data = vec![<$t>::ZERO; buffer.len()];
                        buffer.read(&mut data[..]).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }
    };
}
450
// Stream encoding/decoding for each supported element type. The second
// argument is the description returned by `Visitor::expecting`, so it must
// name the element type correctly.
#[cfg(feature = "stream")]
decode_buffer!(
    u8,
    "byte array",
    decode_array_u8,
    visit_array_u8,
    encode_array_u8
);

#[cfg(feature = "stream")]
decode_buffer!(
    u16,
    "16-bit unsigned int array",
    decode_array_u16,
    visit_array_u16,
    encode_array_u16
);

#[cfg(feature = "stream")]
decode_buffer!(
    u32,
    "32-bit unsigned int array",
    decode_array_u32,
    visit_array_u32,
    encode_array_u32
);

#[cfg(feature = "stream")]
decode_buffer!(
    u64,
    "64-bit unsigned int array",
    decode_array_u64,
    visit_array_u64,
    encode_array_u64
);

#[cfg(feature = "stream")]
decode_buffer!(
    i16,
    "16-bit int array",
    decode_array_i16,
    visit_array_i16,
    encode_array_i16
);

#[cfg(feature = "stream")]
decode_buffer!(
    i32,
    "32-bit int array",
    decode_array_i32,
    visit_array_i32,
    encode_array_i32
);

#[cfg(feature = "stream")]
decode_buffer!(
    i64,
    "64-bit int array",
    decode_array_i64,
    visit_array_i64,
    encode_array_i64
);

#[cfg(feature = "stream")]
decode_buffer!(
    f32,
    "32-bit float array",
    decode_array_f32,
    visit_array_f32,
    encode_array_f32
);

#[cfg(feature = "stream")]
decode_buffer!(
    f64,
    "64-bit float array",
    decode_array_f64,
    visit_array_f64,
    encode_array_f64
);
531
532impl<T: CType + fmt::Debug> fmt::Debug for Buffer<T> {
533 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
534 match self {
535 Self::Host(buffer) => fmt::Debug::fmt(buffer, f),
536 #[cfg(feature = "opencl")]
537 Self::CL(buffer) => fmt::Debug::fmt(buffer, f),
538 }
539 }
540}