1use std::fmt;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::{Buffer, BufferConverter, BufferInstance};
5#[cfg(feature = "opencl")]
6use crate::opencl;
7use crate::ops::*;
8use crate::{host, Axes, CType, Error, Float, Range, Shape};
9
/// A handle to a compute backend.
///
/// Implementors are `Copy`, comparable, printable with `{:?}`, and can be sent and
/// shared across threads.
pub trait PlatformInstance: Copy + Clone + Eq + PartialEq + fmt::Debug + Send + Sync {
    /// Chooses an instance of this platform for a workload described by `size_hint`.
    fn select(size_hint: usize) -> Self;
}
15
16pub trait Constant<T: CType>: PlatformInstance {
18 type Buffer: BufferInstance<T>;
20
21 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error>;
23}
24
25pub trait Convert<T: CType>: PlatformInstance {
27 type Buffer: BufferInstance<T>;
29
30 fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error>;
31}
32
33#[derive(Debug, Copy, Clone, Eq, PartialEq)]
35pub enum Platform {
36 #[cfg(feature = "opencl")]
37 CL(opencl::OpenCL),
38 Host(host::Host),
39}
40
41#[cfg(feature = "opencl")]
42impl PlatformInstance for Platform {
43 fn select(size_hint: usize) -> Self {
44 if size_hint < opencl::GPU_MIN_SIZE {
45 Self::Host(host::Host::select(size_hint))
46 } else {
47 Self::CL(opencl::OpenCL)
48 }
49 }
50}
51
52#[cfg(not(feature = "opencl"))]
53impl PlatformInstance for Platform {
54 fn select(size_hint: usize) -> Self {
55 Self::Host(host::Host::select(size_hint))
56 }
57}
58
/// Wraps an OpenCL handle as a [`Platform`].
#[cfg(feature = "opencl")]
impl From<opencl::OpenCL> for Platform {
    fn from(platform: opencl::OpenCL) -> Self {
        Platform::CL(platform)
    }
}
65
66impl From<host::Host> for Platform {
67 fn from(host: host::Host) -> Self {
68 Self::Host(host)
69 }
70}
71
72impl<T: CType> Convert<T> for Platform {
73 type Buffer = Buffer<T>;
74
75 fn convert<'a>(&self, buffer: BufferConverter<'a, T>) -> Result<Self::Buffer, Error> {
76 match self {
77 #[cfg(feature = "opencl")]
78 Self::CL(cl) => cl.convert(buffer).map(Buffer::CL),
79 Self::Host(host) => host.convert(buffer).map(Buffer::Host),
80 }
81 }
82}
83
84#[cfg(not(feature = "opencl"))]
85impl<T: CType> Constant<T> for Platform {
86 type Buffer = Buffer<T>;
87
88 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
89 match self {
90 Self::Host(host) => host.constant(value, size).map(Buffer::Host),
91 }
92 }
93}
94
95#[cfg(feature = "opencl")]
96impl<T: CType> Constant<T> for Platform {
97 type Buffer = Buffer<T>;
98
99 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
100 match self {
101 Self::CL(cl) => cl.constant(value, size).map(Buffer::CL),
102 Self::Host(host) => host.constant(value, size).map(Buffer::Host),
103 }
104 }
105}
106
107#[cfg(not(feature = "opencl"))]
108impl<T: CType> Construct<T> for Platform {
109 type Range = Linear<T>;
110
111 fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
112 match self {
113 Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
114 }
115 }
116}
117
118#[cfg(feature = "opencl")]
119impl<T: CType> Construct<T> for Platform {
120 type Range = Linear<T>;
121
122 fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
123 match self {
124 Self::CL(cl) => cl.range(start, stop, size).map(AccessOp::wrap),
125 Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
126 }
127 }
128}
129
130#[cfg(not(feature = "opencl"))]
131impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
132where
133 L: Access<T>,
134 R: Access<T>,
135 T: CType,
136{
137 type Op = Dual<L, R, T, u8>;
138
139 fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
140 match self {
141 Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
142 }
143 }
144 fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
145 match self {
146 Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
147 }
148 }
149 fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
150 match self {
151 Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
152 }
153 }
154}
155
156#[cfg(feature = "opencl")]
157impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
158where
159 L: Access<T>,
160 R: Access<T>,
161 T: CType,
162{
163 type Op = Dual<L, R, T, u8>;
164
165 fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
166 match self {
167 Self::CL(cl) => cl.and(left, right).map(AccessOp::wrap),
168 Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
169 }
170 }
171
172 fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
173 match self {
174 Self::CL(cl) => cl.or(left, right).map(AccessOp::wrap),
175 Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
176 }
177 }
178
179 fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
180 match self {
181 Self::CL(cl) => cl.xor(left, right).map(AccessOp::wrap),
182 Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
183 }
184 }
185}
186
187#[cfg(not(feature = "opencl"))]
188impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Platform {
189 type Op = Scalar<A, T, u8>;
190
191 fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
192 match self {
193 Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
194 }
195 }
196 fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
197 match self {
198 Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
199 }
200 }
201 fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
202 match self {
203 Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
204 }
205 }
206}
207
208#[cfg(feature = "opencl")]
209impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Platform {
210 type Op = Scalar<A, T, u8>;
211
212 fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
213 match self {
214 Self::CL(cl) => cl.and_scalar(left, right).map(AccessOp::wrap),
215 Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
216 }
217 }
218
219 fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
220 match self {
221 Self::CL(cl) => cl.or_scalar(left, right).map(AccessOp::wrap),
222 Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
223 }
224 }
225
226 fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
227 match self {
228 Self::CL(cl) => cl.xor_scalar(left, right).map(AccessOp::wrap),
229 Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
230 }
231 }
232}
233
234#[cfg(not(feature = "opencl"))]
235impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Platform {
236 type Op = Cast<A, IT, OT>;
237
238 fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
239 match self {
240 Self::Host(host) => host.cast(access).map(AccessOp::wrap),
241 }
242 }
243}
244
245#[cfg(feature = "opencl")]
246impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Platform {
247 type Op = Cast<A, IT, OT>;
248
249 fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
250 match self {
251 Self::CL(cl) => cl.cast(access).map(AccessOp::wrap),
252 Self::Host(host) => host.cast(access).map(AccessOp::wrap),
253 }
254 }
255}
256
257#[cfg(not(feature = "opencl"))]
258impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
259where
260 L: Access<T>,
261 R: Access<T>,
262 T: CType,
263{
264 type Op = Dual<L, R, T, u8>;
265
266 fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
267 match self {
268 Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
269 }
270 }
271
272 fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
273 match self {
274 Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
275 }
276 }
277
278 fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
279 match self {
280 Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
281 }
282 }
283
284 fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
285 match self {
286 Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
287 }
288 }
289
290 fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
291 match self {
292 Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
293 }
294 }
295
296 fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
297 match self {
298 Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
299 }
300 }
301}
302
303#[cfg(feature = "opencl")]
304impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
305where
306 L: Access<T>,
307 R: Access<T>,
308 T: CType,
309{
310 type Op = Dual<L, R, T, u8>;
311
312 fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
313 match self {
314 Self::CL(cl) => cl.eq(left, right).map(AccessOp::wrap),
315 Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
316 }
317 }
318
319 fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
320 match self {
321 Self::CL(cl) => cl.ge(left, right).map(AccessOp::wrap),
322 Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
323 }
324 }
325
326 fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
327 match self {
328 Self::CL(cl) => cl.gt(left, right).map(AccessOp::wrap),
329 Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
330 }
331 }
332
333 fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
334 match self {
335 Self::CL(cl) => cl.le(left, right).map(AccessOp::wrap),
336 Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
337 }
338 }
339
340 fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
341 match self {
342 Self::CL(cl) => cl.lt(left, right).map(AccessOp::wrap),
343 Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
344 }
345 }
346
347 fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
348 match self {
349 Self::CL(cl) => cl.ne(left, right).map(AccessOp::wrap),
350 Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
351 }
352 }
353}
354
355#[cfg(not(feature = "opencl"))]
356impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Platform {
357 type Op = Scalar<A, T, u8>;
358
359 fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
360 match self {
361 Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
362 }
363 }
364
365 fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
366 match self {
367 Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
368 }
369 }
370
371 fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
372 match self {
373 Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
374 }
375 }
376
377 fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
378 match self {
379 Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
380 }
381 }
382
383 fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
384 match self {
385 Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
386 }
387 }
388
389 fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
390 match self {
391 Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
392 }
393 }
394}
395
396#[cfg(feature = "opencl")]
397impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Platform {
398 type Op = Scalar<A, T, u8>;
399
400 fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
401 match self {
402 Self::CL(cl) => cl.eq_scalar(left, right).map(AccessOp::wrap),
403 Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
404 }
405 }
406
407 fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
408 match self {
409 Self::CL(cl) => cl.ge_scalar(left, right).map(AccessOp::wrap),
410 Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
411 }
412 }
413
414 fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
415 match self {
416 Self::CL(cl) => cl.gt_scalar(left, right).map(AccessOp::wrap),
417 Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
418 }
419 }
420
421 fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
422 match self {
423 Self::CL(cl) => cl.le_scalar(left, right).map(AccessOp::wrap),
424 Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
425 }
426 }
427
428 fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
429 match self {
430 Self::CL(cl) => cl.lt_scalar(left, right).map(AccessOp::wrap),
431 Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
432 }
433 }
434
435 fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
436 match self {
437 Self::CL(cl) => cl.ne_scalar(left, right).map(AccessOp::wrap),
438 Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
439 }
440 }
441}
442
443#[cfg(not(feature = "opencl"))]
444impl<L, R, T> ElementwiseDual<L, R, T> for Platform
445where
446 L: Access<T>,
447 R: Access<T>,
448 T: CType,
449{
450 type Op = Dual<L, R, T, T>;
451
452 fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
453 match self {
454 Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
455 }
456 }
457
458 fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
459 match self {
460 Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
461 }
462 }
463
464 fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
465 match self {
466 Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
467 }
468 }
469
470 fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
471 match self {
472 Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
473 }
474 }
475
476 fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
477 match self {
478 Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
479 }
480 }
481
482 fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
483 match self {
484 Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
485 }
486 }
487
488 fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
489 match self {
490 Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
491 }
492 }
493}
494
495#[cfg(feature = "opencl")]
496impl<L, R, T> ElementwiseDual<L, R, T> for Platform
497where
498 L: Access<T>,
499 R: Access<T>,
500 T: CType,
501{
502 type Op = Dual<L, R, T, T>;
503
504 fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
505 match self {
506 Self::CL(cl) => cl.add(left, right).map(AccessOp::wrap),
507 Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
508 }
509 }
510
511 fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
512 match self {
513 Self::CL(cl) => cl.div(left, right).map(AccessOp::wrap),
514 Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
515 }
516 }
517
518 fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
519 match self {
520 Self::CL(cl) => cl.log(arg, base).map(AccessOp::wrap),
521 Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
522 }
523 }
524
525 fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
526 match self {
527 Self::CL(cl) => cl.mul(left, right).map(AccessOp::wrap),
528 Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
529 }
530 }
531
532 fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
533 match self {
534 Self::CL(cl) => cl.pow(arg, exp).map(AccessOp::wrap),
535 Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
536 }
537 }
538
539 fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
540 match self {
541 Self::CL(cl) => cl.rem(left, right).map(AccessOp::wrap),
542 Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
543 }
544 }
545
546 fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
547 match self {
548 Self::CL(cl) => cl.sub(left, right).map(AccessOp::wrap),
549 Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
550 }
551 }
552}
553
554#[cfg(not(feature = "opencl"))]
555impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Platform {
556 type Op = Scalar<A, T, T>;
557
558 fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
559 match self {
560 Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
561 }
562 }
563
564 fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
565 match self {
566 Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
567 }
568 }
569
570 fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
571 match self {
572 Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
573 }
574 }
575
576 fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
577 match self {
578 Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
579 }
580 }
581
582 fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
583 match self {
584 Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
585 }
586 }
587
588 fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
589 match self {
590 Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
591 }
592 }
593
594 fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
595 match self {
596 Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
597 }
598 }
599}
600
601#[cfg(feature = "opencl")]
602impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Platform {
603 type Op = Scalar<A, T, T>;
604
605 fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
606 match self {
607 Self::CL(cl) => cl.add_scalar(left, right).map(AccessOp::wrap),
608 Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
609 }
610 }
611
612 fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
613 match self {
614 Self::CL(cl) => cl.div_scalar(left, right).map(AccessOp::wrap),
615 Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
616 }
617 }
618
619 fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
620 match self {
621 Self::CL(cl) => cl.log_scalar(arg, base).map(AccessOp::wrap),
622 Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
623 }
624 }
625
626 fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
627 match self {
628 Self::CL(cl) => cl.mul_scalar(left, right).map(AccessOp::wrap),
629 Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
630 }
631 }
632
633 fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
634 match self {
635 Self::CL(cl) => cl.pow_scalar(arg, exp).map(AccessOp::wrap),
636 Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
637 }
638 }
639
640 fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
641 match self {
642 Self::CL(cl) => cl.rem_scalar(left, right).map(AccessOp::wrap),
643 Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
644 }
645 }
646
647 fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
648 match self {
649 Self::CL(cl) => cl.sub_scalar(left, right).map(AccessOp::wrap),
650 Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
651 }
652 }
653}
654
655#[cfg(not(feature = "opencl"))]
656impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
657 type Op = Unary<A, T, u8>;
658
659 fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
660 match self {
661 Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
662 }
663 }
664
665 fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
666 match self {
667 Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
668 }
669 }
670}
671
672#[cfg(feature = "opencl")]
673impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
674 type Op = Unary<A, T, u8>;
675
676 fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
677 match self {
678 Self::CL(cl) => cl.is_inf(access).map(AccessOp::wrap),
679 Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
680 }
681 }
682
683 fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
684 match self {
685 Self::CL(cl) => cl.is_nan(access).map(AccessOp::wrap),
686 Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
687 }
688 }
689}
690
691#[cfg(not(feature = "opencl"))]
692impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Platform {
693 type Op = Unary<A, T, T::Float>;
694
695 fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
696 match self {
697 Self::Host(host) => host.sin(access).map(AccessOp::wrap),
698 }
699 }
700
701 fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
702 match self {
703 Self::Host(host) => host.asin(access).map(AccessOp::wrap),
704 }
705 }
706
707 fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
708 match self {
709 Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
710 }
711 }
712
713 fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
714 match self {
715 Self::Host(host) => host.cos(access).map(AccessOp::wrap),
716 }
717 }
718
719 fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
720 match self {
721 Self::Host(host) => host.acos(access).map(AccessOp::wrap),
722 }
723 }
724
725 fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
726 match self {
727 Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
728 }
729 }
730
731 fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
732 match self {
733 Self::Host(host) => host.tan(access).map(AccessOp::wrap),
734 }
735 }
736
737 fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
738 match self {
739 Self::Host(host) => host.atan(access).map(AccessOp::wrap),
740 }
741 }
742
743 fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
744 match self {
745 Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
746 }
747 }
748}
749
750#[cfg(feature = "opencl")]
751impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Platform {
752 type Op = Unary<A, T, T::Float>;
753
754 fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
755 match self {
756 Self::CL(cl) => cl.sin(access).map(AccessOp::wrap),
757 Self::Host(host) => host.sin(access).map(AccessOp::wrap),
758 }
759 }
760
761 fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
762 match self {
763 Self::CL(cl) => cl.asin(access).map(AccessOp::wrap),
764 Self::Host(host) => host.asin(access).map(AccessOp::wrap),
765 }
766 }
767
768 fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
769 match self {
770 Self::CL(cl) => cl.sinh(access).map(AccessOp::wrap),
771 Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
772 }
773 }
774
775 fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
776 match self {
777 Self::CL(cl) => cl.cos(access).map(AccessOp::wrap),
778 Self::Host(host) => host.cos(access).map(AccessOp::wrap),
779 }
780 }
781
782 fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
783 match self {
784 Self::CL(cl) => cl.acos(access).map(AccessOp::wrap),
785 Self::Host(host) => host.acos(access).map(AccessOp::wrap),
786 }
787 }
788
789 fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
790 match self {
791 Self::CL(cl) => cl.cosh(access).map(AccessOp::wrap),
792 Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
793 }
794 }
795
796 fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
797 match self {
798 Self::CL(cl) => cl.tan(access).map(AccessOp::wrap),
799 Self::Host(host) => host.tan(access).map(AccessOp::wrap),
800 }
801 }
802
803 fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
804 match self {
805 Self::CL(cl) => cl.atan(access).map(AccessOp::wrap),
806 Self::Host(host) => host.atan(access).map(AccessOp::wrap),
807 }
808 }
809
810 fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
811 match self {
812 Self::CL(cl) => cl.tanh(access).map(AccessOp::wrap),
813 Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
814 }
815 }
816}
817
818#[cfg(not(feature = "opencl"))]
819impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Platform {
820 type Op = Unary<A, T, T>;
821
822 fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
823 match self {
824 Self::Host(host) => host.abs(access).map(AccessOp::wrap),
825 }
826 }
827
828 fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
829 match self {
830 Self::Host(host) => host.exp(access).map(AccessOp::wrap),
831 }
832 }
833
834 fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
835 match self {
836 Self::Host(host) => host.ln(access).map(AccessOp::wrap),
837 }
838 }
839
840 fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
841 match self {
842 Self::Host(host) => host.round(access).map(AccessOp::wrap),
843 }
844 }
845}
846
847#[cfg(feature = "opencl")]
848impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Platform {
849 type Op = Unary<A, T, T>;
850
851 fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
852 match self {
853 Self::CL(cl) => cl.abs(access).map(AccessOp::wrap),
854 Self::Host(host) => host.abs(access).map(AccessOp::wrap),
855 }
856 }
857
858 fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
859 match self {
860 Self::CL(cl) => cl.exp(access).map(AccessOp::wrap),
861 Self::Host(host) => host.exp(access).map(AccessOp::wrap),
862 }
863 }
864
865 fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
866 match self {
867 Self::CL(cl) => cl.ln(access).map(AccessOp::wrap),
868 Self::Host(host) => host.ln(access).map(AccessOp::wrap),
869 }
870 }
871
872 fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
873 match self {
874 Self::CL(cl) => cl.round(access).map(AccessOp::wrap),
875 Self::Host(host) => host.round(access).map(AccessOp::wrap),
876 }
877 }
878}
879
880#[cfg(not(feature = "opencl"))]
881impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Platform {
882 type Op = Unary<A, T, u8>;
883
884 fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
885 match self {
886 Self::Host(host) => host.not(access).map(AccessOp::wrap),
887 }
888 }
889}
890
891#[cfg(feature = "opencl")]
892impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Platform {
893 type Op = Unary<A, T, u8>;
894
895 fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
896 match self {
897 Self::CL(cl) => cl.not(access).map(AccessOp::wrap),
898 Self::Host(host) => host.not(access).map(AccessOp::wrap),
899 }
900 }
901}
902
903#[cfg(not(feature = "opencl"))]
904impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
905where
906 A: Access<u8>,
907 L: Access<T>,
908 R: Access<T>,
909 T: CType,
910{
911 type Op = Cond<A, L, R, T>;
912
913 fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
914 match self {
915 Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
916 }
917 }
918}
919
920#[cfg(feature = "opencl")]
921impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
922where
923 A: Access<u8>,
924 L: Access<T>,
925 R: Access<T>,
926 T: CType,
927{
928 type Op = Cond<A, L, R, T>;
929
930 fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
931 match self {
932 Self::CL(cl) => cl.cond(cond, then, or_else).map(AccessOp::wrap),
933 Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
934 }
935 }
936}
937
938#[cfg(not(feature = "opencl"))]
939impl<L, R, T> LinAlgDual<L, R, T> for Platform
940where
941 L: Access<T>,
942 R: Access<T>,
943 T: CType,
944{
945 type Op = MatMul<L, R, T>;
946
947 fn matmul(
948 self,
949 left: L,
950 right: R,
951 dims: [usize; 4],
952 ) -> Result<AccessOp<Self::Op, Self>, Error> {
953 match self {
954 Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
955 }
956 }
957}
958
959#[cfg(feature = "opencl")]
960impl<L, R, T> LinAlgDual<L, R, T> for Platform
961where
962 L: Access<T>,
963 R: Access<T>,
964 T: CType,
965{
966 type Op = MatMul<L, R, T>;
967
968 fn matmul(
969 self,
970 left: L,
971 right: R,
972 dims: [usize; 4],
973 ) -> Result<AccessOp<Self::Op, Self>, Error> {
974 match self {
975 Self::CL(cl) => cl.matmul(left, right, dims).map(AccessOp::wrap),
976 Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
977 }
978 }
979}
980
981#[cfg(not(feature = "opencl"))]
982impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Platform {
983 type Op = MatDiag<A, T>;
984
985 fn diag(
986 self,
987 access: A,
988 batch_size: usize,
989 dim: usize,
990 ) -> Result<AccessOp<Self::Op, Self>, Error> {
991 match self {
992 Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
993 }
994 }
995}
996
997#[cfg(feature = "opencl")]
998impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Platform {
999 type Op = MatDiag<A, T>;
1000
1001 fn diag(
1002 self,
1003 access: A,
1004 batch_size: usize,
1005 dim: usize,
1006 ) -> Result<AccessOp<Self::Op, Self>, Error> {
1007 match self {
1008 Self::CL(cl) => cl.diag(access, batch_size, dim).map(AccessOp::wrap),
1009 Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
1010 }
1011 }
1012}
1013
1014#[cfg(not(feature = "opencl"))]
1015impl Random for Platform {
1016 type Normal = RandomNormal;
1017 type Uniform = RandomUniform;
1018
1019 fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
1020 match self {
1021 Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
1022 }
1023 }
1024
1025 fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
1026 match self {
1027 Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
1028 }
1029 }
1030}
1031
1032#[cfg(feature = "opencl")]
1033impl Random for Platform {
1034 type Normal = RandomNormal;
1035 type Uniform = RandomUniform;
1036
1037 fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
1038 match self {
1039 Self::CL(cl) => cl.random_normal(size).map(AccessOp::wrap),
1040 Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
1041 }
1042 }
1043
1044 fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
1045 match self {
1046 Self::CL(cl) => cl.random_uniform(size).map(AccessOp::wrap),
1047 Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
1048 }
1049 }
1050}
1051
1052#[cfg(not(feature = "opencl"))]
1053impl<A: Access<T>, T: CType> ReduceAll<A, T> for Platform {
1054 fn all(self, access: A) -> Result<bool, Error> {
1055 match self {
1056 Self::Host(host) => host.all(access),
1057 }
1058 }
1059
1060 fn any(self, access: A) -> Result<bool, Error> {
1061 match self {
1062 Self::Host(host) => host.any(access),
1063 }
1064 }
1065
1066 fn max(self, access: A) -> Result<T, Error> {
1067 match self {
1068 Self::Host(host) => ReduceAll::max(host, access),
1069 }
1070 }
1071
1072 fn min(self, access: A) -> Result<T, Error> {
1073 match self {
1074 Self::Host(host) => ReduceAll::min(host, access),
1075 }
1076 }
1077
1078 fn product(self, access: A) -> Result<T, Error> {
1079 match self {
1080 Self::Host(host) => ReduceAll::product(host, access),
1081 }
1082 }
1083
1084 fn sum(self, access: A) -> Result<T, Error> {
1085 match self {
1086 Self::Host(host) => ReduceAll::sum(host, access),
1087 }
1088 }
1089}
1090
1091#[cfg(feature = "opencl")]
1092impl<A, T> ReduceAll<A, T> for Platform
1093where
1094 A: Access<T>,
1095 T: CType,
1096{
1097 fn all(self, access: A) -> Result<bool, Error> {
1098 match self {
1099 Self::CL(cl) => cl.all(access),
1100 Self::Host(host) => host.all(access),
1101 }
1102 }
1103
1104 fn any(self, access: A) -> Result<bool, Error> {
1105 match self {
1106 Self::CL(cl) => cl.any(access),
1107 Self::Host(host) => host.any(access),
1108 }
1109 }
1110
1111 fn max(self, access: A) -> Result<T, Error> {
1112 match self {
1113 Self::CL(cl) => ReduceAll::max(cl, access),
1114 Self::Host(host) => ReduceAll::max(host, access),
1115 }
1116 }
1117
1118 fn min(self, access: A) -> Result<T, Error> {
1119 match self {
1120 Self::CL(cl) => ReduceAll::min(cl, access),
1121 Self::Host(host) => ReduceAll::min(host, access),
1122 }
1123 }
1124
1125 fn product(self, access: A) -> Result<T, Error> {
1126 match self {
1127 Self::CL(cl) => ReduceAll::product(cl, access),
1128 Self::Host(host) => ReduceAll::product(host, access),
1129 }
1130 }
1131
1132 fn sum(self, access: A) -> Result<T, Error> {
1133 match self {
1134 Self::CL(cl) => ReduceAll::sum(cl, access),
1135 Self::Host(host) => ReduceAll::sum(host, access),
1136 }
1137 }
1138}
1139
1140#[cfg(not(feature = "opencl"))]
1141impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Platform {
1142 type Op = Reduce<A, T>;
1143
1144 fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1145 match self {
1146 Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
1147 }
1148 }
1149
1150 fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1151 match self {
1152 Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
1153 }
1154 }
1155
1156 fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1157 match self {
1158 Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
1159 }
1160 }
1161
1162 fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1163 match self {
1164 Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
1165 }
1166 }
1167}
1168
1169#[cfg(feature = "opencl")]
1170impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Platform {
1171 type Op = Reduce<A, T>;
1172
1173 fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1174 match self {
1175 Self::CL(cl) => ReduceAxes::max(cl, access, stride).map(AccessOp::wrap),
1176 Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
1177 }
1178 }
1179
1180 fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1181 match self {
1182 Self::CL(cl) => ReduceAxes::min(cl, access, stride).map(AccessOp::wrap),
1183 Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
1184 }
1185 }
1186
1187 fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1188 match self {
1189 Self::CL(cl) => ReduceAxes::product(cl, access, stride).map(AccessOp::wrap),
1190 Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
1191 }
1192 }
1193
1194 fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1195 match self {
1196 Self::CL(cl) => ReduceAxes::sum(cl, access, stride).map(AccessOp::wrap),
1197 Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
1198 }
1199 }
1200}
1201
1202impl<A: Access<T>, T: CType> Transform<A, T> for Platform {
1203 type Broadcast = View<A, T>;
1204 type Slice = Slice<A, T>;
1205 type Transpose = View<A, T>;
1206
1207 fn broadcast(
1208 self,
1209 access: A,
1210 shape: Shape,
1211 broadcast: Shape,
1212 ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
1213 match self {
1214 #[cfg(feature = "opencl")]
1215 Self::CL(cl) => cl.broadcast(access, shape, broadcast).map(AccessOp::wrap),
1216 Self::Host(host) => host.broadcast(access, shape, broadcast).map(AccessOp::wrap),
1217 }
1218 }
1219
1220 fn slice(
1221 self,
1222 access: A,
1223 shape: &[usize],
1224 range: Range,
1225 ) -> Result<AccessOp<Self::Slice, Self>, Error> {
1226 match self {
1227 #[cfg(feature = "opencl")]
1228 Self::CL(cl) => cl.slice(access, shape, range).map(AccessOp::wrap),
1229 Self::Host(host) => host.slice(access, shape, range).map(AccessOp::wrap),
1230 }
1231 }
1232
1233 fn transpose(
1234 self,
1235 access: A,
1236 shape: Shape,
1237 permutation: Axes,
1238 ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
1239 match self {
1240 #[cfg(feature = "opencl")]
1241 Self::CL(cl) => cl.transpose(access, shape, permutation).map(AccessOp::wrap),
1242 Self::Host(host) => host
1243 .transpose(access, shape, permutation)
1244 .map(AccessOp::wrap),
1245 }
1246 }
1247}