Skip to main content

ha_ndarray/ops/
mod.rs

1//! Array operations
2
3use std::marker::PhantomData;
4
5use crate::access::*;
6use crate::buffer::Buffer;
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::platform::{Platform, PlatformInstance};
10use crate::{
11    host, range_shape, strides_for, Axes, AxisRange, BufferConverter, Error, Float, Number, Range,
12    Real, Shape, Strides,
13};
14
15#[cfg(feature = "complex")]
16pub mod complex;
17
/// Expands to a `match` over the platform variants of `$target`.
///
/// Each arm binds the wrapped platform-specific op to `$bind` and evaluates `$body`.
/// The `CL` arm is only emitted when the `opencl` feature is enabled.
macro_rules! op_dispatch {
    ($target:expr, $bind:ident, $body:expr) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL($bind) => $body,
            Self::Host($bind) => $body,
        }
    };
}
27
/// Expands to a `match` over the platform variants of `$target` which enqueues the
/// wrapped op for data type `$dtype` and wraps the result in the matching
/// [`Buffer`] variant.
///
/// The `CL` arm is only emitted when the `opencl` feature is enabled.
macro_rules! op_enqueue {
    ($target:expr, $dtype:ty) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL(inner) => Enqueue::<opencl::OpenCL, $dtype>::enqueue(inner).map(Buffer::CL),
            Self::Host(inner) => Enqueue::<host::Host, $dtype>::enqueue(inner).map(Buffer::Host),
        }
    };
}
37
/// The base trait of every array operation in this module.
///
/// Implementors must be `Send + Sync`.
pub trait Op: Send + Sync {
    /// Returns the size of this op.
    ///
    /// NOTE(review): the implementations in this module (e.g. `Concat`, which sums the
    /// sizes of its inputs) suggest this is an element count -- confirm.
    fn size(&self) -> usize;
}
41
42pub trait Enqueue<P: PlatformInstance, T: Number>: Op {
43    type Buffer: Into<BufferConverter<'static, T>>;
44
45    fn enqueue(&self) -> Result<Self::Buffer, Error>;
46}
47
48pub trait ReadValue<P: PlatformInstance, T: Number>: Op {
49    fn read_value(&self, offset: usize) -> Result<T, Error>;
50}
51
52pub trait ReadOp<P, T>: Enqueue<P, T> + ReadValue<P, T>
53where
54    P: PlatformInstance,
55    T: Number,
56{
57}
58
59impl<O, P, T> ReadOp<P, T> for O
60where
61    O: Enqueue<P, T> + ReadValue<P, T>,
62    P: PlatformInstance,
63    T: Number,
64{
65}
66
67pub trait Write<P: PlatformInstance, T: Number>: Enqueue<P, T> {
68    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
69
70    fn write_value(&mut self, value: T) -> Result<(), Error>;
71
72    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
73}
74
75pub trait ConstructConcat<A, T>: PlatformInstance
76where
77    A: Access<T>,
78    T: Number,
79{
80    type Op: ReadOp<Self, T>;
81
82    fn concat(self, data: Vec<A>) -> Result<AccessOp<Self::Op, Self>, Error>;
83}
84
85pub trait ConstructRange<T: Number>: PlatformInstance {
86    type Range: Enqueue<Self, T>;
87
88    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error>;
89}
90
91pub trait ElementwiseAbs<A, T>: PlatformInstance
92where
93    A: Access<T>,
94    T: Number,
95{
96    type Op: ReadOp<Self, T::Abs>;
97
98    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
99}
100
101pub trait ElementwiseBoolean<L, R, T>: PlatformInstance {
102    type Op: ReadOp<Self, u8>;
103
104    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
105
106    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
107
108    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
109}
110
111pub trait ElementwiseBooleanScalar<A, T>: PlatformInstance {
112    type Op: ReadOp<Self, u8>;
113
114    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
115
116    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
117
118    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
119}
120
121pub trait ElementwiseCast<A, IT, OT>: PlatformInstance
122where
123    A: Access<IT>,
124    IT: Number,
125    OT: Number,
126{
127    type Op: ReadOp<Self, OT>;
128
129    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
130}
131
132pub trait ElementwiseCompare<L, R, T>: PlatformInstance {
133    type Op: ReadOp<Self, u8>;
134
135    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
136
137    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
138    where
139        T: Real;
140
141    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
142    where
143        T: Real;
144
145    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
146    where
147        T: Real;
148
149    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
150    where
151        T: Real;
152
153    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
154}
155
156pub trait ElementwiseCompareScalar<A, T>: PlatformInstance {
157    type Op: ReadOp<Self, u8>;
158
159    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
160
161    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
162    where
163        T: Real;
164
165    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
166    where
167        T: Real;
168
169    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
170    where
171        T: Real;
172
173    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
174    where
175        T: Real;
176
177    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
178}
179
180pub trait ElementwiseDual<L, R, T>: PlatformInstance
181where
182    L: Access<T>,
183    R: Access<T>,
184    T: Number,
185{
186    type Op: ReadOp<Self, T>;
187
188    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
189
190    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
191
192    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>
193    where
194        T: Float;
195
196    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
197
198    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error>;
199
200    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
201    where
202        T: Real;
203
204    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
205}
206
207pub trait ElementwiseScalar<A, T>: PlatformInstance
208where
209    A: Access<T>,
210    T: Number,
211{
212    type Op: ReadOp<Self, T>;
213
214    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
215
216    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
217
218    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>
219    where
220        T: Float;
221
222    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
223
224    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error>;
225
226    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
227    where
228        T: Real;
229
230    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
231}
232
233pub trait ElementwiseNumeric<A, T>: PlatformInstance
234where
235    A: Access<T>,
236    T: Number,
237{
238    type Op: ReadOp<Self, u8>;
239
240    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
241
242    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
243}
244
245pub trait ElementwiseTrig<A, T>: PlatformInstance
246where
247    A: Access<T>,
248    T: Float,
249{
250    type Op: ReadOp<Self, T>;
251
252    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
253
254    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
255
256    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
257
258    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
259
260    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
261
262    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
263
264    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
265
266    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
267
268    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
269}
270
271pub trait ElementwiseUnary<A, T>: PlatformInstance
272where
273    A: Access<T>,
274    T: Float,
275{
276    type Op: ReadOp<Self, T>;
277
278    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
279
280    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
281
282    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>
283    where
284        T: Real;
285}
286
287pub trait ElementwiseUnaryBoolean<A, T>: PlatformInstance
288where
289    A: Access<T>,
290    T: Number,
291{
292    type Op: ReadOp<Self, u8>;
293
294    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
295}
296
297pub trait GatherCond<A, L, R, T>: PlatformInstance
298where
299    A: Access<u8>,
300    L: Access<T>,
301    R: Access<T>,
302    T: Number,
303{
304    type Op: ReadOp<Self, T>;
305
306    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error>;
307}
308
309pub trait LinAlgDual<L, R, T>: PlatformInstance
310where
311    L: Access<T>,
312    R: Access<T>,
313    T: Number,
314{
315    type Op: ReadOp<Self, T>;
316
317    fn matmul(self, left: L, right: R, dims: [usize; 4])
318        -> Result<AccessOp<Self::Op, Self>, Error>;
319}
320
321pub trait LinAlgUnary<A, T>: PlatformInstance
322where
323    A: Access<T>,
324    T: Number,
325{
326    type Op: ReadOp<Self, T>;
327
328    fn diag(
329        self,
330        access: A,
331        batch_size: usize,
332        dim: usize,
333    ) -> Result<AccessOp<Self::Op, Self>, Error>;
334}
335
336pub trait Random: PlatformInstance {
337    type Normal: Enqueue<Self, f32>;
338    type Uniform: Enqueue<Self, f32>;
339
340    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error>;
341
342    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error>;
343}
344
345pub trait ReduceAll<A, T>: PlatformInstance {
346    fn all(self, access: A) -> Result<bool, Error>;
347
348    fn any(self, access: A) -> Result<bool, Error>;
349
350    fn max(self, access: A) -> Result<T, Error>
351    where
352        T: Real;
353
354    fn min(self, access: A) -> Result<T, Error>
355    where
356        T: Real;
357
358    fn product(self, access: A) -> Result<T, Error>;
359
360    fn sum(self, access: A) -> Result<T, Error>;
361}
362
363pub trait ReduceAxes<A: Access<T>, T: Number>: PlatformInstance {
364    type Op: ReadOp<Self, T>;
365
366    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
367    where
368        T: Real;
369
370    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
371    where
372        T: Real;
373
374    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
375
376    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
377}
378
379pub trait Transform<A: Access<T>, T: Number>: PlatformInstance {
380    type Broadcast: ReadOp<Self, T>;
381    type Flip: ReadOp<Self, T>;
382    type Slice: ReadOp<Self, T>;
383    type Transpose: ReadOp<Self, T>;
384
385    fn broadcast(
386        self,
387        access: A,
388        shape: Shape,
389        broadcast: Shape,
390    ) -> Result<AccessOp<Self::Broadcast, Self>, Error>;
391
392    fn flip(
393        self,
394        access: A,
395        shape: Shape,
396        axis: usize,
397    ) -> Result<AccessOp<Self::Flip, Self>, Error>;
398
399    fn slice(
400        self,
401        access: A,
402        shape: &[usize],
403        range: Range,
404    ) -> Result<AccessOp<Self::Slice, Self>, Error>;
405
406    fn transpose(
407        self,
408        access: A,
409        shape: Shape,
410        permutation: Axes,
411    ) -> Result<AccessOp<Self::Transpose, Self>, Error>;
412}
413
414pub enum Cast<A, IT, OT> {
415    #[cfg(feature = "opencl")]
416    CL(opencl::ops::Cast<A, IT, OT>),
417    Host(host::ops::Cast<A, IT, OT>),
418}
419
420impl<A: Access<IT>, IT: Number, OT: Number> Op for Cast<A, IT, OT> {
421    fn size(&self) -> usize {
422        op_dispatch!(self, op, op.size())
423    }
424}
425
426impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Platform, OT> for Cast<A, IT, OT> {
427    type Buffer = Buffer<OT>;
428
429    fn enqueue(&self) -> Result<Self::Buffer, Error> {
430        op_enqueue!(self, OT)
431    }
432}
433
434impl<A: Access<IT>, IT: Number, OT: Number> ReadValue<Platform, OT> for Cast<A, IT, OT> {
435    fn read_value(&self, offset: usize) -> Result<OT, Error> {
436        op_dispatch!(self, op, op.read_value(offset))
437    }
438}
439
440impl<A, IT, OT> From<host::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
441    fn from(op: host::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
442        Self::Host(op)
443    }
444}
445
446#[cfg(feature = "opencl")]
447impl<A, IT, OT> From<opencl::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
448    fn from(op: opencl::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
449        Self::CL(op)
450    }
451}
452
/// A concatenation op over a list of accessors of type `A` with data type `T`.
pub struct Concat<A, T> {
    /// The accessors to concatenate, in order.
    data: Vec<A>,
    /// Marker for the element data type; no `T` is stored.
    dtype: PhantomData<T>,
}

impl<A, T> Concat<A, T> {
    /// Constructs a new concatenation op over `data`.
    pub fn new(data: Vec<A>) -> Self {
        let dtype = PhantomData;
        Concat { data, dtype }
    }

    /// Borrows the accessors of this op as a slice.
    pub(crate) fn data(&self) -> &[A] {
        self.data.as_slice()
    }
}
470
471impl<A, T> Op for Concat<A, T>
472where
473    A: Access<T>,
474    T: Number,
475{
476    fn size(&self) -> usize {
477        self.data.iter().map(|access| access.size()).sum()
478    }
479}
480
481impl<A, T> Enqueue<Platform, T> for Concat<A, T>
482where
483    A: Access<T>,
484    T: Number,
485{
486    type Buffer = Buffer<T>;
487
488    fn enqueue(&self) -> Result<Self::Buffer, Error> {
489        match Platform::select(self.size()) {
490            #[cfg(feature = "opencl")]
491            Platform::CL(_) => Enqueue::<opencl::OpenCL, T>::enqueue(self).map(Buffer::from),
492            Platform::Host(_) => Enqueue::<host::Host, T>::enqueue(self).map(Buffer::from),
493        }
494    }
495}
496
497impl<A, T> ReadValue<Platform, T> for Concat<A, T>
498where
499    A: Access<T>,
500    T: Number,
501{
502    fn read_value(&self, offset: usize) -> Result<T, Error> {
503        let mut start = 0;
504
505        for access in &self.data {
506            let end = start + access.size();
507            if offset < end {
508                return access.read_value(offset - start);
509            }
510            start = end;
511        }
512
513        Err(Error::bounds(format!(
514            "offset {} is out of bounds for a concatenation of size",
515            self.size()
516        )))
517    }
518}
519
520pub enum Cond<A, L, R, T> {
521    #[cfg(feature = "opencl")]
522    CL(opencl::ops::Cond<A, L, R, T>),
523    Host(host::ops::Cond<A, L, R, T>),
524}
525
526impl<A, L, R, T> Op for Cond<A, L, R, T>
527where
528    A: Access<u8>,
529    L: Access<T>,
530    R: Access<T>,
531    T: Number,
532{
533    fn size(&self) -> usize {
534        op_dispatch!(self, op, op.size())
535    }
536}
537
538impl<A, L, R, T> Enqueue<Platform, T> for Cond<A, L, R, T>
539where
540    A: Access<u8>,
541    L: Access<T>,
542    R: Access<T>,
543    T: Number,
544{
545    type Buffer = Buffer<T>;
546
547    fn enqueue(&self) -> Result<Self::Buffer, Error> {
548        op_enqueue!(self, T)
549    }
550}
551
552impl<A, L, R, T> ReadValue<Platform, T> for Cond<A, L, R, T>
553where
554    A: Access<u8>,
555    L: Access<T>,
556    R: Access<T>,
557    T: Number,
558{
559    fn read_value(&self, offset: usize) -> Result<T, Error> {
560        op_dispatch!(self, op, op.read_value(offset))
561    }
562}
563
564impl<A, L, R, T> From<host::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
565    fn from(op: host::ops::Cond<A, L, R, T>) -> Self {
566        Self::Host(op)
567    }
568}
569
570#[cfg(feature = "opencl")]
571impl<A, L, R, T> From<opencl::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
572    fn from(op: opencl::ops::Cond<A, L, R, T>) -> Self {
573        Self::CL(op)
574    }
575}
576
577pub enum Dual<L, R, IT, OT> {
578    #[cfg(feature = "opencl")]
579    CL(opencl::ops::Dual<L, R, IT, OT>),
580    Host(host::ops::Dual<L, R, IT, OT>),
581}
582
583impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
584where
585    L: Access<IT>,
586    R: Access<IT>,
587    IT: Number,
588    OT: Number,
589{
590    fn size(&self) -> usize {
591        op_dispatch!(self, op, op.size())
592    }
593}
594
595impl<L, R, IT, OT> Enqueue<Platform, OT> for Dual<L, R, IT, OT>
596where
597    L: Access<IT>,
598    R: Access<IT>,
599    IT: Number,
600    OT: Number,
601{
602    type Buffer = Buffer<OT>;
603
604    fn enqueue(&self) -> Result<Self::Buffer, Error> {
605        op_enqueue!(self, OT)
606    }
607}
608
609impl<L, R, IT, OT> ReadValue<Platform, OT> for Dual<L, R, IT, OT>
610where
611    L: Access<IT>,
612    R: Access<IT>,
613    IT: Number,
614    OT: Number,
615{
616    fn read_value(&self, offset: usize) -> Result<OT, Error> {
617        op_dispatch!(self, op, op.read_value(offset))
618    }
619}
620
621#[cfg(feature = "opencl")]
622impl<L, R, IT, OT> From<opencl::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
623    fn from(op: opencl::ops::Dual<L, R, IT, OT>) -> Self {
624        Self::CL(op)
625    }
626}
627
628impl<L, R, IT, OT> From<host::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
629    fn from(op: host::ops::Dual<L, R, IT, OT>) -> Self {
630        Self::Host(op)
631    }
632}
633
634pub enum Flip<A, T> {
635    #[cfg(feature = "opencl")]
636    CL(opencl::ops::Flip<A, T>),
637    Host(host::ops::Flip<A, T>),
638}
639
640impl<A: Access<T>, T: Number> Op for Flip<A, T> {
641    fn size(&self) -> usize {
642        op_dispatch!(self, op, op.size())
643    }
644}
645
646impl<A: Access<T>, T: Number> Enqueue<Platform, T> for Flip<A, T> {
647    type Buffer = Buffer<T>;
648
649    fn enqueue(&self) -> Result<Self::Buffer, Error> {
650        op_enqueue!(self, T)
651    }
652}
653
654impl<A: Access<T>, T: Number> ReadValue<Platform, T> for Flip<A, T> {
655    fn read_value(&self, offset: usize) -> Result<T, Error> {
656        op_dispatch!(self, op, op.read_value(offset))
657    }
658}
659
660#[cfg(feature = "opencl")]
661impl<A, T> From<opencl::ops::Flip<A, T>> for Flip<A, T> {
662    fn from(op: opencl::ops::Flip<A, T>) -> Self {
663        Self::CL(op)
664    }
665}
666
667impl<A, T> From<host::ops::Flip<A, T>> for Flip<A, T> {
668    fn from(op: host::ops::Flip<A, T>) -> Self {
669        Self::Host(op)
670    }
671}
672
673pub enum Linear<T> {
674    #[cfg(feature = "opencl")]
675    CL(opencl::ops::Linear<T>),
676    Host(host::ops::Linear<T>),
677}
678
679#[cfg(feature = "opencl")]
680impl<T> From<opencl::ops::Linear<T>> for Linear<T> {
681    fn from(op: opencl::ops::Linear<T>) -> Self {
682        Self::CL(op)
683    }
684}
685
686impl<T> From<host::ops::Linear<T>> for Linear<T> {
687    fn from(op: host::ops::Linear<T>) -> Self {
688        Self::Host(op)
689    }
690}
691
692impl<T: Send + Sync> Op for Linear<T> {
693    fn size(&self) -> usize {
694        op_dispatch!(self, op, op.size())
695    }
696}
697
698impl<T: Number> Enqueue<Platform, T> for Linear<T> {
699    type Buffer = Buffer<T>;
700
701    fn enqueue(&self) -> Result<Self::Buffer, Error> {
702        op_enqueue!(self, T)
703    }
704}
705
706impl<T: Number> ReadValue<Platform, T> for Linear<T> {
707    fn read_value(&self, offset: usize) -> Result<T, Error> {
708        op_dispatch!(self, op, op.read_value(offset))
709    }
710}
711
712pub enum MatDiag<A, T> {
713    #[cfg(feature = "opencl")]
714    CL(opencl::ops::MatDiag<A, T>),
715    Host(host::ops::MatDiag<A, T>),
716}
717
718impl<A: Access<T>, T: Number> Op for MatDiag<A, T> {
719    fn size(&self) -> usize {
720        op_dispatch!(self, op, op.size())
721    }
722}
723
724impl<A: Access<T>, T: Number> Enqueue<Platform, T> for MatDiag<A, T> {
725    type Buffer = Buffer<T>;
726
727    fn enqueue(&self) -> Result<Self::Buffer, Error> {
728        op_enqueue!(self, T)
729    }
730}
731
732impl<A: Access<T>, T: Number> ReadValue<Platform, T> for MatDiag<A, T> {
733    fn read_value(&self, offset: usize) -> Result<T, Error> {
734        op_dispatch!(self, op, op.read_value(offset))
735    }
736}
737
738impl<A, T> From<host::ops::MatDiag<A, T>> for MatDiag<A, T> {
739    fn from(op: host::ops::MatDiag<A, T>) -> Self {
740        Self::Host(op)
741    }
742}
743
744#[cfg(feature = "opencl")]
745impl<A, T> From<opencl::ops::MatDiag<A, T>> for MatDiag<A, T> {
746    fn from(op: opencl::ops::MatDiag<A, T>) -> Self {
747        Self::CL(op)
748    }
749}
750
751pub enum MatMul<L, R, T> {
752    #[cfg(feature = "opencl")]
753    CL(opencl::ops::MatMul<L, R, T>),
754    Host(host::ops::MatMul<L, R, T>),
755}
756
757impl<L, R, T> Op for MatMul<L, R, T>
758where
759    L: Access<T>,
760    R: Access<T>,
761    T: Number,
762{
763    fn size(&self) -> usize {
764        op_dispatch!(self, op, op.size())
765    }
766}
767
768impl<L, R, T> Enqueue<Platform, T> for MatMul<L, R, T>
769where
770    L: Access<T>,
771    R: Access<T>,
772    T: Number,
773{
774    type Buffer = Buffer<T>;
775
776    fn enqueue(&self) -> Result<Self::Buffer, Error> {
777        op_enqueue!(self, T)
778    }
779}
780
781impl<L, R, T> ReadValue<Platform, T> for MatMul<L, R, T>
782where
783    L: Access<T>,
784    R: Access<T>,
785    T: Number,
786{
787    fn read_value(&self, offset: usize) -> Result<T, Error> {
788        op_dispatch!(self, op, op.read_value(offset))
789    }
790}
791
792#[cfg(feature = "opencl")]
793impl<L, R, T> From<opencl::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
794    fn from(op: opencl::ops::MatMul<L, R, T>) -> Self {
795        Self::CL(op)
796    }
797}
798
799impl<L, R, T> From<host::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
800    fn from(op: host::ops::MatMul<L, R, T>) -> Self {
801        Self::Host(op)
802    }
803}
804
805pub enum RandomNormal {
806    #[cfg(feature = "opencl")]
807    CL(opencl::ops::RandomNormal),
808    Host(host::ops::RandomNormal),
809}
810
811#[cfg(feature = "opencl")]
812impl From<opencl::ops::RandomNormal> for RandomNormal {
813    fn from(op: opencl::ops::RandomNormal) -> Self {
814        Self::CL(op)
815    }
816}
817
818impl From<host::ops::RandomNormal> for RandomNormal {
819    fn from(op: host::ops::RandomNormal) -> Self {
820        Self::Host(op)
821    }
822}
823
824pub enum RandomUniform {
825    #[cfg(feature = "opencl")]
826    CL(opencl::ops::RandomUniform),
827    Host(host::ops::RandomUniform),
828}
829
830#[cfg(feature = "opencl")]
831impl From<opencl::ops::RandomUniform> for RandomUniform {
832    fn from(op: opencl::ops::RandomUniform) -> Self {
833        Self::CL(op)
834    }
835}
836
837impl From<host::ops::RandomUniform> for RandomUniform {
838    fn from(op: host::ops::RandomUniform) -> Self {
839        Self::Host(op)
840    }
841}
842
843macro_rules! impl_random {
844    ($t:ty) => {
845        impl Op for $t {
846            fn size(&self) -> usize {
847                op_dispatch!(self, op, op.size())
848            }
849        }
850
851        impl Enqueue<Platform, f32> for $t {
852            type Buffer = Buffer<f32>;
853
854            fn enqueue(&self) -> Result<Self::Buffer, Error> {
855                op_enqueue!(self, f32)
856            }
857        }
858
859        impl ReadValue<Platform, f32> for $t {
860            fn read_value(&self, offset: usize) -> Result<f32, Error> {
861                op_dispatch!(self, op, op.read_value(offset))
862            }
863        }
864    };
865}
866
867impl_random!(RandomNormal);
868impl_random!(RandomUniform);
869
/// Implements [`Op`], [`Enqueue`] and [`ReadValue`] (on [`Platform`]) for a
/// platform-dispatching wrapper type `$wrapper` with output type `$out`.
///
/// The generated impls are generic over `A: Access<T>` and `T: Number`, so
/// `$wrapper` and `$out` should be written in terms of `A` and `T`.
macro_rules! impl_unary {
    ($wrapper:ty, $out:ty) => {
        impl<A, T> Op for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            fn size(&self) -> usize {
                op_dispatch!(self, inner, inner.size())
            }
        }

        impl<A, T> Enqueue<Platform, $out> for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            type Buffer = Buffer<$out>;

            fn enqueue(&self) -> Result<Self::Buffer, Error> {
                op_enqueue!(self, $out)
            }
        }

        impl<A, T> ReadValue<Platform, $out> for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            fn read_value(&self, offset: usize) -> Result<$out, Error> {
                op_dispatch!(self, inner, inner.read_value(offset))
            }
        }
    };
}
893
894pub enum Reduce<A, T: Number> {
895    #[cfg(feature = "opencl")]
896    CL(opencl::ops::Reduce<A, T>),
897    Host(host::ops::Reduce<A, T>),
898}
899
900impl_unary!(Reduce<A, T>, T);
901
902impl<A, T: Number> From<host::ops::Reduce<A, T>> for Reduce<A, T> {
903    fn from(op: host::ops::Reduce<A, T>) -> Self {
904        Self::Host(op)
905    }
906}
907
908#[cfg(feature = "opencl")]
909impl<A, T: Number> From<opencl::ops::Reduce<A, T>> for Reduce<A, T> {
910    fn from(op: opencl::ops::Reduce<A, T>) -> Self {
911        Self::CL(op)
912    }
913}
914
915#[derive(Clone, Eq, PartialEq, Hash)]
916pub struct SliceSpec {
917    pub range: Range,
918    pub shape: Shape,
919    pub strides: Strides,
920    pub source_strides: Strides,
921}
922
923impl SliceSpec {
924    pub fn new(source_shape: &[usize], range: Range) -> Self {
925        debug_assert!(range.len() <= source_shape.len());
926
927        let shape = range_shape(source_shape, &range);
928        let strides = strides_for(&shape, shape.len()).collect();
929        let source_strides = strides_for(source_shape, source_shape.len()).collect();
930
931        Self {
932            range,
933            shape,
934            strides,
935            source_strides,
936        }
937    }
938
939    pub fn source_offset(&self, offset: usize) -> usize {
940        debug_assert!(!self.shape.is_empty());
941        debug_assert_eq!(self.shape.len(), self.strides.len());
942
943        let mut coord = self
944            .strides
945            .iter()
946            .copied()
947            .zip(&self.shape)
948            .map(|(stride, dim)| {
949                if stride == 0 {
950                    0
951                } else {
952                    (offset / stride) % dim
953                }
954            });
955
956        let mut offset = 0;
957        for (stride, bound) in self.source_strides.iter().zip(self.range.iter()) {
958            let i = match bound {
959                AxisRange::At(i) => *i,
960                AxisRange::In(start, stop, step) => {
961                    let i = start + (coord.next().expect("i") * step);
962                    debug_assert!(i < *stop);
963                    i
964                }
965                AxisRange::Of(indices) => indices[coord.next().expect("i")],
966            };
967
968            offset += i * stride;
969        }
970
971        offset
972    }
973
974    pub fn size(&self) -> usize {
975        self.shape.iter().product()
976    }
977}
978
979pub enum Scalar<A, IT, OT> {
980    #[cfg(feature = "opencl")]
981    CL(opencl::ops::Scalar<A, IT, OT>),
982    Host(host::ops::Scalar<A, IT, OT>),
983}
984
985impl<A, IT, OT> Op for Scalar<A, IT, OT>
986where
987    A: Access<IT>,
988    IT: Number,
989    OT: Number,
990{
991    fn size(&self) -> usize {
992        op_dispatch!(self, op, op.size())
993    }
994}
995
996impl<A, IT, OT> Enqueue<Platform, OT> for Scalar<A, IT, OT>
997where
998    A: Access<IT>,
999    IT: Number,
1000    OT: Number,
1001{
1002    type Buffer = Buffer<OT>;
1003
1004    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1005        op_enqueue!(self, OT)
1006    }
1007}
1008
1009impl<A, IT, OT> ReadValue<Platform, OT> for Scalar<A, IT, OT>
1010where
1011    A: Access<IT>,
1012    IT: Number,
1013    OT: Number,
1014{
1015    fn read_value(&self, offset: usize) -> Result<OT, Error> {
1016        op_dispatch!(self, op, op.read_value(offset))
1017    }
1018}
1019
1020#[cfg(feature = "opencl")]
1021impl<A, IT, OT> From<opencl::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
1022    fn from(op: opencl::ops::Scalar<A, IT, OT>) -> Self {
1023        Self::CL(op)
1024    }
1025}
1026
1027impl<A, IT, OT> From<host::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
1028    fn from(op: host::ops::Scalar<A, IT, OT>) -> Self {
1029        Self::Host(op)
1030    }
1031}
1032
1033pub enum Slice<A, T> {
1034    #[cfg(feature = "opencl")]
1035    CL(opencl::ops::Slice<A, T>),
1036    Host(host::ops::Slice<A, T>),
1037}
1038
1039impl_unary!(Slice<A, T>, T);
1040
/// Write support for [`Slice`] when the `opencl` feature is enabled.
///
/// Every method forwards to the `Write` implementation of the wrapped
/// platform-specific op.
#[cfg(feature = "opencl")]
impl<A, T> Write<Platform, T> for Slice<A, T>
where
    T: Number,
    A: AccessMut<T> + std::fmt::Debug,
{
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        match self {
            Self::CL(cl_op) => Write::<opencl::OpenCL, T>::write(cl_op, data),
            Self::Host(host_op) => Write::<host::Host, T>::write(host_op, data),
        }
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        match self {
            Self::CL(cl_op) => Write::<opencl::OpenCL, T>::write_value(cl_op, value),
            Self::Host(host_op) => Write::<host::Host, T>::write_value(host_op, value),
        }
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        match self {
            Self::CL(cl_op) => Write::<opencl::OpenCL, T>::write_value_at(cl_op, offset, value),
            Self::Host(host_op) => Write::<host::Host, T>::write_value_at(host_op, offset, value),
        }
    }
}
1068
1069#[cfg(not(feature = "opencl"))]
1070impl<A, T> Write<Platform, T> for Slice<A, T>
1071where
1072    T: Number,
1073    A: AccessMut<T>,
1074{
1075    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1076        match self {
1077            Self::Host(op) => Write::<host::Host, T>::write(op, data),
1078        }
1079    }
1080
1081    fn write_value(&mut self, value: T) -> Result<(), Error> {
1082        match self {
1083            Self::Host(op) => Write::<host::Host, T>::write_value(op, value),
1084        }
1085    }
1086
1087    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1088        match self {
1089            Self::Host(op) => Write::<host::Host, T>::write_value_at(op, offset, value),
1090        }
1091    }
1092}
1093
1094#[cfg(feature = "opencl")]
1095impl<A, T> From<opencl::ops::Slice<A, T>> for Slice<A, T> {
1096    fn from(op: opencl::ops::Slice<A, T>) -> Self {
1097        Self::CL(op)
1098    }
1099}
1100
1101impl<A, T> From<host::ops::Slice<A, T>> for Slice<A, T> {
1102    fn from(op: host::ops::Slice<A, T>) -> Self {
1103        Self::Host(op)
1104    }
1105}
1106
1107pub enum Unary<A, IT, OT> {
1108    #[cfg(feature = "opencl")]
1109    CL(opencl::ops::Unary<A, IT, OT>),
1110    Host(host::ops::Unary<A, IT, OT>),
1111}
1112
1113impl<A, IT, OT> Op for Unary<A, IT, OT>
1114where
1115    A: Access<IT>,
1116    IT: Number,
1117    OT: Number,
1118{
1119    fn size(&self) -> usize {
1120        op_dispatch!(self, op, op.size())
1121    }
1122}
1123
1124impl<A, IT, OT> Enqueue<Platform, OT> for Unary<A, IT, OT>
1125where
1126    A: Access<IT>,
1127    IT: Number,
1128    OT: Number,
1129{
1130    type Buffer = Buffer<OT>;
1131
1132    fn enqueue(&self) -> Result<Self::Buffer, Error> {
1133        op_enqueue!(self, OT)
1134    }
1135}
1136
1137impl<A, IT, OT> ReadValue<Platform, OT> for Unary<A, IT, OT>
1138where
1139    A: Access<IT>,
1140    IT: Number,
1141    OT: Number,
1142{
1143    fn read_value(&self, offset: usize) -> Result<OT, Error> {
1144        op_dispatch!(self, op, op.read_value(offset))
1145    }
1146}
1147
1148impl<A, IT, OT> From<host::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
1149    fn from(op: host::ops::Unary<A, IT, OT>) -> Self {
1150        Self::Host(op)
1151    }
1152}
1153
1154#[cfg(feature = "opencl")]
1155impl<A, IT, OT> From<opencl::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
1156    fn from(op: opencl::ops::Unary<A, IT, OT>) -> Self {
1157        Self::CL(op)
1158    }
1159}
1160
1161#[derive(Clone, Eq, PartialEq, Hash)]
1162pub struct FlipSpec {
1163    pub shape: Shape,
1164    pub strides: Strides,
1165    pub axis: usize,
1166}
1167
1168impl FlipSpec {
1169    pub fn new(shape: Shape, axis: usize) -> Result<Self, Error> {
1170        if axis < shape.len() {
1171            let strides = strides_for(&shape, shape.len()).collect();
1172
1173            Ok(Self {
1174                shape,
1175                strides,
1176                axis,
1177            })
1178        } else {
1179            Err(Error::bounds(format!("shape {shape:?} has no axis {axis}")))
1180        }
1181    }
1182
1183    pub fn source_offset(&self, offset: usize) -> usize {
1184        self.strides
1185            .iter()
1186            .copied()
1187            .zip(self.shape.iter().copied())
1188            .map(|(stride, dim)| {
1189                if stride == 0 {
1190                    0
1191                } else {
1192                    (offset / stride) % dim
1193                }
1194            }) // coord
1195            .zip(self.strides.iter().copied())
1196            .enumerate()
1197            .map(|(x, (i, source_stride))| {
1198                let i = if x == self.axis {
1199                    self.shape[x] - i - 1
1200                } else {
1201                    i
1202                };
1203
1204                i * source_stride
1205            })
1206            .sum::<usize>()
1207    }
1208}
1209
1210#[derive(Clone, Eq, PartialEq, Hash)]
1211pub struct ViewSpec {
1212    pub shape: Shape,
1213    pub strides: Strides,
1214    pub source_strides: Strides,
1215}
1216
1217impl ViewSpec {
1218    pub fn new(shape: Shape, source_strides: Strides) -> Self {
1219        let strides = strides_for(&shape, shape.len()).collect();
1220
1221        Self {
1222            shape,
1223            strides,
1224            source_strides,
1225        }
1226    }
1227
1228    pub fn source_offset(&self, offset: usize) -> usize {
1229        debug_assert!(offset < self.size());
1230
1231        let source_offset = self
1232            .strides
1233            .iter()
1234            .copied()
1235            .zip(self.shape.iter().copied())
1236            .rev()
1237            .take(self.source_strides.len())
1238            .map(|(stride, dim)| {
1239                if stride == 0 {
1240                    0
1241                } else {
1242                    (offset / stride) % dim
1243                }
1244            }) // coord
1245            .zip(self.source_strides.iter().rev().copied())
1246            .map(|(i, source_stride)| i * source_stride)
1247            .sum::<usize>();
1248
1249        source_offset
1250    }
1251
1252    pub fn size(&self) -> usize {
1253        self.shape.iter().product()
1254    }
1255}
1256
1257pub enum View<A, T> {
1258    #[cfg(feature = "opencl")]
1259    CL(opencl::ops::View<A, T>),
1260    Host(host::ops::View<A, T>),
1261}
1262
1263impl_unary!(View<A, T>, T);
1264
1265#[cfg(feature = "opencl")]
1266impl<A, T> From<opencl::ops::View<A, T>> for View<A, T> {
1267    fn from(op: opencl::ops::View<A, T>) -> Self {
1268        Self::CL(op)
1269    }
1270}
1271
1272impl<A, T> From<host::ops::View<A, T>> for View<A, T> {
1273    fn from(op: host::ops::View<A, T>) -> Self {
1274        Self::Host(op)
1275    }
1276}