// ha_ndarray/ops.rs

//! Array operations

use crate::access::*;
use crate::buffer::Buffer;
#[cfg(feature = "opencl")]
use crate::opencl;
use crate::platform::{Platform, PlatformInstance};
use crate::{
    host, range_shape, strides_for, Axes, AxisRange, BufferConverter, CType, Error, Range, Shape,
    Strides,
};

/// Evaluate `$call` against the platform-specific op wrapped by the enum value `$this`.
///
/// `$this` must be a value of the enclosing `Self` type, which must have a `Host` variant and,
/// when the `opencl` feature is enabled, a `CL` variant. The wrapped op is bound to the
/// identifier `$op`, which `$call` may refer to.
macro_rules! op_dispatch {
    ($this:expr, $op:ident, $call:expr) => {
        match $this {
            #[cfg(feature = "opencl")]
            Self::CL($op) => $call,
            Self::Host($op) => $call,
        }
    };
}

/// Enqueue the platform-specific op wrapped by the enum value `$this` and wrap the resulting
/// buffer of element type `$t` in the matching `Buffer` variant.
///
/// The `Host` variant is enqueued through `Enqueue<host::Host, $t>`; when the `opencl` feature
/// is enabled, the `CL` variant is enqueued through `Enqueue<opencl::OpenCL, $t>`.
macro_rules! op_enqueue {
    ($this:expr, $t:ty) => {
        match $this {
            #[cfg(feature = "opencl")]
            Self::CL(op) => Enqueue::<opencl::OpenCL, $t>::enqueue(op).map(Buffer::CL),
            Self::Host(op) => Enqueue::<host::Host, $t>::enqueue(op).map(Buffer::Host),
        }
    };
}

/// The base trait of every op in this module: a `Send + Sync` value which reports a size.
pub trait Op: Send + Sync {
    /// The size of this op (presumably the number of elements in its output; confirm against
    /// the implementations).
    fn size(&self) -> usize;
}

37pub trait Enqueue<P: PlatformInstance, T: CType>: Op {
38    type Buffer: Into<BufferConverter<'static, T>>;
39
40    fn enqueue(&self) -> Result<Self::Buffer, Error>;
41}
42
43pub trait ReadValue<P: PlatformInstance, T: CType>: Op {
44    fn read_value(&self, offset: usize) -> Result<T, Error>;
45}
46
47pub trait ReadOp<P, T>: Enqueue<P, T> + ReadValue<P, T>
48where
49    P: PlatformInstance,
50    T: CType,
51{
52}
53
54impl<O, P, T> ReadOp<P, T> for O
55where
56    O: Enqueue<P, T> + ReadValue<P, T>,
57    P: PlatformInstance,
58    T: CType,
59{
60}
61
62pub trait Write<P: PlatformInstance, T: CType>: Enqueue<P, T> {
63    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
64
65    fn write_value(&mut self, value: T) -> Result<(), Error>;
66
67    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
68}
69
70pub trait Construct<T: CType>: PlatformInstance {
71    type Range: Enqueue<Self, T>;
72
73    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error>;
74}
75
76pub trait ElementwiseBoolean<L, R, T>: PlatformInstance {
77    type Op: ReadOp<Self, u8>;
78
79    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
80
81    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
82
83    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
84}
85
86pub trait ElementwiseBooleanScalar<A, T>: PlatformInstance {
87    type Op: ReadOp<Self, u8>;
88
89    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
90
91    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
92
93    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
94}
95
96pub trait ElementwiseCast<A, IT, OT>: PlatformInstance
97where
98    A: Access<IT>,
99    IT: CType,
100    OT: CType,
101{
102    type Op: ReadOp<Self, OT>;
103
104    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
105}
106
107pub trait ElementwiseCompare<L, R, T>: PlatformInstance {
108    type Op: ReadOp<Self, u8>;
109
110    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
111
112    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
113
114    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
115
116    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
117
118    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
119
120    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
121}
122
123pub trait ElementwiseScalarCompare<A, T>: PlatformInstance {
124    type Op: ReadOp<Self, u8>;
125
126    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
127
128    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
129
130    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
131
132    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
133
134    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
135
136    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
137}
138
139pub trait ElementwiseDual<L, R, T>: PlatformInstance
140where
141    L: Access<T>,
142    R: Access<T>,
143    T: CType,
144{
145    type Op: ReadOp<Self, T>;
146
147    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
148
149    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
150
151    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>;
152
153    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
154
155    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error>;
156
157    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
158
159    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
160}
161
162pub trait ElementwiseScalar<A, T>: PlatformInstance
163where
164    A: Access<T>,
165    T: CType,
166{
167    type Op: ReadOp<Self, T>;
168
169    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
170
171    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
172
173    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>;
174
175    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
176
177    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error>;
178
179    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
180
181    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
182}
183
184pub trait ElementwiseNumeric<A, T>: PlatformInstance
185where
186    A: Access<T>,
187    T: CType,
188{
189    type Op: ReadOp<Self, u8>;
190
191    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
192
193    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
194}
195
196pub trait ElementwiseTrig<A, T>: PlatformInstance
197where
198    A: Access<T>,
199    T: CType,
200{
201    type Op: ReadOp<Self, T::Float>;
202
203    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
204
205    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
206
207    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
208
209    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
210
211    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
212
213    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
214
215    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
216
217    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
218
219    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
220}
221
222pub trait ElementwiseUnary<A, T>: PlatformInstance
223where
224    A: Access<T>,
225    T: CType,
226{
227    type Op: ReadOp<Self, T>;
228
229    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
230
231    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
232
233    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
234
235    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
236}
237
238pub trait ElementwiseUnaryBoolean<A, T>: PlatformInstance
239where
240    A: Access<T>,
241    T: CType,
242{
243    type Op: ReadOp<Self, u8>;
244
245    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
246}
247
248pub trait GatherCond<A, L, R, T>: PlatformInstance
249where
250    A: Access<u8>,
251    L: Access<T>,
252    R: Access<T>,
253    T: CType,
254{
255    type Op: ReadOp<Self, T>;
256
257    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error>;
258}
259
260pub trait LinAlgDual<L, R, T>: PlatformInstance
261where
262    L: Access<T>,
263    R: Access<T>,
264    T: CType,
265{
266    type Op: ReadOp<Self, T>;
267
268    fn matmul(self, left: L, right: R, dims: [usize; 4])
269        -> Result<AccessOp<Self::Op, Self>, Error>;
270}
271
272pub trait LinAlgUnary<A, T>: PlatformInstance
273where
274    A: Access<T>,
275    T: CType,
276{
277    type Op: ReadOp<Self, T>;
278
279    fn diag(
280        self,
281        access: A,
282        batch_size: usize,
283        dim: usize,
284    ) -> Result<AccessOp<Self::Op, Self>, Error>;
285}
286
287pub trait Random: PlatformInstance {
288    type Normal: Enqueue<Self, f32>;
289    type Uniform: Enqueue<Self, f32>;
290
291    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error>;
292
293    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error>;
294}
295
296pub trait ReduceAll<A, T>: PlatformInstance {
297    fn all(self, access: A) -> Result<bool, Error>;
298
299    fn any(self, access: A) -> Result<bool, Error>;
300
301    fn max(self, access: A) -> Result<T, Error>;
302
303    fn min(self, access: A) -> Result<T, Error>;
304
305    fn product(self, access: A) -> Result<T, Error>;
306
307    fn sum(self, access: A) -> Result<T, Error>;
308}
309
310pub trait ReduceAxes<A: Access<T>, T: CType>: PlatformInstance {
311    type Op: ReadOp<Self, T>;
312
313    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
314
315    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
316
317    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
318
319    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
320}
321
322pub trait Transform<A: Access<T>, T: CType>: PlatformInstance {
323    type Broadcast: ReadOp<Self, T>;
324    type Slice: ReadOp<Self, T>;
325    type Transpose: ReadOp<Self, T>;
326
327    fn broadcast(
328        self,
329        access: A,
330        shape: Shape,
331        broadcast: Shape,
332    ) -> Result<AccessOp<Self::Broadcast, Self>, Error>;
333
334    fn slice(
335        self,
336        access: A,
337        shape: &[usize],
338        range: Range,
339    ) -> Result<AccessOp<Self::Slice, Self>, Error>;
340
341    fn transpose(
342        self,
343        access: A,
344        shape: Shape,
345        permutation: Axes,
346    ) -> Result<AccessOp<Self::Transpose, Self>, Error>;
347}
348
349pub enum Cast<A, IT, OT> {
350    #[cfg(feature = "opencl")]
351    CL(opencl::ops::Cast<A, IT, OT>),
352    Host(host::ops::Cast<A, IT, OT>),
353}
354
355impl<A: Access<IT>, IT: CType, OT: CType> Op for Cast<A, IT, OT> {
356    fn size(&self) -> usize {
357        op_dispatch!(self, op, op.size())
358    }
359}
360
361impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Platform, OT> for Cast<A, IT, OT> {
362    type Buffer = Buffer<OT>;
363
364    fn enqueue(&self) -> Result<Self::Buffer, Error> {
365        op_enqueue!(self, OT)
366    }
367}
368
369impl<A: Access<IT>, IT: CType, OT: CType> ReadValue<Platform, OT> for Cast<A, IT, OT> {
370    fn read_value(&self, offset: usize) -> Result<OT, Error> {
371        op_dispatch!(self, op, op.read_value(offset))
372    }
373}
374
375impl<A, IT, OT> From<host::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
376    fn from(op: host::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
377        Self::Host(op)
378    }
379}
380
381#[cfg(feature = "opencl")]
382impl<A, IT, OT> From<opencl::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
383    fn from(op: opencl::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
384        Self::CL(op)
385    }
386}
387
388pub enum Cond<A, L, R, T> {
389    #[cfg(feature = "opencl")]
390    CL(opencl::ops::Cond<A, L, R, T>),
391    Host(host::ops::Cond<A, L, R, T>),
392}
393
394impl<A, L, R, T> Op for Cond<A, L, R, T>
395where
396    A: Access<u8>,
397    L: Access<T>,
398    R: Access<T>,
399    T: CType,
400{
401    fn size(&self) -> usize {
402        op_dispatch!(self, op, op.size())
403    }
404}
405
406impl<A, L, R, T> Enqueue<Platform, T> for Cond<A, L, R, T>
407where
408    A: Access<u8>,
409    L: Access<T>,
410    R: Access<T>,
411    T: CType,
412{
413    type Buffer = Buffer<T>;
414
415    fn enqueue(&self) -> Result<Self::Buffer, Error> {
416        op_enqueue!(self, T)
417    }
418}
419
420impl<A, L, R, T> ReadValue<Platform, T> for Cond<A, L, R, T>
421where
422    A: Access<u8>,
423    L: Access<T>,
424    R: Access<T>,
425    T: CType,
426{
427    fn read_value(&self, offset: usize) -> Result<T, Error> {
428        op_dispatch!(self, op, op.read_value(offset))
429    }
430}
431
432impl<A, L, R, T> From<host::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
433    fn from(op: host::ops::Cond<A, L, R, T>) -> Self {
434        Self::Host(op)
435    }
436}
437
438#[cfg(feature = "opencl")]
439impl<A, L, R, T> From<opencl::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
440    fn from(op: opencl::ops::Cond<A, L, R, T>) -> Self {
441        Self::CL(op)
442    }
443}
444
445pub enum Dual<L, R, IT, OT> {
446    #[cfg(feature = "opencl")]
447    CL(opencl::ops::Dual<L, R, IT, OT>),
448    Host(host::ops::Dual<L, R, IT, OT>),
449}
450
451impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
452where
453    L: Access<IT>,
454    R: Access<IT>,
455    IT: CType,
456    OT: CType,
457{
458    fn size(&self) -> usize {
459        op_dispatch!(self, op, op.size())
460    }
461}
462
463impl<L, R, IT, OT> Enqueue<Platform, OT> for Dual<L, R, IT, OT>
464where
465    L: Access<IT>,
466    R: Access<IT>,
467    IT: CType,
468    OT: CType,
469{
470    type Buffer = Buffer<OT>;
471
472    fn enqueue(&self) -> Result<Self::Buffer, Error> {
473        op_enqueue!(self, OT)
474    }
475}
476
477impl<L, R, IT, OT> ReadValue<Platform, OT> for Dual<L, R, IT, OT>
478where
479    L: Access<IT>,
480    R: Access<IT>,
481    IT: CType,
482    OT: CType,
483{
484    fn read_value(&self, offset: usize) -> Result<OT, Error> {
485        op_dispatch!(self, op, op.read_value(offset))
486    }
487}
488
489#[cfg(feature = "opencl")]
490impl<L, R, IT, OT> From<opencl::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
491    fn from(op: opencl::ops::Dual<L, R, IT, OT>) -> Self {
492        Self::CL(op)
493    }
494}
495
496impl<L, R, IT, OT> From<host::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
497    fn from(op: host::ops::Dual<L, R, IT, OT>) -> Self {
498        Self::Host(op)
499    }
500}
501
502pub enum Linear<T> {
503    #[cfg(feature = "opencl")]
504    CL(opencl::ops::Linear<T>),
505    Host(host::ops::Linear<T>),
506}
507
508#[cfg(feature = "opencl")]
509impl<T> From<opencl::ops::Linear<T>> for Linear<T> {
510    fn from(op: opencl::ops::Linear<T>) -> Self {
511        Self::CL(op)
512    }
513}
514
515impl<T> From<host::ops::Linear<T>> for Linear<T> {
516    fn from(op: host::ops::Linear<T>) -> Self {
517        Self::Host(op)
518    }
519}
520
521impl<T: Send + Sync> Op for Linear<T> {
522    fn size(&self) -> usize {
523        op_dispatch!(self, op, op.size())
524    }
525}
526
527impl<T: CType> Enqueue<Platform, T> for Linear<T> {
528    type Buffer = Buffer<T>;
529
530    fn enqueue(&self) -> Result<Self::Buffer, Error> {
531        op_enqueue!(self, T)
532    }
533}
534
535impl<T: CType> ReadValue<Platform, T> for Linear<T> {
536    fn read_value(&self, offset: usize) -> Result<T, Error> {
537        op_dispatch!(self, op, op.read_value(offset))
538    }
539}
540
541pub enum MatDiag<A, T> {
542    #[cfg(feature = "opencl")]
543    CL(opencl::ops::MatDiag<A, T>),
544    Host(host::ops::MatDiag<A, T>),
545}
546
547impl<A: Access<T>, T: CType> Op for MatDiag<A, T> {
548    fn size(&self) -> usize {
549        op_dispatch!(self, op, op.size())
550    }
551}
552
553impl<A: Access<T>, T: CType> Enqueue<Platform, T> for MatDiag<A, T> {
554    type Buffer = Buffer<T>;
555
556    fn enqueue(&self) -> Result<Self::Buffer, Error> {
557        op_enqueue!(self, T)
558    }
559}
560
561impl<A: Access<T>, T: CType> ReadValue<Platform, T> for MatDiag<A, T> {
562    fn read_value(&self, offset: usize) -> Result<T, Error> {
563        op_dispatch!(self, op, op.read_value(offset))
564    }
565}
566
567impl<A, T> From<host::ops::MatDiag<A, T>> for MatDiag<A, T> {
568    fn from(op: host::ops::MatDiag<A, T>) -> Self {
569        Self::Host(op)
570    }
571}
572
573#[cfg(feature = "opencl")]
574impl<A, T> From<opencl::ops::MatDiag<A, T>> for MatDiag<A, T> {
575    fn from(op: opencl::ops::MatDiag<A, T>) -> Self {
576        Self::CL(op)
577    }
578}
579
580pub enum MatMul<L, R, T> {
581    #[cfg(feature = "opencl")]
582    CL(opencl::ops::MatMul<L, R, T>),
583    Host(host::ops::MatMul<L, R, T>),
584}
585
586impl<L, R, T> Op for MatMul<L, R, T>
587where
588    L: Access<T>,
589    R: Access<T>,
590    T: CType,
591{
592    fn size(&self) -> usize {
593        op_dispatch!(self, op, op.size())
594    }
595}
596
597impl<L, R, T> Enqueue<Platform, T> for MatMul<L, R, T>
598where
599    L: Access<T>,
600    R: Access<T>,
601    T: CType,
602{
603    type Buffer = Buffer<T>;
604
605    fn enqueue(&self) -> Result<Self::Buffer, Error> {
606        op_enqueue!(self, T)
607    }
608}
609
610impl<L, R, T> ReadValue<Platform, T> for MatMul<L, R, T>
611where
612    L: Access<T>,
613    R: Access<T>,
614    T: CType,
615{
616    fn read_value(&self, offset: usize) -> Result<T, Error> {
617        op_dispatch!(self, op, op.read_value(offset))
618    }
619}
620
621#[cfg(feature = "opencl")]
622impl<L, R, T> From<opencl::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
623    fn from(op: opencl::ops::MatMul<L, R, T>) -> Self {
624        Self::CL(op)
625    }
626}
627
628impl<L, R, T> From<host::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
629    fn from(op: host::ops::MatMul<L, R, T>) -> Self {
630        Self::Host(op)
631    }
632}
633
634pub enum RandomNormal {
635    #[cfg(feature = "opencl")]
636    CL(opencl::ops::RandomNormal),
637    Host(host::ops::RandomNormal),
638}
639
640#[cfg(feature = "opencl")]
641impl From<opencl::ops::RandomNormal> for RandomNormal {
642    fn from(op: opencl::ops::RandomNormal) -> Self {
643        Self::CL(op)
644    }
645}
646
647impl From<host::ops::RandomNormal> for RandomNormal {
648    fn from(op: host::ops::RandomNormal) -> Self {
649        Self::Host(op)
650    }
651}
652
653pub enum RandomUniform {
654    #[cfg(feature = "opencl")]
655    CL(opencl::ops::RandomUniform),
656    Host(host::ops::RandomUniform),
657}
658
659#[cfg(feature = "opencl")]
660impl From<opencl::ops::RandomUniform> for RandomUniform {
661    fn from(op: opencl::ops::RandomUniform) -> Self {
662        Self::CL(op)
663    }
664}
665
666impl From<host::ops::RandomUniform> for RandomUniform {
667    fn from(op: host::ops::RandomUniform) -> Self {
668        Self::Host(op)
669    }
670}
671
672macro_rules! impl_random {
673    ($t:ty) => {
674        impl Op for $t {
675            fn size(&self) -> usize {
676                op_dispatch!(self, op, op.size())
677            }
678        }
679
680        impl Enqueue<Platform, f32> for $t {
681            type Buffer = Buffer<f32>;
682
683            fn enqueue(&self) -> Result<Self::Buffer, Error> {
684                op_enqueue!(self, f32)
685            }
686        }
687
688        impl ReadValue<Platform, f32> for $t {
689            fn read_value(&self, offset: usize) -> Result<f32, Error> {
690                op_dispatch!(self, op, op.read_value(offset))
691            }
692        }
693    };
694}
695
696impl_random!(RandomNormal);
697impl_random!(RandomUniform);
698
/// Implement [`Op`], `Enqueue<Platform, $t>`, and `ReadValue<Platform, $t>` for the dispatch
/// enum `$op` by forwarding every call to the wrapped op.
///
/// The generated impls are generic over `A: Access<T>` and `T: CType`, so `$op` (and `$t`, if
/// generic) must be written in terms of the type parameters `A` and `T`.
macro_rules! impl_unary {
    ($op:ty, $t:ty) => {
        impl<A: Access<T>, T: CType> Op for $op {
            fn size(&self) -> usize {
                op_dispatch!(self, op, op.size())
            }
        }

        impl<A: Access<T>, T: CType> Enqueue<Platform, $t> for $op {
            type Buffer = Buffer<$t>;

            fn enqueue(&self) -> Result<Self::Buffer, Error> {
                op_enqueue!(self, $t)
            }
        }

        impl<A: Access<T>, T: CType> ReadValue<Platform, $t> for $op {
            fn read_value(&self, offset: usize) -> Result<$t, Error> {
                op_dispatch!(self, op, op.read_value(offset))
            }
        }
    };
}

723pub enum Reduce<A, T: CType> {
724    #[cfg(feature = "opencl")]
725    CL(opencl::ops::Reduce<A, T>),
726    Host(host::ops::Reduce<A, T>),
727}
728
729impl_unary!(Reduce<A, T>, T);
730
731impl<A, T: CType> From<host::ops::Reduce<A, T>> for Reduce<A, T> {
732    fn from(op: host::ops::Reduce<A, T>) -> Self {
733        Self::Host(op)
734    }
735}
736
737#[cfg(feature = "opencl")]
738impl<A, T: CType> From<opencl::ops::Reduce<A, T>> for Reduce<A, T> {
739    fn from(op: opencl::ops::Reduce<A, T>) -> Self {
740        Self::CL(op)
741    }
742}
743
744#[derive(Clone, Eq, PartialEq, Hash)]
745pub struct SliceSpec {
746    pub range: Range,
747    pub shape: Shape,
748    pub strides: Strides,
749    pub source_strides: Strides,
750}
751
752impl SliceSpec {
753    pub fn new(source_shape: &[usize], range: Range) -> Self {
754        debug_assert!(range.len() <= source_shape.len());
755
756        let shape = range_shape(source_shape, &range);
757        let strides = strides_for(&shape, shape.len()).collect();
758        let source_strides = strides_for(source_shape, source_shape.len()).collect();
759
760        Self {
761            range,
762            shape,
763            strides,
764            source_strides,
765        }
766    }
767
768    pub fn source_offset(&self, offset: usize) -> usize {
769        debug_assert!(!self.shape.is_empty());
770        debug_assert_eq!(self.shape.len(), self.strides.len());
771
772        let mut coord = self
773            .strides
774            .iter()
775            .copied()
776            .zip(&self.shape)
777            .map(|(stride, dim)| {
778                if stride == 0 {
779                    0
780                } else {
781                    (offset / stride) % dim
782                }
783            });
784
785        let mut offset = 0;
786        for (stride, bound) in self.source_strides.iter().zip(self.range.iter()) {
787            let i = match bound {
788                AxisRange::At(i) => *i,
789                AxisRange::In(start, stop, step) => {
790                    let i = start + (coord.next().expect("i") * step);
791                    debug_assert!(i < *stop);
792                    i
793                }
794                AxisRange::Of(indices) => indices[coord.next().expect("i")],
795            };
796
797            offset += i * stride;
798        }
799
800        offset
801    }
802
803    pub fn size(&self) -> usize {
804        self.shape.iter().product()
805    }
806}
807
808pub enum Scalar<A, IT, OT> {
809    #[cfg(feature = "opencl")]
810    CL(opencl::ops::Scalar<A, IT, OT>),
811    Host(host::ops::Scalar<A, IT, OT>),
812}
813
814impl<A, IT, OT> Op for Scalar<A, IT, OT>
815where
816    A: Access<IT>,
817    IT: CType,
818    OT: CType,
819{
820    fn size(&self) -> usize {
821        op_dispatch!(self, op, op.size())
822    }
823}
824
825impl<A, IT, OT> Enqueue<Platform, OT> for Scalar<A, IT, OT>
826where
827    A: Access<IT>,
828    IT: CType,
829    OT: CType,
830{
831    type Buffer = Buffer<OT>;
832
833    fn enqueue(&self) -> Result<Self::Buffer, Error> {
834        op_enqueue!(self, OT)
835    }
836}
837
838impl<A, IT, OT> ReadValue<Platform, OT> for Scalar<A, IT, OT>
839where
840    A: Access<IT>,
841    IT: CType,
842    OT: CType,
843{
844    fn read_value(&self, offset: usize) -> Result<OT, Error> {
845        op_dispatch!(self, op, op.read_value(offset))
846    }
847}
848
849#[cfg(feature = "opencl")]
850impl<A, IT, OT> From<opencl::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
851    fn from(op: opencl::ops::Scalar<A, IT, OT>) -> Self {
852        Self::CL(op)
853    }
854}
855
856impl<A, IT, OT> From<host::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
857    fn from(op: host::ops::Scalar<A, IT, OT>) -> Self {
858        Self::Host(op)
859    }
860}
861
862pub enum Slice<A, T> {
863    #[cfg(feature = "opencl")]
864    CL(opencl::ops::Slice<A, T>),
865    Host(host::ops::Slice<A, T>),
866}
867
868impl_unary!(Slice<A, T>, T);
869
870#[cfg(feature = "opencl")]
871impl<A, T> Write<Platform, T> for Slice<A, T>
872where
873    A: AccessMut<T> + std::fmt::Debug,
874    T: CType,
875{
876    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
877        match self {
878            Self::CL(op) => Write::<opencl::OpenCL, T>::write(op, data),
879            Self::Host(op) => Write::<host::Host, T>::write(op, data),
880        }
881    }
882
883    fn write_value(&mut self, value: T) -> Result<(), Error> {
884        match self {
885            Self::CL(op) => Write::<opencl::OpenCL, T>::write_value(op, value),
886            Self::Host(op) => Write::<host::Host, T>::write_value(op, value),
887        }
888    }
889
890    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
891        match self {
892            Self::CL(op) => Write::<opencl::OpenCL, T>::write_value_at(op, offset, value),
893            Self::Host(op) => Write::<host::Host, T>::write_value_at(op, offset, value),
894        }
895    }
896}
897
898#[cfg(not(feature = "opencl"))]
899impl<A, T> Write<Platform, T> for Slice<A, T>
900where
901    T: CType,
902    A: AccessMut<T>,
903{
904    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
905        match self {
906            Self::Host(op) => Write::<host::Host, T>::write(op, data),
907        }
908    }
909
910    fn write_value(&mut self, value: T) -> Result<(), Error> {
911        match self {
912            Self::Host(op) => Write::<host::Host, T>::write_value(op, value),
913        }
914    }
915
916    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
917        match self {
918            Self::Host(op) => Write::<host::Host, T>::write_value_at(op, offset, value),
919        }
920    }
921}
922
923#[cfg(feature = "opencl")]
924impl<A, T> From<opencl::ops::Slice<A, T>> for Slice<A, T> {
925    fn from(op: opencl::ops::Slice<A, T>) -> Self {
926        Self::CL(op)
927    }
928}
929
930impl<A, T> From<host::ops::Slice<A, T>> for Slice<A, T> {
931    fn from(op: host::ops::Slice<A, T>) -> Self {
932        Self::Host(op)
933    }
934}
935
936pub enum Unary<A, IT, OT> {
937    #[cfg(feature = "opencl")]
938    CL(opencl::ops::Unary<A, IT, OT>),
939    Host(host::ops::Unary<A, IT, OT>),
940}
941
942impl<A, IT, OT> Op for Unary<A, IT, OT>
943where
944    A: Access<IT>,
945    IT: CType,
946    OT: CType,
947{
948    fn size(&self) -> usize {
949        op_dispatch!(self, op, op.size())
950    }
951}
952
953impl<A, IT, OT> Enqueue<Platform, OT> for Unary<A, IT, OT>
954where
955    A: Access<IT>,
956    IT: CType,
957    OT: CType,
958{
959    type Buffer = Buffer<OT>;
960
961    fn enqueue(&self) -> Result<Self::Buffer, Error> {
962        op_enqueue!(self, OT)
963    }
964}
965
966impl<A, IT, OT> ReadValue<Platform, OT> for Unary<A, IT, OT>
967where
968    A: Access<IT>,
969    IT: CType,
970    OT: CType,
971{
972    fn read_value(&self, offset: usize) -> Result<OT, Error> {
973        op_dispatch!(self, op, op.read_value(offset))
974    }
975}
976
977impl<A, IT, OT> From<host::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
978    fn from(op: host::ops::Unary<A, IT, OT>) -> Self {
979        Self::Host(op)
980    }
981}
982
983#[cfg(feature = "opencl")]
984impl<A, IT, OT> From<opencl::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
985    fn from(op: opencl::ops::Unary<A, IT, OT>) -> Self {
986        Self::CL(op)
987    }
988}
989
990#[derive(Clone, Eq, PartialEq, Hash)]
991pub struct ViewSpec {
992    pub shape: Shape,
993    pub strides: Strides,
994    pub source_strides: Strides,
995}
996
997impl ViewSpec {
998    pub fn new(shape: Shape, source_strides: Strides) -> Self {
999        let strides = strides_for(&shape, shape.len()).collect();
1000
1001        Self {
1002            shape,
1003            strides,
1004            source_strides,
1005        }
1006    }
1007
1008    pub fn source_offset(&self, offset: usize) -> usize {
1009        debug_assert!(offset < self.size());
1010
1011        let source_offset = self
1012            .strides
1013            .iter()
1014            .copied()
1015            .zip(self.shape.iter().copied())
1016            .rev()
1017            .take(self.source_strides.len())
1018            .map(|(stride, dim)| {
1019                if stride == 0 {
1020                    0
1021                } else {
1022                    (offset / stride) % dim
1023                }
1024            }) // coord
1025            .zip(self.source_strides.iter().rev().copied())
1026            .map(|(i, source_stride)| i * source_stride)
1027            .sum::<usize>();
1028
1029        source_offset
1030    }
1031
1032    pub fn size(&self) -> usize {
1033        self.shape.iter().product()
1034    }
1035}
1036
1037pub enum View<A, T> {
1038    #[cfg(feature = "opencl")]
1039    CL(opencl::ops::View<A, T>),
1040    Host(host::ops::View<A, T>),
1041}
1042
1043impl_unary!(View<A, T>, T);
1044
1045#[cfg(feature = "opencl")]
1046impl<A, T> From<opencl::ops::View<A, T>> for View<A, T> {
1047    fn from(op: opencl::ops::View<A, T>) -> Self {
1048        Self::CL(op)
1049    }
1050}
1051
1052impl<A, T> From<host::ops::View<A, T>> for View<A, T> {
1053    fn from(op: host::ops::View<A, T>) -> Self {
1054        Self::Host(op)
1055    }
1056}