1use std::fmt;
2
3#[cfg(feature = "stream")]
4use destream::{de, en};
5use get_size::GetSize;
6
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::{host, Error, Number};
10
/// Read-only access to a buffer holding elements of type `T`.
///
/// Every implementor must be `Send + Sync`.
pub trait BufferInstance<T: Number>: Send + Sync {
    /// Borrow this buffer as a [`BufferConverter`] whose lifetime is tied to `&self`.
    fn read(&self) -> BufferConverter<'_, T>;

    /// Read the single value at `offset`.
    ///
    /// NOTE(review): `offset` is presumably an element index, and an out-of-range
    /// `offset` presumably yields an `Err` -- confirm against the implementors.
    fn read_value(&self, offset: usize) -> Result<T, Error>;

    /// Return the length of this buffer.
    ///
    /// NOTE(review): presumably counted in elements of `T`, not bytes -- confirm.
    fn len(&self) -> usize;

    /// Return `true` if [`BufferInstance::len`] reports zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
26
/// Mutable access to a buffer holding elements of type `T`.
///
/// Implementors must also implement [`BufferInstance`] and `fmt::Debug`.
pub trait BufferMut<T: Number>: BufferInstance<T> + fmt::Debug {
    /// Borrow this buffer as a mutable `ocl::Buffer`.
    ///
    /// The default implementation always fails with an "unsupported" error
    /// whose message includes the `Debug` representation of `self`.
    #[cfg(feature = "opencl")]
    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
        Err(Error::unsupported(format!(
            "not an OpenCL buffer: {self:?}"
        )))
    }

    /// Write the contents of `data` into this buffer.
    ///
    /// NOTE(review): whether `data` must have the same length as this buffer
    /// is decided by the implementor -- confirm before relying on it.
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;

    /// Write `value` into this buffer.
    ///
    /// NOTE(review): presumably this overwrites every element with `value` -- confirm.
    fn write_value(&mut self, value: T) -> Result<(), Error>;

    /// Write `value` at the position given by `offset`.
    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
}
46
/// A buffer of `T` backed by either a host buffer or an OpenCL buffer.
#[derive(Clone)]
pub enum Buffer<T: Number> {
    /// A buffer wrapping an `ocl::Buffer` (only with the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(ocl::Buffer<T>),
    /// A buffer wrapping a `host::Buffer`.
    Host(host::Buffer<T>),
}
54
55impl<T: Number> Buffer<T> {
56 pub fn from_slice(slice: &[T]) -> Result<Self, Error> {
58 BufferConverter::from(slice).into_buffer()
59 }
60}
61
62impl<T: Number> GetSize for Buffer<T> {
63 fn get_size(&self) -> usize {
64 self.len() * std::mem::size_of::<T>()
65 }
66}
67
68impl<T: Number> BufferInstance<T> for Buffer<T> {
69 fn read(&self) -> BufferConverter<'_, T> {
70 BufferConverter::from(self)
71 }
72
73 fn read_value(&self, offset: usize) -> Result<T, Error> {
74 match self {
75 #[cfg(feature = "opencl")]
76 Self::CL(buf) => buf.read_value(offset),
77 Self::Host(buf) => buf.read_value(offset),
78 }
79 }
80
81 fn len(&self) -> usize {
82 match self {
83 #[cfg(feature = "opencl")]
84 Self::CL(buf) => buf.len(),
85 Self::Host(buf) => buf.len(),
86 }
87 }
88}
89
90impl<T: Number> BufferMut<T> for Buffer<T> {
91 #[cfg(feature = "opencl")]
92 fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
93 match self {
94 #[cfg(feature = "opencl")]
95 Self::CL(buf) => BufferMut::<T>::cl(buf),
96 Self::Host(buf) => buf.cl(),
97 }
98 }
99
100 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
101 match self {
102 #[cfg(feature = "opencl")]
103 Self::CL(buf) => buf.write(data),
104 Self::Host(buf) => buf.write(data),
105 }
106 }
107
108 fn write_value(&mut self, value: T) -> Result<(), Error> {
109 match self {
110 #[cfg(feature = "opencl")]
111 Self::CL(buf) => buf.write_value(value),
112 Self::Host(buf) => buf.write_value(value),
113 }
114 }
115
116 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
117 match self {
118 #[cfg(feature = "opencl")]
119 Self::CL(buf) => buf.write_value_at(offset, value),
120 Self::Host(buf) => buf.write_value_at(offset, value),
121 }
122 }
123}
124
125impl<T: Number> BufferInstance<T> for &Buffer<T> {
126 fn read(&self) -> BufferConverter<'_, T> {
127 BufferConverter::from(*self)
128 }
129
130 fn read_value(&self, offset: usize) -> Result<T, Error> {
131 BufferInstance::read_value(*self, offset)
132 }
133
134 fn len(&self) -> usize {
135 BufferInstance::len(*self)
136 }
137}
138
139impl<T: Number> BufferInstance<T> for &mut Buffer<T> {
140 fn read(&self) -> BufferConverter<'_, T> {
141 BufferConverter::from(&**self)
142 }
143
144 fn read_value(&self, offset: usize) -> Result<T, Error> {
145 BufferInstance::read_value(&**self, offset)
146 }
147
148 fn len(&self) -> usize {
149 BufferInstance::len(*self)
150 }
151}
152
153impl<T: Number> BufferMut<T> for &mut Buffer<T> {
154 #[cfg(feature = "opencl")]
155 fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
156 Buffer::<T>::cl(&mut **self)
157 }
158
159 fn write<'b>(&mut self, data: BufferConverter<'b, T>) -> Result<(), Error> {
160 Buffer::<T>::write(*self, data)
161 }
162
163 fn write_value(&mut self, value: T) -> Result<(), Error> {
164 Buffer::<T>::write_value(*self, value)
165 }
166
167 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
168 Buffer::<T>::write_value_at(*self, offset, value)
169 }
170}
171
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferInstance<T> for freqfs::FileReadGuardOwned<FE, Buffer<T>> {
    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn read(&self) -> BufferConverter<'_, T> {
        <Buffer<T> as BufferInstance<T>>::read(&**self)
    }

    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(&**self, offset)
    }

    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(&**self)
    }
}
186
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferInstance<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn read(&self) -> BufferConverter<'_, T> {
        <Buffer<T> as BufferInstance<T>>::read(&**self)
    }

    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        <Buffer<T> as BufferInstance<T>>::read_value(&**self, offset)
    }

    /// Delegate to the [`Buffer`] this guard dereferences to.
    fn len(&self) -> usize {
        <Buffer<T> as BufferInstance<T>>::len(&**self)
    }
}
201
#[cfg(feature = "freqfs")]
impl<FE: Send + Sync, T: Number> BufferMut<T> for freqfs::FileWriteGuardOwned<FE, Buffer<T>> {
    /// Delegate to the [`Buffer`] this guard mutably dereferences to.
    #[cfg(feature = "opencl")]
    fn cl(&mut self) -> Result<&mut ocl::Buffer<T>, Error> {
        <Buffer<T> as BufferMut<T>>::cl(&mut **self)
    }

    /// Delegate to the [`Buffer`] this guard mutably dereferences to.
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write(&mut **self, data)
    }

    /// Delegate to the [`Buffer`] this guard mutably dereferences to.
    fn write_value(&mut self, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value(&mut **self, value)
    }

    /// Delegate to the [`Buffer`] this guard mutably dereferences to.
    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        <Buffer<T> as BufferMut<T>>::write_value_at(&mut **self, offset, value)
    }
}
221
#[cfg(feature = "opencl")]
impl<T: Number> From<ocl::Buffer<T>> for Buffer<T> {
    /// Wrap an `ocl::Buffer` in the [`Buffer::CL`] variant.
    fn from(cl_buf: ocl::Buffer<T>) -> Self {
        Buffer::CL(cl_buf)
    }
}
228
229impl<T: Number> From<host::StackVec<T>> for Buffer<T> {
230 fn from(buf: host::StackVec<T>) -> Self {
231 Self::Host(buf.into())
232 }
233}
234
235impl<T: Number> From<Vec<T>> for Buffer<T> {
236 fn from(buf: Vec<T>) -> Self {
237 Self::Host(buf.into())
238 }
239}
240
241impl<T: Number> From<host::Buffer<T>> for Buffer<T> {
242 fn from(buf: host::Buffer<T>) -> Self {
243 Self::Host(buf)
244 }
245}
246
/// An intermediate form of buffer data, parameterized by the lifetime `'a`,
/// which can be turned into an owned [`Buffer`], an OpenCL converter, or a host slice converter.
#[derive(Clone)]
pub enum BufferConverter<'a, T: Number> {
    /// Data held by an `opencl::CLConverter` (only with the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(opencl::CLConverter<'a, T>),
    /// Data held by a `host::SliceConverter`.
    Host(host::SliceConverter<'a, T>),
}
254
255impl<'a, T: Number> BufferConverter<'a, T> {
256 pub fn into_buffer(self) -> Result<Buffer<T>, Error> {
258 match self {
259 #[cfg(feature = "opencl")]
260 Self::CL(buffer) => buffer.into_buffer().map(Buffer::CL),
261 Self::Host(buffer) => Ok(Buffer::Host(buffer.into_buffer())),
262 }
263 }
264
265 pub fn len(&self) -> usize {
267 match self {
268 #[cfg(feature = "opencl")]
269 Self::CL(buffer) => buffer.len(),
270 Self::Host(buffer) => buffer.len(),
271 }
272 }
273
274 pub fn is_empty(&self) -> bool {
275 self.len() == 0
276 }
277
278 #[cfg(feature = "opencl")]
279 pub fn to_cl(self) -> Result<opencl::CLConverter<'a, T>, ocl::Error> {
281 match self {
282 Self::CL(buffer) => Ok(buffer),
283 Self::Host(buffer) => {
284 opencl::OpenCL::copy_into_buffer::<T>(&buffer).map(opencl::CLConverter::Owned)
285 }
286 }
287 }
288
289 #[inline]
291 pub fn to_slice(self) -> Result<host::SliceConverter<'a, T>, Error> {
292 match self {
293 #[cfg(feature = "opencl")]
294 Self::CL(buffer) => {
295 let mut copy = vec![T::default(); buffer.len()];
296 buffer.read(&mut copy[..]).enq()?;
297
298 let copy = copy.into_iter().collect::<Vec<T>>();
299 Ok(host::SliceConverter::from(copy))
300 }
301 Self::Host(buffer) => Ok(buffer),
302 }
303 }
304}
305
306impl<T: Number> From<Buffer<T>> for BufferConverter<'static, T> {
307 fn from(buf: Buffer<T>) -> Self {
308 match buf {
309 #[cfg(feature = "opencl")]
310 Buffer::CL(buf) => Self::CL(buf.into()),
311 Buffer::Host(buf) => Self::Host(buf.into()),
312 }
313 }
314}
315
316impl<'a, T: Number> From<&'a Buffer<T>> for BufferConverter<'a, T> {
317 fn from(buf: &'a Buffer<T>) -> Self {
318 match buf {
319 #[cfg(feature = "opencl")]
320 Buffer::CL(buf) => Self::CL(buf.into()),
321 Buffer::Host(buf) => Self::Host(buf.into()),
322 }
323 }
324}
325
326impl<T: Number> From<Vec<T>> for BufferConverter<'static, T> {
327 fn from(buf: Vec<T>) -> Self {
328 Self::Host(buf.into())
329 }
330}
331
332impl<T: Number> From<host::StackVec<T>> for BufferConverter<'static, T> {
333 fn from(buf: host::StackVec<T>) -> Self {
334 Self::Host(buf.into())
335 }
336}
337
338impl<T: Number> From<host::Buffer<T>> for BufferConverter<'static, T> {
339 fn from(buf: host::Buffer<T>) -> Self {
340 Self::Host(buf.into())
341 }
342}
343
344impl<'a, T: Number> From<&'a [T]> for BufferConverter<'a, T> {
345 fn from(buf: &'a [T]) -> Self {
346 Self::Host(buf.into())
347 }
348}
349
#[cfg(feature = "opencl")]
impl<T: Number> From<ocl::Buffer<T>> for BufferConverter<'static, T> {
    /// Convert an owned `ocl::Buffer` into an `opencl::CLConverter` and wrap it in
    /// [`BufferConverter::CL`].
    fn from(cl_buf: ocl::Buffer<T>) -> Self {
        let converter: opencl::CLConverter<'static, T> = cl_buf.into();
        BufferConverter::CL(converter)
    }
}
356
#[cfg(feature = "opencl")]
impl<'a, T: Number> From<&'a ocl::Buffer<T>> for BufferConverter<'a, T> {
    /// Convert a borrowed `ocl::Buffer` into an `opencl::CLConverter` and wrap it in
    /// [`BufferConverter::CL`].
    fn from(cl_buf: &'a ocl::Buffer<T>) -> Self {
        let converter: opencl::CLConverter<'a, T> = cl_buf.into();
        BufferConverter::CL(converter)
    }
}
363
/// A visitor used to decode a [`Buffer`] from a stream; it holds the elements decoded so far.
#[cfg(feature = "stream")]
struct BufferVisitor<T> {
    // the accumulated elements, which the visit method extends and then converts into a buffer
    data: Vec<T>,
}
368
#[cfg(feature = "stream")]
impl<T> BufferVisitor<T> {
    /// Construct a visitor which has not yet accumulated any elements.
    fn new() -> Self {
        let data = Vec::new();
        BufferVisitor { data }
    }
}
375
/// Generate the `destream` decoding and encoding implementations for `Buffer<$t>`.
///
/// Arguments:
///  - `$t`: the element type
///  - `$name`: the description returned by `Visitor::expecting`
///  - `$decode`: the `Decoder` method used to decode an array of `$t`
///  - `$visit`: the `Visitor` method which receives the decoded array
///  - `$encode`: the `Encoder` method used to encode an array of `$t`
#[cfg(feature = "stream")]
macro_rules! decode_buffer {
    ($t:ty, $name:expr, $decode:ident, $visit:ident, $encode:ident) => {
        impl de::Visitor for BufferVisitor<$t> {
            type Value = Buffer<$t>;

            fn expecting() -> &'static str {
                $name
            }

            async fn $visit<A: de::ArrayAccess<$t>>(
                self,
                mut array: A,
            ) -> Result<Self::Value, A::Error> {
                use crate::platform::{Convert, PlatformInstance};

                const BUF_SIZE: usize = 4_096;
                let mut data = self.data;

                // decode in fixed-size chunks until the array reports that it is exhausted
                let mut buf = [<$t>::ZERO; BUF_SIZE];
                loop {
                    let len = array.buffer(&mut buf).await?;
                    if len == 0 {
                        break;
                    } else {
                        data.extend_from_slice(&buf[..len]);
                    }
                }

                // let the platform selected for this size construct the buffer
                crate::Platform::select(data.len())
                    .convert(data.into())
                    .map_err(de::Error::custom)
            }
        }

        impl de::FromStream for Buffer<$t> {
            type Context = ();

            async fn from_stream<D: de::Decoder>(
                _cxt: (),
                decoder: &mut D,
            ) -> Result<Self, D::Error> {
                decoder.$decode(BufferVisitor::<$t>::new()).await
            }
        }

        impl<'en> en::ToStream<'en> for Buffer<$t> {
            fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let fut = futures::future::ready(buffer.to_vec());
                        let stream = futures::stream::once(fut);
                        encoder.$encode(stream)
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // The read destination must already hold `buffer.len()` elements:
                        // a `Vec` created with `with_capacity` has length zero, so reading
                        // into it would not copy any of the buffer's contents.
                        let mut data = vec![<$t>::default(); buffer.len()];
                        buffer.read(&mut data).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }

        impl<'en> en::IntoStream<'en> for Buffer<$t> {
            fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
                match self {
                    Self::Host(buffer) => {
                        let buffer = buffer.to_vec();
                        encoder.$encode(futures::stream::once(futures::future::ready(buffer)))
                    }
                    #[cfg(feature = "opencl")]
                    Self::CL(buffer) => {
                        // The read destination must already hold `buffer.len()` elements:
                        // a `Vec` created with `with_capacity` has length zero, so reading
                        // into it would not copy any of the buffer's contents.
                        let mut data = vec![<$t>::default(); buffer.len()];
                        buffer.read(&mut data).enq().map_err(en::Error::custom)?;
                        encoder.$encode(futures::stream::once(futures::future::ready(data)))
                    }
                }
            }
        }
    };
}
458
// Stream encoding and decoding for buffers of unsigned integers.

#[cfg(feature = "stream")]
decode_buffer!(
    u8,
    "byte array",
    decode_array_u8,
    visit_array_u8,
    encode_array_u8
);

#[cfg(feature = "stream")]
decode_buffer!(
    u16,
    "16-bit unsigned int array",
    decode_array_u16,
    visit_array_u16,
    encode_array_u16
);

#[cfg(feature = "stream")]
decode_buffer!(
    u32,
    "32-bit unsigned int array",
    decode_array_u32,
    visit_array_u32,
    encode_array_u32
);

#[cfg(feature = "stream")]
decode_buffer!(
    u64,
    "64-bit unsigned int array",
    decode_array_u64,
    visit_array_u64,
    encode_array_u64
);

// Stream encoding and decoding for buffers of signed integers.

#[cfg(feature = "stream")]
decode_buffer!(
    i16,
    "16-bit int array",
    decode_array_i16,
    visit_array_i16,
    encode_array_i16
);

#[cfg(feature = "stream")]
decode_buffer!(
    i32,
    "32-bit int array",
    decode_array_i32,
    visit_array_i32,
    encode_array_i32
);

#[cfg(feature = "stream")]
decode_buffer!(
    i64,
    "64-bit int array",
    decode_array_i64,
    visit_array_i64,
    encode_array_i64
);

// Stream encoding and decoding for buffers of floating-point numbers.
// The description is what `Visitor::expecting` returns, so it must name the float type.

#[cfg(feature = "stream")]
decode_buffer!(
    f32,
    "32-bit float array",
    decode_array_f32,
    visit_array_f32,
    encode_array_f32
);

#[cfg(feature = "stream")]
decode_buffer!(
    f64,
    "64-bit float array",
    decode_array_f64,
    visit_array_f64,
    encode_array_f64
);
539
540impl<T: Number + fmt::Debug> fmt::Debug for Buffer<T> {
541 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
542 match self {
543 Self::Host(buffer) => fmt::Debug::fmt(buffer, f),
544 #[cfg(feature = "opencl")]
545 Self::CL(buffer) => fmt::Debug::fmt(buffer, f),
546 }
547 }
548}