1use std::fmt;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::{Buffer, BufferConverter, BufferInstance};
5#[cfg(feature = "opencl")]
6use crate::opencl;
7use crate::ops::*;
8use crate::{host, Axes, Error, Float, Number, Range, Real, Shape};
9
/// A compute platform (e.g. the host, or an OpenCL backend) on which array operations run.
///
/// Implementors are small value types: they must be `Copy`, comparable, and shareable
/// across threads.
pub trait PlatformInstance: PartialEq + Eq + Clone + Copy + Send + Sync + fmt::Debug {
    /// Choose a platform instance for an operation whose size is approximately `size_hint`.
    fn select(size_hint: usize) -> Self;
}
15
/// A platform which can construct a buffer from a single constant value.
pub trait Constant<T: Number>: PlatformInstance {
    /// The buffer type this platform constructs.
    type Buffer: BufferInstance<T>;

    /// Construct a buffer of `size` elements from the constant `value`.
    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error>;
}
24
/// A platform which can convert a [`BufferConverter`] into its own buffer type.
pub trait Convert<T: Number>: PlatformInstance {
    /// The buffer type this platform produces.
    type Buffer: BufferInstance<T>;

    /// Convert `buffer` into this platform's `Self::Buffer` type.
    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error>;
}
32
/// The set of compute platforms this crate can dispatch to.
///
/// The `CL` variant only exists when the `opencl` feature is enabled.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Platform {
    /// An OpenCL platform (requires the `opencl` feature).
    #[cfg(feature = "opencl")]
    CL(opencl::OpenCL),
    /// The host platform.
    Host(host::Host),
}
40
#[cfg(feature = "opencl")]
impl PlatformInstance for Platform {
    /// Pick OpenCL once `size_hint` reaches `opencl::GPU_MIN_SIZE`; smaller inputs stay on
    /// the host, which makes its own selection from the same hint.
    fn select(size_hint: usize) -> Self {
        let large_enough_for_gpu = size_hint >= opencl::GPU_MIN_SIZE;

        if large_enough_for_gpu {
            Platform::CL(opencl::OpenCL)
        } else {
            Platform::Host(host::Host::select(size_hint))
        }
    }
}
51
52#[cfg(not(feature = "opencl"))]
53impl PlatformInstance for Platform {
54 fn select(size_hint: usize) -> Self {
55 Self::Host(host::Host::select(size_hint))
56 }
57}
58
#[cfg(feature = "opencl")]
impl From<opencl::OpenCL> for Platform {
    /// Wrap an OpenCL platform in the `CL` variant.
    fn from(opencl: opencl::OpenCL) -> Self {
        Platform::CL(opencl)
    }
}
65
66impl From<host::Host> for Platform {
67 fn from(host: host::Host) -> Self {
68 Self::Host(host)
69 }
70}
71
72impl<T: Number> Convert<T> for Platform {
73 type Buffer = Buffer<T>;
74
75 fn convert<'a>(&self, buffer: BufferConverter<'a, T>) -> Result<Self::Buffer, Error> {
76 match self {
77 #[cfg(feature = "opencl")]
78 Self::CL(cl) => cl.convert(buffer).map(Buffer::CL),
79 Self::Host(host) => host.convert(buffer).map(Buffer::Host),
80 }
81 }
82}
83
84impl<T: Number> Constant<T> for Platform {
85 type Buffer = Buffer<T>;
86
87 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
88 match self {
89 #[cfg(feature = "opencl")]
90 Self::CL(cl) => cl.constant(value, size).map(Buffer::CL),
91 Self::Host(host) => host.constant(value, size).map(Buffer::Host),
92 }
93 }
94}
95
96impl<A, T> ConstructConcat<A, T> for Platform
97where
98 A: Access<T>,
99 T: Number,
100{
101 type Op = Concat<A, T>;
102
103 fn concat(self, data: Vec<A>) -> Result<AccessOp<Self::Op, Self>, Error> {
104 match self {
105 #[cfg(feature = "opencl")]
106 Self::CL(cl) => cl.concat(data).map(AccessOp::wrap),
107 Self::Host(host) => host.concat(data).map(AccessOp::wrap),
108 }
109 }
110}
111
112impl<T: Number + PartialOrd> ConstructRange<T> for Platform {
113 type Range = Linear<T>;
114
115 fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
116 match self {
117 #[cfg(feature = "opencl")]
118 Self::CL(cl) => cl.range(start, stop, size).map(AccessOp::wrap),
119 Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
120 }
121 }
122}
123
124impl<A, T> ElementwiseAbs<A, T> for Platform
125where
126 A: Access<T>,
127 T: Number,
128{
129 type Op = Unary<A, T, T::Abs>;
130
131 fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
132 match self {
133 #[cfg(feature = "opencl")]
134 Self::CL(cl) => cl.abs(access).map(AccessOp::wrap),
135 Self::Host(host) => host.abs(access).map(AccessOp::wrap),
136 }
137 }
138}
139
140impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
141where
142 L: Access<T>,
143 R: Access<T>,
144 T: Number,
145{
146 type Op = Dual<L, R, T, u8>;
147
148 fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
149 match self {
150 #[cfg(feature = "opencl")]
151 Self::CL(cl) => cl.and(left, right).map(AccessOp::wrap),
152 Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
153 }
154 }
155
156 fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
157 match self {
158 #[cfg(feature = "opencl")]
159 Self::CL(cl) => cl.or(left, right).map(AccessOp::wrap),
160 Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
161 }
162 }
163
164 fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
165 match self {
166 #[cfg(feature = "opencl")]
167 Self::CL(cl) => cl.xor(left, right).map(AccessOp::wrap),
168 Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
169 }
170 }
171}
172
173impl<A: Access<T>, T: Number> ElementwiseBooleanScalar<A, T> for Platform {
174 type Op = Scalar<A, T, u8>;
175
176 fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
177 match self {
178 #[cfg(feature = "opencl")]
179 Self::CL(cl) => cl.and_scalar(left, right).map(AccessOp::wrap),
180 Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
181 }
182 }
183
184 fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
185 match self {
186 #[cfg(feature = "opencl")]
187 Self::CL(cl) => cl.or_scalar(left, right).map(AccessOp::wrap),
188 Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
189 }
190 }
191
192 fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
193 match self {
194 #[cfg(feature = "opencl")]
195 Self::CL(cl) => cl.xor_scalar(left, right).map(AccessOp::wrap),
196 Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
197 }
198 }
199}
200
201impl<A: Access<IT>, IT: Number, OT: Number> ElementwiseCast<A, IT, OT> for Platform {
202 type Op = Cast<A, IT, OT>;
203
204 fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
205 match self {
206 #[cfg(feature = "opencl")]
207 Self::CL(cl) => cl.cast(access).map(AccessOp::wrap),
208 Self::Host(host) => host.cast(access).map(AccessOp::wrap),
209 }
210 }
211}
212
213impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
214where
215 L: Access<T>,
216 R: Access<T>,
217 T: Number,
218{
219 type Op = Dual<L, R, T, u8>;
220
221 fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
222 match self {
223 #[cfg(feature = "opencl")]
224 Self::CL(cl) => cl.eq(left, right).map(AccessOp::wrap),
225 Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
226 }
227 }
228
229 fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
230 where
231 T: Real,
232 {
233 match self {
234 #[cfg(feature = "opencl")]
235 Self::CL(cl) => cl.ge(left, right).map(AccessOp::wrap),
236 Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
237 }
238 }
239
240 fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
241 where
242 T: Real,
243 {
244 match self {
245 #[cfg(feature = "opencl")]
246 Self::CL(cl) => cl.gt(left, right).map(AccessOp::wrap),
247 Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
248 }
249 }
250
251 fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
252 where
253 T: Real,
254 {
255 match self {
256 #[cfg(feature = "opencl")]
257 Self::CL(cl) => cl.le(left, right).map(AccessOp::wrap),
258 Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
259 }
260 }
261
262 fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
263 where
264 T: Real,
265 {
266 match self {
267 #[cfg(feature = "opencl")]
268 Self::CL(cl) => cl.lt(left, right).map(AccessOp::wrap),
269 Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
270 }
271 }
272
273 fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
274 match self {
275 #[cfg(feature = "opencl")]
276 Self::CL(cl) => cl.ne(left, right).map(AccessOp::wrap),
277 Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
278 }
279 }
280}
281
282impl<A: Access<T>, T: Number> ElementwiseCompareScalar<A, T> for Platform {
283 type Op = Scalar<A, T, u8>;
284
285 fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
286 match self {
287 #[cfg(feature = "opencl")]
288 Self::CL(cl) => cl.eq_scalar(left, right).map(AccessOp::wrap),
289 Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
290 }
291 }
292
293 fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
294 where
295 T: Real,
296 {
297 match self {
298 #[cfg(feature = "opencl")]
299 Self::CL(cl) => cl.ge_scalar(left, right).map(AccessOp::wrap),
300 Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
301 }
302 }
303
304 fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
305 where
306 T: Real,
307 {
308 match self {
309 #[cfg(feature = "opencl")]
310 Self::CL(cl) => cl.gt_scalar(left, right).map(AccessOp::wrap),
311 Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
312 }
313 }
314
315 fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
316 where
317 T: Real,
318 {
319 match self {
320 #[cfg(feature = "opencl")]
321 Self::CL(cl) => cl.le_scalar(left, right).map(AccessOp::wrap),
322 Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
323 }
324 }
325
326 fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
327 where
328 T: Real,
329 {
330 match self {
331 #[cfg(feature = "opencl")]
332 Self::CL(cl) => cl.lt_scalar(left, right).map(AccessOp::wrap),
333 Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
334 }
335 }
336
337 fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
338 match self {
339 #[cfg(feature = "opencl")]
340 Self::CL(cl) => cl.ne_scalar(left, right).map(AccessOp::wrap),
341 Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
342 }
343 }
344}
345
/// Forwards each complex unary op to the active backend and re-wraps the result for
/// [`Platform`].
#[cfg(feature = "complex")]
impl<A, T> complex::ElementwiseUnaryComplex<A, T> for Platform
where
    A: Access<T>,
    T: crate::Complex,
{
    type Real = Unary<A, T, T::Real>;
    type Complex = Unary<A, T, T>;

    fn angle(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            #[cfg(feature = "opencl")]
            Platform::CL(backend) => Ok(AccessOp::wrap(backend.angle(access)?)),
            Platform::Host(backend) => Ok(AccessOp::wrap(backend.angle(access)?)),
        }
    }

    fn conj(self, access: A) -> Result<AccessOp<Self::Complex, Self>, Error> {
        match self {
            #[cfg(feature = "opencl")]
            Platform::CL(backend) => Ok(AccessOp::wrap(backend.conj(access)?)),
            Platform::Host(backend) => Ok(AccessOp::wrap(backend.conj(access)?)),
        }
    }

    fn re(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            #[cfg(feature = "opencl")]
            Platform::CL(backend) => Ok(AccessOp::wrap(backend.re(access)?)),
            Platform::Host(backend) => Ok(AccessOp::wrap(backend.re(access)?)),
        }
    }

    fn im(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            #[cfg(feature = "opencl")]
            Platform::CL(backend) => Ok(AccessOp::wrap(backend.im(access)?)),
            Platform::Host(backend) => Ok(AccessOp::wrap(backend.im(access)?)),
        }
    }
}
387
388impl<L, R, T> ElementwiseDual<L, R, T> for Platform
389where
390 L: Access<T>,
391 R: Access<T>,
392 T: Number,
393{
394 type Op = Dual<L, R, T, T>;
395
396 fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
397 match self {
398 #[cfg(feature = "opencl")]
399 Self::CL(cl) => cl.add(left, right).map(AccessOp::wrap),
400 Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
401 }
402 }
403
404 fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
405 match self {
406 #[cfg(feature = "opencl")]
407 Self::CL(cl) => cl.div(left, right).map(AccessOp::wrap),
408 Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
409 }
410 }
411
412 fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>
413 where
414 T: Float,
415 {
416 match self {
417 #[cfg(feature = "opencl")]
418 Self::CL(cl) => cl.log(arg, base).map(AccessOp::wrap),
419 Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
420 }
421 }
422
423 fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
424 match self {
425 #[cfg(feature = "opencl")]
426 Self::CL(cl) => cl.mul(left, right).map(AccessOp::wrap),
427 Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
428 }
429 }
430
431 fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
432 match self {
433 #[cfg(feature = "opencl")]
434 Self::CL(cl) => cl.pow(arg, exp).map(AccessOp::wrap),
435 Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
436 }
437 }
438
439 fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
440 where
441 T: Real,
442 {
443 match self {
444 #[cfg(feature = "opencl")]
445 Self::CL(cl) => cl.rem(left, right).map(AccessOp::wrap),
446 Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
447 }
448 }
449
450 fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
451 match self {
452 #[cfg(feature = "opencl")]
453 Self::CL(cl) => cl.sub(left, right).map(AccessOp::wrap),
454 Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
455 }
456 }
457}
458
459impl<A: Access<T>, T: Number> ElementwiseScalar<A, T> for Platform {
460 type Op = Scalar<A, T, T>;
461
462 fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
463 match self {
464 #[cfg(feature = "opencl")]
465 Self::CL(cl) => cl.add_scalar(left, right).map(AccessOp::wrap),
466 Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
467 }
468 }
469
470 fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
471 match self {
472 #[cfg(feature = "opencl")]
473 Self::CL(cl) => cl.div_scalar(left, right).map(AccessOp::wrap),
474 Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
475 }
476 }
477
478 fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>
479 where
480 T: Float,
481 {
482 match self {
483 #[cfg(feature = "opencl")]
484 Self::CL(cl) => cl.log_scalar(arg, base).map(AccessOp::wrap),
485 Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
486 }
487 }
488
489 fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
490 match self {
491 #[cfg(feature = "opencl")]
492 Self::CL(cl) => cl.mul_scalar(left, right).map(AccessOp::wrap),
493 Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
494 }
495 }
496
497 fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
498 match self {
499 #[cfg(feature = "opencl")]
500 Self::CL(cl) => cl.pow_scalar(arg, exp).map(AccessOp::wrap),
501 Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
502 }
503 }
504
505 fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
506 where
507 T: Real,
508 {
509 match self {
510 #[cfg(feature = "opencl")]
511 Self::CL(cl) => cl.rem_scalar(left, right).map(AccessOp::wrap),
512 Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
513 }
514 }
515
516 fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
517 match self {
518 #[cfg(feature = "opencl")]
519 Self::CL(cl) => cl.sub_scalar(left, right).map(AccessOp::wrap),
520 Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
521 }
522 }
523}
524
525impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
526 type Op = Unary<A, T, u8>;
527
528 fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
529 match self {
530 #[cfg(feature = "opencl")]
531 Self::CL(cl) => cl.is_inf(access).map(AccessOp::wrap),
532 Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
533 }
534 }
535
536 fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
537 match self {
538 #[cfg(feature = "opencl")]
539 Self::CL(cl) => cl.is_nan(access).map(AccessOp::wrap),
540 Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
541 }
542 }
543}
544
545impl<A: Access<T>, T: Float> ElementwiseTrig<A, T> for Platform {
546 type Op = Unary<A, T, T>;
547
548 fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
549 match self {
550 #[cfg(feature = "opencl")]
551 Self::CL(cl) => cl.sin(access).map(AccessOp::wrap),
552 Self::Host(host) => host.sin(access).map(AccessOp::wrap),
553 }
554 }
555
556 fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
557 match self {
558 #[cfg(feature = "opencl")]
559 Self::CL(cl) => cl.asin(access).map(AccessOp::wrap),
560 Self::Host(host) => host.asin(access).map(AccessOp::wrap),
561 }
562 }
563
564 fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
565 match self {
566 #[cfg(feature = "opencl")]
567 Self::CL(cl) => cl.sinh(access).map(AccessOp::wrap),
568 Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
569 }
570 }
571
572 fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
573 match self {
574 #[cfg(feature = "opencl")]
575 Self::CL(cl) => cl.cos(access).map(AccessOp::wrap),
576 Self::Host(host) => host.cos(access).map(AccessOp::wrap),
577 }
578 }
579
580 fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
581 match self {
582 #[cfg(feature = "opencl")]
583 Self::CL(cl) => cl.acos(access).map(AccessOp::wrap),
584 Self::Host(host) => host.acos(access).map(AccessOp::wrap),
585 }
586 }
587
588 fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
589 match self {
590 #[cfg(feature = "opencl")]
591 Self::CL(cl) => cl.cosh(access).map(AccessOp::wrap),
592 Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
593 }
594 }
595
596 fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
597 match self {
598 #[cfg(feature = "opencl")]
599 Self::CL(cl) => cl.tan(access).map(AccessOp::wrap),
600 Self::Host(host) => host.tan(access).map(AccessOp::wrap),
601 }
602 }
603
604 fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
605 match self {
606 #[cfg(feature = "opencl")]
607 Self::CL(cl) => cl.atan(access).map(AccessOp::wrap),
608 Self::Host(host) => host.atan(access).map(AccessOp::wrap),
609 }
610 }
611
612 fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
613 match self {
614 #[cfg(feature = "opencl")]
615 Self::CL(cl) => cl.tanh(access).map(AccessOp::wrap),
616 Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
617 }
618 }
619}
620
621impl<A: Access<T>, T: Float> ElementwiseUnary<A, T> for Platform {
622 type Op = Unary<A, T, T>;
623
624 fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
625 match self {
626 #[cfg(feature = "opencl")]
627 Self::CL(cl) => cl.exp(access).map(AccessOp::wrap),
628 Self::Host(host) => host.exp(access).map(AccessOp::wrap),
629 }
630 }
631
632 fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
633 match self {
634 #[cfg(feature = "opencl")]
635 Self::CL(cl) => cl.ln(access).map(AccessOp::wrap),
636 Self::Host(host) => host.ln(access).map(AccessOp::wrap),
637 }
638 }
639
640 fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>
641 where
642 T: Real,
643 {
644 match self {
645 #[cfg(feature = "opencl")]
646 Self::CL(cl) => cl.round(access).map(AccessOp::wrap),
647 Self::Host(host) => host.round(access).map(AccessOp::wrap),
648 }
649 }
650}
651
652impl<A: Access<T>, T: Number> ElementwiseUnaryBoolean<A, T> for Platform {
653 type Op = Unary<A, T, u8>;
654
655 fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
656 match self {
657 #[cfg(feature = "opencl")]
658 Self::CL(cl) => cl.not(access).map(AccessOp::wrap),
659 Self::Host(host) => host.not(access).map(AccessOp::wrap),
660 }
661 }
662}
663
/// Fourier transforms for [`Platform`].
///
/// Both transforms always run on the host: a `CL` platform is replaced by a host platform
/// selected from `access.size()`, and a `Host` platform is used as-is.
#[cfg(feature = "complex")]
impl<A, T> complex::Fourier<A, num_complex::Complex<T>> for Platform
where
    A: Access<num_complex::Complex<T>>,
    T: rustfft::FftNum,
    num_complex::Complex<T>: crate::Complex,
{
    type Op = complex::FFT<A, num_complex::Complex<T>>;

    /// Forward FFT of `access` along `dim`, computed on the host.
    fn fft(self, access: A, dim: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
        let host = match self {
            // no OpenCL path is used here; fall back to a host platform sized to the input
            #[cfg(feature = "opencl")]
            Self::CL(_cl) => host::Host::select(access.size()),
            Self::Host(host) => host,
        };

        host.fft(access, dim).map(AccessOp::wrap)
    }

    /// Inverse FFT of `access` along `dim`, computed on the host.
    fn ifft(self, access: A, dim: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
        let host = match self {
            // no OpenCL path is used here; fall back to a host platform sized to the input
            #[cfg(feature = "opencl")]
            Self::CL(_cl) => host::Host::select(access.size()),
            Self::Host(host) => host,
        };

        // must dispatch to the host's *inverse* transform (previously this called `fft`)
        host.ifft(access, dim).map(AccessOp::wrap)
    }
}
695
696impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
697where
698 A: Access<u8>,
699 L: Access<T>,
700 R: Access<T>,
701 T: Number,
702{
703 type Op = Cond<A, L, R, T>;
704
705 fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
706 match self {
707 #[cfg(feature = "opencl")]
708 Self::CL(cl) => cl.cond(cond, then, or_else).map(AccessOp::wrap),
709 Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
710 }
711 }
712}
713
714impl<L, R, T> LinAlgDual<L, R, T> for Platform
715where
716 L: Access<T>,
717 R: Access<T>,
718 T: Number,
719{
720 type Op = MatMul<L, R, T>;
721
722 fn matmul(
723 self,
724 left: L,
725 right: R,
726 dims: [usize; 4],
727 ) -> Result<AccessOp<Self::Op, Self>, Error> {
728 match self {
729 #[cfg(feature = "opencl")]
730 Self::CL(cl) => cl.matmul(left, right, dims).map(AccessOp::wrap),
731 Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
732 }
733 }
734}
735
736impl<A: Access<T>, T: Number> LinAlgUnary<A, T> for Platform {
737 type Op = MatDiag<A, T>;
738
739 fn diag(
740 self,
741 access: A,
742 batch_size: usize,
743 dim: usize,
744 ) -> Result<AccessOp<Self::Op, Self>, Error> {
745 match self {
746 #[cfg(feature = "opencl")]
747 Self::CL(cl) => cl.diag(access, batch_size, dim).map(AccessOp::wrap),
748 Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
749 }
750 }
751}
752
753impl Random for Platform {
754 type Normal = RandomNormal;
755 type Uniform = RandomUniform;
756
757 fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
758 match self {
759 #[cfg(feature = "opencl")]
760 Self::CL(cl) => cl.random_normal(size).map(AccessOp::wrap),
761 Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
762 }
763 }
764
765 fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
766 match self {
767 #[cfg(feature = "opencl")]
768 Self::CL(cl) => cl.random_uniform(size).map(AccessOp::wrap),
769 Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
770 }
771 }
772}
773
774impl<A, T> ReduceAll<A, T> for Platform
775where
776 A: Access<T>,
777 T: Number,
778{
779 fn all(self, access: A) -> Result<bool, Error> {
780 match self {
781 #[cfg(feature = "opencl")]
782 Self::CL(cl) => cl.all(access),
783 Self::Host(host) => host.all(access),
784 }
785 }
786
787 fn any(self, access: A) -> Result<bool, Error> {
788 match self {
789 #[cfg(feature = "opencl")]
790 Self::CL(cl) => cl.any(access),
791 Self::Host(host) => host.any(access),
792 }
793 }
794
795 fn max(self, access: A) -> Result<T, Error>
796 where
797 T: Real,
798 {
799 match self {
800 #[cfg(feature = "opencl")]
801 Self::CL(cl) => ReduceAll::max(cl, access),
802 Self::Host(host) => ReduceAll::max(host, access),
803 }
804 }
805
806 fn min(self, access: A) -> Result<T, Error>
807 where
808 T: Real,
809 {
810 match self {
811 #[cfg(feature = "opencl")]
812 Self::CL(cl) => ReduceAll::min(cl, access),
813 Self::Host(host) => ReduceAll::min(host, access),
814 }
815 }
816
817 fn product(self, access: A) -> Result<T, Error> {
818 match self {
819 #[cfg(feature = "opencl")]
820 Self::CL(cl) => ReduceAll::product(cl, access),
821 Self::Host(host) => ReduceAll::product(host, access),
822 }
823 }
824
825 fn sum(self, access: A) -> Result<T, Error> {
826 match self {
827 #[cfg(feature = "opencl")]
828 Self::CL(cl) => ReduceAll::sum(cl, access),
829 Self::Host(host) => ReduceAll::sum(host, access),
830 }
831 }
832}
833
834impl<A: Access<T>, T: Number> ReduceAxes<A, T> for Platform {
835 type Op = Reduce<A, T>;
836
837 fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
838 where
839 T: Real,
840 {
841 match self {
842 #[cfg(feature = "opencl")]
843 Self::CL(cl) => ReduceAxes::max(cl, access, stride).map(AccessOp::wrap),
844 Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
845 }
846 }
847
848 fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
849 where
850 T: Real,
851 {
852 match self {
853 #[cfg(feature = "opencl")]
854 Self::CL(cl) => ReduceAxes::min(cl, access, stride).map(AccessOp::wrap),
855 Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
856 }
857 }
858
859 fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
860 match self {
861 #[cfg(feature = "opencl")]
862 Self::CL(cl) => ReduceAxes::product(cl, access, stride).map(AccessOp::wrap),
863 Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
864 }
865 }
866
867 fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
868 match self {
869 #[cfg(feature = "opencl")]
870 Self::CL(cl) => ReduceAxes::sum(cl, access, stride).map(AccessOp::wrap),
871 Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
872 }
873 }
874}
875
876impl<A: Access<T>, T: Number> Transform<A, T> for Platform {
877 type Broadcast = View<A, T>;
878 type Flip = Flip<A, T>;
879 type Slice = Slice<A, T>;
880 type Transpose = View<A, T>;
881
882 fn broadcast(
883 self,
884 access: A,
885 shape: Shape,
886 broadcast: Shape,
887 ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
888 match self {
889 #[cfg(feature = "opencl")]
890 Self::CL(cl) => cl.broadcast(access, shape, broadcast).map(AccessOp::wrap),
891 Self::Host(host) => host.broadcast(access, shape, broadcast).map(AccessOp::wrap),
892 }
893 }
894
895 fn flip(
896 self,
897 access: A,
898 shape: Shape,
899 axis: usize,
900 ) -> Result<AccessOp<Self::Flip, Self>, Error> {
901 match self {
902 #[cfg(feature = "opencl")]
903 Self::CL(cl) => cl.flip(access, shape, axis).map(AccessOp::wrap),
904 Self::Host(host) => host.flip(access, shape, axis).map(AccessOp::wrap),
905 }
906 }
907
908 fn slice(
909 self,
910 access: A,
911 shape: &[usize],
912 range: Range,
913 ) -> Result<AccessOp<Self::Slice, Self>, Error> {
914 match self {
915 #[cfg(feature = "opencl")]
916 Self::CL(cl) => cl.slice(access, shape, range).map(AccessOp::wrap),
917 Self::Host(host) => host.slice(access, shape, range).map(AccessOp::wrap),
918 }
919 }
920
921 fn transpose(
922 self,
923 access: A,
924 shape: Shape,
925 permutation: Axes,
926 ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
927 match self {
928 #[cfg(feature = "opencl")]
929 Self::CL(cl) => cl.transpose(access, shape, permutation).map(AccessOp::wrap),
930 Self::Host(host) => host
931 .transpose(access, shape, permutation)
932 .map(AccessOp::wrap),
933 }
934 }
935}