ha_ndarray/
platform.rs

1use std::fmt;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::{Buffer, BufferConverter, BufferInstance};
5#[cfg(feature = "opencl")]
6use crate::opencl;
7use crate::ops::*;
8use crate::{host, Axes, CType, Error, Float, Range, Shape};
9
/// A platform that ha-ndarray can dispatch array operations to.
///
/// Implementors are `Copy` handles which can be compared, debug-printed,
/// and shared between threads.
pub trait PlatformInstance:
    fmt::Debug + Copy + Clone + Eq + PartialEq + Send + Sync
{
    /// Choose the concrete (sub-)platform to use, given a hint about the data size.
    fn select(size_hint: usize) -> Self;
}
15
16/// Constructor for a new buffer filled with a single value
17pub trait Constant<T: CType>: PlatformInstance {
18    /// The type of buffer use by this platform
19    type Buffer: BufferInstance<T>;
20
21    /// Construct a new buffer filled with a single value.
22    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error>;
23}
24
25/// Converter to construct an owned, platform-specific buffer
26pub trait Convert<T: CType>: PlatformInstance {
27    /// The type of buffer use by this platform
28    type Buffer: BufferInstance<T>;
29
30    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error>;
31}
32
33/// The global platform, responsible for delegating to specific hardware platforms
34#[derive(Debug, Copy, Clone, Eq, PartialEq)]
35pub enum Platform {
36    #[cfg(feature = "opencl")]
37    CL(opencl::OpenCL),
38    Host(host::Host),
39}
40
#[cfg(feature = "opencl")]
impl PlatformInstance for Platform {
    /// Use OpenCL once `size_hint` reaches `opencl::GPU_MIN_SIZE`;
    /// below that threshold, let the host choose its own sub-platform.
    fn select(size_hint: usize) -> Self {
        if size_hint >= opencl::GPU_MIN_SIZE {
            Self::CL(opencl::OpenCL)
        } else {
            Self::Host(host::Host::select(size_hint))
        }
    }
}
51
52#[cfg(not(feature = "opencl"))]
53impl PlatformInstance for Platform {
54    fn select(size_hint: usize) -> Self {
55        Self::Host(host::Host::select(size_hint))
56    }
57}
58
#[cfg(feature = "opencl")]
impl From<opencl::OpenCL> for Platform {
    /// Wrap an OpenCL platform handle in the `CL` variant.
    fn from(cl: opencl::OpenCL) -> Self {
        Platform::CL(cl)
    }
}
65
66impl From<host::Host> for Platform {
67    fn from(host: host::Host) -> Self {
68        Self::Host(host)
69    }
70}
71
72impl<T: CType> Convert<T> for Platform {
73    type Buffer = Buffer<T>;
74
75    fn convert<'a>(&self, buffer: BufferConverter<'a, T>) -> Result<Self::Buffer, Error> {
76        match self {
77            #[cfg(feature = "opencl")]
78            Self::CL(cl) => cl.convert(buffer).map(Buffer::CL),
79            Self::Host(host) => host.convert(buffer).map(Buffer::Host),
80        }
81    }
82}
83
84#[cfg(not(feature = "opencl"))]
85impl<T: CType> Constant<T> for Platform {
86    type Buffer = Buffer<T>;
87
88    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
89        match self {
90            Self::Host(host) => host.constant(value, size).map(Buffer::Host),
91        }
92    }
93}
94
95#[cfg(feature = "opencl")]
96impl<T: CType> Constant<T> for Platform {
97    type Buffer = Buffer<T>;
98
99    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
100        match self {
101            Self::CL(cl) => cl.constant(value, size).map(Buffer::CL),
102            Self::Host(host) => host.constant(value, size).map(Buffer::Host),
103        }
104    }
105}
106
107#[cfg(not(feature = "opencl"))]
108impl<T: CType> Construct<T> for Platform {
109    type Range = Linear<T>;
110
111    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
112        match self {
113            Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
114        }
115    }
116}
117
118#[cfg(feature = "opencl")]
119impl<T: CType> Construct<T> for Platform {
120    type Range = Linear<T>;
121
122    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
123        match self {
124            Self::CL(cl) => cl.range(start, stop, size).map(AccessOp::wrap),
125            Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
126        }
127    }
128}
129
130#[cfg(not(feature = "opencl"))]
131impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
132where
133    L: Access<T>,
134    R: Access<T>,
135    T: CType,
136{
137    type Op = Dual<L, R, T, u8>;
138
139    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
140        match self {
141            Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
142        }
143    }
144    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
145        match self {
146            Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
147        }
148    }
149    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
150        match self {
151            Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
152        }
153    }
154}
155
156#[cfg(feature = "opencl")]
157impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
158where
159    L: Access<T>,
160    R: Access<T>,
161    T: CType,
162{
163    type Op = Dual<L, R, T, u8>;
164
165    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
166        match self {
167            Self::CL(cl) => cl.and(left, right).map(AccessOp::wrap),
168            Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
169        }
170    }
171
172    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
173        match self {
174            Self::CL(cl) => cl.or(left, right).map(AccessOp::wrap),
175            Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
176        }
177    }
178
179    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
180        match self {
181            Self::CL(cl) => cl.xor(left, right).map(AccessOp::wrap),
182            Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
183        }
184    }
185}
186
187#[cfg(not(feature = "opencl"))]
188impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Platform {
189    type Op = Scalar<A, T, u8>;
190
191    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
192        match self {
193            Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
194        }
195    }
196    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
197        match self {
198            Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
199        }
200    }
201    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
202        match self {
203            Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
204        }
205    }
206}
207
208#[cfg(feature = "opencl")]
209impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Platform {
210    type Op = Scalar<A, T, u8>;
211
212    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
213        match self {
214            Self::CL(cl) => cl.and_scalar(left, right).map(AccessOp::wrap),
215            Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
216        }
217    }
218
219    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
220        match self {
221            Self::CL(cl) => cl.or_scalar(left, right).map(AccessOp::wrap),
222            Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
223        }
224    }
225
226    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
227        match self {
228            Self::CL(cl) => cl.xor_scalar(left, right).map(AccessOp::wrap),
229            Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
230        }
231    }
232}
233
234#[cfg(not(feature = "opencl"))]
235impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Platform {
236    type Op = Cast<A, IT, OT>;
237
238    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
239        match self {
240            Self::Host(host) => host.cast(access).map(AccessOp::wrap),
241        }
242    }
243}
244
245#[cfg(feature = "opencl")]
246impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Platform {
247    type Op = Cast<A, IT, OT>;
248
249    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
250        match self {
251            Self::CL(cl) => cl.cast(access).map(AccessOp::wrap),
252            Self::Host(host) => host.cast(access).map(AccessOp::wrap),
253        }
254    }
255}
256
257#[cfg(not(feature = "opencl"))]
258impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
259where
260    L: Access<T>,
261    R: Access<T>,
262    T: CType,
263{
264    type Op = Dual<L, R, T, u8>;
265
266    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
267        match self {
268            Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
269        }
270    }
271
272    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
273        match self {
274            Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
275        }
276    }
277
278    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
279        match self {
280            Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
281        }
282    }
283
284    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
285        match self {
286            Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
287        }
288    }
289
290    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
291        match self {
292            Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
293        }
294    }
295
296    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
297        match self {
298            Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
299        }
300    }
301}
302
303#[cfg(feature = "opencl")]
304impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
305where
306    L: Access<T>,
307    R: Access<T>,
308    T: CType,
309{
310    type Op = Dual<L, R, T, u8>;
311
312    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
313        match self {
314            Self::CL(cl) => cl.eq(left, right).map(AccessOp::wrap),
315            Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
316        }
317    }
318
319    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
320        match self {
321            Self::CL(cl) => cl.ge(left, right).map(AccessOp::wrap),
322            Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
323        }
324    }
325
326    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
327        match self {
328            Self::CL(cl) => cl.gt(left, right).map(AccessOp::wrap),
329            Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
330        }
331    }
332
333    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
334        match self {
335            Self::CL(cl) => cl.le(left, right).map(AccessOp::wrap),
336            Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
337        }
338    }
339
340    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
341        match self {
342            Self::CL(cl) => cl.lt(left, right).map(AccessOp::wrap),
343            Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
344        }
345    }
346
347    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
348        match self {
349            Self::CL(cl) => cl.ne(left, right).map(AccessOp::wrap),
350            Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
351        }
352    }
353}
354
355#[cfg(not(feature = "opencl"))]
356impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Platform {
357    type Op = Scalar<A, T, u8>;
358
359    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
360        match self {
361            Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
362        }
363    }
364
365    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
366        match self {
367            Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
368        }
369    }
370
371    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
372        match self {
373            Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
374        }
375    }
376
377    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
378        match self {
379            Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
380        }
381    }
382
383    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
384        match self {
385            Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
386        }
387    }
388
389    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
390        match self {
391            Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
392        }
393    }
394}
395
396#[cfg(feature = "opencl")]
397impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Platform {
398    type Op = Scalar<A, T, u8>;
399
400    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
401        match self {
402            Self::CL(cl) => cl.eq_scalar(left, right).map(AccessOp::wrap),
403            Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
404        }
405    }
406
407    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
408        match self {
409            Self::CL(cl) => cl.ge_scalar(left, right).map(AccessOp::wrap),
410            Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
411        }
412    }
413
414    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
415        match self {
416            Self::CL(cl) => cl.gt_scalar(left, right).map(AccessOp::wrap),
417            Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
418        }
419    }
420
421    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
422        match self {
423            Self::CL(cl) => cl.le_scalar(left, right).map(AccessOp::wrap),
424            Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
425        }
426    }
427
428    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
429        match self {
430            Self::CL(cl) => cl.lt_scalar(left, right).map(AccessOp::wrap),
431            Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
432        }
433    }
434
435    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
436        match self {
437            Self::CL(cl) => cl.ne_scalar(left, right).map(AccessOp::wrap),
438            Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
439        }
440    }
441}
442
443#[cfg(not(feature = "opencl"))]
444impl<L, R, T> ElementwiseDual<L, R, T> for Platform
445where
446    L: Access<T>,
447    R: Access<T>,
448    T: CType,
449{
450    type Op = Dual<L, R, T, T>;
451
452    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
453        match self {
454            Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
455        }
456    }
457
458    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
459        match self {
460            Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
461        }
462    }
463
464    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
465        match self {
466            Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
467        }
468    }
469
470    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
471        match self {
472            Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
473        }
474    }
475
476    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
477        match self {
478            Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
479        }
480    }
481
482    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
483        match self {
484            Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
485        }
486    }
487
488    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
489        match self {
490            Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
491        }
492    }
493}
494
495#[cfg(feature = "opencl")]
496impl<L, R, T> ElementwiseDual<L, R, T> for Platform
497where
498    L: Access<T>,
499    R: Access<T>,
500    T: CType,
501{
502    type Op = Dual<L, R, T, T>;
503
504    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
505        match self {
506            Self::CL(cl) => cl.add(left, right).map(AccessOp::wrap),
507            Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
508        }
509    }
510
511    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
512        match self {
513            Self::CL(cl) => cl.div(left, right).map(AccessOp::wrap),
514            Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
515        }
516    }
517
518    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
519        match self {
520            Self::CL(cl) => cl.log(arg, base).map(AccessOp::wrap),
521            Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
522        }
523    }
524
525    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
526        match self {
527            Self::CL(cl) => cl.mul(left, right).map(AccessOp::wrap),
528            Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
529        }
530    }
531
532    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
533        match self {
534            Self::CL(cl) => cl.pow(arg, exp).map(AccessOp::wrap),
535            Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
536        }
537    }
538
539    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
540        match self {
541            Self::CL(cl) => cl.rem(left, right).map(AccessOp::wrap),
542            Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
543        }
544    }
545
546    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
547        match self {
548            Self::CL(cl) => cl.sub(left, right).map(AccessOp::wrap),
549            Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
550        }
551    }
552}
553
554#[cfg(not(feature = "opencl"))]
555impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Platform {
556    type Op = Scalar<A, T, T>;
557
558    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
559        match self {
560            Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
561        }
562    }
563
564    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
565        match self {
566            Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
567        }
568    }
569
570    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
571        match self {
572            Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
573        }
574    }
575
576    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
577        match self {
578            Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
579        }
580    }
581
582    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
583        match self {
584            Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
585        }
586    }
587
588    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
589        match self {
590            Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
591        }
592    }
593
594    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
595        match self {
596            Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
597        }
598    }
599}
600
601#[cfg(feature = "opencl")]
602impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Platform {
603    type Op = Scalar<A, T, T>;
604
605    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
606        match self {
607            Self::CL(cl) => cl.add_scalar(left, right).map(AccessOp::wrap),
608            Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
609        }
610    }
611
612    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
613        match self {
614            Self::CL(cl) => cl.div_scalar(left, right).map(AccessOp::wrap),
615            Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
616        }
617    }
618
619    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
620        match self {
621            Self::CL(cl) => cl.log_scalar(arg, base).map(AccessOp::wrap),
622            Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
623        }
624    }
625
626    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
627        match self {
628            Self::CL(cl) => cl.mul_scalar(left, right).map(AccessOp::wrap),
629            Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
630        }
631    }
632
633    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
634        match self {
635            Self::CL(cl) => cl.pow_scalar(arg, exp).map(AccessOp::wrap),
636            Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
637        }
638    }
639
640    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
641        match self {
642            Self::CL(cl) => cl.rem_scalar(left, right).map(AccessOp::wrap),
643            Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
644        }
645    }
646
647    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
648        match self {
649            Self::CL(cl) => cl.sub_scalar(left, right).map(AccessOp::wrap),
650            Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
651        }
652    }
653}
654
655#[cfg(not(feature = "opencl"))]
656impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
657    type Op = Unary<A, T, u8>;
658
659    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
660        match self {
661            Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
662        }
663    }
664
665    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
666        match self {
667            Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
668        }
669    }
670}
671
672#[cfg(feature = "opencl")]
673impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
674    type Op = Unary<A, T, u8>;
675
676    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
677        match self {
678            Self::CL(cl) => cl.is_inf(access).map(AccessOp::wrap),
679            Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
680        }
681    }
682
683    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
684        match self {
685            Self::CL(cl) => cl.is_nan(access).map(AccessOp::wrap),
686            Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
687        }
688    }
689}
690
691#[cfg(not(feature = "opencl"))]
692impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Platform {
693    type Op = Unary<A, T, T::Float>;
694
695    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
696        match self {
697            Self::Host(host) => host.sin(access).map(AccessOp::wrap),
698        }
699    }
700
701    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
702        match self {
703            Self::Host(host) => host.asin(access).map(AccessOp::wrap),
704        }
705    }
706
707    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
708        match self {
709            Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
710        }
711    }
712
713    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
714        match self {
715            Self::Host(host) => host.cos(access).map(AccessOp::wrap),
716        }
717    }
718
719    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
720        match self {
721            Self::Host(host) => host.acos(access).map(AccessOp::wrap),
722        }
723    }
724
725    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
726        match self {
727            Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
728        }
729    }
730
731    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
732        match self {
733            Self::Host(host) => host.tan(access).map(AccessOp::wrap),
734        }
735    }
736
737    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
738        match self {
739            Self::Host(host) => host.atan(access).map(AccessOp::wrap),
740        }
741    }
742
743    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
744        match self {
745            Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
746        }
747    }
748}
749
750#[cfg(feature = "opencl")]
751impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Platform {
752    type Op = Unary<A, T, T::Float>;
753
754    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
755        match self {
756            Self::CL(cl) => cl.sin(access).map(AccessOp::wrap),
757            Self::Host(host) => host.sin(access).map(AccessOp::wrap),
758        }
759    }
760
761    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
762        match self {
763            Self::CL(cl) => cl.asin(access).map(AccessOp::wrap),
764            Self::Host(host) => host.asin(access).map(AccessOp::wrap),
765        }
766    }
767
768    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
769        match self {
770            Self::CL(cl) => cl.sinh(access).map(AccessOp::wrap),
771            Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
772        }
773    }
774
775    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
776        match self {
777            Self::CL(cl) => cl.cos(access).map(AccessOp::wrap),
778            Self::Host(host) => host.cos(access).map(AccessOp::wrap),
779        }
780    }
781
782    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
783        match self {
784            Self::CL(cl) => cl.acos(access).map(AccessOp::wrap),
785            Self::Host(host) => host.acos(access).map(AccessOp::wrap),
786        }
787    }
788
789    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
790        match self {
791            Self::CL(cl) => cl.cosh(access).map(AccessOp::wrap),
792            Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
793        }
794    }
795
796    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
797        match self {
798            Self::CL(cl) => cl.tan(access).map(AccessOp::wrap),
799            Self::Host(host) => host.tan(access).map(AccessOp::wrap),
800        }
801    }
802
803    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
804        match self {
805            Self::CL(cl) => cl.atan(access).map(AccessOp::wrap),
806            Self::Host(host) => host.atan(access).map(AccessOp::wrap),
807        }
808    }
809
810    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
811        match self {
812            Self::CL(cl) => cl.tanh(access).map(AccessOp::wrap),
813            Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
814        }
815    }
816}
817
818#[cfg(not(feature = "opencl"))]
819impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Platform {
820    type Op = Unary<A, T, T>;
821
822    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
823        match self {
824            Self::Host(host) => host.abs(access).map(AccessOp::wrap),
825        }
826    }
827
828    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
829        match self {
830            Self::Host(host) => host.exp(access).map(AccessOp::wrap),
831        }
832    }
833
834    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
835        match self {
836            Self::Host(host) => host.ln(access).map(AccessOp::wrap),
837        }
838    }
839
840    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
841        match self {
842            Self::Host(host) => host.round(access).map(AccessOp::wrap),
843        }
844    }
845}
846
847#[cfg(feature = "opencl")]
848impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Platform {
849    type Op = Unary<A, T, T>;
850
851    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
852        match self {
853            Self::CL(cl) => cl.abs(access).map(AccessOp::wrap),
854            Self::Host(host) => host.abs(access).map(AccessOp::wrap),
855        }
856    }
857
858    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
859        match self {
860            Self::CL(cl) => cl.exp(access).map(AccessOp::wrap),
861            Self::Host(host) => host.exp(access).map(AccessOp::wrap),
862        }
863    }
864
865    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
866        match self {
867            Self::CL(cl) => cl.ln(access).map(AccessOp::wrap),
868            Self::Host(host) => host.ln(access).map(AccessOp::wrap),
869        }
870    }
871
872    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
873        match self {
874            Self::CL(cl) => cl.round(access).map(AccessOp::wrap),
875            Self::Host(host) => host.round(access).map(AccessOp::wrap),
876        }
877    }
878}
879
880#[cfg(not(feature = "opencl"))]
881impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Platform {
882    type Op = Unary<A, T, u8>;
883
884    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
885        match self {
886            Self::Host(host) => host.not(access).map(AccessOp::wrap),
887        }
888    }
889}
890
891#[cfg(feature = "opencl")]
892impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Platform {
893    type Op = Unary<A, T, u8>;
894
895    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
896        match self {
897            Self::CL(cl) => cl.not(access).map(AccessOp::wrap),
898            Self::Host(host) => host.not(access).map(AccessOp::wrap),
899        }
900    }
901}
902
903#[cfg(not(feature = "opencl"))]
904impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
905where
906    A: Access<u8>,
907    L: Access<T>,
908    R: Access<T>,
909    T: CType,
910{
911    type Op = Cond<A, L, R, T>;
912
913    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
914        match self {
915            Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
916        }
917    }
918}
919
920#[cfg(feature = "opencl")]
921impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
922where
923    A: Access<u8>,
924    L: Access<T>,
925    R: Access<T>,
926    T: CType,
927{
928    type Op = Cond<A, L, R, T>;
929
930    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
931        match self {
932            Self::CL(cl) => cl.cond(cond, then, or_else).map(AccessOp::wrap),
933            Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
934        }
935    }
936}
937
938#[cfg(not(feature = "opencl"))]
939impl<L, R, T> LinAlgDual<L, R, T> for Platform
940where
941    L: Access<T>,
942    R: Access<T>,
943    T: CType,
944{
945    type Op = MatMul<L, R, T>;
946
947    fn matmul(
948        self,
949        left: L,
950        right: R,
951        dims: [usize; 4],
952    ) -> Result<AccessOp<Self::Op, Self>, Error> {
953        match self {
954            Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
955        }
956    }
957}
958
959#[cfg(feature = "opencl")]
960impl<L, R, T> LinAlgDual<L, R, T> for Platform
961where
962    L: Access<T>,
963    R: Access<T>,
964    T: CType,
965{
966    type Op = MatMul<L, R, T>;
967
968    fn matmul(
969        self,
970        left: L,
971        right: R,
972        dims: [usize; 4],
973    ) -> Result<AccessOp<Self::Op, Self>, Error> {
974        match self {
975            Self::CL(cl) => cl.matmul(left, right, dims).map(AccessOp::wrap),
976            Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
977        }
978    }
979}
980
981#[cfg(not(feature = "opencl"))]
982impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Platform {
983    type Op = MatDiag<A, T>;
984
985    fn diag(
986        self,
987        access: A,
988        batch_size: usize,
989        dim: usize,
990    ) -> Result<AccessOp<Self::Op, Self>, Error> {
991        match self {
992            Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
993        }
994    }
995}
996
997#[cfg(feature = "opencl")]
998impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Platform {
999    type Op = MatDiag<A, T>;
1000
1001    fn diag(
1002        self,
1003        access: A,
1004        batch_size: usize,
1005        dim: usize,
1006    ) -> Result<AccessOp<Self::Op, Self>, Error> {
1007        match self {
1008            Self::CL(cl) => cl.diag(access, batch_size, dim).map(AccessOp::wrap),
1009            Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
1010        }
1011    }
1012}
1013
1014#[cfg(not(feature = "opencl"))]
1015impl Random for Platform {
1016    type Normal = RandomNormal;
1017    type Uniform = RandomUniform;
1018
1019    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
1020        match self {
1021            Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
1022        }
1023    }
1024
1025    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
1026        match self {
1027            Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
1028        }
1029    }
1030}
1031
1032#[cfg(feature = "opencl")]
1033impl Random for Platform {
1034    type Normal = RandomNormal;
1035    type Uniform = RandomUniform;
1036
1037    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
1038        match self {
1039            Self::CL(cl) => cl.random_normal(size).map(AccessOp::wrap),
1040            Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
1041        }
1042    }
1043
1044    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
1045        match self {
1046            Self::CL(cl) => cl.random_uniform(size).map(AccessOp::wrap),
1047            Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
1048        }
1049    }
1050}
1051
1052#[cfg(not(feature = "opencl"))]
1053impl<A: Access<T>, T: CType> ReduceAll<A, T> for Platform {
1054    fn all(self, access: A) -> Result<bool, Error> {
1055        match self {
1056            Self::Host(host) => host.all(access),
1057        }
1058    }
1059
1060    fn any(self, access: A) -> Result<bool, Error> {
1061        match self {
1062            Self::Host(host) => host.any(access),
1063        }
1064    }
1065
1066    fn max(self, access: A) -> Result<T, Error> {
1067        match self {
1068            Self::Host(host) => ReduceAll::max(host, access),
1069        }
1070    }
1071
1072    fn min(self, access: A) -> Result<T, Error> {
1073        match self {
1074            Self::Host(host) => ReduceAll::min(host, access),
1075        }
1076    }
1077
1078    fn product(self, access: A) -> Result<T, Error> {
1079        match self {
1080            Self::Host(host) => ReduceAll::product(host, access),
1081        }
1082    }
1083
1084    fn sum(self, access: A) -> Result<T, Error> {
1085        match self {
1086            Self::Host(host) => ReduceAll::sum(host, access),
1087        }
1088    }
1089}
1090
1091#[cfg(feature = "opencl")]
1092impl<A, T> ReduceAll<A, T> for Platform
1093where
1094    A: Access<T>,
1095    T: CType,
1096{
1097    fn all(self, access: A) -> Result<bool, Error> {
1098        match self {
1099            Self::CL(cl) => cl.all(access),
1100            Self::Host(host) => host.all(access),
1101        }
1102    }
1103
1104    fn any(self, access: A) -> Result<bool, Error> {
1105        match self {
1106            Self::CL(cl) => cl.any(access),
1107            Self::Host(host) => host.any(access),
1108        }
1109    }
1110
1111    fn max(self, access: A) -> Result<T, Error> {
1112        match self {
1113            Self::CL(cl) => ReduceAll::max(cl, access),
1114            Self::Host(host) => ReduceAll::max(host, access),
1115        }
1116    }
1117
1118    fn min(self, access: A) -> Result<T, Error> {
1119        match self {
1120            Self::CL(cl) => ReduceAll::min(cl, access),
1121            Self::Host(host) => ReduceAll::min(host, access),
1122        }
1123    }
1124
1125    fn product(self, access: A) -> Result<T, Error> {
1126        match self {
1127            Self::CL(cl) => ReduceAll::product(cl, access),
1128            Self::Host(host) => ReduceAll::product(host, access),
1129        }
1130    }
1131
1132    fn sum(self, access: A) -> Result<T, Error> {
1133        match self {
1134            Self::CL(cl) => ReduceAll::sum(cl, access),
1135            Self::Host(host) => ReduceAll::sum(host, access),
1136        }
1137    }
1138}
1139
1140#[cfg(not(feature = "opencl"))]
1141impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Platform {
1142    type Op = Reduce<A, T>;
1143
1144    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1145        match self {
1146            Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
1147        }
1148    }
1149
1150    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1151        match self {
1152            Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
1153        }
1154    }
1155
1156    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1157        match self {
1158            Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
1159        }
1160    }
1161
1162    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1163        match self {
1164            Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
1165        }
1166    }
1167}
1168
1169#[cfg(feature = "opencl")]
1170impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Platform {
1171    type Op = Reduce<A, T>;
1172
1173    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1174        match self {
1175            Self::CL(cl) => ReduceAxes::max(cl, access, stride).map(AccessOp::wrap),
1176            Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
1177        }
1178    }
1179
1180    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1181        match self {
1182            Self::CL(cl) => ReduceAxes::min(cl, access, stride).map(AccessOp::wrap),
1183            Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
1184        }
1185    }
1186
1187    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1188        match self {
1189            Self::CL(cl) => ReduceAxes::product(cl, access, stride).map(AccessOp::wrap),
1190            Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
1191        }
1192    }
1193
1194    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
1195        match self {
1196            Self::CL(cl) => ReduceAxes::sum(cl, access, stride).map(AccessOp::wrap),
1197            Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
1198        }
1199    }
1200}
1201
1202impl<A: Access<T>, T: CType> Transform<A, T> for Platform {
1203    type Broadcast = View<A, T>;
1204    type Slice = Slice<A, T>;
1205    type Transpose = View<A, T>;
1206
1207    fn broadcast(
1208        self,
1209        access: A,
1210        shape: Shape,
1211        broadcast: Shape,
1212    ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
1213        match self {
1214            #[cfg(feature = "opencl")]
1215            Self::CL(cl) => cl.broadcast(access, shape, broadcast).map(AccessOp::wrap),
1216            Self::Host(host) => host.broadcast(access, shape, broadcast).map(AccessOp::wrap),
1217        }
1218    }
1219
1220    fn slice(
1221        self,
1222        access: A,
1223        shape: &[usize],
1224        range: Range,
1225    ) -> Result<AccessOp<Self::Slice, Self>, Error> {
1226        match self {
1227            #[cfg(feature = "opencl")]
1228            Self::CL(cl) => cl.slice(access, shape, range).map(AccessOp::wrap),
1229            Self::Host(host) => host.slice(access, shape, range).map(AccessOp::wrap),
1230        }
1231    }
1232
1233    fn transpose(
1234        self,
1235        access: A,
1236        shape: Shape,
1237        permutation: Axes,
1238    ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
1239        match self {
1240            #[cfg(feature = "opencl")]
1241            Self::CL(cl) => cl.transpose(access, shape, permutation).map(AccessOp::wrap),
1242            Self::Host(host) => host
1243                .transpose(access, shape, permutation)
1244                .map(AccessOp::wrap),
1245        }
1246    }
1247}