1use crate::access::*;
4use crate::buffer::Buffer;
5#[cfg(feature = "opencl")]
6use crate::opencl;
7use crate::platform::{Platform, PlatformInstance};
8use crate::{
9 host, range_shape, strides_for, Axes, AxisRange, BufferConverter, CType, Error, Range, Shape,
10 Strides,
11};
12
/// Evaluate `$body` against the platform-specific op wrapped by `$target`.
///
/// `$binding` is the identifier bound to the inner op in every match arm, so `$body` may refer
/// to it. The `CL` arm is compiled only when the `opencl` feature is enabled. The arms are
/// spelled in terms of `Self`, so this macro is only usable inside an `impl` of an enum with a
/// `Host` variant (and a `CL` variant when `opencl` is enabled).
macro_rules! op_dispatch {
    ($target:expr, $binding:ident, $body:expr) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL($binding) => $body,
            Self::Host($binding) => $body,
        }
    };
}
22
/// Enqueue the platform-specific op wrapped by `$target` and wrap the resulting buffer of
/// `$dtype` elements in the matching [`Buffer`] variant.
///
/// Like `op_dispatch!`, this relies on `Self` and therefore only works inside an `impl` of an
/// enum with a `Host` variant (and a `CL` variant when the `opencl` feature is enabled).
macro_rules! op_enqueue {
    ($target:expr, $dtype:ty) => {
        match $target {
            #[cfg(feature = "opencl")]
            Self::CL(inner) => Enqueue::<opencl::OpenCL, $dtype>::enqueue(inner).map(Buffer::CL),
            Self::Host(inner) => Enqueue::<host::Host, $dtype>::enqueue(inner).map(Buffer::Host),
        }
    };
}
32
/// The base trait of every op in this module; an op must be `Send + Sync`.
pub trait Op: Send + Sync {
    /// Returns the size of this op.
    // NOTE(review): presumably the number of elements the op produces — confirm against the
    // platform implementations.
    fn size(&self) -> usize;
}
36
/// An [`Op`] which can be enqueued on a platform `P` to produce a buffer of `T` elements.
pub trait Enqueue<P: PlatformInstance, T: CType>: Op {
    /// The buffer type returned by [`Enqueue::enqueue`]; convertible into a
    /// [`BufferConverter`].
    type Buffer: Into<BufferConverter<'static, T>>;

    /// Enqueue this op and return its output buffer, or an [`Error`] on failure.
    fn enqueue(&self) -> Result<Self::Buffer, Error>;
}
42
/// An [`Op`] from which an individual value of type `T` can be read on a platform `P`.
pub trait ReadValue<P: PlatformInstance, T: CType>: Op {
    /// Read the value at the given `offset`, or return an [`Error`] on failure.
    // NOTE(review): `offset` is presumably a linear index into the op's output — confirm.
    fn read_value(&self, offset: usize) -> Result<T, Error>;
}
46
/// Marker trait for an op which supports both [`Enqueue`] and [`ReadValue`] on platform `P`.
///
/// It declares no items of its own.
pub trait ReadOp<P, T>: Enqueue<P, T> + ReadValue<P, T>
where
    P: PlatformInstance,
    T: CType,
{
}
53
54impl<O, P, T> ReadOp<P, T> for O
55where
56 O: Enqueue<P, T> + ReadValue<P, T>,
57 P: PlatformInstance,
58 T: CType,
59{
60}
61
/// An [`Enqueue`] op which also accepts writes of `T` values on a platform `P`.
pub trait Write<P: PlatformInstance, T: CType>: Enqueue<P, T> {
    /// Write the contents of `data` to this op.
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error>;

    /// Write a single `value` to this op (no offset is given).
    // NOTE(review): presumably this fills every element with `value` — confirm.
    fn write_value(&mut self, value: T) -> Result<(), Error>;

    /// Write a single `value` at the given `offset`.
    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error>;
}
69
/// A platform which can construct ops that generate values of type `T`.
pub trait Construct<T: CType>: PlatformInstance {
    /// The op type wrapped by the return value of [`Construct::range`].
    type Range: Enqueue<Self, T>;

    /// Returns an [`AccessOp`] over a range op built from `start`, `stop`, and `size`.
    // NOTE(review): presumably `size` values spanning `start..stop`; the endpoint convention is
    // not visible here — confirm against the platform implementations.
    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error>;
}
75
/// A platform which supports boolean ops over a `left` and a `right` operand.
///
/// The resulting ops read out `u8` values.
pub trait ElementwiseBoolean<L, R, T>: PlatformInstance {
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op for the logical AND of `left` and `right`.
    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logical OR of `left` and `right`.
    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logical XOR of `left` and `right`.
    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
}
85
/// A platform which supports boolean ops between an operand `left` and a scalar `right`.
///
/// The resulting ops read out `u8` values.
pub trait ElementwiseBooleanScalar<A, T>: PlatformInstance {
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op for the logical AND of `left` with the scalar `right`.
    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logical OR of `left` with the scalar `right`.
    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logical XOR of `left` with the scalar `right`.
    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
}
95
/// A platform which can cast an accessor of `IT` values into an op which reads out `OT` values.
pub trait ElementwiseCast<A, IT, OT>: PlatformInstance
where
    A: Access<IT>,
    IT: CType,
    OT: CType,
{
    /// The op type wrapped by the [`AccessOp`] that [`ElementwiseCast::cast`] returns.
    type Op: ReadOp<Self, OT>;

    /// Returns an op which casts the values of `access` from `IT` to `OT`.
    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
}
106
/// A platform which supports comparisons between a `left` and a `right` operand.
///
/// The resulting ops read out `u8` values.
pub trait ElementwiseCompare<L, R, T>: PlatformInstance {
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op for the comparison `left == right`.
    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left >= right`.
    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left > right`.
    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left <= right`.
    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left < right`.
    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left != right`.
    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
}
122
/// A platform which supports comparisons between an operand `left` and a scalar `right`.
///
/// The resulting ops read out `u8` values.
pub trait ElementwiseScalarCompare<A, T>: PlatformInstance {
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op for the comparison `left == right`.
    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left >= right`.
    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left > right`.
    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left <= right`.
    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left < right`.
    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the comparison `left != right`.
    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
}
138
/// A platform which supports arithmetic ops over two accessors of the same element type `T`.
///
/// The resulting ops read out `T` values.
pub trait ElementwiseDual<L, R, T>: PlatformInstance
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for `left + right`.
    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left / right`.
    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logarithm of `arg` with respect to `base`.
    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left * right`.
    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op which raises `arg` to the power `exp`.
    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the remainder of `left / right`.
    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left - right`.
    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>;
}
161
/// A platform which supports arithmetic ops between an accessor of `T` and a scalar `T`.
///
/// The resulting ops read out `T` values.
pub trait ElementwiseScalar<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for `left + right`.
    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left / right`.
    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the logarithm of `arg` with respect to the scalar `base`.
    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left * right`.
    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op which raises `arg` to the scalar power `exp`.
    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the remainder of `left / right`.
    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for `left - right`.
    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>;
}
183
/// A platform which supports numeric classification checks on an accessor of `T` values.
///
/// The resulting ops read out `u8` values.
pub trait ElementwiseNumeric<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op which checks the values of `access` for infinity.
    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op which checks the values of `access` for not-a-number.
    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
}
195
/// A platform which supports trigonometric ops on an accessor of `T` values.
///
/// The resulting ops read out `T::Float` values rather than `T` values.
pub trait ElementwiseTrig<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, T::Float>;

    /// Returns an op for the sine of `access`.
    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the arcsine of `access`.
    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the hyperbolic sine of `access`.
    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the cosine of `access`.
    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the arccosine of `access`.
    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the hyperbolic cosine of `access`.
    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the tangent of `access`.
    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the arctangent of `access`.
    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the hyperbolic tangent of `access`.
    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
}
221
/// A platform which supports unary math ops on an accessor of `T` values.
///
/// The resulting ops read out `T` values.
pub trait ElementwiseUnary<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for the absolute value of `access`.
    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the exponential of `access`.
    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the natural logarithm of `access`.
    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op which rounds the values of `access`.
    // NOTE(review): the rounding mode is not visible here — confirm against the implementations.
    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
}
237
/// A platform which supports unary boolean ops on an accessor of `T` values.
///
/// The resulting op reads out `u8` values.
pub trait ElementwiseUnaryBoolean<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that [`ElementwiseUnaryBoolean::not`] returns.
    type Op: ReadOp<Self, u8>;

    /// Returns an op for the logical negation of `access`.
    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>;
}
247
/// A platform which supports a conditional selection between two accessors of `T` values,
/// driven by an accessor of `u8` values.
pub trait GatherCond<A, L, R, T>: PlatformInstance
where
    A: Access<u8>,
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that [`GatherCond::cond`] returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op which selects between `then` and `or_else` according to `cond`.
    // NOTE(review): presumably a nonzero `cond` value selects `then` — confirm.
    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error>;
}
259
/// A platform which supports linear algebra ops over two accessors of `T` values.
pub trait LinAlgDual<L, R, T>: PlatformInstance
where
    L: Access<T>,
    R: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that [`LinAlgDual::matmul`] returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for the matrix product of `left` and `right`.
    // NOTE(review): `dims` presumably encodes the batch size and the matrix dimensions, but the
    // order of its four entries is not visible here — confirm against the implementations.
    fn matmul(self, left: L, right: R, dims: [usize; 4])
        -> Result<AccessOp<Self::Op, Self>, Error>;
}
271
/// A platform which supports linear algebra ops on a single accessor of `T` values.
pub trait LinAlgUnary<A, T>: PlatformInstance
where
    A: Access<T>,
    T: CType,
{
    /// The op type wrapped by the [`AccessOp`] that [`LinAlgUnary::diag`] returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for the diagonal(s) of `access`, given a `batch_size` and a `dim`.
    // NOTE(review): presumably `access` holds `batch_size` square matrices of side `dim` —
    // confirm against the implementations.
    fn diag(
        self,
        access: A,
        batch_size: usize,
        dim: usize,
    ) -> Result<AccessOp<Self::Op, Self>, Error>;
}
286
/// A platform which can construct ops that produce random `f32` values.
pub trait Random: PlatformInstance {
    /// The op type wrapped by the return value of [`Random::random_normal`].
    type Normal: Enqueue<Self, f32>;
    /// The op type wrapped by the return value of [`Random::random_uniform`].
    type Uniform: Enqueue<Self, f32>;

    /// Returns an op of the given `size` which produces normally distributed values.
    // NOTE(review): the distribution parameters are not visible here — confirm.
    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error>;

    /// Returns an op of the given `size` which produces uniformly distributed values.
    // NOTE(review): the value range is not visible here — confirm.
    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error>;
}
295
/// A platform which can reduce an entire accessor to a single value.
///
/// Unlike most platform traits in this module, these methods return the result directly
/// rather than an [`AccessOp`].
pub trait ReduceAll<A, T>: PlatformInstance {
    /// Returns whether all values of `access` are true.
    fn all(self, access: A) -> Result<bool, Error>;

    /// Returns whether any value of `access` is true.
    fn any(self, access: A) -> Result<bool, Error>;

    /// Returns the maximum value of `access`.
    fn max(self, access: A) -> Result<T, Error>;

    /// Returns the minimum value of `access`.
    fn min(self, access: A) -> Result<T, Error>;

    /// Returns the product of the values of `access`.
    fn product(self, access: A) -> Result<T, Error>;

    /// Returns the sum of the values of `access`.
    fn sum(self, access: A) -> Result<T, Error>;
}
309
/// A platform which can reduce an accessor of `T` values with a given `stride`, producing an
/// op which reads out `T` values.
// NOTE(review): `stride` is presumably the number of consecutive elements folded into each
// output element — confirm against the implementations.
pub trait ReduceAxes<A: Access<T>, T: CType>: PlatformInstance {
    /// The op type wrapped by the [`AccessOp`] that each method returns.
    type Op: ReadOp<Self, T>;

    /// Returns an op for the maximum of `access`, reduced with the given `stride`.
    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the minimum of `access`, reduced with the given `stride`.
    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the product of `access`, reduced with the given `stride`.
    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;

    /// Returns an op for the sum of `access`, reduced with the given `stride`.
    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>;
}
321
/// A platform which supports shape transformations of an accessor of `T` values.
pub trait Transform<A: Access<T>, T: CType>: PlatformInstance {
    /// The op type wrapped by the return value of [`Transform::broadcast`].
    type Broadcast: ReadOp<Self, T>;
    /// The op type wrapped by the return value of [`Transform::slice`].
    type Slice: ReadOp<Self, T>;
    /// The op type wrapped by the return value of [`Transform::transpose`].
    type Transpose: ReadOp<Self, T>;

    /// Returns an op which broadcasts `access`.
    // NOTE(review): presumably `shape` is the source shape and `broadcast` the target shape —
    // confirm against the implementations.
    fn broadcast(
        self,
        access: A,
        shape: Shape,
        broadcast: Shape,
    ) -> Result<AccessOp<Self::Broadcast, Self>, Error>;

    /// Returns an op which slices `access` with the given `range`.
    // NOTE(review): `shape` is presumably the shape of the source — confirm.
    fn slice(
        self,
        access: A,
        shape: &[usize],
        range: Range,
    ) -> Result<AccessOp<Self::Slice, Self>, Error>;

    /// Returns an op which transposes `access` according to `permutation`.
    // NOTE(review): `shape` is presumably the shape of the source — confirm.
    fn transpose(
        self,
        access: A,
        shape: Shape,
        permutation: Axes,
    ) -> Result<AccessOp<Self::Transpose, Self>, Error>;
}
348
349pub enum Cast<A, IT, OT> {
350 #[cfg(feature = "opencl")]
351 CL(opencl::ops::Cast<A, IT, OT>),
352 Host(host::ops::Cast<A, IT, OT>),
353}
354
355impl<A: Access<IT>, IT: CType, OT: CType> Op for Cast<A, IT, OT> {
356 fn size(&self) -> usize {
357 op_dispatch!(self, op, op.size())
358 }
359}
360
361impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Platform, OT> for Cast<A, IT, OT> {
362 type Buffer = Buffer<OT>;
363
364 fn enqueue(&self) -> Result<Self::Buffer, Error> {
365 op_enqueue!(self, OT)
366 }
367}
368
369impl<A: Access<IT>, IT: CType, OT: CType> ReadValue<Platform, OT> for Cast<A, IT, OT> {
370 fn read_value(&self, offset: usize) -> Result<OT, Error> {
371 op_dispatch!(self, op, op.read_value(offset))
372 }
373}
374
375impl<A, IT, OT> From<host::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
376 fn from(op: host::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
377 Self::Host(op)
378 }
379}
380
381#[cfg(feature = "opencl")]
382impl<A, IT, OT> From<opencl::ops::Cast<A, IT, OT>> for Cast<A, IT, OT> {
383 fn from(op: opencl::ops::Cast<A, IT, OT>) -> Cast<A, IT, OT> {
384 Self::CL(op)
385 }
386}
387
388pub enum Cond<A, L, R, T> {
389 #[cfg(feature = "opencl")]
390 CL(opencl::ops::Cond<A, L, R, T>),
391 Host(host::ops::Cond<A, L, R, T>),
392}
393
394impl<A, L, R, T> Op for Cond<A, L, R, T>
395where
396 A: Access<u8>,
397 L: Access<T>,
398 R: Access<T>,
399 T: CType,
400{
401 fn size(&self) -> usize {
402 op_dispatch!(self, op, op.size())
403 }
404}
405
406impl<A, L, R, T> Enqueue<Platform, T> for Cond<A, L, R, T>
407where
408 A: Access<u8>,
409 L: Access<T>,
410 R: Access<T>,
411 T: CType,
412{
413 type Buffer = Buffer<T>;
414
415 fn enqueue(&self) -> Result<Self::Buffer, Error> {
416 op_enqueue!(self, T)
417 }
418}
419
420impl<A, L, R, T> ReadValue<Platform, T> for Cond<A, L, R, T>
421where
422 A: Access<u8>,
423 L: Access<T>,
424 R: Access<T>,
425 T: CType,
426{
427 fn read_value(&self, offset: usize) -> Result<T, Error> {
428 op_dispatch!(self, op, op.read_value(offset))
429 }
430}
431
432impl<A, L, R, T> From<host::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
433 fn from(op: host::ops::Cond<A, L, R, T>) -> Self {
434 Self::Host(op)
435 }
436}
437
438#[cfg(feature = "opencl")]
439impl<A, L, R, T> From<opencl::ops::Cond<A, L, R, T>> for Cond<A, L, R, T> {
440 fn from(op: opencl::ops::Cond<A, L, R, T>) -> Self {
441 Self::CL(op)
442 }
443}
444
445pub enum Dual<L, R, IT, OT> {
446 #[cfg(feature = "opencl")]
447 CL(opencl::ops::Dual<L, R, IT, OT>),
448 Host(host::ops::Dual<L, R, IT, OT>),
449}
450
451impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
452where
453 L: Access<IT>,
454 R: Access<IT>,
455 IT: CType,
456 OT: CType,
457{
458 fn size(&self) -> usize {
459 op_dispatch!(self, op, op.size())
460 }
461}
462
463impl<L, R, IT, OT> Enqueue<Platform, OT> for Dual<L, R, IT, OT>
464where
465 L: Access<IT>,
466 R: Access<IT>,
467 IT: CType,
468 OT: CType,
469{
470 type Buffer = Buffer<OT>;
471
472 fn enqueue(&self) -> Result<Self::Buffer, Error> {
473 op_enqueue!(self, OT)
474 }
475}
476
477impl<L, R, IT, OT> ReadValue<Platform, OT> for Dual<L, R, IT, OT>
478where
479 L: Access<IT>,
480 R: Access<IT>,
481 IT: CType,
482 OT: CType,
483{
484 fn read_value(&self, offset: usize) -> Result<OT, Error> {
485 op_dispatch!(self, op, op.read_value(offset))
486 }
487}
488
489#[cfg(feature = "opencl")]
490impl<L, R, IT, OT> From<opencl::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
491 fn from(op: opencl::ops::Dual<L, R, IT, OT>) -> Self {
492 Self::CL(op)
493 }
494}
495
496impl<L, R, IT, OT> From<host::ops::Dual<L, R, IT, OT>> for Dual<L, R, IT, OT> {
497 fn from(op: host::ops::Dual<L, R, IT, OT>) -> Self {
498 Self::Host(op)
499 }
500}
501
502pub enum Linear<T> {
503 #[cfg(feature = "opencl")]
504 CL(opencl::ops::Linear<T>),
505 Host(host::ops::Linear<T>),
506}
507
508#[cfg(feature = "opencl")]
509impl<T> From<opencl::ops::Linear<T>> for Linear<T> {
510 fn from(op: opencl::ops::Linear<T>) -> Self {
511 Self::CL(op)
512 }
513}
514
515impl<T> From<host::ops::Linear<T>> for Linear<T> {
516 fn from(op: host::ops::Linear<T>) -> Self {
517 Self::Host(op)
518 }
519}
520
521impl<T: Send + Sync> Op for Linear<T> {
522 fn size(&self) -> usize {
523 op_dispatch!(self, op, op.size())
524 }
525}
526
527impl<T: CType> Enqueue<Platform, T> for Linear<T> {
528 type Buffer = Buffer<T>;
529
530 fn enqueue(&self) -> Result<Self::Buffer, Error> {
531 op_enqueue!(self, T)
532 }
533}
534
535impl<T: CType> ReadValue<Platform, T> for Linear<T> {
536 fn read_value(&self, offset: usize) -> Result<T, Error> {
537 op_dispatch!(self, op, op.read_value(offset))
538 }
539}
540
541pub enum MatDiag<A, T> {
542 #[cfg(feature = "opencl")]
543 CL(opencl::ops::MatDiag<A, T>),
544 Host(host::ops::MatDiag<A, T>),
545}
546
547impl<A: Access<T>, T: CType> Op for MatDiag<A, T> {
548 fn size(&self) -> usize {
549 op_dispatch!(self, op, op.size())
550 }
551}
552
553impl<A: Access<T>, T: CType> Enqueue<Platform, T> for MatDiag<A, T> {
554 type Buffer = Buffer<T>;
555
556 fn enqueue(&self) -> Result<Self::Buffer, Error> {
557 op_enqueue!(self, T)
558 }
559}
560
561impl<A: Access<T>, T: CType> ReadValue<Platform, T> for MatDiag<A, T> {
562 fn read_value(&self, offset: usize) -> Result<T, Error> {
563 op_dispatch!(self, op, op.read_value(offset))
564 }
565}
566
567impl<A, T> From<host::ops::MatDiag<A, T>> for MatDiag<A, T> {
568 fn from(op: host::ops::MatDiag<A, T>) -> Self {
569 Self::Host(op)
570 }
571}
572
573#[cfg(feature = "opencl")]
574impl<A, T> From<opencl::ops::MatDiag<A, T>> for MatDiag<A, T> {
575 fn from(op: opencl::ops::MatDiag<A, T>) -> Self {
576 Self::CL(op)
577 }
578}
579
580pub enum MatMul<L, R, T> {
581 #[cfg(feature = "opencl")]
582 CL(opencl::ops::MatMul<L, R, T>),
583 Host(host::ops::MatMul<L, R, T>),
584}
585
586impl<L, R, T> Op for MatMul<L, R, T>
587where
588 L: Access<T>,
589 R: Access<T>,
590 T: CType,
591{
592 fn size(&self) -> usize {
593 op_dispatch!(self, op, op.size())
594 }
595}
596
597impl<L, R, T> Enqueue<Platform, T> for MatMul<L, R, T>
598where
599 L: Access<T>,
600 R: Access<T>,
601 T: CType,
602{
603 type Buffer = Buffer<T>;
604
605 fn enqueue(&self) -> Result<Self::Buffer, Error> {
606 op_enqueue!(self, T)
607 }
608}
609
610impl<L, R, T> ReadValue<Platform, T> for MatMul<L, R, T>
611where
612 L: Access<T>,
613 R: Access<T>,
614 T: CType,
615{
616 fn read_value(&self, offset: usize) -> Result<T, Error> {
617 op_dispatch!(self, op, op.read_value(offset))
618 }
619}
620
621#[cfg(feature = "opencl")]
622impl<L, R, T> From<opencl::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
623 fn from(op: opencl::ops::MatMul<L, R, T>) -> Self {
624 Self::CL(op)
625 }
626}
627
628impl<L, R, T> From<host::ops::MatMul<L, R, T>> for MatMul<L, R, T> {
629 fn from(op: host::ops::MatMul<L, R, T>) -> Self {
630 Self::Host(op)
631 }
632}
633
634pub enum RandomNormal {
635 #[cfg(feature = "opencl")]
636 CL(opencl::ops::RandomNormal),
637 Host(host::ops::RandomNormal),
638}
639
640#[cfg(feature = "opencl")]
641impl From<opencl::ops::RandomNormal> for RandomNormal {
642 fn from(op: opencl::ops::RandomNormal) -> Self {
643 Self::CL(op)
644 }
645}
646
647impl From<host::ops::RandomNormal> for RandomNormal {
648 fn from(op: host::ops::RandomNormal) -> Self {
649 Self::Host(op)
650 }
651}
652
653pub enum RandomUniform {
654 #[cfg(feature = "opencl")]
655 CL(opencl::ops::RandomUniform),
656 Host(host::ops::RandomUniform),
657}
658
659#[cfg(feature = "opencl")]
660impl From<opencl::ops::RandomUniform> for RandomUniform {
661 fn from(op: opencl::ops::RandomUniform) -> Self {
662 Self::CL(op)
663 }
664}
665
666impl From<host::ops::RandomUniform> for RandomUniform {
667 fn from(op: host::ops::RandomUniform) -> Self {
668 Self::Host(op)
669 }
670}
671
672macro_rules! impl_random {
673 ($t:ty) => {
674 impl Op for $t {
675 fn size(&self) -> usize {
676 op_dispatch!(self, op, op.size())
677 }
678 }
679
680 impl Enqueue<Platform, f32> for $t {
681 type Buffer = Buffer<f32>;
682
683 fn enqueue(&self) -> Result<Self::Buffer, Error> {
684 op_enqueue!(self, f32)
685 }
686 }
687
688 impl ReadValue<Platform, f32> for $t {
689 fn read_value(&self, offset: usize) -> Result<f32, Error> {
690 op_dispatch!(self, op, op.read_value(offset))
691 }
692 }
693 };
694}
695
696impl_random!(RandomNormal);
697impl_random!(RandomUniform);
698
/// Implement [`Op`], `Enqueue<Platform, $dtype>`, and `ReadValue<Platform, $dtype>` for the
/// op enum `$wrapper` by forwarding to whichever platform-specific op it wraps.
///
/// The generated impls are generic over `A: Access<T>` and `T: CType`, so `$wrapper` (and
/// `$dtype`, if generic) must be written in terms of the type parameters `A` and `T`.
macro_rules! impl_unary {
    ($wrapper:ty, $dtype:ty) => {
        impl<A, T> Op for $wrapper
        where
            A: Access<T>,
            T: CType,
        {
            fn size(&self) -> usize {
                op_dispatch!(self, inner, inner.size())
            }
        }

        impl<A, T> Enqueue<Platform, $dtype> for $wrapper
        where
            A: Access<T>,
            T: CType,
        {
            type Buffer = Buffer<$dtype>;

            fn enqueue(&self) -> Result<Self::Buffer, Error> {
                op_enqueue!(self, $dtype)
            }
        }

        impl<A, T> ReadValue<Platform, $dtype> for $wrapper
        where
            A: Access<T>,
            T: CType,
        {
            fn read_value(&self, offset: usize) -> Result<$dtype, Error> {
                op_dispatch!(self, inner, inner.read_value(offset))
            }
        }
    };
}
722
723pub enum Reduce<A, T: CType> {
724 #[cfg(feature = "opencl")]
725 CL(opencl::ops::Reduce<A, T>),
726 Host(host::ops::Reduce<A, T>),
727}
728
729impl_unary!(Reduce<A, T>, T);
730
731impl<A, T: CType> From<host::ops::Reduce<A, T>> for Reduce<A, T> {
732 fn from(op: host::ops::Reduce<A, T>) -> Self {
733 Self::Host(op)
734 }
735}
736
737#[cfg(feature = "opencl")]
738impl<A, T: CType> From<opencl::ops::Reduce<A, T>> for Reduce<A, T> {
739 fn from(op: opencl::ops::Reduce<A, T>) -> Self {
740 Self::CL(op)
741 }
742}
743
/// The specification of a slice of a source array: the requested range, the shape of the
/// slice, and the strides of both the slice and its source.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct SliceSpec {
    /// The bounds applied to the axes of the source.
    pub range: Range,
    /// The shape of the slice (computed by `range_shape` in [`SliceSpec::new`]).
    pub shape: Shape,
    /// The strides of the slice (computed by `strides_for` from `shape`).
    pub strides: Strides,
    /// The strides of the source (computed by `strides_for` from the source shape).
    pub source_strides: Strides,
}
751
752impl SliceSpec {
753 pub fn new(source_shape: &[usize], range: Range) -> Self {
754 debug_assert!(range.len() <= source_shape.len());
755
756 let shape = range_shape(source_shape, &range);
757 let strides = strides_for(&shape, shape.len()).collect();
758 let source_strides = strides_for(source_shape, source_shape.len()).collect();
759
760 Self {
761 range,
762 shape,
763 strides,
764 source_strides,
765 }
766 }
767
768 pub fn source_offset(&self, offset: usize) -> usize {
769 debug_assert!(!self.shape.is_empty());
770 debug_assert_eq!(self.shape.len(), self.strides.len());
771
772 let mut coord = self
773 .strides
774 .iter()
775 .copied()
776 .zip(&self.shape)
777 .map(|(stride, dim)| {
778 if stride == 0 {
779 0
780 } else {
781 (offset / stride) % dim
782 }
783 });
784
785 let mut offset = 0;
786 for (stride, bound) in self.source_strides.iter().zip(self.range.iter()) {
787 let i = match bound {
788 AxisRange::At(i) => *i,
789 AxisRange::In(start, stop, step) => {
790 let i = start + (coord.next().expect("i") * step);
791 debug_assert!(i < *stop);
792 i
793 }
794 AxisRange::Of(indices) => indices[coord.next().expect("i")],
795 };
796
797 offset += i * stride;
798 }
799
800 offset
801 }
802
803 pub fn size(&self) -> usize {
804 self.shape.iter().product()
805 }
806}
807
808pub enum Scalar<A, IT, OT> {
809 #[cfg(feature = "opencl")]
810 CL(opencl::ops::Scalar<A, IT, OT>),
811 Host(host::ops::Scalar<A, IT, OT>),
812}
813
814impl<A, IT, OT> Op for Scalar<A, IT, OT>
815where
816 A: Access<IT>,
817 IT: CType,
818 OT: CType,
819{
820 fn size(&self) -> usize {
821 op_dispatch!(self, op, op.size())
822 }
823}
824
825impl<A, IT, OT> Enqueue<Platform, OT> for Scalar<A, IT, OT>
826where
827 A: Access<IT>,
828 IT: CType,
829 OT: CType,
830{
831 type Buffer = Buffer<OT>;
832
833 fn enqueue(&self) -> Result<Self::Buffer, Error> {
834 op_enqueue!(self, OT)
835 }
836}
837
838impl<A, IT, OT> ReadValue<Platform, OT> for Scalar<A, IT, OT>
839where
840 A: Access<IT>,
841 IT: CType,
842 OT: CType,
843{
844 fn read_value(&self, offset: usize) -> Result<OT, Error> {
845 op_dispatch!(self, op, op.read_value(offset))
846 }
847}
848
849#[cfg(feature = "opencl")]
850impl<A, IT, OT> From<opencl::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
851 fn from(op: opencl::ops::Scalar<A, IT, OT>) -> Self {
852 Self::CL(op)
853 }
854}
855
856impl<A, IT, OT> From<host::ops::Scalar<A, IT, OT>> for Scalar<A, IT, OT> {
857 fn from(op: host::ops::Scalar<A, IT, OT>) -> Self {
858 Self::Host(op)
859 }
860}
861
/// A `Slice` op over an accessor `A` of `T` values, wrapping the implementation of a specific
/// platform.
pub enum Slice<A, T> {
    /// The OpenCL implementation (only with the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(opencl::ops::Slice<A, T>),
    /// The host implementation.
    Host(host::ops::Slice<A, T>),
}

// `Op`, `Enqueue<Platform, T>`, and `ReadValue<Platform, T>` are generated by `impl_unary!`.
impl_unary!(Slice<A, T>, T);
869
// With the `opencl` feature enabled, writes are forwarded to whichever platform-specific
// slice is wrapped. Note that this variant of the impl additionally requires `A: Debug`.
#[cfg(feature = "opencl")]
impl<A: AccessMut<T> + std::fmt::Debug, T: CType> Write<Platform, T> for Slice<A, T> {
    fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
        match self {
            Self::CL(cl_slice) => Write::<opencl::OpenCL, T>::write(cl_slice, data),
            Self::Host(host_slice) => Write::<host::Host, T>::write(host_slice, data),
        }
    }

    fn write_value(&mut self, value: T) -> Result<(), Error> {
        match self {
            Self::CL(cl_slice) => Write::<opencl::OpenCL, T>::write_value(cl_slice, value),
            Self::Host(host_slice) => Write::<host::Host, T>::write_value(host_slice, value),
        }
    }

    fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
        match self {
            Self::CL(cl_slice) => {
                Write::<opencl::OpenCL, T>::write_value_at(cl_slice, offset, value)
            }
            Self::Host(host_slice) => {
                Write::<host::Host, T>::write_value_at(host_slice, offset, value)
            }
        }
    }
}
897
898#[cfg(not(feature = "opencl"))]
899impl<A, T> Write<Platform, T> for Slice<A, T>
900where
901 T: CType,
902 A: AccessMut<T>,
903{
904 fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
905 match self {
906 Self::Host(op) => Write::<host::Host, T>::write(op, data),
907 }
908 }
909
910 fn write_value(&mut self, value: T) -> Result<(), Error> {
911 match self {
912 Self::Host(op) => Write::<host::Host, T>::write_value(op, value),
913 }
914 }
915
916 fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
917 match self {
918 Self::Host(op) => Write::<host::Host, T>::write_value_at(op, offset, value),
919 }
920 }
921}
922
923#[cfg(feature = "opencl")]
924impl<A, T> From<opencl::ops::Slice<A, T>> for Slice<A, T> {
925 fn from(op: opencl::ops::Slice<A, T>) -> Self {
926 Self::CL(op)
927 }
928}
929
930impl<A, T> From<host::ops::Slice<A, T>> for Slice<A, T> {
931 fn from(op: host::ops::Slice<A, T>) -> Self {
932 Self::Host(op)
933 }
934}
935
936pub enum Unary<A, IT, OT> {
937 #[cfg(feature = "opencl")]
938 CL(opencl::ops::Unary<A, IT, OT>),
939 Host(host::ops::Unary<A, IT, OT>),
940}
941
942impl<A, IT, OT> Op for Unary<A, IT, OT>
943where
944 A: Access<IT>,
945 IT: CType,
946 OT: CType,
947{
948 fn size(&self) -> usize {
949 op_dispatch!(self, op, op.size())
950 }
951}
952
953impl<A, IT, OT> Enqueue<Platform, OT> for Unary<A, IT, OT>
954where
955 A: Access<IT>,
956 IT: CType,
957 OT: CType,
958{
959 type Buffer = Buffer<OT>;
960
961 fn enqueue(&self) -> Result<Self::Buffer, Error> {
962 op_enqueue!(self, OT)
963 }
964}
965
966impl<A, IT, OT> ReadValue<Platform, OT> for Unary<A, IT, OT>
967where
968 A: Access<IT>,
969 IT: CType,
970 OT: CType,
971{
972 fn read_value(&self, offset: usize) -> Result<OT, Error> {
973 op_dispatch!(self, op, op.read_value(offset))
974 }
975}
976
977impl<A, IT, OT> From<host::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
978 fn from(op: host::ops::Unary<A, IT, OT>) -> Self {
979 Self::Host(op)
980 }
981}
982
983#[cfg(feature = "opencl")]
984impl<A, IT, OT> From<opencl::ops::Unary<A, IT, OT>> for Unary<A, IT, OT> {
985 fn from(op: opencl::ops::Unary<A, IT, OT>) -> Self {
986 Self::CL(op)
987 }
988}
989
/// The specification of a view of a source array: the shape of the view, its strides, and
/// the strides used to address the source.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct ViewSpec {
    /// The shape of the view.
    pub shape: Shape,
    /// The strides of the view (computed by `strides_for` from `shape` in [`ViewSpec::new`]).
    pub strides: Strides,
    /// The strides used to compute offsets into the source, as given to [`ViewSpec::new`].
    pub source_strides: Strides,
}
996
997impl ViewSpec {
998 pub fn new(shape: Shape, source_strides: Strides) -> Self {
999 let strides = strides_for(&shape, shape.len()).collect();
1000
1001 Self {
1002 shape,
1003 strides,
1004 source_strides,
1005 }
1006 }
1007
1008 pub fn source_offset(&self, offset: usize) -> usize {
1009 debug_assert!(offset < self.size());
1010
1011 let source_offset = self
1012 .strides
1013 .iter()
1014 .copied()
1015 .zip(self.shape.iter().copied())
1016 .rev()
1017 .take(self.source_strides.len())
1018 .map(|(stride, dim)| {
1019 if stride == 0 {
1020 0
1021 } else {
1022 (offset / stride) % dim
1023 }
1024 }) .zip(self.source_strides.iter().rev().copied())
1026 .map(|(i, source_stride)| i * source_stride)
1027 .sum::<usize>();
1028
1029 source_offset
1030 }
1031
1032 pub fn size(&self) -> usize {
1033 self.shape.iter().product()
1034 }
1035}
1036
1037pub enum View<A, T> {
1038 #[cfg(feature = "opencl")]
1039 CL(opencl::ops::View<A, T>),
1040 Host(host::ops::View<A, T>),
1041}
1042
1043impl_unary!(View<A, T>, T);
1044
1045#[cfg(feature = "opencl")]
1046impl<A, T> From<opencl::ops::View<A, T>> for View<A, T> {
1047 fn from(op: opencl::ops::View<A, T>) -> Self {
1048 Self::CL(op)
1049 }
1050}
1051
1052impl<A, T> From<host::ops::View<A, T>> for View<A, T> {
1053 fn from(op: host::ops::View<A, T>) -> Self {
1054 Self::Host(op)
1055 }
1056}