Skip to main content

ha_ndarray/host/ops/
mod.rs

1use std::f32::consts::PI;
2use std::iter;
3use std::marker::PhantomData;
4
5use frand::Rand;
6use number_general as ng;
7use rayon::join;
8use rayon::prelude::*;
9
10use crate::access::Access;
11use crate::ops::{Concat, Enqueue, FlipSpec, Op, ReadValue, SliceSpec, ViewSpec};
12#[cfg(feature = "complex")]
13use crate::Complex;
14use crate::{
15    strides_for, AccessMut, Axes, BufferConverter, Error, Float, Number, Platform, Range, Real,
16    Shape, Strides,
17};
18
19use super::buffer::Buffer;
20use super::platform::{Heap, Host, Stack};
21use super::{SliceConverter, StackVec, VEC_MIN_SIZE};
22
23#[cfg(feature = "complex")]
24pub mod complex;
25
// Dispatch a host enqueue to either stack or heap storage.
//
// `$op` is the operation to enqueue, `$use_stack` selects the `Stack`
// implementation when true (and `Heap` otherwise), and `$dtype` is the
// element type. The result is wrapped in the matching `Buffer` variant.
macro_rules! host_enqueue {
    ($op:expr, $use_stack:expr, $dtype:ty) => {
        match $use_stack {
            true => Enqueue::<Stack, $dtype>::enqueue($op).map(Buffer::Stack),
            false => Enqueue::<Heap, $dtype>::enqueue($op).map(Buffer::Heap),
        }
    };
}
35
36pub struct Cast<A, IT, OT> {
37    access: A,
38    dtype: PhantomData<(IT, OT)>,
39}
40
41impl<A, IT, OT> Cast<A, IT, OT> {
42    pub fn new(access: A) -> Self {
43        Self {
44            access,
45            dtype: PhantomData,
46        }
47    }
48}
49
50impl<A: Access<IT>, IT: Number, OT: Number> Op for Cast<A, IT, OT> {
51    fn size(&self) -> usize {
52        self.access.size()
53    }
54}
55
56impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Heap, OT> for Cast<A, IT, OT> {
57    type Buffer = Vec<OT>;
58
59    fn enqueue(&self) -> Result<Self::Buffer, Error> {
60        self.access
61            .read()
62            .and_then(|buf| buf.to_slice())
63            .map(|slice| {
64                slice
65                    .into_par_iter()
66                    .copied()
67                    .map(|n| n.into())
68                    .map(OT::cast_from)
69                    .collect()
70            })
71    }
72}
73
74impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Stack, OT> for Cast<A, IT, OT> {
75    type Buffer = StackVec<OT>;
76
77    fn enqueue(&self) -> Result<Self::Buffer, Error> {
78        self.access
79            .read()
80            .and_then(|buf| buf.to_slice())
81            .map(|slice| {
82                slice
83                    .iter()
84                    .copied()
85                    .map(|n| n.into())
86                    .map(OT::cast_from)
87                    .collect()
88            })
89    }
90}
91
92impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Host, OT> for Cast<A, IT, OT> {
93    type Buffer = Buffer<OT>;
94
95    fn enqueue(&self) -> Result<Self::Buffer, Error> {
96        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
97    }
98}
99
100impl<A: Access<IT>, IT: Number, OT: Number> ReadValue<Host, OT> for Cast<A, IT, OT> {
101    fn read_value(&self, offset: usize) -> Result<OT, Error> {
102        self.access
103            .read_value(offset)
104            .map(|n| n.into())
105            .map(OT::cast_from)
106    }
107}
108
109pub struct Cond<A, L, R, T> {
110    cond: A,
111    then: L,
112    or_else: R,
113    dtype: PhantomData<T>,
114}
115
116impl<A, L, R, T> Cond<A, L, R, T> {
117    pub fn new(cond: A, then: L, or_else: R) -> Self {
118        Self {
119            cond,
120            then,
121            or_else,
122            dtype: PhantomData,
123        }
124    }
125}
126
127impl<A, L, R, T> Op for Cond<A, L, R, T>
128where
129    A: Access<u8>,
130    L: Access<T>,
131    R: Access<T>,
132    T: Number,
133{
134    fn size(&self) -> usize {
135        debug_assert_eq!(self.cond.size(), self.then.size());
136        debug_assert_eq!(self.cond.size(), self.or_else.size());
137        self.cond.size()
138    }
139}
140
141impl<A, L, R, T> Enqueue<Stack, T> for Cond<A, L, R, T>
142where
143    A: Access<u8>,
144    L: Access<T>,
145    R: Access<T>,
146    T: Number,
147{
148    type Buffer = StackVec<T>;
149
150    fn enqueue(&self) -> Result<Self::Buffer, Error> {
151        let cond = self.cond.read()?.to_slice()?;
152        let then = self.then.read()?.to_slice()?;
153        let or_else = self.or_else.read()?.to_slice()?;
154
155        let output = cond
156            .iter()
157            .copied()
158            .zip(then.iter().copied().zip(or_else.iter().copied()))
159            .map(|(cond, (then, or_else))| if cond != 0 { then } else { or_else })
160            .collect();
161
162        Ok(output)
163    }
164}
165
166impl<A, L, R, T> Enqueue<Heap, T> for Cond<A, L, R, T>
167where
168    A: Access<u8>,
169    L: Access<T>,
170    R: Access<T>,
171    T: Number,
172{
173    type Buffer = Vec<T>;
174
175    fn enqueue(&self) -> Result<Self::Buffer, Error> {
176        let (cond, (then, or_else)) = join(
177            || self.cond.read().and_then(|buf| buf.to_slice()),
178            || {
179                join(
180                    || self.then.read().and_then(|buf| buf.to_slice()),
181                    || self.or_else.read().and_then(|buf| buf.to_slice()),
182                )
183            },
184        );
185
186        let (cond, (then, or_else)) = (cond?, (then?, or_else?));
187
188        let output = cond
189            .into_par_iter()
190            .copied()
191            .zip(then.into_par_iter().zip(or_else.into_par_iter()))
192            .map(
193                |(cond, (then, or_else))| {
194                    if cond != 0 {
195                        then
196                    } else {
197                        or_else
198                    }
199                },
200            )
201            .copied()
202            .collect();
203
204        Ok(output)
205    }
206}
207
208impl<A, L, R, T> Enqueue<Host, T> for Cond<A, L, R, T>
209where
210    A: Access<u8>,
211    L: Access<T>,
212    R: Access<T>,
213    T: Number,
214{
215    type Buffer = Buffer<T>;
216
217    fn enqueue(&self) -> Result<Self::Buffer, Error> {
218        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
219    }
220}
221
222impl<A, L, R, T> ReadValue<Host, T> for Cond<A, L, R, T>
223where
224    A: Access<u8>,
225    L: Access<T>,
226    R: Access<T>,
227    T: Number,
228{
229    fn read_value(&self, offset: usize) -> Result<T, Error> {
230        let (cond, (then, or_else)) = join(
231            || self.cond.read_value(offset),
232            || {
233                join(
234                    || self.then.read_value(offset),
235                    || self.or_else.read_value(offset),
236                )
237            },
238        );
239
240        let (cond, (then, or_else)) = (cond?, (then?, or_else?));
241
242        if cond != 0 {
243            Ok(then)
244        } else {
245            Ok(or_else)
246        }
247    }
248}
249
250impl<A, T> Enqueue<Host, T> for Concat<A, T>
251where
252    A: Access<T>,
253    T: Number,
254{
255    type Buffer = Buffer<T>;
256
257    fn enqueue(&self) -> Result<Self::Buffer, Error> {
258        let mut buffer = Vec::with_capacity(self.size());
259
260        for access in self.data() {
261            let data = access.read()?.to_slice()?;
262            buffer.par_extend(data.into_par_iter().copied());
263        }
264
265        Ok(buffer.into())
266    }
267}
268
269impl<A, T> ReadValue<Host, T> for Concat<A, T>
270where
271    A: Access<T>,
272    T: Number,
273{
274    fn read_value(&self, offset: usize) -> Result<T, Error> {
275        ReadValue::<Platform, T>::read_value(self, offset)
276    }
277}
278
279pub struct Dual<L, R, IT, OT> {
280    left: L,
281    right: R,
282    zip: fn(IT, IT) -> OT,
283}
284
285// arithmetic
286impl<L, R, T: Number> Dual<L, R, T, T> {
287    pub fn add(left: L, right: R) -> Self {
288        Self {
289            left,
290            right,
291            zip: T::add,
292        }
293    }
294
295    pub fn div(left: L, right: R) -> Self {
296        Self {
297            left,
298            right,
299            zip: T::div,
300        }
301    }
302
303    pub fn mul(left: L, right: R) -> Self {
304        Self {
305            left,
306            right,
307            zip: T::mul,
308        }
309    }
310
311    pub fn pow(left: L, right: R) -> Self {
312        Self {
313            left,
314            right,
315            zip: T::pow,
316        }
317    }
318
319    pub fn rem(left: L, right: R) -> Self
320    where
321        T: Real,
322    {
323        Self {
324            left,
325            right,
326            zip: T::pow,
327        }
328    }
329
330    pub fn sub(left: L, right: R) -> Self {
331        Self {
332            left,
333            right,
334            zip: T::sub,
335        }
336    }
337}
338
339// floating-point arithmetic
340impl<L, R, T: Float> Dual<L, R, T, T> {
341    pub fn log(left: L, right: R) -> Self {
342        Self {
343            left,
344            right,
345            zip: T::log,
346        }
347    }
348}
349
350// boolean operations
351impl<L, R, T: Number> Dual<L, R, T, u8> {
352    pub fn and(left: L, right: R) -> Self {
353        Self {
354            left,
355            right,
356            zip: |l, r| if l != T::ZERO && r != T::ZERO { 1 } else { 0 },
357        }
358    }
359
360    pub fn or(left: L, right: R) -> Self {
361        Self {
362            left,
363            right,
364            zip: |l, r| if l != T::ZERO || r != T::ZERO { 1 } else { 0 },
365        }
366    }
367
368    pub fn xor(left: L, right: R) -> Self {
369        Self {
370            left,
371            right,
372            zip: |l, r| {
373                if (l != T::ZERO) ^ (r != T::ZERO) {
374                    1
375                } else {
376                    0
377                }
378            },
379        }
380    }
381}
382
383// comparison
384impl<L, R, T: Number> Dual<L, R, T, u8> {
385    pub fn eq(left: L, right: R) -> Self {
386        Self {
387            left,
388            right,
389            zip: |l, r| if l == r { 1 } else { 0 },
390        }
391    }
392
393    pub fn ne(left: L, right: R) -> Self {
394        Self {
395            left,
396            right,
397            zip: |l, r| if l != r { 1 } else { 0 },
398        }
399    }
400}
401
402impl<L, R, T: Number + PartialOrd> Dual<L, R, T, u8> {
403    pub fn ge(left: L, right: R) -> Self {
404        Self {
405            left,
406            right,
407            zip: |l, r| if l >= r { 1 } else { 0 },
408        }
409    }
410
411    pub fn gt(left: L, right: R) -> Self {
412        Self {
413            left,
414            right,
415            zip: |l, r| if l > r { 1 } else { 0 },
416        }
417    }
418
419    pub fn le(left: L, right: R) -> Self {
420        Self {
421            left,
422            right,
423            zip: |l, r| if l <= r { 1 } else { 0 },
424        }
425    }
426
427    pub fn lt(left: L, right: R) -> Self {
428        Self {
429            left,
430            right,
431            zip: |l, r| if l < r { 1 } else { 0 },
432        }
433    }
434}
435
436impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
437where
438    L: Access<IT>,
439    R: Access<IT>,
440    IT: Number,
441    OT: Number,
442{
443    fn size(&self) -> usize {
444        self.left.size()
445    }
446}
447
448impl<L, R, IT, OT> Enqueue<Stack, OT> for Dual<L, R, IT, OT>
449where
450    L: Access<IT>,
451    R: Access<IT>,
452    IT: Number,
453    OT: Number,
454{
455    type Buffer = StackVec<OT>;
456
457    fn enqueue(&self) -> Result<Self::Buffer, Error> {
458        let left = self.left.read()?.to_slice()?;
459        let right = self.right.read()?.to_slice()?;
460        exec_dual(self.zip, left, right)
461    }
462}
463
464impl<L, R, IT, OT> Enqueue<Heap, OT> for Dual<L, R, IT, OT>
465where
466    L: Access<IT>,
467    R: Access<IT>,
468    IT: Number,
469    OT: Number,
470{
471    type Buffer = Vec<OT>;
472
473    fn enqueue(&self) -> Result<Self::Buffer, Error> {
474        let (left, right) = try_join_read(&self.left, &self.right)?;
475        exec_dual_parallel(self.zip, left, right)
476    }
477}
478
479impl<L, R, IT, OT> Enqueue<Host, OT> for Dual<L, R, IT, OT>
480where
481    L: Access<IT>,
482    R: Access<IT>,
483    IT: Number,
484    OT: Number,
485{
486    type Buffer = Buffer<OT>;
487
488    fn enqueue(&self) -> Result<Self::Buffer, Error> {
489        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
490    }
491}
492
493impl<L, R, IT, OT> ReadValue<Host, OT> for Dual<L, R, IT, OT>
494where
495    L: Access<IT>,
496    R: Access<IT>,
497    IT: Number,
498    OT: Number,
499{
500    fn read_value(&self, offset: usize) -> Result<OT, Error> {
501        try_join_value(&self.left, &self.right, offset).map(|(l, r)| (self.zip)(l, r))
502    }
503}
504
505pub struct Flip<A, T> {
506    access: A,
507    spec: FlipSpec,
508    dtype: PhantomData<T>,
509}
510
511impl<A, T> Flip<A, T> {
512    pub fn new(access: A, shape: Shape, axis: usize) -> Result<Self, Error> {
513        FlipSpec::new(shape, axis).map(|spec| Self {
514            access,
515            spec,
516            dtype: PhantomData,
517        })
518    }
519}
520
521impl<A, T> Op for Flip<A, T>
522where
523    A: Access<T>,
524    T: Number,
525{
526    fn size(&self) -> usize {
527        self.access.size()
528    }
529}
530
531impl<A, T> Enqueue<Heap, T> for Flip<A, T>
532where
533    A: Access<T>,
534    T: Number,
535{
536    type Buffer = Vec<T>;
537
538    fn enqueue(&self) -> Result<Self::Buffer, Error> {
539        (0..self.size())
540            .into_par_iter()
541            .map(|offset| self.read_value(offset))
542            .collect()
543    }
544}
545
546impl<A, T> Enqueue<Stack, T> for Flip<A, T>
547where
548    A: Access<T>,
549    T: Number,
550{
551    type Buffer = StackVec<T>;
552
553    fn enqueue(&self) -> Result<Self::Buffer, Error> {
554        (0..self.size())
555            .map(|offset| self.read_value(offset))
556            .collect()
557    }
558}
559
560impl<A, T> Enqueue<Host, T> for Flip<A, T>
561where
562    A: Access<T>,
563    T: Number,
564{
565    type Buffer = Buffer<T>;
566
567    fn enqueue(&self) -> Result<Self::Buffer, Error> {
568        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
569    }
570}
571
572impl<A, T> ReadValue<Host, T> for Flip<A, T>
573where
574    A: Access<T>,
575    T: Number,
576{
577    fn read_value(&self, offset: usize) -> Result<T, Error> {
578        debug_assert!(offset < self.size());
579        let offset = self.spec.source_offset(offset);
580        self.access.read_value(offset)
581    }
582}
583
584pub struct Linear<T> {
585    start: T,
586    step: T,
587    size: usize,
588}
589
590impl<T> Linear<T> {
591    pub fn new(start: T, step: T, size: usize) -> Self {
592        Self { start, step, size }
593    }
594
595    #[inline]
596    fn value_at(&self, offset: usize) -> T
597    where
598        T: Number,
599    {
600        let offset = T::cast_from(ng::Number::from(offset as u64));
601        T::add(self.start, T::mul(offset, self.step))
602    }
603}
604
605impl<T: Send + Sync> Op for Linear<T> {
606    fn size(&self) -> usize {
607        self.size
608    }
609}
610
611impl<T: Number> Enqueue<Stack, T> for Linear<T> {
612    type Buffer = StackVec<T>;
613
614    fn enqueue(&self) -> Result<Self::Buffer, Error> {
615        let buffer = (0..self.size).map(|offset| self.value_at(offset)).collect();
616
617        Ok(buffer)
618    }
619}
620
621impl<T: Number> Enqueue<Heap, T> for Linear<T> {
622    type Buffer = Vec<T>;
623
624    fn enqueue(&self) -> Result<Self::Buffer, Error> {
625        let buffer = (0..self.size)
626            .into_par_iter()
627            .map(|offset| self.value_at(offset))
628            .collect();
629
630        Ok(buffer)
631    }
632}
633
634impl<T: Number> Enqueue<Host, T> for Linear<T> {
635    type Buffer = Buffer<T>;
636
637    fn enqueue(&self) -> Result<Self::Buffer, Error> {
638        host_enqueue!(self, self.size < VEC_MIN_SIZE, T)
639    }
640}
641
642impl<T: Number> ReadValue<Host, T> for Linear<T> {
643    fn read_value(&self, offset: usize) -> Result<T, Error> {
644        Ok(self.value_at(offset))
645    }
646}
647
648pub struct MatDiag<A, T> {
649    access: A,
650    dim: usize,
651    batch_size: usize,
652    dtype: PhantomData<T>,
653}
654
655impl<A, T> MatDiag<A, T> {
656    pub fn new(access: A, batch_size: usize, dim: usize) -> Self {
657        Self {
658            access,
659            dim,
660            batch_size,
661            dtype: PhantomData,
662        }
663    }
664}
665
666impl<A: Access<T>, T: Number> Op for MatDiag<A, T> {
667    fn size(&self) -> usize {
668        debug_assert_eq!(self.access.size(), self.batch_size * self.dim * self.dim);
669        self.batch_size * self.dim
670    }
671}
672
673impl<A: Access<T>, T: Number> Enqueue<Heap, T> for MatDiag<A, T> {
674    type Buffer = Vec<T>;
675
676    fn enqueue(&self) -> Result<Self::Buffer, Error> {
677        let input = self.access.read()?.to_slice()?;
678
679        let diagonals = input
680            .par_chunks_exact(self.dim * self.dim)
681            .flat_map(|matrix| {
682                matrix
683                    .par_chunks_exact(self.dim)
684                    .enumerate()
685                    .map(|(i, row)| row[i])
686            })
687            .collect();
688
689        Ok(diagonals)
690    }
691}
692
693impl<A: Access<T>, T: Number> Enqueue<Stack, T> for MatDiag<A, T> {
694    type Buffer = StackVec<T>;
695
696    fn enqueue(&self) -> Result<Self::Buffer, Error> {
697        let input = self.access.read()?.to_slice()?;
698
699        let diagonals = input
700            .chunks_exact(self.dim * self.dim)
701            .flat_map(|matrix| {
702                matrix
703                    .chunks_exact(self.dim)
704                    .enumerate()
705                    .map(|(i, row)| row[i])
706            })
707            .collect();
708
709        Ok(diagonals)
710    }
711}
712
713impl<A: Access<T>, T: Number> Enqueue<Host, T> for MatDiag<A, T> {
714    type Buffer = Buffer<T>;
715
716    fn enqueue(&self) -> Result<Self::Buffer, Error> {
717        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
718    }
719}
720
721impl<A: Access<T>, T: Number> ReadValue<Host, T> for MatDiag<A, T> {
722    fn read_value(&self, offset: usize) -> Result<T, Error> {
723        let batch = offset / self.batch_size;
724        let i = offset % self.batch_size;
725        let source_offset = (batch * self.dim * self.dim) + (i * self.dim) + i;
726        self.access.read_value(source_offset)
727    }
728}
729
730pub struct MatMul<L, R, T> {
731    left: L,
732    right: R,
733    batch_size: usize,
734    dims: [usize; 3],
735    dtype: PhantomData<T>,
736}
737
738impl<L, R, T> MatMul<L, R, T> {
739    pub fn new(left: L, right: R, dims: [usize; 4]) -> Self {
740        let [batch_size, a, b, c] = dims;
741
742        Self {
743            left,
744            right,
745            batch_size,
746            dims: [a, b, c],
747            dtype: PhantomData,
748        }
749    }
750}
751
752impl<L, R, T> Op for MatMul<L, R, T>
753where
754    L: Send + Sync,
755    R: Send + Sync,
756    T: Send + Sync,
757{
758    fn size(&self) -> usize {
759        self.batch_size * self.dims[0] * self.dims[2]
760    }
761}
762
763impl<L, R, T> Enqueue<Stack, T> for MatMul<L, R, T>
764where
765    L: Access<T>,
766    R: Access<T>,
767    T: Number,
768{
769    type Buffer = StackVec<T>;
770
771    fn enqueue(&self) -> Result<Self::Buffer, Error> {
772        let left = self.left.read()?.to_slice()?;
773        let right = self.right.read()?.to_slice()?;
774
775        let [a, b, c] = self.dims;
776
777        let mut product = StackVec::with_capacity(self.batch_size * a * c);
778
779        for batch in 0..self.batch_size {
780            let l_start = batch * a * b;
781            let r_start = batch * b * c;
782
783            for x in 0..a {
784                for z in 0..c {
785                    let mut sum = T::ZERO;
786
787                    for y in 0..b {
788                        let l_offset = l_start + (x * b) + y;
789                        let r_offset = r_start + (y * c) + z;
790                        sum = T::add(sum, T::mul(left[l_offset], right[r_offset]));
791                    }
792
793                    product.push(sum)
794                }
795            }
796        }
797
798        debug_assert_eq!(product.len(), self.size());
799
800        Ok(product)
801    }
802}
803
804impl<L, R, T> Enqueue<Heap, T> for MatMul<L, R, T>
805where
806    L: Access<T>,
807    R: Access<T>,
808    T: Number,
809{
810    type Buffer = Vec<T>;
811
812    fn enqueue(&self) -> Result<Self::Buffer, Error> {
813        let [a, b, c] = self.dims;
814
815        let (left, right) = try_join_read(&self.left, &self.right)?;
816
817        // transpose the right matrices
818        let right_size = b * c;
819        let right_matrices = right.par_chunks_exact(right_size).map(|right| {
820            let mut right_t = vec![T::ZERO; right_size];
821            transpose::transpose(right, &mut right_t[..], c, b);
822            right_t
823        });
824
825        let left_size = a * b;
826        let left_matrices = left.par_chunks_exact(left_size);
827
828        let output_size = a * c;
829        let mut output = Vec::<T>::with_capacity(self.batch_size * output_size);
830        let output_matrices = left_matrices
831            .zip(right_matrices)
832            .map(|(lm, rm)| {
833                let mut out = Vec::<T>::with_capacity(output_size);
834
835                let product = lm
836                    .par_chunks_exact(b)
837                    .map(|row| {
838                        rm.par_chunks_exact(b).map(move |col| {
839                            // chunk the dot product to encourage the compiler to vectorize
840                            let col = col.par_chunks(8).map(|cc| cc.iter().copied());
841
842                            row.par_chunks(8)
843                                .zip(col)
844                                .map(|(rc, cc)| {
845                                    rc.iter()
846                                        .copied()
847                                        .zip(cc)
848                                        .map(|(r, c)| T::mul(r, c))
849                                        .reduce(T::add)
850                                        .expect("sum")
851                                })
852                                .reduce(|| T::ZERO, T::add)
853                        })
854                    })
855                    .flatten();
856
857                out.par_extend(product);
858                out
859            })
860            .flatten();
861
862        output.par_extend(output_matrices);
863
864        debug_assert_eq!(output.len(), self.batch_size * output_size);
865
866        Ok(output)
867    }
868}
869
870impl<L, R, T> Enqueue<Host, T> for MatMul<L, R, T>
871where
872    L: Access<T>,
873    R: Access<T>,
874    T: Number,
875{
876    type Buffer = Buffer<T>;
877
878    fn enqueue(&self) -> Result<Self::Buffer, Error> {
879        host_enqueue!(
880            self,
881            self.left.size() < VEC_MIN_SIZE && self.right.size() < VEC_MIN_SIZE,
882            T
883        )
884    }
885}
886
887impl<L, R, T> ReadValue<Host, T> for MatMul<L, R, T>
888where
889    L: Access<T>,
890    R: Access<T>,
891    T: Number,
892{
893    fn read_value(&self, _offset: usize) -> Result<T, Error> {
894        Err(Error::bounds(
895            "reading an individual value from a matrix multiplication is not implemented"
896                .to_string(),
897        ))
898    }
899}
900
901pub struct Scalar<A, IT, OT> {
902    access: A,
903    scalar: IT,
904    op: fn(IT, IT) -> OT,
905}
906
907impl<A, IT, OT> Scalar<A, IT, OT> {
908    fn new(access: A, scalar: IT, op: fn(IT, IT) -> OT) -> Self {
909        Self { access, scalar, op }
910    }
911}
912
913impl<A, T: Number> Scalar<A, T, T> {
914    pub fn add(access: A, scalar: T) -> Self {
915        Self::new(access, scalar, T::add)
916    }
917
918    pub fn div(access: A, scalar: T) -> Self {
919        Self::new(access, scalar, T::div)
920    }
921
922    pub fn mul(access: A, scalar: T) -> Self {
923        Self::new(access, scalar, T::mul)
924    }
925
926    pub fn pow(access: A, scalar: T) -> Self {
927        Self::new(access, scalar, T::pow)
928    }
929
930    pub fn rem(access: A, scalar: T) -> Self
931    where
932        T: Real,
933    {
934        Self::new(access, scalar, T::rem)
935    }
936
937    pub fn sub(access: A, scalar: T) -> Self {
938        Self::new(access, scalar, T::sub)
939    }
940}
941
942impl<A, T: Float> Scalar<A, T, T> {
943    pub fn log(access: A, scalar: T) -> Self {
944        Self::new(access, scalar, T::log)
945    }
946}
947
948impl<A, T> Scalar<A, T, u8>
949where
950    T: Number,
951{
952    pub fn and(access: A, scalar: T) -> Self {
953        Self::new(access, scalar, |l, r| {
954            if (l != T::ZERO) && (r != T::ZERO) {
955                1
956            } else {
957                0
958            }
959        })
960    }
961
962    pub fn or(access: A, scalar: T) -> Self {
963        Self::new(access, scalar, |l, r| {
964            if (l != T::ZERO) || (r != T::ZERO) {
965                1
966            } else {
967                0
968            }
969        })
970    }
971
972    pub fn xor(access: A, scalar: T) -> Self {
973        Self::new(access, scalar, |l, r| {
974            if (l != T::ZERO) ^ (r != T::ZERO) {
975                1
976            } else {
977                0
978            }
979        })
980    }
981
982    pub fn eq(access: A, scalar: T) -> Self {
983        Self::new(access, scalar, |l, r| if l == r { 1 } else { 0 })
984    }
985
986    pub fn ge(access: A, scalar: T) -> Self
987    where
988        T: Real,
989    {
990        Self::new(access, scalar, |l, r| if l >= r { 1 } else { 0 })
991    }
992
993    pub fn gt(access: A, scalar: T) -> Self
994    where
995        T: Real,
996    {
997        Self::new(access, scalar, |l, r| if l > r { 1 } else { 0 })
998    }
999
1000    pub fn le(access: A, scalar: T) -> Self
1001    where
1002        T: Real,
1003    {
1004        Self::new(access, scalar, |l, r| if l <= r { 1 } else { 0 })
1005    }
1006
1007    pub fn lt(access: A, scalar: T) -> Self
1008    where
1009        T: Real,
1010    {
1011        Self::new(access, scalar, |l, r| if l < r { 1 } else { 0 })
1012    }
1013
1014    pub fn ne(access: A, scalar: T) -> Self {
1015        Self::new(access, scalar, |l, r| if l != r { 1 } else { 0 })
1016    }
1017}
1018
1019impl<A, IT, OT> Op for Scalar<A, IT, OT>
1020where
1021    A: Access<IT>,
1022    IT: Number,
1023    OT: Number,
1024{
1025    fn size(&self) -> usize {
1026        self.access.size()
1027    }
1028}
1029
1030impl<A, IT, OT> Enqueue<Heap, OT> for Scalar<A, IT, OT>
1031where
1032    A: Access<IT>,
1033    IT: Number,
1034    OT: Number,
1035{
1036    type Buffer = Vec<OT>;
1037
1038    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1039        self.access
1040            .read()
1041            .and_then(|buf| buf.to_slice())
1042            .map(|slice| {
1043                slice
1044                    .as_ref()
1045                    .into_par_iter()
1046                    .copied()
1047                    .map(|l| (self.op)(l, self.scalar))
1048                    .collect()
1049            })
1050    }
1051}
1052
1053impl<A, IT, OT> Enqueue<Stack, OT> for Scalar<A, IT, OT>
1054where
1055    A: Access<IT>,
1056    IT: Number,
1057    OT: Number,
1058{
1059    type Buffer = StackVec<OT>;
1060
1061    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1062        self.access
1063            .read()
1064            .and_then(|buf| buf.to_slice())
1065            .map(|slice| {
1066                slice
1067                    .as_ref()
1068                    .iter()
1069                    .copied()
1070                    .map(|l| (self.op)(l, self.scalar))
1071                    .collect()
1072            })
1073    }
1074}
1075
1076impl<A, IT, OT> Enqueue<Host, OT> for Scalar<A, IT, OT>
1077where
1078    A: Access<IT>,
1079    IT: Number,
1080    OT: Number,
1081{
1082    type Buffer = Buffer<OT>;
1083
1084    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1085        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
1086    }
1087}
1088
1089impl<A, IT, OT> ReadValue<Host, OT> for Scalar<A, IT, OT>
1090where
1091    A: Access<IT>,
1092    IT: Number,
1093    OT: Number,
1094{
1095    fn read_value(&self, offset: usize) -> Result<OT, Error> {
1096        self.access
1097            .read_value(offset)
1098            .map(|n| (self.op)(n, self.scalar))
1099    }
1100}
1101
/// Generates `size` samples from the standard normal distribution using the
/// Box-Muller transform.
pub struct RandomNormal {
    size: usize,
}

impl RandomNormal {
    /// Construct a generator for `size` samples.
    pub fn new(size: usize) -> Self {
        Self { size }
    }

    /// Map two uniform samples `[u1, u2]` to two standard normal samples.
    ///
    /// `u1` is clamped to at least `f32::MIN_POSITIVE`. Otherwise `u1 == 0`
    /// (a value a uniform generator can return) would give `ln(0) = -inf`
    /// and an infinite or NaN result.
    fn box_muller(u: [f32; 2]) -> [f32; 2] {
        let [u1, u2] = u;
        let u1 = u1.max(f32::MIN_POSITIVE);
        let r = (u1.ln() * -2.).sqrt();
        let theta = 2. * PI * u2;
        [r * theta.cos(), r * theta.sin()]
    }
}
1118
1119impl Op for RandomNormal {
1120    fn size(&self) -> usize {
1121        self.size
1122    }
1123}
1124
1125impl Enqueue<Heap, f32> for RandomNormal {
1126    type Buffer = Vec<f32>;
1127
1128    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1129        let mut rng = Rand::new();
1130
1131        let mut output = (0..self.size.div_ceil(2))
1132            .flat_map(|_| Self::box_muller([rng.gen(), rng.gen()]))
1133            .collect::<Vec<f32>>();
1134
1135        if output.len() > self.size {
1136            output.pop();
1137        }
1138
1139        debug_assert_eq!(output.len(), self.size);
1140
1141        Ok(output)
1142    }
1143}
1144
1145impl Enqueue<Stack, f32> for RandomNormal {
1146    type Buffer = StackVec<f32>;
1147
1148    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1149        let mut rng = Rand::new();
1150
1151        let mut output = iter::repeat_with(|| [rng.gen(), rng.gen()])
1152            .take(self.size.div_ceil(2))
1153            .flat_map(Self::box_muller)
1154            .collect::<StackVec<f32>>();
1155
1156        if output.len() > self.size {
1157            output.pop();
1158        }
1159
1160        debug_assert_eq!(output.len(), self.size);
1161
1162        Ok(output)
1163    }
1164}
1165
1166impl Enqueue<Host, f32> for RandomNormal {
1167    type Buffer = Buffer<f32>;
1168
1169    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1170        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
1171    }
1172}
1173
1174impl ReadValue<Host, f32> for RandomNormal {
1175    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
1176        Err(Error::bounds(
1177            "cannot calculate an individual value of a random normal distribution".to_string(),
1178        ))
1179    }
1180}
1181
1182pub struct RandomUniform {
1183    size: usize,
1184}
1185
1186impl RandomUniform {
1187    pub fn new(size: usize) -> Self {
1188        Self { size }
1189    }
1190}
1191
1192impl Op for RandomUniform {
1193    fn size(&self) -> usize {
1194        self.size
1195    }
1196}
1197
1198impl Enqueue<Heap, f32> for RandomUniform {
1199    type Buffer = Vec<f32>;
1200
1201    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1202        let mut rng = Rand::new();
1203        Ok((0..self.size).map(|_| rng.gen()).collect())
1204    }
1205}
1206
1207impl Enqueue<Stack, f32> for RandomUniform {
1208    type Buffer = StackVec<f32>;
1209
1210    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1211        let mut rng = Rand::new();
1212        Ok((0..self.size).map(|_| rng.gen()).collect())
1213    }
1214}
1215
1216impl Enqueue<Host, f32> for RandomUniform {
1217    type Buffer = Buffer<f32>;
1218
1219    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1220        host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
1221    }
1222}
1223
1224impl ReadValue<Host, f32> for RandomUniform {
1225    fn read_value(&self, _offset: usize) -> Result<f32, Error> {
1226        Ok(Rand::new().gen())
1227    }
1228}
1229
/// A reduction over consecutive, non-overlapping runs of `stride` elements of `access`.
///
/// Each run produces one output element.
pub struct Reduce<A, T> {
    /// The source of the elements to reduce.
    access: A,
    /// How many consecutive input elements are combined into each output element.
    stride: usize,
    /// The binary operation used to combine two elements.
    reduce: fn(T, T) -> T,
    /// The starting value associated with `reduce` (e.g. zero for a sum, one for a product).
    id: T,
}
1236
1237impl<A, T> Reduce<A, T>
1238where
1239    T: Number,
1240{
1241    pub fn product(access: A, stride: usize) -> Self {
1242        Self {
1243            access,
1244            stride,
1245            reduce: T::mul,
1246            id: T::ONE,
1247        }
1248    }
1249
1250    pub fn sum(access: A, stride: usize) -> Self {
1251        Self {
1252            access,
1253            stride,
1254            reduce: T::add,
1255            id: T::ZERO,
1256        }
1257    }
1258}
1259
1260impl<A, T> Reduce<A, T>
1261where
1262    T: Real,
1263{
1264    pub fn max(access: A, stride: usize) -> Self {
1265        Self {
1266            access,
1267            stride,
1268            reduce: Real::max,
1269            id: T::MIN,
1270        }
1271    }
1272
1273    pub fn min(access: A, stride: usize) -> Self {
1274        Self {
1275            access,
1276            stride,
1277            reduce: Real::min,
1278            id: T::MAX,
1279        }
1280    }
1281}
1282
1283impl<A: Access<T>, T: Number> Op for Reduce<A, T> {
1284    fn size(&self) -> usize {
1285        debug_assert_eq!(self.access.size() % self.stride, 0);
1286        self.access.size() / self.stride
1287    }
1288}
1289
1290impl<A: Access<T>, T: Number> Enqueue<Heap, T> for Reduce<A, T> {
1291    type Buffer = Vec<T>;
1292
1293    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1294        self.access
1295            .read()
1296            .and_then(|buf| buf.to_slice())
1297            .map(|slice| {
1298                slice
1299                    .chunks_exact(self.stride)
1300                    .map(|chunk| {
1301                        chunk
1302                            // encourage the compiler to vectorize
1303                            .par_chunks(8)
1304                            .map(|chunk| {
1305                                chunk.iter().copied().reduce(self.reduce).expect("reduced")
1306                            })
1307                            .reduce(|| self.id, self.reduce)
1308                    })
1309                    .collect()
1310            })
1311    }
1312}
1313
1314impl<A: Access<T>, T: Number> Enqueue<Stack, T> for Reduce<A, T> {
1315    type Buffer = StackVec<T>;
1316
1317    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1318        self.access
1319            .read()
1320            .and_then(|buf| buf.to_slice())
1321            .map(|slice| {
1322                slice
1323                    .chunks_exact(self.stride)
1324                    .map(|chunk| chunk.iter().copied().reduce(self.reduce).expect("reduced"))
1325                    .collect()
1326            })
1327    }
1328}
1329
1330impl<A: Access<T>, T: Number> Enqueue<Host, T> for Reduce<A, T> {
1331    type Buffer = Buffer<T>;
1332
1333    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1334        host_enqueue!(
1335            self,
1336            self.stride < VEC_MIN_SIZE && self.size() < VEC_MIN_SIZE,
1337            T
1338        )
1339    }
1340}
1341
1342impl<A: Access<T>, T: Number> ReadValue<Host, T> for Reduce<A, T> {
1343    fn read_value(&self, offset: usize) -> Result<T, Error> {
1344        let offset = offset * self.stride;
1345
1346        if offset < self.access.size() {
1347            (offset..(offset + self.stride))
1348                .into_par_iter()
1349                .map(|offset| self.access.read_value(offset))
1350                .try_reduce(|| self.id, |r, v| Ok((self.reduce)(r, v)))
1351        } else {
1352            Err(Error::bounds(format!(
1353                "invalid offset {offset} for a reduce op with size {}",
1354                self.size()
1355            )))
1356        }
1357    }
1358}
1359
1360pub struct Slice<A, T> {
1361    access: A,
1362    spec: SliceSpec,
1363    dtype: PhantomData<T>,
1364}
1365
1366impl<A, T> Slice<A, T> {
1367    pub fn new(access: A, shape: &[usize], range: Range) -> Self {
1368        let spec = SliceSpec::new(shape, range);
1369
1370        Self {
1371            access,
1372            spec,
1373            dtype: PhantomData,
1374        }
1375    }
1376}
1377
1378impl<A: Send + Sync, T: Copy + Send + Sync> Slice<A, T> {
1379    fn read(&self, source: &[T]) -> Result<StackVec<T>, Error> {
1380        let output = (0..self.size())
1381            .map(|offset_out| self.spec.source_offset(offset_out))
1382            .map(|offset_in| source[offset_in])
1383            .collect();
1384
1385        Ok(output)
1386    }
1387
1388    fn read_parallel(&self, source: &[T]) -> Result<Vec<T>, Error> {
1389        let output = (0..self.size())
1390            .into_par_iter()
1391            .map(|offset_out| self.spec.source_offset(offset_out))
1392            .map(|offset_in| source[offset_in])
1393            .collect();
1394
1395        Ok(output)
1396    }
1397}
1398
1399impl<A, T> Slice<A, T>
1400where
1401    T: Number,
1402    A: AccessMut<T>,
1403{
1404    fn overwrite<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1405        if data.len() == self.size() {
1406            let data = data.to_slice()?;
1407
1408            for (offset, value) in data.iter().copied().enumerate() {
1409                let source_offset = self.spec.source_offset(offset);
1410                self.access.write_value_at(source_offset, value)?;
1411            }
1412
1413            Ok(())
1414        } else {
1415            Err(Error::bounds(format!(
1416                "cannot overwrite a slice of size {} with a buffer of size {}",
1417                self.size(),
1418                data.len(),
1419            )))
1420        }
1421    }
1422
1423    fn overwrite_value(&mut self, value: T) -> Result<(), Error> {
1424        for offset in 0..self.access.size() {
1425            let source_offset = self.spec.source_offset(offset);
1426            self.access.write_value_at(source_offset, value)?;
1427        }
1428
1429        Ok(())
1430    }
1431
1432    fn overwrite_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1433        let source_offset = self.spec.source_offset(offset);
1434        self.access.write_value_at(source_offset, value)
1435    }
1436}
1437
1438impl<A: Send + Sync, T: Send + Sync> Op for Slice<A, T> {
1439    fn size(&self) -> usize {
1440        self.spec.size()
1441    }
1442}
1443
1444impl<A: Access<T>, T: Number> Enqueue<Heap, T> for Slice<A, T> {
1445    type Buffer = Vec<T>;
1446
1447    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1448        self.access
1449            .read()
1450            .and_then(|buf| buf.to_slice())
1451            .and_then(|buf| self.read_parallel(&buf))
1452    }
1453}
1454
1455impl<A: Access<T>, T: Number> Enqueue<Stack, T> for Slice<A, T> {
1456    type Buffer = StackVec<T>;
1457
1458    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1459        self.access
1460            .read()
1461            .and_then(|buf| buf.to_slice())
1462            .and_then(|buf| self.read(&buf))
1463    }
1464}
1465
1466impl<A: Access<T>, T: Number> Enqueue<Host, T> for Slice<A, T> {
1467    type Buffer = Buffer<T>;
1468
1469    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1470        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
1471    }
1472}
1473
1474impl<A: Access<T>, T: Number> ReadValue<Host, T> for Slice<A, T> {
1475    fn read_value(&self, offset: usize) -> Result<T, Error> {
1476        let offset = self.spec.source_offset(offset);
1477        self.access.read_value(offset)
1478    }
1479}
1480
1481impl<A, T> crate::ops::Write<Heap, T> for Slice<A, T>
1482where
1483    T: Number,
1484    A: AccessMut<T>,
1485{
1486    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1487        self.overwrite(data)
1488    }
1489
1490    fn write_value(&mut self, value: T) -> Result<(), Error> {
1491        self.overwrite_value(value)
1492    }
1493
1494    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1495        self.overwrite_value_at(offset, value)
1496    }
1497}
1498
1499impl<A, T> crate::ops::Write<Stack, T> for Slice<A, T>
1500where
1501    T: Number,
1502    A: AccessMut<T>,
1503{
1504    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1505        self.overwrite(data)
1506    }
1507
1508    fn write_value(&mut self, value: T) -> Result<(), Error> {
1509        self.overwrite_value(value)
1510    }
1511
1512    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1513        self.overwrite_value_at(offset, value)
1514    }
1515}
1516
1517impl<A, T> crate::ops::Write<Host, T> for Slice<A, T>
1518where
1519    T: Number,
1520    A: AccessMut<T>,
1521{
1522    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1523        self.overwrite(data)
1524    }
1525
1526    fn write_value(&mut self, value: T) -> Result<(), Error> {
1527        self.overwrite_value(value)
1528    }
1529
1530    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1531        self.overwrite_value_at(offset, value)
1532    }
1533}
1534
/// An elementwise op which maps every element of `access` through `op`.
pub struct Unary<A, IT, OT> {
    /// The source of the input elements.
    access: A,
    /// The function applied to each input element.
    op: fn(IT) -> OT,
}
1539
1540impl<A: Access<T>, T: Float> Unary<A, T, T> {
1541    pub fn exp(access: A) -> Self {
1542        Self { access, op: T::exp }
1543    }
1544
1545    pub fn ln(access: A) -> Self {
1546        Self { access, op: T::ln }
1547    }
1548}
1549
1550impl<A: Access<T>, T: Number> Unary<A, T, T::Abs> {
1551    pub fn abs(access: A) -> Self {
1552        Self {
1553            access,
1554            op: Number::abs,
1555        }
1556    }
1557}
1558
1559impl<A: Access<T>, T: Real> Unary<A, T, T> {
1560    pub fn round(access: A) -> Self {
1561        Self {
1562            access,
1563            op: Real::round,
1564        }
1565    }
1566}
1567
1568impl<A: Access<T>, T: Float> Unary<A, T, T> {
1569    pub fn sin(access: A) -> Self {
1570        Self {
1571            access,
1572            op: |n| n.sin(),
1573        }
1574    }
1575
1576    pub fn asin(access: A) -> Self {
1577        Self {
1578            access,
1579            op: |n| n.asin(),
1580        }
1581    }
1582
1583    pub fn sinh(access: A) -> Self {
1584        Self {
1585            access,
1586            op: |n| n.sinh(),
1587        }
1588    }
1589
1590    pub fn cos(access: A) -> Self {
1591        Self {
1592            access,
1593            op: |n| n.cos(),
1594        }
1595    }
1596
1597    pub fn acos(access: A) -> Self {
1598        Self {
1599            access,
1600            op: |n| n.acos(),
1601        }
1602    }
1603
1604    pub fn cosh(access: A) -> Self {
1605        Self {
1606            access,
1607            op: |n| n.cosh(),
1608        }
1609    }
1610
1611    pub fn tan(access: A) -> Self {
1612        Self {
1613            access,
1614            op: |n| n.tan(),
1615        }
1616    }
1617
1618    pub fn atan(access: A) -> Self {
1619        Self {
1620            access,
1621            op: |n| n.atan(),
1622        }
1623    }
1624
1625    pub fn tanh(access: A) -> Self {
1626        Self {
1627            access,
1628            op: |n| n.tanh(),
1629        }
1630    }
1631}
1632
1633impl<A: Access<T>, T: Number> Unary<A, T, u8> {
1634    pub fn not(access: A) -> Self {
1635        Self {
1636            access,
1637            op: |n| if n == T::ZERO { 1 } else { 0 },
1638        }
1639    }
1640}
1641
1642impl<A: Access<T>, T: Float> Unary<A, T, u8> {
1643    pub fn inf(access: A) -> Self {
1644        Self {
1645            access,
1646            op: |n| if n.is_inf() { 1 } else { 0 },
1647        }
1648    }
1649
1650    pub fn nan(access: A) -> Self {
1651        Self {
1652            access,
1653            op: |n| if n.is_nan() { 1 } else { 0 },
1654        }
1655    }
1656}
1657
#[cfg(feature = "complex")]
impl<A, T> Unary<A, T, T>
where
    A: Access<T>,
    T: Complex,
{
    /// Construct an op computing the complex conjugate of each element.
    pub fn conj(access: A) -> Self {
        Self {
            op: |z| z.conj(),
            access,
        }
    }
}
1671
#[cfg(feature = "complex")]
impl<A, T> Unary<A, T, T::Real>
where
    A: Access<T>,
    T: Complex,
{
    /// Construct an op computing the angle (argument) of each complex element.
    pub fn angle(access: A) -> Self {
        Self {
            op: |z| z.angle(),
            access,
        }
    }

    /// Construct an op extracting the real part of each complex element.
    pub fn re(access: A) -> Self {
        Self {
            op: |z| z.re(),
            access,
        }
    }

    /// Construct an op extracting the imaginary part of each complex element.
    pub fn im(access: A) -> Self {
        Self {
            op: |z| z.im(),
            access,
        }
    }
}
1699
1700impl<A, IT, OT> Op for Unary<A, IT, OT>
1701where
1702    A: Access<IT>,
1703    IT: Number,
1704    OT: Number,
1705{
1706    fn size(&self) -> usize {
1707        self.access.size()
1708    }
1709}
1710
1711impl<A, IT, OT> Enqueue<Heap, OT> for Unary<A, IT, OT>
1712where
1713    A: Access<IT>,
1714    IT: Number,
1715    OT: Number,
1716{
1717    type Buffer = Vec<OT>;
1718
1719    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1720        self.access
1721            .read()
1722            .and_then(|buf| buf.to_slice())
1723            .map(|input| input.into_par_iter().copied().map(self.op).collect())
1724    }
1725}
1726
1727impl<A, IT, OT> Enqueue<Stack, OT> for Unary<A, IT, OT>
1728where
1729    A: Access<IT>,
1730    IT: Number,
1731    OT: Number,
1732{
1733    type Buffer = StackVec<OT>;
1734
1735    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1736        self.access
1737            .read()
1738            .and_then(|buf| buf.to_slice())
1739            .map(|input| input.iter().copied().map(self.op).collect())
1740    }
1741}
1742
1743impl<A, IT, OT> Enqueue<Host, OT> for Unary<A, IT, OT>
1744where
1745    A: Access<IT>,
1746    IT: Number,
1747    OT: Number,
1748{
1749    type Buffer = Buffer<OT>;
1750
1751    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1752        host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
1753    }
1754}
1755
1756impl<A, IT, OT> ReadValue<Host, OT> for Unary<A, IT, OT>
1757where
1758    A: Access<IT>,
1759    IT: Number,
1760    OT: Number,
1761{
1762    fn read_value(&self, offset: usize) -> Result<OT, Error> {
1763        self.access.read_value(offset).map(|n| (self.op)(n))
1764    }
1765}
1766
1767pub struct View<A, T> {
1768    access: A,
1769    spec: ViewSpec,
1770    dtype: PhantomData<T>,
1771}
1772
1773impl<A: Access<T>, T: Number> View<A, T> {
1774    pub fn broadcast(access: A, shape: Shape, broadcast: Shape) -> Self {
1775        let source_strides = strides_for(&shape, shape.len()).collect();
1776
1777        Self {
1778            access,
1779            spec: ViewSpec::new(broadcast, source_strides),
1780            dtype: PhantomData,
1781        }
1782    }
1783
1784    pub fn transpose(access: A, shape: Shape, axes: Axes) -> Self {
1785        let strides = strides_for(&shape, shape.len()).collect::<Strides>();
1786        let shape = axes.iter().copied().map(|x| shape[x]).collect::<Strides>();
1787        let source_strides = axes.into_iter().map(|x| strides[x]).collect::<Strides>();
1788
1789        Self {
1790            access,
1791            spec: ViewSpec::new(shape, source_strides),
1792            dtype: PhantomData,
1793        }
1794    }
1795}
1796
1797impl<A: Access<T>, T: Number> Op for View<A, T> {
1798    fn size(&self) -> usize {
1799        self.spec.size()
1800    }
1801}
1802
1803impl<A: Access<T>, T: Number> Enqueue<Stack, T> for View<A, T> {
1804    type Buffer = StackVec<T>;
1805
1806    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1807        let source = self.access.read().and_then(|source| source.to_slice())?;
1808
1809        let buffer = (0..self.spec.size())
1810            .map(|offset| self.spec.source_offset(offset))
1811            .map(|source_offset| source[source_offset])
1812            .collect();
1813
1814        Ok(buffer)
1815    }
1816}
1817
1818impl<A: Access<T>, T: Number> Enqueue<Heap, T> for View<A, T> {
1819    type Buffer = Vec<T>;
1820
1821    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1822        let source = self.access.read().and_then(|source| source.to_slice())?;
1823
1824        let buffer = (0..self.spec.size())
1825            .into_par_iter()
1826            .map(|offset| self.spec.source_offset(offset))
1827            .map(|source_offset| source[source_offset])
1828            .collect();
1829
1830        Ok(buffer)
1831    }
1832}
1833
1834impl<A: Access<T>, T: Number> Enqueue<Host, T> for View<A, T> {
1835    type Buffer = Buffer<T>;
1836
1837    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1838        host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
1839    }
1840}
1841
1842impl<A: Access<T>, T: Number> ReadValue<Host, T> for View<A, T> {
1843    fn read_value(&self, offset: usize) -> Result<T, Error> {
1844        self.access.read_value(self.spec.source_offset(offset))
1845    }
1846}
1847
1848fn exec_dual<IT: Number, OT: Number>(
1849    zip: fn(IT, IT) -> OT,
1850    left: SliceConverter<IT>,
1851    right: SliceConverter<IT>,
1852) -> Result<StackVec<OT>, Error> {
1853    let output = left
1854        .iter()
1855        .copied()
1856        .zip(right.iter().copied())
1857        .map(|(l, r)| (zip)(l, r))
1858        .collect();
1859
1860    Ok(output)
1861}
1862
1863fn exec_dual_parallel<IT: Number, OT: Number>(
1864    zip: fn(IT, IT) -> OT,
1865    left: SliceConverter<IT>,
1866    right: SliceConverter<IT>,
1867) -> Result<Vec<OT>, Error> {
1868    let output = left
1869        .into_par_iter()
1870        .copied()
1871        .zip(right.into_par_iter().copied())
1872        .map(|(l, r)| (zip)(l, r))
1873        .collect();
1874
1875    Ok(output)
1876}
1877
1878#[inline]
1879fn try_join_read<'a, L, R, T>(
1880    left: &'a L,
1881    right: &'a R,
1882) -> Result<(SliceConverter<'a, T>, SliceConverter<'a, T>), Error>
1883where
1884    L: Access<T>,
1885    R: Access<T>,
1886    T: Number,
1887{
1888    let (l, r) = join(
1889        || left.read().and_then(|buf| buf.to_slice()),
1890        || right.read().and_then(|buf| buf.to_slice()),
1891    );
1892
1893    Ok((l?, r?))
1894}
1895
1896#[inline]
1897fn try_join_value<'a, L, R, T>(left: &'a L, right: &'a R, offset: usize) -> Result<(T, T), Error>
1898where
1899    L: Access<T>,
1900    R: Access<T>,
1901    T: Number,
1902{
1903    let (l, r) = join(|| left.read_value(offset), || right.read_value(offset));
1904
1905    Ok((l?, r?))
1906}