use std::f32::consts::PI;
use std::iter;
use std::marker::PhantomData;

use rand::Rng;
use rayon::join;
use rayon::prelude::*;

use crate::access::Access;
use crate::ops::{Enqueue, Op, ReadValue, SliceSpec, ViewSpec};
use crate::{
    stackvec, strides_for, AccessMut, Axes, BufferConverter, CType, Error, Float, Range, Shape,
    Strides,
};

use super::buffer::Buffer;
use super::platform::{Heap, Host, Stack};
use super::{SliceConverter, StackVec, VEC_MIN_SIZE};

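// Dispatch a host op to the stack (sequential) or heap (rayon-parallelized) backend
// at runtime, based on a size condition such as `self.size() < VEC_MIN_SIZE`.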
macro_rules! host_enqueue {
    ($this:expr, $cond:expr, $t:ty) => {
        if $cond {
            Enqueue::<Stack, $t>::enqueue($this).map(Buffer::Stack)
        } else {
            Enqueue::<Heap, $t>::enqueue($this).map(Buffer::Heap)
        }
    };
}

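/// A cast of each element from type `IT` to type `OT`, converting through `f64`.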
pub struct Cast<A, IT, OT> {
    access: A,
    dtype: PhantomData<(IT, OT)>,
}

impl<A, IT, OT> Cast<A, IT, OT> {
    pub fn new(access: A) -> Self {
        Self {
            access,
            dtype: PhantomData,
        }
    }
}

impl<A: Access<IT>, IT: CType, OT: CType> Op for Cast<A, IT, OT> {
    fn size(&self) -> usize {
        self.access.size()
    }
}

impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Heap, OT> for Cast<A, IT, OT> {
    type Buffer = Vec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .into_par_iter()
                    .map(|n| n.to_f64())
                    .map(OT::from_f64)
                    .collect()
            })
    }
}

impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Stack, OT> for Cast<A, IT, OT> {
    type Buffer = StackVec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .into_iter()
                    .map(|n| n.to_f64())
                    .map(OT::from_f64)
                    .collect()
            })
    }
}

impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Host, OT> for Cast<A, IT, OT> {
    type Buffer = Buffer<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}

impl<A: Access<IT>, IT: CType, OT: CType> ReadValue<Host, OT> for Cast<A, IT, OT> {
    fn read_value(&self, offset: usize) -> Result<OT, Error> {
        self.access
            .read_value(offset)
            .map(|n| n.to_f64())
            .map(OT::from_f64)
    }
}

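/// An element-wise operation on two arrays, which are assumed to have the same size.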
pub struct Dual<L, R, IT, OT> {
    left: L,
    right: R,
    zip: fn(IT, IT) -> OT,
}

impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn size(&self) -> usize {
        self.left.size()
    }
}

impl<L, R, T: CType> Dual<L, R, T, T> {
    pub fn add(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::add,
        }
    }

    pub fn div(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::div,
        }
    }

    pub fn log(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |a, b| T::from_float(a.to_float().log(b.to_float())),
        }
    }

    pub fn mul(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::mul,
        }
    }

    pub fn pow(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::pow,
        }
    }

    pub fn rem(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::rem,
        }
    }

    pub fn sub(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::sub,
        }
    }
}

impl<L, R, T: CType> Dual<L, R, T, u8> {
    pub fn and(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l != T::ZERO && r != T::ZERO { 1 } else { 0 },
        }
    }

    pub fn or(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l != T::ZERO || r != T::ZERO { 1 } else { 0 },
        }
    }

    pub fn xor(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| {
                if (l != T::ZERO) ^ (r != T::ZERO) {
                    1
                } else {
                    0
                }
            },
        }
    }
}

impl<L, R, T: CType> Dual<L, R, T, u8> {
    pub fn eq(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l == r { 1 } else { 0 },
        }
    }

    pub fn ge(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l >= r { 1 } else { 0 },
        }
    }

    pub fn gt(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l > r { 1 } else { 0 },
        }
    }

    pub fn le(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l <= r { 1 } else { 0 },
        }
    }

    pub fn lt(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l < r { 1 } else { 0 },
        }
    }

    pub fn ne(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: |l, r| if l != r { 1 } else { 0 },
        }
    }
}

impl<L, R, IT, OT> Enqueue<Stack, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = StackVec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let left = self.left.read()?.to_slice()?;
        let right = self.right.read()?.to_slice()?;
        exec_dual(self.zip, left, right)
    }
}

impl<L, R, IT, OT> Enqueue<Heap, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Vec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let (left, right) = try_join_read(&self.left, &self.right)?;
        exec_dual_parallel(self.zip, left, right)
    }
}

impl<L, R, IT, OT> Enqueue<Host, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Buffer<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}

impl<L, R, IT, OT> ReadValue<Host, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn read_value(&self, offset: usize) -> Result<OT, Error> {
        try_join_value(&self.left, &self.right, offset).map(|(l, r)| (self.zip)(l, r))
    }
}

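/// An element-wise ternary select: where `cond` is nonzero, choose the value from `then`;
/// otherwise, choose the value from `or_else`.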
pub struct Cond<A, L, R, T> {
    cond: A,
    then: L,
    or_else: R,
    dtype: PhantomData<T>,
}

impl<A, L, R, T> Cond<A, L, R, T> {
    pub fn new(cond: A, then: L, or_else: R) -> Self {
        Self {
            cond,
            then,
            or_else,
            dtype: PhantomData,
        }
    }
}

impl<A, L, R, T> Op for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    fn size(&self) -> usize {
        debug_assert_eq!(self.cond.size(), self.then.size());
        debug_assert_eq!(self.cond.size(), self.or_else.size());
        self.cond.size()
    }
}

impl<A, L, R, T> Enqueue<Stack, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let cond = self.cond.read()?.to_slice()?;
        let then = self.then.read()?.to_slice()?;
        let or_else = self.or_else.read()?.to_slice()?;

        let output = cond
            .into_iter()
            .copied()
            .zip(then.into_iter().zip(or_else.into_iter()))
            .map(|(cond, (then, or_else))| if cond != 0 { then } else { or_else })
            .copied()
            .collect();

        Ok(output)
    }
}

impl<A, L, R, T> Enqueue<Heap, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let (cond, (then, or_else)) = join(
            || self.cond.read().and_then(|buf| buf.to_slice()),
            || {
                join(
                    || self.then.read().and_then(|buf| buf.to_slice()),
                    || self.or_else.read().and_then(|buf| buf.to_slice()),
                )
            },
        );

        let (cond, (then, or_else)) = (cond?, (then?, or_else?));

        let output = cond
            .into_par_iter()
            .copied()
            .zip(then.into_par_iter().zip(or_else.into_par_iter()))
            .map(|(cond, (then, or_else))| if cond != 0 { then } else { or_else })
            .copied()
            .collect();

        Ok(output)
    }
}

impl<A, L, R, T> Enqueue<Host, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}

impl<A, L, R, T> ReadValue<Host, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        let (cond, (then, or_else)) = join(
            || self.cond.read_value(offset),
            || {
                join(
                    || self.then.read_value(offset),
                    || self.or_else.read_value(offset),
                )
            },
        );

        let (cond, (then, or_else)) = (cond?, (then?, or_else?));

        if cond != 0 {
            Ok(then)
        } else {
            Ok(or_else)
        }
    }
}

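/// An evenly-spaced sequence of `size` values which begins at `start` and increases by `step`.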
pub struct Linear<T> {
    start: T,
    step: f64,
    size: usize,
}

impl<T> Linear<T> {
    pub fn new(start: T, step: f64, size: usize) -> Self {
        Self { start, step, size }
    }

    #[inline]
    fn value_at(&self, offset: usize) -> T
    where
        T: CType,
    {
        T::add(self.start, T::from_f64((offset as f64) * self.step))
    }
}

impl<T: Send + Sync> Op for Linear<T> {
    fn size(&self) -> usize {
        self.size
    }
}

impl<T: CType> Enqueue<Stack, T> for Linear<T> {
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let start = self.start.to_f64();

        let buffer = (0..self.size)
            .map(|i| i as f64)
            .map(|i| i * self.step)
            .map(|o| start + o)
            .map(T::from_f64)
            .collect();

        Ok(buffer)
    }
}

impl<T: CType> Enqueue<Heap, T> for Linear<T> {
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let buffer = (0..self.size)
            .into_par_iter()
            .map(|offset| self.value_at(offset))
            .collect();

        Ok(buffer)
    }
}

impl<T: CType> Enqueue<Host, T> for Linear<T> {
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, T)
    }
}

impl<T: CType> ReadValue<Host, T> for Linear<T> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        Ok(self.value_at(offset))
    }
}

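/// An operation to read the main diagonal of each square matrix in a batch.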
pub struct MatDiag<A, T> {
    access: A,
    dim: usize,
    batch_size: usize,
    dtype: PhantomData<T>,
}

impl<A, T> MatDiag<A, T> {
    pub fn new(access: A, batch_size: usize, dim: usize) -> Self {
        Self {
            access,
            dim,
            batch_size,
            dtype: PhantomData,
        }
    }
}

impl<A: Access<T>, T: CType> Op for MatDiag<A, T> {
    fn size(&self) -> usize {
        debug_assert_eq!(self.access.size(), self.batch_size * self.dim * self.dim);
        self.batch_size * self.dim
    }
}

impl<A: Access<T>, T: CType> Enqueue<Heap, T> for MatDiag<A, T> {
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let input = self.access.read()?.to_slice()?;

        let diagonals = input
            .par_chunks_exact(self.dim * self.dim)
            .map(|matrix| {
                matrix
                    .par_chunks_exact(self.dim)
                    .enumerate()
                    .map(|(i, row)| row[i])
            })
            .flatten()
            .collect();

        Ok(diagonals)
    }
}

impl<A: Access<T>, T: CType> Enqueue<Stack, T> for MatDiag<A, T> {
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let input = self.access.read()?.to_slice()?;

        let diagonals = input
            .chunks_exact(self.dim * self.dim)
            .map(|matrix| {
                matrix
                    .chunks_exact(self.dim)
                    .enumerate()
                    .map(|(i, row)| row[i])
            })
            .flatten()
            .collect();

        Ok(diagonals)
    }
}

impl<A: Access<T>, T: CType> Enqueue<Host, T> for MatDiag<A, T> {
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}

impl<A: Access<T>, T: CType> ReadValue<Host, T> for MatDiag<A, T> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        // each matrix in the batch contributes `dim` diagonal entries
        let batch = offset / self.dim;
        let i = offset % self.dim;
        let source_offset = (batch * self.dim * self.dim) + (i * self.dim) + i;
        self.access.read_value(source_offset)
    }
}

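/// A batched matrix multiplication of shape `[batch_size, a, b]` x `[batch_size, b, c]`.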
pub struct MatMul<L, R, T> {
    left: L,
    right: R,
    batch_size: usize,
    dims: [usize; 3],
    dtype: PhantomData<T>,
}

impl<L, R, T> MatMul<L, R, T> {
    pub fn new(left: L, right: R, dims: [usize; 4]) -> Self {
        let [batch_size, a, b, c] = dims;

        Self {
            left,
            right,
            batch_size,
            dims: [a, b, c],
            dtype: PhantomData,
        }
    }
}

impl<L, R, T> Op for MatMul<L, R, T>
where
    L: Send + Sync,
    R: Send + Sync,
    T: Send + Sync,
{
    fn size(&self) -> usize {
        self.batch_size * self.dims[0] * self.dims[2]
    }
}

impl<L, R, T> Enqueue<Stack, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let left = self.left.read()?.to_slice()?;
        let right = self.right.read()?.to_slice()?;

        let [a, b, c] = self.dims;

        let mut product = StackVec::with_capacity(self.batch_size * a * c);

        for batch in 0..self.batch_size {
            // offset each input matrix by its position within the batch
            let l_base = batch * a * b;
            let r_base = batch * b * c;

            for x in 0..a {
                for z in 0..c {
                    let mut sum = T::ZERO;

                    for y in 0..b {
                        let l_offset = l_base + (x * b) + y;
                        let r_offset = r_base + (y * c) + z;
                        sum = T::add(sum, T::mul(left[l_offset], right[r_offset]));
                    }

                    product.push(sum);
                }
            }
        }

        debug_assert_eq!(product.len(), self.size());

        Ok(product)
    }
}

impl<L, R, T> Enqueue<Heap, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let [a, b, c] = self.dims;

        let (left, right) = try_join_read(&self.left, &self.right)?;

        // transpose each right matrix so that both operands of the dot products
        // below are read from contiguous memory
        let right_size = b * c;
        let right_matrices = right.par_chunks_exact(right_size).map(|right| {
            let mut right_t = vec![T::ZERO; right_size];
            transpose::transpose(right, &mut right_t[..], c, b);
            right_t
        });

        let left_size = a * b;
        let left_matrices = left.par_chunks_exact(left_size);

        let output_size = a * c;
        let mut output = Vec::<T>::with_capacity(self.batch_size * output_size);
        let output_matrices = left_matrices
            .zip(right_matrices)
            .map(|(lm, rm)| {
                let mut out = Vec::<T>::with_capacity(output_size);

                // take the dot product of each row of the left matrix with each
                // (transposed) column of the right matrix, in chunks of 8
                let product = lm
                    .par_chunks_exact(b)
                    .map(|row| {
                        rm.par_chunks_exact(b).map(move |col| {
                            let col = col.par_chunks(8).map(|cc| cc.into_iter().copied());

                            row.par_chunks(8)
                                .zip(col)
                                .map(|(rc, cc)| {
                                    rc.into_iter()
                                        .copied()
                                        .zip(cc)
                                        .map(|(r, c)| T::mul(r, c))
                                        .reduce(T::add)
                                        .expect("sum")
                                })
                                .reduce(|| T::ZERO, T::add)
                        })
                    })
                    .flatten();

                out.par_extend(product);
                out
            })
            .flatten();

        output.par_extend(output_matrices);

        debug_assert_eq!(output.len(), self.batch_size * output_size);

        Ok(output)
    }
}

impl<L, R, T> Enqueue<Host, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(
            self,
            self.left.size() < VEC_MIN_SIZE && self.right.size() < VEC_MIN_SIZE,
            T
        )
    }
}

impl<L, R, T> ReadValue<Host, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    fn read_value(&self, _offset: usize) -> Result<T, Error> {
        Err(Error::Bounds(
            "reading an individual value from a matrix multiplication is not implemented"
                .to_string(),
        ))
    }
}

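/// An element-wise operation between an array and a single scalar value.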
pub struct Scalar<A, IT, OT> {
    access: A,
    scalar: IT,
    op: fn(IT, IT) -> OT,
}

impl<A, IT, OT> Scalar<A, IT, OT> {
    fn new(access: A, scalar: IT, op: fn(IT, IT) -> OT) -> Self {
        Self { access, scalar, op }
    }
}

impl<A, T: CType> Scalar<A, T, T> {
    pub fn add(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::add)
    }

    pub fn div(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::div)
    }

    pub fn log(access: A, scalar: T) -> Self {
        Self::new(access, scalar, |a, b| {
            T::from_float(a.to_float().log(b.to_float()))
        })
    }

    pub fn mul(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::mul)
    }

    pub fn pow(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::pow)
    }

    pub fn rem(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::rem)
    }

    pub fn sub(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::sub)
    }
}

impl<A, T> Scalar<A, T, u8> {
    pub fn and(access: A, scalar: T) -> Self
    where
        T: CType,
    {
        Self::new(access, scalar, |l, r| {
            if (l != T::ZERO) && (r != T::ZERO) {
                1
            } else {
                0
            }
        })
    }

    pub fn or(access: A, scalar: T) -> Self
    where
        T: CType,
    {
        Self::new(access, scalar, |l, r| {
            if (l != T::ZERO) || (r != T::ZERO) {
                1
            } else {
                0
            }
        })
    }

    pub fn xor(access: A, scalar: T) -> Self
    where
        T: CType,
    {
        Self::new(access, scalar, |l, r| {
            if (l != T::ZERO) ^ (r != T::ZERO) {
                1
            } else {
                0
            }
        })
    }

    pub fn eq(access: A, scalar: T) -> Self
    where
        T: PartialEq,
    {
        Self::new(access, scalar, |l, r| if l == r { 1 } else { 0 })
    }

    pub fn ge(access: A, scalar: T) -> Self
    where
        T: PartialOrd,
    {
        Self::new(access, scalar, |l, r| if l >= r { 1 } else { 0 })
    }

    pub fn gt(access: A, scalar: T) -> Self
    where
        T: PartialOrd,
    {
        Self::new(access, scalar, |l, r| if l > r { 1 } else { 0 })
    }

    pub fn le(access: A, scalar: T) -> Self
    where
        T: PartialOrd,
    {
        Self::new(access, scalar, |l, r| if l <= r { 1 } else { 0 })
    }

    pub fn lt(access: A, scalar: T) -> Self
    where
        T: PartialOrd,
    {
        Self::new(access, scalar, |l, r| if l < r { 1 } else { 0 })
    }

    pub fn ne(access: A, scalar: T) -> Self
    where
        T: PartialEq,
    {
        Self::new(access, scalar, |l, r| if l != r { 1 } else { 0 })
    }
}

impl<A, IT, OT> Op for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn size(&self) -> usize {
        self.access.size()
    }
}

impl<A, IT, OT> Enqueue<Heap, OT> for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Vec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .as_ref()
                    .into_par_iter()
                    .copied()
                    .map(|l| (self.op)(l, self.scalar))
                    .collect()
            })
    }
}

impl<A, IT, OT> Enqueue<Stack, OT> for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = StackVec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .as_ref()
                    .into_iter()
                    .copied()
                    .map(|l| (self.op)(l, self.scalar))
                    .collect()
            })
    }
}

impl<A, IT, OT> Enqueue<Host, OT> for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Buffer<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}

impl<A, IT, OT> ReadValue<Host, OT> for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn read_value(&self, offset: usize) -> Result<OT, Error> {
        self.access
            .read_value(offset)
            .map(|n| (self.op)(n, self.scalar))
    }
}

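/// A pseudorandom sample of the standard normal distribution.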
pub struct RandomNormal {
    size: usize,
}

impl RandomNormal {
    pub fn new(size: usize) -> Self {
        Self { size }
    }

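    // the Box-Muller transform: map two independent uniform samples on [0, 1)
    // to two independent samples of the standard normal distribution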
    fn box_muller(u: [f32; 2]) -> [f32; 2] {
        let [u1, u2] = u;
        let r = (u1.ln() * -2.).sqrt();
        let theta = 2. * PI * u2;
        [r * theta.cos(), r * theta.sin()]
    }
}

impl Op for RandomNormal {
    fn size(&self) -> usize {
        self.size
    }
}

impl Enqueue<Heap, f32> for RandomNormal {
    type Buffer = Vec<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut u = vec![
            0.0f32;
            if self.size % 2 == 0 {
                self.size
            } else {
                self.size + 1
            }
        ];

        rand::thread_rng().fill(&mut u[..]);

        let mut output = u
            .par_chunks_exact(2)
            .map(|u| {
                let u: [f32; 2] = u.try_into().expect("u");
                Self::box_muller(u)
            })
            .flatten()
            .collect::<Vec<f32>>();

        if output.len() > self.size {
            output.pop();
        }

        debug_assert_eq!(output.len(), self.size);

        Ok(output)
    }
}

impl Enqueue<Stack, f32> for RandomNormal {
    type Buffer = StackVec<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut rng = rand::thread_rng();

        let mut output = iter::repeat_with(|| [rng.gen(), rng.gen()])
            .take(self.size.div_ceil(2))
            .map(Self::box_muller)
            .flatten()
            .collect::<StackVec<f32>>();

        if output.len() > self.size {
            output.pop();
        }

        debug_assert_eq!(output.len(), self.size);

        Ok(output)
    }
}

impl Enqueue<Host, f32> for RandomNormal {
    type Buffer = Buffer<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
    }
}

impl ReadValue<Host, f32> for RandomNormal {
    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
        Err(Error::Bounds(
            "cannot calculate an individual value of a random normal distribution".to_string(),
        ))
    }
}

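/// A pseudorandom sample of the uniform distribution over [0, 1).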
pub struct RandomUniform {
    size: usize,
}

impl RandomUniform {
    pub fn new(size: usize) -> Self {
        Self { size }
    }
}

impl Op for RandomUniform {
    fn size(&self) -> usize {
        self.size
    }
}

impl Enqueue<Heap, f32> for RandomUniform {
    type Buffer = Vec<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut data = vec![0.; self.size];
        rand::thread_rng().fill(&mut data[..]);
        Ok(data)
    }
}

impl Enqueue<Stack, f32> for RandomUniform {
    type Buffer = StackVec<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut data = stackvec![0.; self.size];
        rand::thread_rng().fill(&mut data[..]);
        Ok(data)
    }
}

impl Enqueue<Host, f32> for RandomUniform {
    type Buffer = Buffer<f32>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
    }
}

impl ReadValue<Host, f32> for RandomUniform {
    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
        Ok(rand::thread_rng().gen())
    }
}

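/// A reduce operation which folds each contiguous block of `stride` elements into one value.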
pub struct Reduce<A, T> {
    access: A,
    stride: usize,
    reduce: fn(T, T) -> T,
    id: T,
}

impl<A, T> Reduce<A, T>
where
    T: CType,
{
    pub fn max(access: A, stride: usize) -> Self {
        Self {
            access,
            stride,
            reduce: CType::max,
            id: T::MIN,
        }
    }

    pub fn min(access: A, stride: usize) -> Self {
        Self {
            access,
            stride,
            reduce: CType::min,
            id: T::MAX,
        }
    }

    pub fn product(access: A, stride: usize) -> Self {
        Self {
            access,
            stride,
            reduce: T::mul,
            id: T::ONE,
        }
    }

    pub fn sum(access: A, stride: usize) -> Self {
        Self {
            access,
            stride,
            reduce: T::add,
            id: T::ZERO,
        }
    }
}

impl<A: Access<T>, T: CType> Op for Reduce<A, T> {
    fn size(&self) -> usize {
        debug_assert_eq!(self.access.size() % self.stride, 0);
        self.access.size() / self.stride
    }
}

impl<A: Access<T>, T: CType> Enqueue<Heap, T> for Reduce<A, T> {
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .chunks_exact(self.stride)
                    .map(|chunk| {
                        chunk
                            .par_chunks(8)
                            .map(|chunk| {
                                chunk.iter().copied().reduce(self.reduce).expect("reduced")
                            })
                            .reduce(|| self.id, self.reduce)
                    })
                    .collect()
            })
    }
}

impl<A: Access<T>, T: CType> Enqueue<Stack, T> for Reduce<A, T> {
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .chunks_exact(self.stride)
                    .map(|chunk| chunk.iter().copied().reduce(self.reduce).expect("reduced"))
                    .collect()
            })
    }
}

impl<A: Access<T>, T: CType> Enqueue<Host, T> for Reduce<A, T> {
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(
            self,
            self.stride < VEC_MIN_SIZE && self.size() < VEC_MIN_SIZE,
            T
        )
    }
}

impl<A: Access<T>, T: CType> ReadValue<Host, T> for Reduce<A, T> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        let offset = offset * self.stride;

        if offset < self.access.size() {
            (offset..(offset + self.stride))
                .into_par_iter()
                .map(|offset| self.access.read_value(offset))
                .try_reduce(|| self.id, |r, v| Ok((self.reduce)(r, v)))
        } else {
            Err(Error::Bounds(format!(
                "invalid offset {offset} for a reduce op with size {}",
                self.size()
            )))
        }
    }
}

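/// A read-write slice of a source array, defined by a `SliceSpec`.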
pub struct Slice<A, T> {
    access: A,
    spec: SliceSpec,
    dtype: PhantomData<T>,
}

impl<A, T> Slice<A, T> {
    pub fn new(access: A, shape: &[usize], range: Range) -> Self {
        let spec = SliceSpec::new(shape, range);

        Self {
            access,
            spec,
            dtype: PhantomData,
        }
    }
}

impl<A: Send + Sync, T: Copy + Send + Sync> Slice<A, T> {
    fn read(&self, source: &[T]) -> Result<StackVec<T>, Error> {
        let output = (0..self.size())
            .map(|offset_out| self.spec.source_offset(offset_out))
            .map(|offset_in| source[offset_in])
            .collect();

        Ok(output)
    }

    fn read_parallel(&self, source: &[T]) -> Result<Vec<T>, Error> {
        let output = (0..self.size())
            .into_par_iter()
            .map(|offset_out| self.spec.source_offset(offset_out))
            .map(|offset_in| source[offset_in])
            .collect();

        Ok(output)
    }
}

impl<A, T> Slice<A, T>
where
    T: CType,
    A: AccessMut<T>,
{
    fn overwrite<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        if data.len() == self.size() {
            let data = data.to_slice()?;

            for (offset, value) in data.into_iter().copied().enumerate() {
                let source_offset = self.spec.source_offset(offset);
                self.access.write_value_at(source_offset, value)?;
            }

            Ok(())
        } else {
            Err(Error::Bounds(format!(
                "cannot overwrite a slice of size {} with a buffer of size {}",
                self.size(),
                data.len(),
            )))
        }
    }

    fn overwrite_value(&mut self, value: T) -> Result<(), Error> {
        // iterate over this slice's own size, not the size of the source buffer
        for offset in 0..self.spec.size() {
            let source_offset = self.spec.source_offset(offset);
            self.access.write_value_at(source_offset, value)?;
        }

        Ok(())
    }

    fn overwrite_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        let source_offset = self.spec.source_offset(offset);
        self.access.write_value_at(source_offset, value)
    }
}

impl<A: Send + Sync, T: Send + Sync> Op for Slice<A, T> {
    fn size(&self) -> usize {
        self.spec.size()
    }
}

impl<A: Access<T>, T: CType> Enqueue<Heap, T> for Slice<A, T> {
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .and_then(|buf| self.read_parallel(&*buf))
    }
}

impl<A: Access<T>, T: CType> Enqueue<Stack, T> for Slice<A, T> {
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .and_then(|buf| self.read(&*buf))
    }
}

impl<A: Access<T>, T: CType> Enqueue<Host, T> for Slice<A, T> {
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}

impl<A: Access<T>, T: CType> ReadValue<Host, T> for Slice<A, T> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        let offset = self.spec.source_offset(offset);
        self.access.read_value(offset)
    }
}

impl<A, T> crate::ops::Write<Heap, T> for Slice<A, T>
where
    T: CType,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}

impl<A, T> crate::ops::Write<Stack, T> for Slice<A, T>
where
    T: CType,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}

impl<A, T> crate::ops::Write<Host, T> for Slice<A, T>
where
    T: CType,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}

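/// An element-wise transformation of a single array.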
pub struct Unary<A, IT, OT> {
    access: A,
    op: fn(IT) -> OT,
}

impl<A: Access<T>, T: CType> Unary<A, T, T> {
    pub fn abs(access: A) -> Self {
        Self {
            access,
            op: CType::abs,
        }
    }

    pub fn exp(access: A) -> Self {
        Self {
            access,
            op: |n| T::from_float(n.to_float().exp()),
        }
    }

    pub fn ln(access: A) -> Self {
        Self {
            access,
            op: |n| T::from_float(n.to_float().ln()),
        }
    }

    pub fn round(access: A) -> Self {
        Self {
            access,
            op: CType::round,
        }
    }
}

impl<A: Access<T>, T: CType> Unary<A, T, T::Float> {
    pub fn sin(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().sin(),
        }
    }

    pub fn asin(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().asin(),
        }
    }

    pub fn sinh(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().sinh(),
        }
    }

    pub fn cos(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().cos(),
        }
    }

    pub fn acos(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().acos(),
        }
    }

    pub fn cosh(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().cosh(),
        }
    }

    pub fn tan(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().tan(),
        }
    }

    pub fn atan(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().atan(),
        }
    }

    pub fn tanh(access: A) -> Self {
        Self {
            access,
            op: |n| n.to_float().tanh(),
        }
    }
}

impl<A: Access<T>, T: CType> Unary<A, T, u8> {
    pub fn not(access: A) -> Self {
        Self {
            access,
            op: |n| if n == T::ZERO { 1 } else { 0 },
        }
    }
}

impl<A: Access<T>, T: Float> Unary<A, T, u8> {
    pub fn inf(access: A) -> Self {
        Self {
            access,
            op: |n| if n.is_inf() { 1 } else { 0 },
        }
    }

    pub fn nan(access: A) -> Self {
        Self {
            access,
            op: |n| if n.is_nan() { 1 } else { 0 },
        }
    }
}

impl<A, IT, OT> Op for Unary<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn size(&self) -> usize {
        self.access.size()
    }
}

impl<A, IT, OT> Enqueue<Heap, OT> for Unary<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Vec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|input| input.into_par_iter().copied().map(self.op).collect())
    }
}

impl<A, IT, OT> Enqueue<Stack, OT> for Unary<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = StackVec<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|input| input.into_iter().copied().map(self.op).collect())
    }
}

impl<A, IT, OT> Enqueue<Host, OT> for Unary<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    type Buffer = Buffer<OT>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}

impl<A, IT, OT> ReadValue<Host, OT> for Unary<A, IT, OT>
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    fn read_value(&self, offset: usize) -> Result<OT, Error> {
        self.access.read_value(offset).map(|n| (self.op)(n))
    }
}

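/// A read-only view of a source array, used to implement broadcasting and transposition.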
pub struct View<A, T> {
    access: A,
    spec: ViewSpec,
    dtype: PhantomData<T>,
}

impl<A: Access<T>, T: CType> View<A, T> {
    pub fn broadcast(access: A, shape: Shape, broadcast: Shape) -> Self {
        // the source strides must be padded to the dimensionality of the broadcast shape
        let source_strides = strides_for(&shape, broadcast.len()).collect();

        Self {
            access,
            spec: ViewSpec::new(broadcast, source_strides),
            dtype: PhantomData,
        }
    }

    pub fn transpose(access: A, shape: Shape, axes: Axes) -> Self {
        let strides = strides_for(&shape, shape.len()).collect::<Strides>();
        let shape = axes.iter().copied().map(|x| shape[x]).collect::<Shape>();
        let source_strides = axes.into_iter().map(|x| strides[x]).collect::<Strides>();

        Self {
            access,
            spec: ViewSpec::new(shape, source_strides),
            dtype: PhantomData,
        }
    }
}

impl<A: Access<T>, T: CType> Op for View<A, T> {
    fn size(&self) -> usize {
        self.spec.size()
    }
}

impl<A: Access<T>, T: CType> Enqueue<Stack, T> for View<A, T> {
    type Buffer = StackVec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let source = self.access.read().and_then(|source| source.to_slice())?;

        let buffer = (0..self.spec.size())
            .map(|offset| self.spec.source_offset(offset))
            .map(|source_offset| source[source_offset])
            .collect();

        Ok(buffer)
    }
}

impl<A: Access<T>, T: CType> Enqueue<Heap, T> for View<A, T> {
    type Buffer = Vec<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let source = self.access.read().and_then(|source| source.to_slice())?;

        let buffer = (0..self.spec.size())
            .into_par_iter()
            .map(|offset| self.spec.source_offset(offset))
            .map(|source_offset| source[source_offset])
            .collect();

        Ok(buffer)
    }
}

impl<A: Access<T>, T: CType> Enqueue<Host, T> for View<A, T> {
    type Buffer = Buffer<T>;

    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}

impl<A: Access<T>, T: CType> ReadValue<Host, T> for View<A, T> {
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        self.access.read_value(self.spec.source_offset(offset))
    }
}

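// Zip two slices element-wise with the given function, sequentially.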
fn exec_dual<IT: CType, OT: CType>(
    zip: fn(IT, IT) -> OT,
    left: SliceConverter<IT>,
    right: SliceConverter<IT>,
) -> Result<StackVec<OT>, Error> {
    let output = left
        .into_iter()
        .copied()
        .zip(right.into_iter().copied())
        .map(|(l, r)| (zip)(l, r))
        .collect();

    Ok(output)
}

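// Zip two slices element-wise with the given function, in parallel.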
fn exec_dual_parallel<IT: CType, OT: CType>(
    zip: fn(IT, IT) -> OT,
    left: SliceConverter<IT>,
    right: SliceConverter<IT>,
) -> Result<Vec<OT>, Error> {
    let output = left
        .into_par_iter()
        .copied()
        .zip(right.into_par_iter().copied())
        .map(|(l, r)| (zip)(l, r))
        .collect();

    Ok(output)
}

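// Read both accessors concurrently (via rayon::join) and convert the results to slices.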
#[inline]
fn try_join_read<'a, L, R, T>(
    left: &'a L,
    right: &'a R,
) -> Result<(SliceConverter<'a, T>, SliceConverter<'a, T>), Error>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    let (l, r) = join(
        || left.read().and_then(|buf| buf.to_slice()),
        || right.read().and_then(|buf| buf.to_slice()),
    );

    Ok((l?, r?))
}

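// Read one value from each accessor concurrently (via rayon::join).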
#[inline]
fn try_join_value<'a, L, R, T>(left: &'a L, right: &'a R, offset: usize) -> Result<(T, T), Error>
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    let (l, r) = join(|| left.read_value(offset), || right.read_value(offset));

    Ok((l?, r?))
}