1use std::marker::PhantomData;
4
5use crate::access::*;
6use crate::buffer::Buffer;
7#[cfg(feature = "opencl")]
8use crate::opencl;
9use crate::platform::{Platform, PlatformInstance};
10use crate::{
11 host, range_shape, strides_for, Axes, AxisRange, BufferConverter, Error, Float, Number, Range,
12 Real, Shape, Strides,
13};
14
15#[cfg(feature = "complex")]
16pub mod complex;
17
/// Evaluate an expression against whichever backend op the wrapper enum holds.
///
/// * `$target` - the value to match on (an enum with `CL` and `Host` variants)
/// * `$inner`  - the identifier to bind the wrapped backend op to in every arm
/// * `$body`   - the expression to evaluate with `$inner` in scope
///
/// The binding name is passed in by the caller so that `$body` can refer to it.
/// The `CL` arm is only compiled when the `opencl` feature is enabled.
macro_rules! op_dispatch {
    ($target:expr, $inner:ident, $body:expr) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL($inner) => $body,
            Self::Host($inner) => $body,
        }
    };
}
27
/// Enqueue the backend op held by `$target` and wrap the resulting backend buffer
/// in the matching [`Buffer`] variant.
///
/// `$dtype` is the element type passed to the backend [`Enqueue`] implementation.
/// The `CL` arm is only compiled when the `opencl` feature is enabled.
macro_rules! op_enqueue {
    ($target:expr, $dtype:ty) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL(inner) => Enqueue::<opencl::OpenCL, $dtype>::enqueue(inner).map(Buffer::CL),
            Self::Host(inner) => Enqueue::<host::Host, $dtype>::enqueue(inner).map(Buffer::Host),
        }
    };
}
37
/// Base trait of every array operation; implementors must be `Send + Sync`.
pub trait Op: Send + Sync {
    /// The size of this op's output (presumably its number of elements).
    fn size(&self) -> usize;
}
41
42pub trait Enqueue<P: PlatformInstance, T: Number>: Op {
43 type Buffer: Into<BufferConverter<'static, T>>;
44
45 fn enqueue(&self) -> Result<Self::Buffer, Error>;
46}
47
48pub trait ReadValue<P: PlatformInstance, T: Number>: Op {
49 fn read_value(&self, offset: usize) -> Result<T, Error>;
50}
51
52pub trait ReadOp<P, T>: Enqueue<P, T> + ReadValue<P, T>
53where
54 P: PlatformInstance,
55 T: Number,
56{
57}
58
59impl<O, P, T> ReadOp<P, T> for O
60where
61 O: Enqueue<P, T> + ReadValue<P, T>,
62 P: PlatformInstance,
63 T: Number,
64{
65}
66
67pub trait Write<P: PlatformInstance, T: Number>: Enqueue<P, T> {
68 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;
69
70 fn write_value(&mut self, value: T) -> Result<(), Error>;
71
72 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
73}
74
75pub trait ConstructConcat<A, T>: PlatformInstance
76where
77 A: Access<T>,
78 T: Number,
79{
80 type Op: ReadOp<Self, T>;
81
82 fn concat(self, data: Vec<A>) -> Result<AccessOp<Self::Op, Self>, Error>;
83}
84
85pub trait ConstructRange<T: Number>: PlatformInstance {
86 type Range: Enqueue<Self, T>;
87
88 fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error>;
89}
90
91pub trait ElementwiseAbs<A, T>: PlatformInstance
92where
93 A: Access<T>,
94 T: Number,
95{
96 type Op: ReadOp<Self, T::Abs>;
97
98 fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
99}
100
101pub trait ElementwiseBoolean<L, R, T>: PlatformInstance {
102 type Op: ReadOp<Self, u8>;
103
104 fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
105
106 fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
107
108 fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
109}
110
111pub trait ElementwiseBooleanScalar<A, T>: PlatformInstance {
112 type Op: ReadOp<Self, u8>;
113
114 fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
115
116 fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
117
118 fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
119}
120
121pub trait ElementwiseCast<A, IT, OT>: PlatformInstance
122where
123 A: Access<IT>,
124 IT: Number,
125 OT: Number,
126{
127 type Op: ReadOp<Self, OT>;
128
129 fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
130}
131
132pub trait ElementwiseCompare<L, R, T>: PlatformInstance {
133 type Op: ReadOp<Self, u8>;
134
135 fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
136
137 fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
138 where
139 T: Real;
140
141 fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
142 where
143 T: Real;
144
145 fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
146 where
147 T: Real;
148
149 fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
150 where
151 T: Real;
152
153 fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
154}
155
156pub trait ElementwiseCompareScalar<A, T>: PlatformInstance {
157 type Op: ReadOp<Self, u8>;
158
159 fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
160
161 fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
162 where
163 T: Real;
164
165 fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
166 where
167 T: Real;
168
169 fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
170 where
171 T: Real;
172
173 fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
174 where
175 T: Real;
176
177 fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
178}
179
180pub trait ElementwiseDual<L, R, T>: PlatformInstance
181where
182 L: Access<T>,
183 R: Access<T>,
184 T: Number,
185{
186 type Op: ReadOp<Self, T>;
187
188 fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
189
190 fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
191
192 fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>
193 where
194 T: Float;
195
196 fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
197
198 fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error>;
199
200 fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
201 where
202 T: Real;
203
204 fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
205}
206
207pub trait ElementwiseScalar<A, T>: PlatformInstance
208where
209 A: Access<T>,
210 T: Number,
211{
212 type Op: ReadOp<Self, T>;
213
214 fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
215
216 fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
217
218 fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>
219 where
220 T: Float;
221
222 fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
223
224 fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error>;
225
226 fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
227 where
228 T: Real;
229
230 fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
231}
232
233pub trait ElementwiseNumeric<A, T>: PlatformInstance
234where
235 A: Access<T>,
236 T: Number,
237{
238 type Op: ReadOp<Self, u8>;
239
240 fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
241
242 fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
243}
244
245pub trait ElementwiseTrig<A, T>: PlatformInstance
246where
247 A: Access<T>,
248 T: Float,
249{
250 type Op: ReadOp<Self, T>;
251
252 fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
253
254 fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
255
256 fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
257
258 fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
259
260 fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
261
262 fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
263
264 fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
265
266 fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
267
268 fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
269}
270
271pub trait ElementwiseUnary<A, T>: PlatformInstance
272where
273 A: Access<T>,
274 T: Float,
275{
276 type Op: ReadOp<Self, T>;
277
278 fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
279
280 fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
281
282 fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>
283 where
284 T: Real;
285}
286
287pub trait ElementwiseUnaryBoolean<A, T>: PlatformInstance
288where
289 A: Access<T>,
290 T: Number,
291{
292 type Op: ReadOp<Self, u8>;
293
294 fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
295}
296
297pub trait GatherCond<A, L, R, T>: PlatformInstance
298where
299 A: Access<u8>,
300 L: Access<T>,
301 R: Access<T>,
302 T: Number,
303{
304 type Op: ReadOp<Self, T>;
305
306 fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error>;
307}
308
309pub trait LinAlgDual<L, R, T>: PlatformInstance
310where
311 L: Access<T>,
312 R: Access<T>,
313 T: Number,
314{
315 type Op: ReadOp<Self, T>;
316
317 fn matmul(self, left: L, right: R, dims: [usize; 4])
318 -> Result<AccessOp<Self::Op, Self>, Error>;
319}
320
321pub trait LinAlgUnary<A, T>: PlatformInstance
322where
323 A: Access<T>,
324 T: Number,
325{
326 type Op: ReadOp<Self, T>;
327
328 fn diag(
329 self,
330 access: A,
331 batch_size: usize,
332 dim: usize,
333 ) -> Result<AccessOp<Self::Op, Self>, Error>;
334}
335
336pub trait Random: PlatformInstance {
337 type Normal: Enqueue<Self, f32>;
338 type Uniform: Enqueue<Self, f32>;
339
340 fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error>;
341
342 fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error>;
343}
344
345pub trait ReduceAll<A, T>: PlatformInstance {
346 fn all(self, access: A) -> Result<bool, Error>;
347
348 fn any(self, access: A) -> Result<bool, Error>;
349
350 fn max(self, access: A) -> Result<T, Error>
351 where
352 T: Real;
353
354 fn min(self, access: A) -> Result<T, Error>
355 where
356 T: Real;
357
358 fn product(self, access: A) -> Result<T, Error>;
359
360 fn sum(self, access: A) -> Result<T, Error>;
361}
362
363pub trait ReduceAxes<A: Access<T>, T: Number>: PlatformInstance {
364 type Op: ReadOp<Self, T>;
365
366 fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
367 where
368 T: Real;
369
370 fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
371 where
372 T: Real;
373
374 fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
375
376 fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
377}
378
379pub trait Transform<A: Access<T>, T: Number>: PlatformInstance {
380 type Broadcast: ReadOp<Self, T>;
381 type Flip: ReadOp<Self, T>;
382 type Slice: ReadOp<Self, T>;
383 type Transpose: ReadOp<Self, T>;
384
385 fn broadcast(
386 self,
387 access: A,
388 shape: Shape,
389 broadcast: Shape,
390 ) -> Result<AccessOp<Self::Broadcast, Self>, Error>;
391
392 fn flip(
393 self,
394 access: A,
395 shape: Shape,
396 axis: usize,
397 ) -> Result<AccessOp<Self::Flip, Self>, Error>;
398
399 fn slice(
400 self,
401 access: A,
402 shape: &[usize],
403 range: Range,
404 ) -> Result<AccessOp<Self::Slice, Self>, Error>;
405
406 fn transpose(
407 self,
408 access: A,
409 shape: Shape,
410 permutation: Axes,
411 ) -> Result<AccessOp<Self::Transpose, Self>, Error>;
412}
413
414pub enum Cast<A, IT, OT> {
415 #[cfg(feature = "opencl")]
416 CL(opencl::ops::Cast<A, IT, OT>),
417 Host(host::ops::Cast<A, IT, OT>),
418}
419
420impl<A: Access<IT>, IT: Number, OT: Number> Op for Cast<A, IT, OT> {
421 fn size(&self) -> usize {
422 op_dispatch!(self, op, op.size())
423 }
424}
425
426impl<A: Access<IT>, IT: Number, OT: Number> Enqueue<Platform, OT> for Cast<A, IT, OT> {
427 type Buffer = Buffer<OT>;
428
429 fn enqueue(&self) -> Result<Self::Buffer, Error> {
430 op_enqueue!(self, OT)
431 }
432}
433
434impl<A: Access<IT>, IT: Number, OT: Number> ReadValue<Platform, OT> for Cast<A, IT, OT> {
435 fn read_value(&self, offset: usize) -> Result<OT, Error> {
436 op_dispatch!(self, op, op.read_value(offset))
437 }
438}
439
440impl<A, IT, OT> From<host::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
441 fn from(op: host::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
442 Self::Host(op)
443 }
444}
445
446#[cfg(feature = "opencl")]
447impl<A, IT, OT> From<opencl::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
448 fn from(op: opencl::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
449 Self::CL(op)
450 }
451}
452
453pub struct Concat<A, T> {
454 data: Vec<A>,
455 dtype: PhantomData<T>,
456}
457
458impl<A, T> Concat<A, T> {
459 pub fn new(data: Vec<A>) -> Self {
460 Self {
461 data,
462 dtype: PhantomData,
463 }
464 }
465
466 pub(crate) fn data(&self) -> &[A] {
467 &self.data
468 }
469}
470
471impl<A, T> Op for Concat<A, T>
472where
473 A: Access<T>,
474 T: Number,
475{
476 fn size(&self) -> usize {
477 self.data.iter().map(|access| access.size()).sum()
478 }
479}
480
481impl<A, T> Enqueue<Platform, T> for Concat<A, T>
482where
483 A: Access<T>,
484 T: Number,
485{
486 type Buffer = Buffer<T>;
487
488 fn enqueue(&self) -> Result<Self::Buffer, Error> {
489 match Platform::select(self.size()) {
490 #[cfg(feature = "opencl")]
491 Platform::CL(_) => Enqueue::<opencl::OpenCL, T>::enqueue(self).map(Buffer::from),
492 Platform::Host(_) => Enqueue::<host::Host, T>::enqueue(self).map(Buffer::from),
493 }
494 }
495}
496
497impl<A, T> ReadValue<Platform, T> for Concat<A, T>
498where
499 A: Access<T>,
500 T: Number,
501{
502 fn read_value(&self, offset: usize) -> Result<T, Error> {
503 let mut start = 0;
504
505 for access in &self.data {
506 let end = start + access.size();
507 if offset < end {
508 return access.read_value(offset - start);
509 }
510 start = end;
511 }
512
513 Err(Error::bounds(format!(
514 "offset {} is out of bounds for a concatenation of size",
515 self.size()
516 )))
517 }
518}
519
520pub enum Cond<A, L, R, T> {
521 #[cfg(feature = "opencl")]
522 CL(opencl::ops::Cond<A, L, R, T>),
523 Host(host::ops::Cond<A, L, R, T>),
524}
525
526impl<A, L, R, T> Op for Cond<A, L, R, T>
527where
528 A: Access<u8>,
529 L: Access<T>,
530 R: Access<T>,
531 T: Number,
532{
533 fn size(&self) -> usize {
534 op_dispatch!(self, op, op.size())
535 }
536}
537
538impl<A, L, R, T> Enqueue<Platform, T> for Cond<A, L, R, T>
539where
540 A: Access<u8>,
541 L: Access<T>,
542 R: Access<T>,
543 T: Number,
544{
545 type Buffer = Buffer<T>;
546
547 fn enqueue(&self) -> Result<Self::Buffer, Error> {
548 op_enqueue!(self, T)
549 }
550}
551
552impl<A, L, R, T> ReadValue<Platform, T> for Cond<A, L, R, T>
553where
554 A: Access<u8>,
555 L: Access<T>,
556 R: Access<T>,
557 T: Number,
558{
559 fn read_value(&self, offset: usize) -> Result<T, Error> {
560 op_dispatch!(self, op, op.read_value(offset))
561 }
562}
563
564impl<A, L, R, T> From<host::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
565 fn from(op: host::ops::Cond<A, L, R, T>) -> Self {
566 Self::Host(op)
567 }
568}
569
570#[cfg(feature = "opencl")]
571impl<A, L, R, T> From<opencl::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
572 fn from(op: opencl::ops::Cond<A, L, R, T>) -> Self {
573 Self::CL(op)
574 }
575}
576
577pub enum Dual<L, R, IT, OT> {
578 #[cfg(feature = "opencl")]
579 CL(opencl::ops::Dual<L, R, IT, OT>),
580 Host(host::ops::Dual<L, R, IT, OT>),
581}
582
583impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
584where
585 L: Access<IT>,
586 R: Access<IT>,
587 IT: Number,
588 OT: Number,
589{
590 fn size(&self) -> usize {
591 op_dispatch!(self, op, op.size())
592 }
593}
594
595impl<L, R, IT, OT> Enqueue<Platform, OT> for Dual<L, R, IT, OT>
596where
597 L: Access<IT>,
598 R: Access<IT>,
599 IT: Number,
600 OT: Number,
601{
602 type Buffer = Buffer<OT>;
603
604 fn enqueue(&self) -> Result<Self::Buffer, Error> {
605 op_enqueue!(self, OT)
606 }
607}
608
609impl<L, R, IT, OT> ReadValue<Platform, OT> for Dual<L, R, IT, OT>
610where
611 L: Access<IT>,
612 R: Access<IT>,
613 IT: Number,
614 OT: Number,
615{
616 fn read_value(&self, offset: usize) -> Result<OT, Error> {
617 op_dispatch!(self, op, op.read_value(offset))
618 }
619}
620
621#[cfg(feature = "opencl")]
622impl<L, R, IT, OT> From<opencl::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
623 fn from(op: opencl::ops::Dual<L, R, IT, OT>) -> Self {
624 Self::CL(op)
625 }
626}
627
628impl<L, R, IT, OT> From<host::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
629 fn from(op: host::ops::Dual<L, R, IT, OT>) -> Self {
630 Self::Host(op)
631 }
632}
633
634pub enum Flip<A, T> {
635 #[cfg(feature = "opencl")]
636 CL(opencl::ops::Flip<A, T>),
637 Host(host::ops::Flip<A, T>),
638}
639
640impl<A: Access<T>, T: Number> Op for Flip<A, T> {
641 fn size(&self) -> usize {
642 op_dispatch!(self, op, op.size())
643 }
644}
645
646impl<A: Access<T>, T: Number> Enqueue<Platform, T> for Flip<A, T> {
647 type Buffer = Buffer<T>;
648
649 fn enqueue(&self) -> Result<Self::Buffer, Error> {
650 op_enqueue!(self, T)
651 }
652}
653
654impl<A: Access<T>, T: Number> ReadValue<Platform, T> for Flip<A, T> {
655 fn read_value(&self, offset: usize) -> Result<T, Error> {
656 op_dispatch!(self, op, op.read_value(offset))
657 }
658}
659
660#[cfg(feature = "opencl")]
661impl<A, T> From<opencl::ops::Flip<A, T>> for Flip<A, T> {
662 fn from(op: opencl::ops::Flip<A, T>) -> Self {
663 Self::CL(op)
664 }
665}
666
667impl<A, T> From<host::ops::Flip<A, T>> for Flip<A, T> {
668 fn from(op: host::ops::Flip<A, T>) -> Self {
669 Self::Host(op)
670 }
671}
672
673pub enum Linear<T> {
674 #[cfg(feature = "opencl")]
675 CL(opencl::ops::Linear<T>),
676 Host(host::ops::Linear<T>),
677}
678
679#[cfg(feature = "opencl")]
680impl<T> From<opencl::ops::Linear<T>> for Linear<T> {
681 fn from(op: opencl::ops::Linear<T>) -> Self {
682 Self::CL(op)
683 }
684}
685
686impl<T> From<host::ops::Linear<T>> for Linear<T> {
687 fn from(op: host::ops::Linear<T>) -> Self {
688 Self::Host(op)
689 }
690}
691
692impl<T: Send + Sync> Op for Linear<T> {
693 fn size(&self) -> usize {
694 op_dispatch!(self, op, op.size())
695 }
696}
697
698impl<T: Number> Enqueue<Platform, T> for Linear<T> {
699 type Buffer = Buffer<T>;
700
701 fn enqueue(&self) -> Result<Self::Buffer, Error> {
702 op_enqueue!(self, T)
703 }
704}
705
706impl<T: Number> ReadValue<Platform, T> for Linear<T> {
707 fn read_value(&self, offset: usize) -> Result<T, Error> {
708 op_dispatch!(self, op, op.read_value(offset))
709 }
710}
711
712pub enum MatDiag<A, T> {
713 #[cfg(feature = "opencl")]
714 CL(opencl::ops::MatDiag<A, T>),
715 Host(host::ops::MatDiag<A, T>),
716}
717
718impl<A: Access<T>, T: Number> Op for MatDiag<A, T> {
719 fn size(&self) -> usize {
720 op_dispatch!(self, op, op.size())
721 }
722}
723
724impl<A: Access<T>, T: Number> Enqueue<Platform, T> for MatDiag<A, T> {
725 type Buffer = Buffer<T>;
726
727 fn enqueue(&self) -> Result<Self::Buffer, Error> {
728 op_enqueue!(self, T)
729 }
730}
731
732impl<A: Access<T>, T: Number> ReadValue<Platform, T> for MatDiag<A, T> {
733 fn read_value(&self, offset: usize) -> Result<T, Error> {
734 op_dispatch!(self, op, op.read_value(offset))
735 }
736}
737
738impl<A, T> From<host::ops::MatDiag<A, T>> for MatDiag<A, T> {
739 fn from(op: host::ops::MatDiag<A, T>) -> Self {
740 Self::Host(op)
741 }
742}
743
744#[cfg(feature = "opencl")]
745impl<A, T> From<opencl::ops::MatDiag<A, T>> for MatDiag<A, T> {
746 fn from(op: opencl::ops::MatDiag<A, T>) -> Self {
747 Self::CL(op)
748 }
749}
750
751pub enum MatMul<L, R, T> {
752 #[cfg(feature = "opencl")]
753 CL(opencl::ops::MatMul<L, R, T>),
754 Host(host::ops::MatMul<L, R, T>),
755}
756
757impl<L, R, T> Op for MatMul<L, R, T>
758where
759 L: Access<T>,
760 R: Access<T>,
761 T: Number,
762{
763 fn size(&self) -> usize {
764 op_dispatch!(self, op, op.size())
765 }
766}
767
768impl<L, R, T> Enqueue<Platform, T> for MatMul<L, R, T>
769where
770 L: Access<T>,
771 R: Access<T>,
772 T: Number,
773{
774 type Buffer = Buffer<T>;
775
776 fn enqueue(&self) -> Result<Self::Buffer, Error> {
777 op_enqueue!(self, T)
778 }
779}
780
781impl<L, R, T> ReadValue<Platform, T> for MatMul<L, R, T>
782where
783 L: Access<T>,
784 R: Access<T>,
785 T: Number,
786{
787 fn read_value(&self, offset: usize) -> Result<T, Error> {
788 op_dispatch!(self, op, op.read_value(offset))
789 }
790}
791
792#[cfg(feature = "opencl")]
793impl<L, R, T> From<opencl::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
794 fn from(op: opencl::ops::MatMul<L, R, T>) -> Self {
795 Self::CL(op)
796 }
797}
798
799impl<L, R, T> From<host::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
800 fn from(op: host::ops::MatMul<L, R, T>) -> Self {
801 Self::Host(op)
802 }
803}
804
805pub enum RandomNormal {
806 #[cfg(feature = "opencl")]
807 CL(opencl::ops::RandomNormal),
808 Host(host::ops::RandomNormal),
809}
810
811#[cfg(feature = "opencl")]
812impl From<opencl::ops::RandomNormal> for RandomNormal {
813 fn from(op: opencl::ops::RandomNormal) -> Self {
814 Self::CL(op)
815 }
816}
817
818impl From<host::ops::RandomNormal> for RandomNormal {
819 fn from(op: host::ops::RandomNormal) -> Self {
820 Self::Host(op)
821 }
822}
823
824pub enum RandomUniform {
825 #[cfg(feature = "opencl")]
826 CL(opencl::ops::RandomUniform),
827 Host(host::ops::RandomUniform),
828}
829
830#[cfg(feature = "opencl")]
831impl From<opencl::ops::RandomUniform> for RandomUniform {
832 fn from(op: opencl::ops::RandomUniform) -> Self {
833 Self::CL(op)
834 }
835}
836
837impl From<host::ops::RandomUniform> for RandomUniform {
838 fn from(op: host::ops::RandomUniform) -> Self {
839 Self::Host(op)
840 }
841}
842
843macro_rules! impl_random {
844 ($t:ty) => {
845 impl Op for $t {
846 fn size(&self) -> usize {
847 op_dispatch!(self, op, op.size())
848 }
849 }
850
851 impl Enqueue<Platform, f32> for $t {
852 type Buffer = Buffer<f32>;
853
854 fn enqueue(&self) -> Result<Self::Buffer, Error> {
855 op_enqueue!(self, f32)
856 }
857 }
858
859 impl ReadValue<Platform, f32> for $t {
860 fn read_value(&self, offset: usize) -> Result<f32, Error> {
861 op_dispatch!(self, op, op.read_value(offset))
862 }
863 }
864 };
865}
866
867impl_random!(RandomNormal);
868impl_random!(RandomUniform);
869
/// Implement [`Op`], [`Enqueue`], and [`ReadValue`] on [`Platform`] for a wrapper enum
/// which is generic over an access `A` of data type `T`.
///
/// * `$wrapper` - the wrapper type, written in terms of the generic parameters `A` and `T`
/// * `$out`     - the output data type of the op
macro_rules! impl_unary {
    ($wrapper:ty, $out:ty) => {
        impl<A, T> Op for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            fn size(&self) -> usize {
                op_dispatch!(self, inner, inner.size())
            }
        }

        impl<A, T> Enqueue<Platform, $out> for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            type Buffer = Buffer<$out>;

            fn enqueue(&self) -> Result<Self::Buffer, Error> {
                op_enqueue!(self, $out)
            }
        }

        impl<A, T> ReadValue<Platform, $out> for $wrapper
        where
            A: Access<T>,
            T: Number,
        {
            fn read_value(&self, offset: usize) -> Result<$out, Error> {
                op_dispatch!(self, inner, inner.read_value(offset))
            }
        }
    };
}
893
894pub enum Reduce<A, T: Number> {
895 #[cfg(feature = "opencl")]
896 CL(opencl::ops::Reduce<A, T>),
897 Host(host::ops::Reduce<A, T>),
898}
899
900impl_unary!(Reduce<A, T>, T);
901
902impl<A, T: Number> From<host::ops::Reduce<A, T>> for Reduce<A, T> {
903 fn from(op: host::ops::Reduce<A, T>) -> Self {
904 Self::Host(op)
905 }
906}
907
908#[cfg(feature = "opencl")]
909impl<A, T: Number> From<opencl::ops::Reduce<A, T>> for Reduce<A, T> {
910 fn from(op: opencl::ops::Reduce<A, T>) -> Self {
911 Self::CL(op)
912 }
913}
914
915#[derive(Clone, Eq, PartialEq, Hash)]
916pub struct SliceSpec {
917 pub range: Range,
918 pub shape: Shape,
919 pub strides: Strides,
920 pub source_strides: Strides,
921}
922
923impl SliceSpec {
924 pub fn new(source_shape: &[usize], range: Range) -> Self {
925 debug_assert!(range.len() <= source_shape.len());
926
927 let shape = range_shape(source_shape, &range);
928 let strides = strides_for(&shape, shape.len()).collect();
929 let source_strides = strides_for(source_shape, source_shape.len()).collect();
930
931 Self {
932 range,
933 shape,
934 strides,
935 source_strides,
936 }
937 }
938
939 pub fn source_offset(&self, offset: usize) -> usize {
940 debug_assert!(!self.shape.is_empty());
941 debug_assert_eq!(self.shape.len(), self.strides.len());
942
943 let mut coord = self
944 .strides
945 .iter()
946 .copied()
947 .zip(&self.shape)
948 .map(|(stride, dim)| {
949 if stride == 0 {
950 0
951 } else {
952 (offset / stride) % dim
953 }
954 });
955
956 let mut offset = 0;
957 for (stride, bound) in self.source_strides.iter().zip(self.range.iter()) {
958 let i = match bound {
959 AxisRange::At(i) => *i,
960 AxisRange::In(start, stop, step) => {
961 let i = start + (coord.next().expect("i") * step);
962 debug_assert!(i < *stop);
963 i
964 }
965 AxisRange::Of(indices) => indices[coord.next().expect("i")],
966 };
967
968 offset += i * stride;
969 }
970
971 offset
972 }
973
974 pub fn size(&self) -> usize {
975 self.shape.iter().product()
976 }
977}
978
979pub enum Scalar<A, IT, OT> {
980 #[cfg(feature = "opencl")]
981 CL(opencl::ops::Scalar<A, IT, OT>),
982 Host(host::ops::Scalar<A, IT, OT>),
983}
984
985impl<A, IT, OT> Op for Scalar<A, IT, OT>
986where
987 A: Access<IT>,
988 IT: Number,
989 OT: Number,
990{
991 fn size(&self) -> usize {
992 op_dispatch!(self, op, op.size())
993 }
994}
995
996impl<A, IT, OT> Enqueue<Platform, OT> for Scalar<A, IT, OT>
997where
998 A: Access<IT>,
999 IT: Number,
1000 OT: Number,
1001{
1002 type Buffer = Buffer<OT>;
1003
1004 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1005 op_enqueue!(self, OT)
1006 }
1007}
1008
1009impl<A, IT, OT> ReadValue<Platform, OT> for Scalar<A, IT, OT>
1010where
1011 A: Access<IT>,
1012 IT: Number,
1013 OT: Number,
1014{
1015 fn read_value(&self, offset: usize) -> Result<OT, Error> {
1016 op_dispatch!(self, op, op.read_value(offset))
1017 }
1018}
1019
1020#[cfg(feature = "opencl")]
1021impl<A, IT, OT> From<opencl::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
1022 fn from(op: opencl::ops::Scalar<A, IT, OT>) -> Self {
1023 Self::CL(op)
1024 }
1025}
1026
1027impl<A, IT, OT> From<host::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
1028 fn from(op: host::ops::Scalar<A, IT, OT>) -> Self {
1029 Self::Host(op)
1030 }
1031}
1032
1033pub enum Slice<A, T> {
1034 #[cfg(feature = "opencl")]
1035 CL(opencl::ops::Slice<A, T>),
1036 Host(host::ops::Slice<A, T>),
1037}
1038
1039impl_unary!(Slice<A, T>, T);
1040
/// Write support for [`Slice`] when the `opencl` feature is enabled.
///
/// Each method forwards to the [`Write`] implementation of the wrapped backend op.
/// Note that this impl also requires `A: Debug`, unlike the host-only impl.
#[cfg(feature = "opencl")]
impl<A: AccessMut<T> + std::fmt::Debug, T: Number> Write<Platform, T> for Slice<A, T> {
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        match self {
            Self::CL(inner) => Write::<opencl::OpenCL, T>::write(inner, data),
            Self::Host(inner) => Write::<host::Host, T>::write(inner, data),
        }
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        match self {
            Self::CL(inner) => Write::<opencl::OpenCL, T>::write_value(inner, value),
            Self::Host(inner) => Write::<host::Host, T>::write_value(inner, value),
        }
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        match self {
            Self::CL(inner) => Write::<opencl::OpenCL, T>::write_value_at(inner, offset, value),
            Self::Host(inner) => Write::<host::Host, T>::write_value_at(inner, offset, value),
        }
    }
}
1068
1069#[cfg(not(feature = "opencl"))]
1070impl<A, T> Write<Platform, T> for Slice<A, T>
1071where
1072 T: Number,
1073 A: AccessMut<T>,
1074{
1075 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
1076 match self {
1077 Self::Host(op) => Write::<host::Host, T>::write(op, data),
1078 }
1079 }
1080
1081 fn write_value(&mut self, value: T) -> Result<(), Error> {
1082 match self {
1083 Self::Host(op) => Write::<host::Host, T>::write_value(op, value),
1084 }
1085 }
1086
1087 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
1088 match self {
1089 Self::Host(op) => Write::<host::Host, T>::write_value_at(op, offset, value),
1090 }
1091 }
1092}
1093
1094#[cfg(feature = "opencl")]
1095impl<A, T> From<opencl::ops::Slice<A, T>> for Slice<A, T> {
1096 fn from(op: opencl::ops::Slice<A, T>) -> Self {
1097 Self::CL(op)
1098 }
1099}
1100
1101impl<A, T> From<host::ops::Slice<A, T>> for Slice<A, T> {
1102 fn from(op: host::ops::Slice<A, T>) -> Self {
1103 Self::Host(op)
1104 }
1105}
1106
1107pub enum Unary<A, IT, OT> {
1108 #[cfg(feature = "opencl")]
1109 CL(opencl::ops::Unary<A, IT, OT>),
1110 Host(host::ops::Unary<A, IT, OT>),
1111}
1112
1113impl<A, IT, OT> Op for Unary<A, IT, OT>
1114where
1115 A: Access<IT>,
1116 IT: Number,
1117 OT: Number,
1118{
1119 fn size(&self) -> usize {
1120 op_dispatch!(self, op, op.size())
1121 }
1122}
1123
1124impl<A, IT, OT> Enqueue<Platform, OT> for Unary<A, IT, OT>
1125where
1126 A: Access<IT>,
1127 IT: Number,
1128 OT: Number,
1129{
1130 type Buffer = Buffer<OT>;
1131
1132 fn enqueue(&self) -> Result<Self::Buffer, Error> {
1133 op_enqueue!(self, OT)
1134 }
1135}
1136
1137impl<A, IT, OT> ReadValue<Platform, OT> for Unary<A, IT, OT>
1138where
1139 A: Access<IT>,
1140 IT: Number,
1141 OT: Number,
1142{
1143 fn read_value(&self, offset: usize) -> Result<OT, Error> {
1144 op_dispatch!(self, op, op.read_value(offset))
1145 }
1146}
1147
1148impl<A, IT, OT> From<host::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
1149 fn from(op: host::ops::Unary<A, IT, OT>) -> Self {
1150 Self::Host(op)
1151 }
1152}
1153
1154#[cfg(feature = "opencl")]
1155impl<A, IT, OT> From<opencl::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
1156 fn from(op: opencl::ops::Unary<A, IT, OT>) -> Self {
1157 Self::CL(op)
1158 }
1159}
1160
1161#[derive(Clone, Eq, PartialEq, Hash)]
1162pub struct FlipSpec {
1163 pub shape: Shape,
1164 pub strides: Strides,
1165 pub axis: usize,
1166}
1167
1168impl FlipSpec {
1169 pub fn new(shape: Shape, axis: usize) -> Result<Self, Error> {
1170 if axis < shape.len() {
1171 let strides = strides_for(&shape, shape.len()).collect();
1172
1173 Ok(Self {
1174 shape,
1175 strides,
1176 axis,
1177 })
1178 } else {
1179 Err(Error::bounds(format!("shape {shape:?} has no axis {axis}")))
1180 }
1181 }
1182
1183 pub fn source_offset(&self, offset: usize) -> usize {
1184 self.strides
1185 .iter()
1186 .copied()
1187 .zip(self.shape.iter().copied())
1188 .map(|(stride, dim)| {
1189 if stride == 0 {
1190 0
1191 } else {
1192 (offset / stride) % dim
1193 }
1194 }) .zip(self.strides.iter().copied())
1196 .enumerate()
1197 .map(|(x, (i, source_stride))| {
1198 let i = if x == self.axis {
1199 self.shape[x] - i - 1
1200 } else {
1201 i
1202 };
1203
1204 i * source_stride
1205 })
1206 .sum::<usize>()
1207 }
1208}
1209
1210#[derive(Clone, Eq, PartialEq, Hash)]
1211pub struct ViewSpec {
1212 pub shape: Shape,
1213 pub strides: Strides,
1214 pub source_strides: Strides,
1215}
1216
1217impl ViewSpec {
1218 pub fn new(shape: Shape, source_strides: Strides) -> Self {
1219 let strides = strides_for(&shape, shape.len()).collect();
1220
1221 Self {
1222 shape,
1223 strides,
1224 source_strides,
1225 }
1226 }
1227
1228 pub fn source_offset(&self, offset: usize) -> usize {
1229 debug_assert!(offset < self.size());
1230
1231 let source_offset = self
1232 .strides
1233 .iter()
1234 .copied()
1235 .zip(self.shape.iter().copied())
1236 .rev()
1237 .take(self.source_strides.len())
1238 .map(|(stride, dim)| {
1239 if stride == 0 {
1240 0
1241 } else {
1242 (offset / stride) % dim
1243 }
1244 }) .zip(self.source_strides.iter().rev().copied())
1246 .map(|(i, source_stride)| i * source_stride)
1247 .sum::<usize>();
1248
1249 source_offset
1250 }
1251
1252 pub fn size(&self) -> usize {
1253 self.shape.iter().product()
1254 }
1255}
1256
1257pub enum View<A, T> {
1258 #[cfg(feature = "opencl")]
1259 CL(opencl::ops::View<A, T>),
1260 Host(host::ops::View<A, T>),
1261}
1262
1263impl_unary!(View<A, T>, T);
1264
1265#[cfg(feature = "opencl")]
1266impl<A, T> From<opencl::ops::View<A, T>> for View<A, T> {
1267 fn from(op: opencl::ops::View<A, T>) -> Self {
1268 Self::CL(op)
1269 }
1270}
1271
1272impl<A, T> From<host::ops::View<A, T>> for View<A, T> {
1273 fn from(op: host::ops::View<A, T>) -> Self {
1274 Self::Host(op)
1275 }
1276}