1use std::f32::consts::PI;
2use std::iter;
3use std::marker::PhantomData;
4
5use frand::Rand;
6use number_general as ng;
7use rayon::join;
8use rayon::prelude::*;
9
10use crate::access::Access;
11use crate::ops::{Concat, Enqueue, FlipSpec, Op, ReadValue, SliceSpec, ViewSpec};
12#[cfg(feature = "complex")]
13use crate::Complex;
14use crate::{
15 strides_for, AccessMut, Axes, BufferConverter, Error, Float, Number, Platform, Range, Real,
16 Shape, Strides,
17};
18
19use super::buffer::Buffer;
20use super::platform::{Heap, Host, Stack};
21use super::{SliceConverter, StackVec, VEC_MIN_SIZE};
22
23#[cfg(feature = "complex")]
24pub mod complex;
25
/// Dispatch a host-side `enqueue`: when `$cond` is true the result is computed
/// into a stack-allocated buffer, otherwise into a heap-allocated one.
macro_rules! host_enqueue {
    ($this:expr, $cond:expr, $t:ty) => {
        if $cond {
            Enqueue::<Stack, $t>::enqueue($this).map(Buffer::Stack)
        } else {
            Enqueue::<Heap, $t>::enqueue($this).map(Buffer::Heap)
        }
    };
}
35
/// An op which casts each element of its source from type `IT` to type `OT`.
pub struct Cast<A, IT, OT> {
    access: A,
    // zero-sized marker tying this op to its input/output element types
    dtype: PhantomData<(IT, OT)>,
}

impl<A, IT, OT> Cast<A, IT, OT> {
    /// Construct a new [`Cast`] op over `access`.
    pub fn new(access: A) -> Self {
        Self {
            access,
            dtype: PhantomData,
        }
    }
}

impl<A: Access<IT>, IT: Number, OT: Number> Op for Cast<A, IT, OT> {
    // a cast is elementwise, so the output size equals the input size
    fn size(&self) -> usize {
        self.access.size()
    }
}
55
56impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Heap, OT> for Cast<A, IT, OT> {
57 type Buffer = Vec<OT>;
58
59 fn enqueue(&self) -> Result<Self::Buffer, Error> {
60 self.access
61 .read()
62 .and_then(|buf| buf.to_slice())
63 .map(|slice| {
64 slice
65 .into_par_iter()
66 .copied()
67 .map(|n| n.into())
68 .map(OT::cast_from)
69 .collect()
70 })
71 }
72}
73
74impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Stack, OT> for Cast<A, IT, OT> {
75 type Buffer = StackVec<OT>;
76
77 fn enqueue(&self) -> Result<Self::Buffer, Error> {
78 self.access
79 .read()
80 .and_then(|buf| buf.to_slice())
81 .map(|slice| {
82 slice
83 .iter()
84 .copied()
85 .map(|n| n.into())
86 .map(OT::cast_from)
87 .collect()
88 })
89 }
90}
91
impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Host, OT> for Cast<A, IT, OT> {
    type Buffer = Buffer<OT>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}
99
100impl<A: Access<IT>, IT: Number, OT: Number> ReadValue<Host, OT> for Cast<A, IT, OT> {
101 fn read_value(&self, offset: usize) -> Result<OT, Error> {
102 self.access
103 .read_value(offset)
104 .map(|n| n.into())
105 .map(OT::cast_from)
106 }
107}
108
/// An elementwise conditional: selects from `then` where `cond` is nonzero,
/// otherwise from `or_else`.
pub struct Cond<A, L, R, T> {
    cond: A,
    then: L,
    or_else: R,
    dtype: PhantomData<T>,
}

impl<A, L, R, T> Cond<A, L, R, T> {
    /// Construct a new [`Cond`] op; all three accessors must have equal size.
    pub fn new(cond: A, then: L, or_else: R) -> Self {
        Self {
            cond,
            then,
            or_else,
            dtype: PhantomData,
        }
    }
}

impl<A, L, R, T> Op for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    // the output is elementwise over three equal-sized inputs
    fn size(&self) -> usize {
        debug_assert_eq!(self.cond.size(), self.then.size());
        debug_assert_eq!(self.cond.size(), self.or_else.size());
        self.cond.size()
    }
}
140
141impl<A, L, R, T> Enqueue<Stack, T> for Cond<A, L, R, T>
142where
143 A: Access<u8>,
144 L: Access<T>,
145 R: Access<T>,
146 T: Number,
147{
148 type Buffer = StackVec<T>;
149
150 fn enqueue(&self) -> Result<Self::Buffer, Error> {
151 let cond = self.cond.read()?.to_slice()?;
152 let then = self.then.read()?.to_slice()?;
153 let or_else = self.or_else.read()?.to_slice()?;
154
155 let output = cond
156 .iter()
157 .copied()
158 .zip(then.iter().copied().zip(or_else.iter().copied()))
159 .map(|(cond, (then, or_else))| if cond != 0 { then } else { or_else })
160 .collect();
161
162 Ok(output)
163 }
164}
165
impl<A, L, R, T> Enqueue<Heap, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    type Buffer = Vec<T>;

    /// Select `then` or `or_else` elementwise in parallel, based on `cond`.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        // read all three inputs concurrently using rayon's join
        let (cond, (then, or_else)) = join(
            || self.cond.read().and_then(|buf| buf.to_slice()),
            || {
                join(
                    || self.then.read().and_then(|buf| buf.to_slice()),
                    || self.or_else.read().and_then(|buf| buf.to_slice()),
                )
            },
        );

        // propagate the first read error, if any
        let (cond, (then, or_else)) = (cond?, (then?, or_else?));

        let output = cond
            .into_par_iter()
            .copied()
            .zip(then.into_par_iter().zip(or_else.into_par_iter()))
            .map(
                |(cond, (then, or_else))| {
                    if cond != 0 {
                        then
                    } else {
                        or_else
                    }
                },
            )
            // `then`/`or_else` are zipped by reference here; copy the selected one out
            .copied()
            .collect();

        Ok(output)
    }
}

impl<A, L, R, T> Enqueue<Host, T> for Cond<A, L, R, T>
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    type Buffer = Buffer<T>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}
221
222impl<A, L, R, T> ReadValue<Host, T> for Cond<A, L, R, T>
223where
224 A: Access<u8>,
225 L: Access<T>,
226 R: Access<T>,
227 T: Number,
228{
229 fn read_value(&self, offset: usize) -> Result<T, Error> {
230 let (cond, (then, or_else)) = join(
231 || self.cond.read_value(offset),
232 || {
233 join(
234 || self.then.read_value(offset),
235 || self.or_else.read_value(offset),
236 )
237 },
238 );
239
240 let (cond, (then, or_else)) = (cond?, (then?, or_else?));
241
242 if cond != 0 {
243 Ok(then)
244 } else {
245 Ok(or_else)
246 }
247 }
248}
249
250impl<A, T> Enqueue<Host, T> for Concat<A, T>
251where
252 A: Access<T>,
253 T: Number,
254{
255 type Buffer = Buffer<T>;
256
257 fn enqueue(&self) -> Result<Self::Buffer, Error> {
258 let mut buffer = Vec::with_capacity(self.size());
259
260 for access in self.data() {
261 let data = access.read()?.to_slice()?;
262 buffer.par_extend(data.into_par_iter().copied());
263 }
264
265 Ok(buffer.into())
266 }
267}
268
impl<A, T> ReadValue<Host, T> for Concat<A, T>
where
    A: Access<T>,
    T: Number,
{
    /// Delegate to the platform-generic `ReadValue` implementation.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        ReadValue::<Platform, T>::read_value(self, offset)
    }
}
278
/// An elementwise binary op which combines `left` and `right` with `zip`,
/// mapping two inputs of type `IT` to an output of type `OT`.
pub struct Dual<L, R, IT, OT> {
    left: L,
    right: R,
    zip: fn(IT, IT) -> OT,
}
284
285impl<L, R, T: Number> Dual<L, R, T, T> {
287 pub fn add(left: L, right: R) -> Self {
288 Self {
289 left,
290 right,
291 zip: T::add,
292 }
293 }
294
295 pub fn div(left: L, right: R) -> Self {
296 Self {
297 left,
298 right,
299 zip: T::div,
300 }
301 }
302
303 pub fn mul(left: L, right: R) -> Self {
304 Self {
305 left,
306 right,
307 zip: T::mul,
308 }
309 }
310
311 pub fn pow(left: L, right: R) -> Self {
312 Self {
313 left,
314 right,
315 zip: T::pow,
316 }
317 }
318
319 pub fn rem(left: L, right: R) -> Self
320 where
321 T: Real,
322 {
323 Self {
324 left,
325 right,
326 zip: T::pow,
327 }
328 }
329
330 pub fn sub(left: L, right: R) -> Self {
331 Self {
332 left,
333 right,
334 zip: T::sub,
335 }
336 }
337}
338
impl<L, R, T: Float> Dual<L, R, T, T> {
    /// Construct an elementwise logarithm op: `log(left)` with base `right`.
    pub fn log(left: L, right: R) -> Self {
        Self {
            left,
            right,
            zip: T::log,
        }
    }
}
349
350impl<L, R, T: Number> Dual<L, R, T, u8> {
352 pub fn and(left: L, right: R) -> Self {
353 Self {
354 left,
355 right,
356 zip: |l, r| if l != T::ZERO && r != T::ZERO { 1 } else { 0 },
357 }
358 }
359
360 pub fn or(left: L, right: R) -> Self {
361 Self {
362 left,
363 right,
364 zip: |l, r| if l != T::ZERO || r != T::ZERO { 1 } else { 0 },
365 }
366 }
367
368 pub fn xor(left: L, right: R) -> Self {
369 Self {
370 left,
371 right,
372 zip: |l, r| {
373 if (l != T::ZERO) ^ (r != T::ZERO) {
374 1
375 } else {
376 0
377 }
378 },
379 }
380 }
381}
382
383impl<L, R, T: Number> Dual<L, R, T, u8> {
385 pub fn eq(left: L, right: R) -> Self {
386 Self {
387 left,
388 right,
389 zip: |l, r| if l == r { 1 } else { 0 },
390 }
391 }
392
393 pub fn ne(left: L, right: R) -> Self {
394 Self {
395 left,
396 right,
397 zip: |l, r| if l != r { 1 } else { 0 },
398 }
399 }
400}
401
402impl<L, R, T: Number + PartialOrd> Dual<L, R, T, u8> {
403 pub fn ge(left: L, right: R) -> Self {
404 Self {
405 left,
406 right,
407 zip: |l, r| if l >= r { 1 } else { 0 },
408 }
409 }
410
411 pub fn gt(left: L, right: R) -> Self {
412 Self {
413 left,
414 right,
415 zip: |l, r| if l > r { 1 } else { 0 },
416 }
417 }
418
419 pub fn le(left: L, right: R) -> Self {
420 Self {
421 left,
422 right,
423 zip: |l, r| if l <= r { 1 } else { 0 },
424 }
425 }
426
427 pub fn lt(left: L, right: R) -> Self {
428 Self {
429 left,
430 right,
431 zip: |l, r| if l < r { 1 } else { 0 },
432 }
433 }
434}
435
impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: Number,
    OT: Number,
{
    // only the left operand's size is consulted; callers are expected to
    // provide equal-sized operands
    fn size(&self) -> usize {
        self.left.size()
    }
}

impl<L, R, IT, OT> Enqueue<Stack, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: Number,
    OT: Number,
{
    type Buffer = StackVec<OT>;

    /// Apply `zip` to both operands sequentially via `exec_dual`.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let left = self.left.read()?.to_slice()?;
        let right = self.right.read()?.to_slice()?;
        exec_dual(self.zip, left, right)
    }
}

impl<L, R, IT, OT> Enqueue<Heap, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: Number,
    OT: Number,
{
    type Buffer = Vec<OT>;

    /// Read both operands concurrently, then apply `zip` in parallel.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let (left, right) = try_join_read(&self.left, &self.right)?;
        exec_dual_parallel(self.zip, left, right)
    }
}

impl<L, R, IT, OT> Enqueue<Host, OT> for Dual<L, R, IT, OT>
where
    L: Access<IT>,
    R: Access<IT>,
    IT: Number,
    OT: Number,
{
    type Buffer = Buffer<OT>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}
492
493impl<L, R, IT, OT> ReadValue<Host, OT> for Dual<L, R, IT, OT>
494where
495 L: Access<IT>,
496 R: Access<IT>,
497 IT: Number,
498 OT: Number,
499{
500 fn read_value(&self, offset: usize) -> Result<OT, Error> {
501 try_join_value(&self.left, &self.right, offset).map(|(l, r)| (self.zip)(l, r))
502 }
503}
504
/// An op which reverses its source along one axis.
pub struct Flip<A, T> {
    access: A,
    // precomputed offset mapping for the flipped axis
    spec: FlipSpec,
    dtype: PhantomData<T>,
}

impl<A, T> Flip<A, T> {
    /// Construct a new [`Flip`] op over `access` with the given `shape`,
    /// flipped along `axis`. Fails if the spec rejects the shape/axis pair.
    pub fn new(access: A, shape: Shape, axis: usize) -> Result<Self, Error> {
        FlipSpec::new(shape, axis).map(|spec| Self {
            access,
            spec,
            dtype: PhantomData,
        })
    }
}

impl<A, T> Op for Flip<A, T>
where
    A: Access<T>,
    T: Number,
{
    // a flip is a permutation, so the output size equals the input size
    fn size(&self) -> usize {
        self.access.size()
    }
}
530
impl<A, T> Enqueue<Heap, T> for Flip<A, T>
where
    A: Access<T>,
    T: Number,
{
    type Buffer = Vec<T>;

    /// Materialize the flipped tensor by reading each output offset in parallel.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        (0..self.size())
            .into_par_iter()
            .map(|offset| self.read_value(offset))
            .collect()
    }
}
545
546impl<A, T> Enqueue<Stack, T> for Flip<A, T>
547where
548 A: Access<T>,
549 T: Number,
550{
551 type Buffer = StackVec<T>;
552
553 fn enqueue(&self) -> Result<Self::Buffer, Error> {
554 (0..self.size())
555 .map(|offset| self.read_value(offset))
556 .collect()
557 }
558}
559
impl<A, T> Enqueue<Host, T> for Flip<A, T>
where
    A: Access<T>,
    T: Number,
{
    type Buffer = Buffer<T>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}
571
572impl<A, T> ReadValue<Host, T> for Flip<A, T>
573where
574 A: Access<T>,
575 T: Number,
576{
577 fn read_value(&self, offset: usize) -> Result<T, Error> {
578 debug_assert!(offset < self.size());
579 let offset = self.spec.source_offset(offset);
580 self.access.read_value(offset)
581 }
582}
583
/// An op which generates an evenly-spaced sequence: `start + i * step`
/// for `i` in `0..size`.
pub struct Linear<T> {
    start: T,
    step: T,
    size: usize,
}

impl<T> Linear<T> {
    /// Construct a new [`Linear`] sequence op.
    pub fn new(start: T, step: T, size: usize) -> Self {
        Self { start, step, size }
    }

    /// Compute the sequence value at `offset`.
    #[inline]
    fn value_at(&self, offset: usize) -> T
    where
        T: Number,
    {
        // convert the integer offset into T via the generic number type
        let offset = T::cast_from(ng::Number::from(offset as u64));
        T::add(self.start, T::mul(offset, self.step))
    }
}

impl<T: Send + Sync> Op for Linear<T> {
    fn size(&self) -> usize {
        self.size
    }
}
610
611impl<T: Number> Enqueue<Stack, T> for Linear<T> {
612 type Buffer = StackVec<T>;
613
614 fn enqueue(&self) -> Result<Self::Buffer, Error> {
615 let buffer = (0..self.size).map(|offset| self.value_at(offset)).collect();
616
617 Ok(buffer)
618 }
619}
620
621impl<T: Number> Enqueue<Heap, T> for Linear<T> {
622 type Buffer = Vec<T>;
623
624 fn enqueue(&self) -> Result<Self::Buffer, Error> {
625 let buffer = (0..self.size)
626 .into_par_iter()
627 .map(|offset| self.value_at(offset))
628 .collect();
629
630 Ok(buffer)
631 }
632}
633
impl<T: Number> Enqueue<Host, T> for Linear<T> {
    type Buffer = Buffer<T>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, T)
    }
}

impl<T: Number> ReadValue<Host, T> for Linear<T> {
    /// Compute a single sequence value directly; this cannot fail.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        Ok(self.value_at(offset))
    }
}
647
/// An op which extracts the main diagonal of each `dim` x `dim` matrix in a
/// batch of `batch_size` matrices.
pub struct MatDiag<A, T> {
    access: A,
    // side length of each (square) matrix
    dim: usize,
    batch_size: usize,
    dtype: PhantomData<T>,
}

impl<A, T> MatDiag<A, T> {
    /// Construct a new [`MatDiag`] op over `access`.
    pub fn new(access: A, batch_size: usize, dim: usize) -> Self {
        Self {
            access,
            dim,
            batch_size,
            dtype: PhantomData,
        }
    }
}

impl<A: Access<T>, T: Number> Op for MatDiag<A, T> {
    // one diagonal of length `dim` per matrix in the batch
    fn size(&self) -> usize {
        debug_assert_eq!(self.access.size(), self.batch_size * self.dim * self.dim);
        self.batch_size * self.dim
    }
}
672
impl<A: Access<T>, T: Number> Enqueue<Heap, T> for MatDiag<A, T> {
    type Buffer = Vec<T>;

    /// Extract the diagonals of all matrices in the batch in parallel.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let input = self.access.read()?.to_slice()?;

        let diagonals = input
            // one chunk per matrix
            .par_chunks_exact(self.dim * self.dim)
            .flat_map(|matrix| {
                matrix
                    // one chunk per row; take element i of row i
                    .par_chunks_exact(self.dim)
                    .enumerate()
                    .map(|(i, row)| row[i])
            })
            .collect();

        Ok(diagonals)
    }
}
692
693impl<A: Access<T>, T: Number> Enqueue<Stack, T> for MatDiag<A, T> {
694 type Buffer = StackVec<T>;
695
696 fn enqueue(&self) -> Result<Self::Buffer, Error> {
697 let input = self.access.read()?.to_slice()?;
698
699 let diagonals = input
700 .chunks_exact(self.dim * self.dim)
701 .flat_map(|matrix| {
702 matrix
703 .chunks_exact(self.dim)
704 .enumerate()
705 .map(|(i, row)| row[i])
706 })
707 .collect();
708
709 Ok(diagonals)
710 }
711}
712
impl<A: Access<T>, T: Number> Enqueue<Host, T> for MatDiag<A, T> {
    type Buffer = Buffer<T>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}
720
721impl<A: Access<T>, T: Number> ReadValue<Host, T> for MatDiag<A, T> {
722 fn read_value(&self, offset: usize) -> Result<T, Error> {
723 let batch = offset / self.batch_size;
724 let i = offset % self.batch_size;
725 let source_offset = (batch * self.dim * self.dim) + (i * self.dim) + i;
726 self.access.read_value(source_offset)
727 }
728}
729
/// A batched matrix-multiply op: for each of `batch_size` pairs, multiplies an
/// `[a, b]` matrix by a `[b, c]` matrix to produce an `[a, c]` matrix.
pub struct MatMul<L, R, T> {
    left: L,
    right: R,
    batch_size: usize,
    // [a, b, c] per-matrix dimensions
    dims: [usize; 3],
    dtype: PhantomData<T>,
}

impl<L, R, T> MatMul<L, R, T> {
    /// Construct a new [`MatMul`] op from `[batch_size, a, b, c]`.
    pub fn new(left: L, right: R, dims: [usize; 4]) -> Self {
        let [batch_size, a, b, c] = dims;

        Self {
            left,
            right,
            batch_size,
            dims: [a, b, c],
            dtype: PhantomData,
        }
    }
}

impl<L, R, T> Op for MatMul<L, R, T>
where
    L: Send + Sync,
    R: Send + Sync,
    T: Send + Sync,
{
    // one [a, c] output matrix per batch entry
    fn size(&self) -> usize {
        self.batch_size * self.dims[0] * self.dims[2]
    }
}
762
impl<L, R, T> Enqueue<Stack, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    type Buffer = StackVec<T>;

    /// Compute the batched matrix product sequentially with a naive triple loop.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let left = self.left.read()?.to_slice()?;
        let right = self.right.read()?.to_slice()?;

        // [a, b] x [b, c] = [a, c] for each matrix in the batch
        let [a, b, c] = self.dims;

        let mut product = StackVec::with_capacity(self.batch_size * a * c);

        for batch in 0..self.batch_size {
            // offsets of this batch entry's left and right matrices
            let l_start = batch * a * b;
            let r_start = batch * b * c;

            for x in 0..a {
                for z in 0..c {
                    // dot product of left row x with right column z
                    let mut sum = T::ZERO;

                    for y in 0..b {
                        let l_offset = l_start + (x * b) + y;
                        let r_offset = r_start + (y * c) + z;
                        sum = T::add(sum, T::mul(left[l_offset], right[r_offset]));
                    }

                    product.push(sum)
                }
            }
        }

        debug_assert_eq!(product.len(), self.size());

        Ok(product)
    }
}

impl<L, R, T> Enqueue<Heap, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    type Buffer = Vec<T>;

    /// Compute the batched matrix product in parallel.
    ///
    /// Each right-hand matrix is transposed first so that both left rows and
    /// right columns are contiguous; dot products are then accumulated over
    /// sub-slices of up to 8 elements.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let [a, b, c] = self.dims;

        // read both operands concurrently
        let (left, right) = try_join_read(&self.left, &self.right)?;

        // transpose each right matrix so its columns become contiguous rows
        let right_size = b * c;
        let right_matrices = right.par_chunks_exact(right_size).map(|right| {
            let mut right_t = vec![T::ZERO; right_size];
            transpose::transpose(right, &mut right_t[..], c, b);
            right_t
        });

        let left_size = a * b;
        let left_matrices = left.par_chunks_exact(left_size);

        let output_size = a * c;
        let mut output = Vec::<T>::with_capacity(self.batch_size * output_size);
        let output_matrices = left_matrices
            .zip(right_matrices)
            .map(|(lm, rm)| {
                let mut out = Vec::<T>::with_capacity(output_size);

                // for each (row, transposed-column) pair, compute a dot product
                let product = lm
                    .par_chunks_exact(b)
                    .map(|row| {
                        rm.par_chunks_exact(b).map(move |col| {
                            let col = col.par_chunks(8).map(|cc| cc.iter().copied());

                            row.par_chunks(8)
                                .zip(col)
                                .map(|(rc, cc)| {
                                    rc.iter()
                                        .copied()
                                        .zip(cc)
                                        .map(|(r, c)| T::mul(r, c))
                                        .reduce(T::add)
                                        .expect("sum")
                                })
                                .reduce(|| T::ZERO, T::add)
                        })
                    })
                    .flatten();

                out.par_extend(product);
                out
            })
            .flatten();

        output.par_extend(output_matrices);

        debug_assert_eq!(output.len(), self.batch_size * output_size);

        Ok(output)
    }
}

impl<L, R, T> Enqueue<Host, T> for MatMul<L, R, T>
where
    L: Access<T>,
    R: Access<T>,
    T: Number,
{
    type Buffer = Buffer<T>;

    /// Use the stack only when both operands are small; otherwise use the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(
            self,
            self.left.size() < VEC_MIN_SIZE && self.right.size() < VEC_MIN_SIZE,
            T
        )
    }
}
886
887impl<L, R, T> ReadValue<Host, T> for MatMul<L, R, T>
888where
889 L: Access<T>,
890 R: Access<T>,
891 T: Number,
892{
893 fn read_value(&self, _offset: usize) -> Result<T, Error> {
894 Err(Error::bounds(
895 "reading an individual value from a matrix multiplication is not implemented"
896 .to_string(),
897 ))
898 }
899}
900
/// An elementwise op which combines each element of `access` with a fixed
/// `scalar` operand via `op`.
pub struct Scalar<A, IT, OT> {
    access: A,
    scalar: IT,
    op: fn(IT, IT) -> OT,
}

impl<A, IT, OT> Scalar<A, IT, OT> {
    /// Internal constructor shared by the named ops below.
    fn new(access: A, scalar: IT, op: fn(IT, IT) -> OT) -> Self {
        Self { access, scalar, op }
    }
}
912
913impl<A, T: Number> Scalar<A, T, T> {
914 pub fn add(access: A, scalar: T) -> Self {
915 Self::new(access, scalar, T::add)
916 }
917
918 pub fn div(access: A, scalar: T) -> Self {
919 Self::new(access, scalar, T::div)
920 }
921
922 pub fn mul(access: A, scalar: T) -> Self {
923 Self::new(access, scalar, T::mul)
924 }
925
926 pub fn pow(access: A, scalar: T) -> Self {
927 Self::new(access, scalar, T::pow)
928 }
929
930 pub fn rem(access: A, scalar: T) -> Self
931 where
932 T: Real,
933 {
934 Self::new(access, scalar, T::rem)
935 }
936
937 pub fn sub(access: A, scalar: T) -> Self {
938 Self::new(access, scalar, T::sub)
939 }
940}
941
impl<A, T: Float> Scalar<A, T, T> {
    /// Construct an op which takes the log of each element with base `scalar`.
    pub fn log(access: A, scalar: T) -> Self {
        Self::new(access, scalar, T::log)
    }
}
947
948impl<A, T> Scalar<A, T, u8>
949where
950 T: Number,
951{
952 pub fn and(access: A, scalar: T) -> Self {
953 Self::new(access, scalar, |l, r| {
954 if (l != T::ZERO) && (r != T::ZERO) {
955 1
956 } else {
957 0
958 }
959 })
960 }
961
962 pub fn or(access: A, scalar: T) -> Self {
963 Self::new(access, scalar, |l, r| {
964 if (l != T::ZERO) || (r != T::ZERO) {
965 1
966 } else {
967 0
968 }
969 })
970 }
971
972 pub fn xor(access: A, scalar: T) -> Self {
973 Self::new(access, scalar, |l, r| {
974 if (l != T::ZERO) ^ (r != T::ZERO) {
975 1
976 } else {
977 0
978 }
979 })
980 }
981
982 pub fn eq(access: A, scalar: T) -> Self {
983 Self::new(access, scalar, |l, r| if l == r { 1 } else { 0 })
984 }
985
986 pub fn ge(access: A, scalar: T) -> Self
987 where
988 T: Real,
989 {
990 Self::new(access, scalar, |l, r| if l >= r { 1 } else { 0 })
991 }
992
993 pub fn gt(access: A, scalar: T) -> Self
994 where
995 T: Real,
996 {
997 Self::new(access, scalar, |l, r| if l > r { 1 } else { 0 })
998 }
999
1000 pub fn le(access: A, scalar: T) -> Self
1001 where
1002 T: Real,
1003 {
1004 Self::new(access, scalar, |l, r| if l <= r { 1 } else { 0 })
1005 }
1006
1007 pub fn lt(access: A, scalar: T) -> Self
1008 where
1009 T: Real,
1010 {
1011 Self::new(access, scalar, |l, r| if l < r { 1 } else { 0 })
1012 }
1013
1014 pub fn ne(access: A, scalar: T) -> Self {
1015 Self::new(access, scalar, |l, r| if l != r { 1 } else { 0 })
1016 }
1017}
1018
impl<A, IT, OT> Op for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: Number,
    OT: Number,
{
    // a scalar op is elementwise, so the output size equals the input size
    fn size(&self) -> usize {
        self.access.size()
    }
}
1029
1030impl<A, IT, OT> Enqueue<Heap, OT> for Scalar<A, IT, OT>
1031where
1032 A: Access<IT>,
1033 IT: Number,
1034 OT: Number,
1035{
1036 type Buffer = Vec<OT>;
1037
1038 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1039 self.access
1040 .read()
1041 .and_then(|buf| buf.to_slice())
1042 .map(|slice| {
1043 slice
1044 .as_ref()
1045 .into_par_iter()
1046 .copied()
1047 .map(|l| (self.op)(l, self.scalar))
1048 .collect()
1049 })
1050 }
1051}
1052
1053impl<A, IT, OT> Enqueue<Stack, OT> for Scalar<A, IT, OT>
1054where
1055 A: Access<IT>,
1056 IT: Number,
1057 OT: Number,
1058{
1059 type Buffer = StackVec<OT>;
1060
1061 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1062 self.access
1063 .read()
1064 .and_then(|buf| buf.to_slice())
1065 .map(|slice| {
1066 slice
1067 .as_ref()
1068 .iter()
1069 .copied()
1070 .map(|l| (self.op)(l, self.scalar))
1071 .collect()
1072 })
1073 }
1074}
1075
impl<A, IT, OT> Enqueue<Host, OT> for Scalar<A, IT, OT>
where
    A: Access<IT>,
    IT: Number,
    OT: Number,
{
    type Buffer = Buffer<OT>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
    }
}
1088
1089impl<A, IT, OT> ReadValue<Host, OT> for Scalar<A, IT, OT>
1090where
1091 A: Access<IT>,
1092 IT: Number,
1093 OT: Number,
1094{
1095 fn read_value(&self, offset: usize) -> Result<OT, Error> {
1096 self.access
1097 .read_value(offset)
1098 .map(|n| (self.op)(n, self.scalar))
1099 }
1100}
1101
/// An op which samples a standard normal distribution.
pub struct RandomNormal {
    size: usize,
}

impl RandomNormal {
    /// Construct an op which produces `size` samples.
    pub fn new(size: usize) -> Self {
        Self { size }
    }

    /// Map two uniform samples in (0, 1] to two independent standard-normal
    /// samples via the Box-Muller transform.
    fn box_muller(u: [f32; 2]) -> [f32; 2] {
        let [u1, u2] = u;
        let radius = (-2. * u1.ln()).sqrt();
        let angle = 2. * PI * u2;
        [radius * angle.cos(), radius * angle.sin()]
    }
}
1118
impl Op for RandomNormal {
    fn size(&self) -> usize {
        self.size
    }
}
1124
impl Enqueue<Heap, f32> for RandomNormal {
    type Buffer = Vec<f32>;

    /// Generate normal samples in Box-Muller pairs (sequentially, since the
    /// RNG is stateful), dropping the extra sample when `size` is odd.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut rng = Rand::new();

        // ceil(size / 2) pairs yields at least `size` samples
        let mut output = (0..self.size.div_ceil(2))
            .flat_map(|_| Self::box_muller([rng.gen(), rng.gen()]))
            .collect::<Vec<f32>>();

        if output.len() > self.size {
            output.pop();
        }

        debug_assert_eq!(output.len(), self.size);

        Ok(output)
    }
}

impl Enqueue<Stack, f32> for RandomNormal {
    type Buffer = StackVec<f32>;

    /// Generate normal samples in Box-Muller pairs into a stack-allocated
    /// buffer, dropping the extra sample when `size` is odd.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        let mut rng = Rand::new();

        let mut output = iter::repeat_with(|| [rng.gen(), rng.gen()])
            .take(self.size.div_ceil(2))
            .flat_map(Self::box_muller)
            .collect::<StackVec<f32>>();

        if output.len() > self.size {
            output.pop();
        }

        debug_assert_eq!(output.len(), self.size);

        Ok(output)
    }
}
1165
impl Enqueue<Host, f32> for RandomNormal {
    type Buffer = Buffer<f32>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
    }
}

impl ReadValue<Host, f32> for RandomNormal {
    /// Individual samples cannot be addressed, since samples are generated in
    /// dependent (Box-Muller) pairs from a stateful RNG.
    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
        Err(Error::bounds(
            "cannot calculate an individual value of a random normal distribution".to_string(),
        ))
    }
}
1181
/// An op which samples a uniform distribution.
pub struct RandomUniform {
    size: usize,
}

impl RandomUniform {
    /// Construct an op which produces `size` samples.
    pub fn new(size: usize) -> Self {
        Self { size }
    }
}

impl Op for RandomUniform {
    fn size(&self) -> usize {
        self.size
    }
}
1197
1198impl Enqueue<Heap, f32> for RandomUniform {
1199 type Buffer = Vec<f32>;
1200
1201 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1202 let mut rng = Rand::new();
1203 Ok((0..self.size).map(|_| rng.gen()).collect())
1204 }
1205}
1206
1207impl Enqueue<Stack, f32> for RandomUniform {
1208 type Buffer = StackVec<f32>;
1209
1210 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1211 let mut rng = Rand::new();
1212 Ok((0..self.size).map(|_| rng.gen()).collect())
1213 }
1214}
1215
impl Enqueue<Host, f32> for RandomUniform {
    type Buffer = Buffer<f32>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
    }
}

impl ReadValue<Host, f32> for RandomUniform {
    /// Draw a fresh sample; samples are independent so no offset math is needed.
    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
        Ok(Rand::new().gen())
    }
}
1229
/// An op which folds each `stride`-sized chunk of its source into one value.
pub struct Reduce<A, T> {
    access: A,
    stride: usize,
    // the binary fold function
    reduce: fn(T, T) -> T,
    // identity element for `reduce` (e.g. zero for addition)
    id: T,
}
1236
1237impl<A, T> Reduce<A, T>
1238where
1239 T: Number,
1240{
1241 pub fn product(access: A, stride: usize) -> Self {
1242 Self {
1243 access,
1244 stride,
1245 reduce: T::mul,
1246 id: T::ONE,
1247 }
1248 }
1249
1250 pub fn sum(access: A, stride: usize) -> Self {
1251 Self {
1252 access,
1253 stride,
1254 reduce: T::add,
1255 id: T::ZERO,
1256 }
1257 }
1258}
1259
1260impl<A, T> Reduce<A, T>
1261where
1262 T: Real,
1263{
1264 pub fn max(access: A, stride: usize) -> Self {
1265 Self {
1266 access,
1267 stride,
1268 reduce: Real::max,
1269 id: T::MIN,
1270 }
1271 }
1272
1273 pub fn min(access: A, stride: usize) -> Self {
1274 Self {
1275 access,
1276 stride,
1277 reduce: Real::min,
1278 id: T::MAX,
1279 }
1280 }
1281}
1282
impl<A: Access<T>, T: Number> Op for Reduce<A, T> {
    // one output value per `stride`-sized chunk; the input must divide evenly
    fn size(&self) -> usize {
        debug_assert_eq!(self.access.size() % self.stride, 0);
        self.access.size() / self.stride
    }
}
1289
impl<A: Access<T>, T: Number> Enqueue<Heap, T> for Reduce<A, T> {
    type Buffer = Vec<T>;

    /// Reduce each `stride`-sized chunk to a single value.
    ///
    /// The outer iteration over chunks is sequential; within each chunk the
    /// elements are reduced in parallel, in sub-chunks of up to 8, then the
    /// partial results are folded together starting from the identity `id`.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        self.access
            .read()
            .and_then(|buf| buf.to_slice())
            .map(|slice| {
                slice
                    .chunks_exact(self.stride)
                    .map(|chunk| {
                        chunk
                            .par_chunks(8)
                            .map(|chunk| {
                                // par_chunks never yields an empty chunk
                                chunk.iter().copied().reduce(self.reduce).expect("reduced")
                            })
                            .reduce(|| self.id, self.reduce)
                    })
                    .collect()
            })
    }
}
1313
1314impl<A: Access<T>, T: Number> Enqueue<Stack, T> for Reduce<A, T> {
1315 type Buffer = StackVec<T>;
1316
1317 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1318 self.access
1319 .read()
1320 .and_then(|buf| buf.to_slice())
1321 .map(|slice| {
1322 slice
1323 .chunks_exact(self.stride)
1324 .map(|chunk| chunk.iter().copied().reduce(self.reduce).expect("reduced"))
1325 .collect()
1326 })
1327 }
1328}
1329
impl<A: Access<T>, T: Number> Enqueue<Host, T> for Reduce<A, T> {
    type Buffer = Buffer<T>;

    /// Use the stack only when both the stride and the output are small.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(
            self,
            self.stride < VEC_MIN_SIZE && self.size() < VEC_MIN_SIZE,
            T
        )
    }
}
1341
impl<A: Access<T>, T: Number> ReadValue<Host, T> for Reduce<A, T> {
    /// Reduce the single `stride`-sized chunk corresponding to output `offset`,
    /// reading its source values in parallel.
    fn read_value(&self, offset: usize) -> Result<T, Error> {
        // translate the output offset to the start of its source chunk
        let offset = offset * self.stride;

        if offset < self.access.size() {
            (offset..(offset + self.stride))
                .into_par_iter()
                .map(|offset| self.access.read_value(offset))
                .try_reduce(|| self.id, |r, v| Ok((self.reduce)(r, v)))
        } else {
            Err(Error::bounds(format!(
                "invalid offset {offset} for a reduce op with size {}",
                self.size()
            )))
        }
    }
}
1359
/// An op which reads (and optionally writes) a sub-range of its source.
pub struct Slice<A, T> {
    access: A,
    // maps slice-local offsets to source offsets
    spec: SliceSpec,
    dtype: PhantomData<T>,
}

impl<A, T> Slice<A, T> {
    /// Construct a new [`Slice`] of `access` with the given `shape` and `range`.
    pub fn new(access: A, shape: &[usize], range: Range) -> Self {
        let spec = SliceSpec::new(shape, range);

        Self {
            access,
            spec,
            dtype: PhantomData,
        }
    }
}
1377
1378impl<A: Send + Sync, T: Copy + Send + Sync> Slice<A, T> {
1379 fn read(&self, source: &[T]) -> Result<StackVec<T>, Error> {
1380 let output = (0..self.size())
1381 .map(|offset_out| self.spec.source_offset(offset_out))
1382 .map(|offset_in| source[offset_in])
1383 .collect();
1384
1385 Ok(output)
1386 }
1387
1388 fn read_parallel(&self, source: &[T]) -> Result<Vec<T>, Error> {
1389 let output = (0..self.size())
1390 .into_par_iter()
1391 .map(|offset_out| self.spec.source_offset(offset_out))
1392 .map(|offset_in| source[offset_in])
1393 .collect();
1394
1395 Ok(output)
1396 }
1397}
1398
1399impl<A, T> Slice<A, T>
1400where
1401 T: Number,
1402 A: AccessMut<T>,
1403{
1404 fn overwrite<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1405 if data.len() == self.size() {
1406 let data = data.to_slice()?;
1407
1408 for (offset, value) in data.iter().copied().enumerate() {
1409 let source_offset = self.spec.source_offset(offset);
1410 self.access.write_value_at(source_offset, value)?;
1411 }
1412
1413 Ok(())
1414 } else {
1415 Err(Error::bounds(format!(
1416 "cannot overwrite a slice of size {} with a buffer of size {}",
1417 self.size(),
1418 data.len(),
1419 )))
1420 }
1421 }
1422
1423 fn overwrite_value(&mut self, value: T) -> Result<(), Error> {
1424 for offset in 0..self.access.size() {
1425 let source_offset = self.spec.source_offset(offset);
1426 self.access.write_value_at(source_offset, value)?;
1427 }
1428
1429 Ok(())
1430 }
1431
1432 fn overwrite_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1433 let source_offset = self.spec.source_offset(offset);
1434 self.access.write_value_at(source_offset, value)
1435 }
1436}
1437
impl<A: Send + Sync, T: Send + Sync> Op for Slice<A, T> {
    // the slice spec determines the output size
    fn size(&self) -> usize {
        self.spec.size()
    }
}
1443
1444impl<A: Access<T>, T: Number> Enqueue<Heap, T> for Slice<A, T> {
1445 type Buffer = Vec<T>;
1446
1447 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1448 self.access
1449 .read()
1450 .and_then(|buf| buf.to_slice())
1451 .and_then(|buf| self.read_parallel(&buf))
1452 }
1453}
1454
1455impl<A: Access<T>, T: Number> Enqueue<Stack, T> for Slice<A, T> {
1456 type Buffer = StackVec<T>;
1457
1458 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1459 self.access
1460 .read()
1461 .and_then(|buf| buf.to_slice())
1462 .and_then(|buf| self.read(&buf))
1463 }
1464}
1465
impl<A: Access<T>, T: Number> Enqueue<Host, T> for Slice<A, T> {
    type Buffer = Buffer<T>;

    /// Compute outputs smaller than `VEC_MIN_SIZE` on the stack, larger ones on the heap.
    fn enqueue(&self) -> Result<Self::Buffer, Error> {
        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
    }
}
1473
1474impl<A: Access<T>, T: Number> ReadValue<Host, T> for Slice<A, T> {
1475 fn read_value(&self, offset: usize) -> Result<T, Error> {
1476 let offset = self.spec.source_offset(offset);
1477 self.access.read_value(offset)
1478 }
1479}
1480
// [`Heap`]-platform writes to a [`Slice`] delegate to the shared `overwrite*`
// methods, which translate each output offset to its source offset before writing.
impl<A, T> crate::ops::Write<Heap, T> for Slice<A, T>
where
    T: Number,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}
1498
// [`Stack`]-platform writes to a [`Slice`]: identical delegation to the shared
// `overwrite*` methods, since writing through a slice is sequential either way.
impl<A, T> crate::ops::Write<Stack, T> for Slice<A, T>
where
    T: Number,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}
1516
// [`Host`]-platform writes to a [`Slice`]: same delegation as the [`Heap`] and
// [`Stack`] impls (writes are not dispatched by size the way reads are).
impl<A, T> crate::ops::Write<Host, T> for Slice<A, T>
where
    T: Number,
    A: AccessMut<T>,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        self.overwrite(data)
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        self.overwrite_value(value)
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        self.overwrite_value_at(offset, value)
    }
}
1534
/// A lazy element-wise unary transform of type `IT -> OT` over a source accessor.
pub struct Unary<A, IT, OT> {
    access: A,        // the source data
    op: fn(IT) -> OT, // the element-wise transform applied on read
}
1539
1540impl<A: Access<T>, T: Float> Unary<A, T, T> {
1541 pub fn exp(access: A) -> Self {
1542 Self { access, op: T::exp }
1543 }
1544
1545 pub fn ln(access: A) -> Self {
1546 Self { access, op: T::ln }
1547 }
1548}
1549
1550impl<A: Access<T>, T: Number> Unary<A, T, T::Abs> {
1551 pub fn abs(access: A) -> Self {
1552 Self {
1553 access,
1554 op: Number::abs,
1555 }
1556 }
1557}
1558
1559impl<A: Access<T>, T: Real> Unary<A, T, T> {
1560 pub fn round(access: A) -> Self {
1561 Self {
1562 access,
1563 op: Real::round,
1564 }
1565 }
1566}
1567
1568impl<A: Access<T>, T: Float> Unary<A, T, T> {
1569 pub fn sin(access: A) -> Self {
1570 Self {
1571 access,
1572 op: |n| n.sin(),
1573 }
1574 }
1575
1576 pub fn asin(access: A) -> Self {
1577 Self {
1578 access,
1579 op: |n| n.asin(),
1580 }
1581 }
1582
1583 pub fn sinh(access: A) -> Self {
1584 Self {
1585 access,
1586 op: |n| n.sinh(),
1587 }
1588 }
1589
1590 pub fn cos(access: A) -> Self {
1591 Self {
1592 access,
1593 op: |n| n.cos(),
1594 }
1595 }
1596
1597 pub fn acos(access: A) -> Self {
1598 Self {
1599 access,
1600 op: |n| n.acos(),
1601 }
1602 }
1603
1604 pub fn cosh(access: A) -> Self {
1605 Self {
1606 access,
1607 op: |n| n.cosh(),
1608 }
1609 }
1610
1611 pub fn tan(access: A) -> Self {
1612 Self {
1613 access,
1614 op: |n| n.tan(),
1615 }
1616 }
1617
1618 pub fn atan(access: A) -> Self {
1619 Self {
1620 access,
1621 op: |n| n.atan(),
1622 }
1623 }
1624
1625 pub fn tanh(access: A) -> Self {
1626 Self {
1627 access,
1628 op: |n| n.tanh(),
1629 }
1630 }
1631}
1632
1633impl<A: Access<T>, T: Number> Unary<A, T, u8> {
1634 pub fn not(access: A) -> Self {
1635 Self {
1636 access,
1637 op: |n| if n == T::ZERO { 1 } else { 0 },
1638 }
1639 }
1640}
1641
1642impl<A: Access<T>, T: Float> Unary<A, T, u8> {
1643 pub fn inf(access: A) -> Self {
1644 Self {
1645 access,
1646 op: |n| if n.is_inf() { 1 } else { 0 },
1647 }
1648 }
1649
1650 pub fn nan(access: A) -> Self {
1651 Self {
1652 access,
1653 op: |n| if n.is_nan() { 1 } else { 0 },
1654 }
1655 }
1656}
1657
#[cfg(feature = "complex")]
impl<A, T> Unary<A, T, T>
where
    A: Access<T>,
    T: Complex,
{
    /// Construct an element-wise complex conjugate op.
    pub fn conj(access: A) -> Self {
        Self {
            access,
            op: |n| n.conj(),
        }
    }
}
1671
// Complex-to-real projections: each maps a complex input to a real output.
#[cfg(feature = "complex")]
impl<A, T> Unary<A, T, T::Real>
where
    A: Access<T>,
    T: Complex,
{
    /// Construct an op computing the angle (argument) of each complex element.
    pub fn angle(access: A) -> Self {
        Self {
            access,
            op: |n| n.angle(),
        }
    }

    /// Construct an op extracting the real part of each complex element.
    pub fn re(access: A) -> Self {
        Self {
            access,
            op: |n| n.re(),
        }
    }

    /// Construct an op extracting the imaginary part of each complex element.
    pub fn im(access: A) -> Self {
        Self {
            access,
            op: |n| n.im(),
        }
    }
}
1699
1700impl<A, IT, OT> Op for Unary<A, IT, OT>
1701where
1702 A: Access<IT>,
1703 IT: Number,
1704 OT: Number,
1705{
1706 fn size(&self) -> usize {
1707 self.access.size()
1708 }
1709}
1710
1711impl<A, IT, OT> Enqueue<Heap, OT> for Unary<A, IT, OT>
1712where
1713 A: Access<IT>,
1714 IT: Number,
1715 OT: Number,
1716{
1717 type Buffer = Vec<OT>;
1718
1719 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1720 self.access
1721 .read()
1722 .and_then(|buf| buf.to_slice())
1723 .map(|input| input.into_par_iter().copied().map(self.op).collect())
1724 }
1725}
1726
1727impl<A, IT, OT> Enqueue<Stack, OT> for Unary<A, IT, OT>
1728where
1729 A: Access<IT>,
1730 IT: Number,
1731 OT: Number,
1732{
1733 type Buffer = StackVec<OT>;
1734
1735 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1736 self.access
1737 .read()
1738 .and_then(|buf| buf.to_slice())
1739 .map(|input| input.iter().copied().map(self.op).collect())
1740 }
1741}
1742
1743impl<A, IT, OT> Enqueue<Host, OT> for Unary<A, IT, OT>
1744where
1745 A: Access<IT>,
1746 IT: Number,
1747 OT: Number,
1748{
1749 type Buffer = Buffer<OT>;
1750
1751 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1752 host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
1753 }
1754}
1755
1756impl<A, IT, OT> ReadValue<Host, OT> for Unary<A, IT, OT>
1757where
1758 A: Access<IT>,
1759 IT: Number,
1760 OT: Number,
1761{
1762 fn read_value(&self, offset: usize) -> Result<OT, Error> {
1763 self.access.read_value(offset).map(|n| (self.op)(n))
1764 }
1765}
1766
/// A non-materialized view of a source tensor, e.g. a broadcast or transpose.
pub struct View<A, T> {
    access: A,             // the source data
    spec: ViewSpec,        // maps offsets in this view to offsets in the source
    dtype: PhantomData<T>, // element type marker (no runtime data)
}
1772
1773impl<A: Access<T>, T: Number> View<A, T> {
1774 pub fn broadcast(access: A, shape: Shape, broadcast: Shape) -> Self {
1775 let source_strides = strides_for(&shape, shape.len()).collect();
1776
1777 Self {
1778 access,
1779 spec: ViewSpec::new(broadcast, source_strides),
1780 dtype: PhantomData,
1781 }
1782 }
1783
1784 pub fn transpose(access: A, shape: Shape, axes: Axes) -> Self {
1785 let strides = strides_for(&shape, shape.len()).collect::<Strides>();
1786 let shape = axes.iter().copied().map(|x| shape[x]).collect::<Strides>();
1787 let source_strides = axes.into_iter().map(|x| strides[x]).collect::<Strides>();
1788
1789 Self {
1790 access,
1791 spec: ViewSpec::new(shape, source_strides),
1792 dtype: PhantomData,
1793 }
1794 }
1795}
1796
impl<A: Access<T>, T: Number> Op for View<A, T> {
    /// The number of elements in this view, as defined by its [`ViewSpec`].
    fn size(&self) -> usize {
        self.spec.size()
    }
}
1802
1803impl<A: Access<T>, T: Number> Enqueue<Stack, T> for View<A, T> {
1804 type Buffer = StackVec<T>;
1805
1806 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1807 let source = self.access.read().and_then(|source| source.to_slice())?;
1808
1809 let buffer = (0..self.spec.size())
1810 .map(|offset| self.spec.source_offset(offset))
1811 .map(|source_offset| source[source_offset])
1812 .collect();
1813
1814 Ok(buffer)
1815 }
1816}
1817
1818impl<A: Access<T>, T: Number> Enqueue<Heap, T> for View<A, T> {
1819 type Buffer = Vec<T>;
1820
1821 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1822 let source = self.access.read().and_then(|source| source.to_slice())?;
1823
1824 let buffer = (0..self.spec.size())
1825 .into_par_iter()
1826 .map(|offset| self.spec.source_offset(offset))
1827 .map(|source_offset| source[source_offset])
1828 .collect();
1829
1830 Ok(buffer)
1831 }
1832}
1833
1834impl<A: Access<T>, T: Number> Enqueue<Host, T> for View<A, T> {
1835 type Buffer = Buffer<T>;
1836
1837 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1838 host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
1839 }
1840}
1841
1842impl<A: Access<T>, T: Number> ReadValue<Host, T> for View<A, T> {
1843 fn read_value(&self, offset: usize) -> Result<T, Error> {
1844 self.access.read_value(self.spec.source_offset(offset))
1845 }
1846}
1847
1848fn exec_dual<IT: Number, OT: Number>(
1849 zip: fn(IT, IT) -> OT,
1850 left: SliceConverter<IT>,
1851 right: SliceConverter<IT>,
1852) -> Result<StackVec<OT>, Error> {
1853 let output = left
1854 .iter()
1855 .copied()
1856 .zip(right.iter().copied())
1857 .map(|(l, r)| (zip)(l, r))
1858 .collect();
1859
1860 Ok(output)
1861}
1862
1863fn exec_dual_parallel<IT: Number, OT: Number>(
1864 zip: fn(IT, IT) -> OT,
1865 left: SliceConverter<IT>,
1866 right: SliceConverter<IT>,
1867) -> Result<Vec<OT>, Error> {
1868 let output = left
1869 .into_par_iter()
1870 .copied()
1871 .zip(right.into_par_iter().copied())
1872 .map(|(l, r)| (zip)(l, r))
1873 .collect();
1874
1875 Ok(output)
1876}
1877
1878#[inline]
1879fn try_join_read<'a, L, R, T>(
1880 left: &'a L,
1881 right: &'a R,
1882) -> Result<(SliceConverter<'a, T>, SliceConverter<'a, T>), Error>
1883where
1884 L: Access<T>,
1885 R: Access<T>,
1886 T: Number,
1887{
1888 let (l, r) = join(
1889 || left.read().and_then(|buf| buf.to_slice()),
1890 || right.read().and_then(|buf| buf.to_slice()),
1891 );
1892
1893 Ok((l?, r?))
1894}
1895
1896#[inline]
1897fn try_join_value<'a, L, R, T>(left: &'a L, right: &'a R, offset: usize) -> Result<(T, T), Error>
1898where
1899 L: Access<T>,
1900 R: Access<T>,
1901 T: Number,
1902{
1903 let (l, r) = join(|| left.read_value(offset), || right.read_value(offset));
1904
1905 Ok((l?, r?))
1906}