Skip to main content

ha_ndarray/
platform.rs

1use std::fmt;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::{Buffer, BufferConverter, BufferInstance};
5#[cfg(feature = "opencl")]
6use crate::opencl;
7use crate::ops::*;
8use crate::{host, Axes, Error, Float, Number, Range, Real, Shape};
9
/// A hardware platform on which ha-ndarray operations can be executed.
///
/// Implementors are small value types: they must be `Copy`, comparable,
/// shareable across threads, and printable for debugging.
pub trait PlatformInstance:
    PartialEq + Eq + Clone + Copy + Send + Sync + fmt::Debug
{
    /// Choose a concrete sub-platform given a hint of the data size.
    fn select(size_hint: usize) -> Self;
}
15
16/// Constructor for a new buffer filled with a single value
17pub trait Constant<T: Number>: PlatformInstance {
18    /// The type of buffer use by this platform
19    type Buffer: BufferInstance<T>;
20
21    /// Construct a new buffer filled with a single value.
22    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error>;
23}
24
25/// Converter to construct an owned, platform-specific buffer
26pub trait Convert<T: Number>: PlatformInstance {
27    /// The type of buffer use by this platform
28    type Buffer: BufferInstance<T>;
29
30    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error>;
31}
32
33/// The global platform, responsible for delegating to specific hardware platforms
34#[derive(Debug, Copy, Clone, Eq, PartialEq)]
35pub enum Platform {
36    #[cfg(feature = "opencl")]
37    CL(opencl::OpenCL),
38    Host(host::Host),
39}
40
#[cfg(feature = "opencl")]
impl PlatformInstance for Platform {
    /// Route data of at least `opencl::GPU_MIN_SIZE` to OpenCL and anything smaller
    /// to a host platform chosen by `host::Host::select`.
    fn select(size_hint: usize) -> Self {
        if size_hint >= opencl::GPU_MIN_SIZE {
            Platform::CL(opencl::OpenCL)
        } else {
            Platform::Host(host::Host::select(size_hint))
        }
    }
}
51
52#[cfg(not(feature = "opencl"))]
53impl PlatformInstance for Platform {
54    fn select(size_hint: usize) -> Self {
55        Self::Host(host::Host::select(size_hint))
56    }
57}
58
#[cfg(feature = "opencl")]
impl From<opencl::OpenCL> for Platform {
    /// Wrap an OpenCL backend in the global [`Platform`] enum.
    fn from(cl: opencl::OpenCL) -> Self {
        Platform::CL(cl)
    }
}
65
66impl From<host::Host> for Platform {
67    fn from(host: host::Host) -> Self {
68        Self::Host(host)
69    }
70}
71
72impl<T: Number> Convert<T> for Platform {
73    type Buffer = Buffer<T>;
74
75    fn convert<'a>(&self, buffer: BufferConverter<'a, T>) -> Result<Self::Buffer, Error> {
76        match self {
77            #[cfg(feature = "opencl")]
78            Self::CL(cl) => cl.convert(buffer).map(Buffer::CL),
79            Self::Host(host) => host.convert(buffer).map(Buffer::Host),
80        }
81    }
82}
83
84impl<T: Number> Constant<T> for Platform {
85    type Buffer = Buffer<T>;
86
87    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
88        match self {
89            #[cfg(feature = "opencl")]
90            Self::CL(cl) => cl.constant(value, size).map(Buffer::CL),
91            Self::Host(host) => host.constant(value, size).map(Buffer::Host),
92        }
93    }
94}
95
96impl<A, T> ConstructConcat<A, T> for Platform
97where
98    A: Access<T>,
99    T: Number,
100{
101    type Op = Concat<A, T>;
102
103    fn concat(self, data: Vec<A>) -> Result<AccessOp<Self::Op, Self>, Error> {
104        match self {
105            #[cfg(feature = "opencl")]
106            Self::CL(cl) => cl.concat(data).map(AccessOp::wrap),
107            Self::Host(host) => host.concat(data).map(AccessOp::wrap),
108        }
109    }
110}
111
112impl<T: Number + PartialOrd> ConstructRange<T> for Platform {
113    type Range = Linear<T>;
114
115    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
116        match self {
117            #[cfg(feature = "opencl")]
118            Self::CL(cl) => cl.range(start, stop, size).map(AccessOp::wrap),
119            Self::Host(host) => host.range(start, stop, size).map(AccessOp::wrap),
120        }
121    }
122}
123
124impl<A, T> ElementwiseAbs<A, T> for Platform
125where
126    A: Access<T>,
127    T: Number,
128{
129    type Op = Unary<A, T, T::Abs>;
130
131    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
132        match self {
133            #[cfg(feature = "opencl")]
134            Self::CL(cl) => cl.abs(access).map(AccessOp::wrap),
135            Self::Host(host) => host.abs(access).map(AccessOp::wrap),
136        }
137    }
138}
139
140impl<L, R, T> ElementwiseBoolean<L, R, T> for Platform
141where
142    L: Access<T>,
143    R: Access<T>,
144    T: Number,
145{
146    type Op = Dual<L, R, T, u8>;
147
148    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
149        match self {
150            #[cfg(feature = "opencl")]
151            Self::CL(cl) => cl.and(left, right).map(AccessOp::wrap),
152            Self::Host(host) => host.and(left, right).map(AccessOp::wrap),
153        }
154    }
155
156    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
157        match self {
158            #[cfg(feature = "opencl")]
159            Self::CL(cl) => cl.or(left, right).map(AccessOp::wrap),
160            Self::Host(host) => host.or(left, right).map(AccessOp::wrap),
161        }
162    }
163
164    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
165        match self {
166            #[cfg(feature = "opencl")]
167            Self::CL(cl) => cl.xor(left, right).map(AccessOp::wrap),
168            Self::Host(host) => host.xor(left, right).map(AccessOp::wrap),
169        }
170    }
171}
172
173impl<A: Access<T>, T: Number> ElementwiseBooleanScalar<A, T> for Platform {
174    type Op = Scalar<A, T, u8>;
175
176    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
177        match self {
178            #[cfg(feature = "opencl")]
179            Self::CL(cl) => cl.and_scalar(left, right).map(AccessOp::wrap),
180            Self::Host(host) => host.and_scalar(left, right).map(AccessOp::wrap),
181        }
182    }
183
184    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
185        match self {
186            #[cfg(feature = "opencl")]
187            Self::CL(cl) => cl.or_scalar(left, right).map(AccessOp::wrap),
188            Self::Host(host) => host.or_scalar(left, right).map(AccessOp::wrap),
189        }
190    }
191
192    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
193        match self {
194            #[cfg(feature = "opencl")]
195            Self::CL(cl) => cl.xor_scalar(left, right).map(AccessOp::wrap),
196            Self::Host(host) => host.xor_scalar(left, right).map(AccessOp::wrap),
197        }
198    }
199}
200
201impl<A: Access<IT>, IT: Number, OT: Number> ElementwiseCast<A, IT, OT> for Platform {
202    type Op = Cast<A, IT, OT>;
203
204    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
205        match self {
206            #[cfg(feature = "opencl")]
207            Self::CL(cl) => cl.cast(access).map(AccessOp::wrap),
208            Self::Host(host) => host.cast(access).map(AccessOp::wrap),
209        }
210    }
211}
212
213impl<L, R, T> ElementwiseCompare<L, R, T> for Platform
214where
215    L: Access<T>,
216    R: Access<T>,
217    T: Number,
218{
219    type Op = Dual<L, R, T, u8>;
220
221    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
222        match self {
223            #[cfg(feature = "opencl")]
224            Self::CL(cl) => cl.eq(left, right).map(AccessOp::wrap),
225            Self::Host(host) => host.eq(left, right).map(AccessOp::wrap),
226        }
227    }
228
229    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
230    where
231        T: Real,
232    {
233        match self {
234            #[cfg(feature = "opencl")]
235            Self::CL(cl) => cl.ge(left, right).map(AccessOp::wrap),
236            Self::Host(host) => host.ge(left, right).map(AccessOp::wrap),
237        }
238    }
239
240    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
241    where
242        T: Real,
243    {
244        match self {
245            #[cfg(feature = "opencl")]
246            Self::CL(cl) => cl.gt(left, right).map(AccessOp::wrap),
247            Self::Host(host) => host.gt(left, right).map(AccessOp::wrap),
248        }
249    }
250
251    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
252    where
253        T: Real,
254    {
255        match self {
256            #[cfg(feature = "opencl")]
257            Self::CL(cl) => cl.le(left, right).map(AccessOp::wrap),
258            Self::Host(host) => host.le(left, right).map(AccessOp::wrap),
259        }
260    }
261
262    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
263    where
264        T: Real,
265    {
266        match self {
267            #[cfg(feature = "opencl")]
268            Self::CL(cl) => cl.lt(left, right).map(AccessOp::wrap),
269            Self::Host(host) => host.lt(left, right).map(AccessOp::wrap),
270        }
271    }
272
273    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
274        match self {
275            #[cfg(feature = "opencl")]
276            Self::CL(cl) => cl.ne(left, right).map(AccessOp::wrap),
277            Self::Host(host) => host.ne(left, right).map(AccessOp::wrap),
278        }
279    }
280}
281
282impl<A: Access<T>, T: Number> ElementwiseCompareScalar<A, T> for Platform {
283    type Op = Scalar<A, T, u8>;
284
285    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
286        match self {
287            #[cfg(feature = "opencl")]
288            Self::CL(cl) => cl.eq_scalar(left, right).map(AccessOp::wrap),
289            Self::Host(host) => host.eq_scalar(left, right).map(AccessOp::wrap),
290        }
291    }
292
293    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
294    where
295        T: Real,
296    {
297        match self {
298            #[cfg(feature = "opencl")]
299            Self::CL(cl) => cl.ge_scalar(left, right).map(AccessOp::wrap),
300            Self::Host(host) => host.ge_scalar(left, right).map(AccessOp::wrap),
301        }
302    }
303
304    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
305    where
306        T: Real,
307    {
308        match self {
309            #[cfg(feature = "opencl")]
310            Self::CL(cl) => cl.gt_scalar(left, right).map(AccessOp::wrap),
311            Self::Host(host) => host.gt_scalar(left, right).map(AccessOp::wrap),
312        }
313    }
314
315    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
316    where
317        T: Real,
318    {
319        match self {
320            #[cfg(feature = "opencl")]
321            Self::CL(cl) => cl.le_scalar(left, right).map(AccessOp::wrap),
322            Self::Host(host) => host.le_scalar(left, right).map(AccessOp::wrap),
323        }
324    }
325
326    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
327    where
328        T: Real,
329    {
330        match self {
331            #[cfg(feature = "opencl")]
332            Self::CL(cl) => cl.lt_scalar(left, right).map(AccessOp::wrap),
333            Self::Host(host) => host.lt_scalar(left, right).map(AccessOp::wrap),
334        }
335    }
336
337    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
338        match self {
339            #[cfg(feature = "opencl")]
340            Self::CL(cl) => cl.ne_scalar(left, right).map(AccessOp::wrap),
341            Self::Host(host) => host.ne_scalar(left, right).map(AccessOp::wrap),
342        }
343    }
344}
345
/// Forwards the complex-number unary operations to the backend held by this `Platform`.
/// Each result is re-wrapped with `AccessOp::wrap`.
#[cfg(feature = "complex")]
impl<A: Access<T>, T: crate::Complex> complex::ElementwiseUnaryComplex<A, T> for Platform {
    type Real = Unary<A, T, T::Real>;
    type Complex = Unary<A, T, T>;

    fn angle(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            Platform::Host(h) => h.angle(access).map(AccessOp::wrap),
            #[cfg(feature = "opencl")]
            Platform::CL(c) => c.angle(access).map(AccessOp::wrap),
        }
    }

    fn conj(self, access: A) -> Result<AccessOp<Self::Complex, Self>, Error> {
        match self {
            Platform::Host(h) => h.conj(access).map(AccessOp::wrap),
            #[cfg(feature = "opencl")]
            Platform::CL(c) => c.conj(access).map(AccessOp::wrap),
        }
    }

    fn re(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            Platform::Host(h) => h.re(access).map(AccessOp::wrap),
            #[cfg(feature = "opencl")]
            Platform::CL(c) => c.re(access).map(AccessOp::wrap),
        }
    }

    fn im(self, access: A) -> Result<AccessOp<Self::Real, Self>, Error> {
        match self {
            Platform::Host(h) => h.im(access).map(AccessOp::wrap),
            #[cfg(feature = "opencl")]
            Platform::CL(c) => c.im(access).map(AccessOp::wrap),
        }
    }
}
387
388impl<L, R, T> ElementwiseDual<L, R, T> for Platform
389where
390    L: Access<T>,
391    R: Access<T>,
392    T: Number,
393{
394    type Op = Dual<L, R, T, T>;
395
396    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
397        match self {
398            #[cfg(feature = "opencl")]
399            Self::CL(cl) => cl.add(left, right).map(AccessOp::wrap),
400            Self::Host(host) => host.add(left, right).map(AccessOp::wrap),
401        }
402    }
403
404    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
405        match self {
406            #[cfg(feature = "opencl")]
407            Self::CL(cl) => cl.div(left, right).map(AccessOp::wrap),
408            Self::Host(host) => host.div(left, right).map(AccessOp::wrap),
409        }
410    }
411
412    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error>
413    where
414        T: Float,
415    {
416        match self {
417            #[cfg(feature = "opencl")]
418            Self::CL(cl) => cl.log(arg, base).map(AccessOp::wrap),
419            Self::Host(host) => host.log(arg, base).map(AccessOp::wrap),
420        }
421    }
422
423    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
424        match self {
425            #[cfg(feature = "opencl")]
426            Self::CL(cl) => cl.mul(left, right).map(AccessOp::wrap),
427            Self::Host(host) => host.mul(left, right).map(AccessOp::wrap),
428        }
429    }
430
431    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
432        match self {
433            #[cfg(feature = "opencl")]
434            Self::CL(cl) => cl.pow(arg, exp).map(AccessOp::wrap),
435            Self::Host(host) => host.pow(arg, exp).map(AccessOp::wrap),
436        }
437    }
438
439    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error>
440    where
441        T: Real,
442    {
443        match self {
444            #[cfg(feature = "opencl")]
445            Self::CL(cl) => cl.rem(left, right).map(AccessOp::wrap),
446            Self::Host(host) => host.rem(left, right).map(AccessOp::wrap),
447        }
448    }
449
450    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
451        match self {
452            #[cfg(feature = "opencl")]
453            Self::CL(cl) => cl.sub(left, right).map(AccessOp::wrap),
454            Self::Host(host) => host.sub(left, right).map(AccessOp::wrap),
455        }
456    }
457}
458
459impl<A: Access<T>, T: Number> ElementwiseScalar<A, T> for Platform {
460    type Op = Scalar<A, T, T>;
461
462    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
463        match self {
464            #[cfg(feature = "opencl")]
465            Self::CL(cl) => cl.add_scalar(left, right).map(AccessOp::wrap),
466            Self::Host(host) => host.add_scalar(left, right).map(AccessOp::wrap),
467        }
468    }
469
470    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
471        match self {
472            #[cfg(feature = "opencl")]
473            Self::CL(cl) => cl.div_scalar(left, right).map(AccessOp::wrap),
474            Self::Host(host) => host.div_scalar(left, right).map(AccessOp::wrap),
475        }
476    }
477
478    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error>
479    where
480        T: Float,
481    {
482        match self {
483            #[cfg(feature = "opencl")]
484            Self::CL(cl) => cl.log_scalar(arg, base).map(AccessOp::wrap),
485            Self::Host(host) => host.log_scalar(arg, base).map(AccessOp::wrap),
486        }
487    }
488
489    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
490        match self {
491            #[cfg(feature = "opencl")]
492            Self::CL(cl) => cl.mul_scalar(left, right).map(AccessOp::wrap),
493            Self::Host(host) => host.mul_scalar(left, right).map(AccessOp::wrap),
494        }
495    }
496
497    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
498        match self {
499            #[cfg(feature = "opencl")]
500            Self::CL(cl) => cl.pow_scalar(arg, exp).map(AccessOp::wrap),
501            Self::Host(host) => host.pow_scalar(arg, exp).map(AccessOp::wrap),
502        }
503    }
504
505    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error>
506    where
507        T: Real,
508    {
509        match self {
510            #[cfg(feature = "opencl")]
511            Self::CL(cl) => cl.rem_scalar(left, right).map(AccessOp::wrap),
512            Self::Host(host) => host.rem_scalar(left, right).map(AccessOp::wrap),
513        }
514    }
515
516    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
517        match self {
518            #[cfg(feature = "opencl")]
519            Self::CL(cl) => cl.sub_scalar(left, right).map(AccessOp::wrap),
520            Self::Host(host) => host.sub_scalar(left, right).map(AccessOp::wrap),
521        }
522    }
523}
524
525impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Platform {
526    type Op = Unary<A, T, u8>;
527
528    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
529        match self {
530            #[cfg(feature = "opencl")]
531            Self::CL(cl) => cl.is_inf(access).map(AccessOp::wrap),
532            Self::Host(host) => host.is_inf(access).map(AccessOp::wrap),
533        }
534    }
535
536    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
537        match self {
538            #[cfg(feature = "opencl")]
539            Self::CL(cl) => cl.is_nan(access).map(AccessOp::wrap),
540            Self::Host(host) => host.is_nan(access).map(AccessOp::wrap),
541        }
542    }
543}
544
545impl<A: Access<T>, T: Float> ElementwiseTrig<A, T> for Platform {
546    type Op = Unary<A, T, T>;
547
548    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
549        match self {
550            #[cfg(feature = "opencl")]
551            Self::CL(cl) => cl.sin(access).map(AccessOp::wrap),
552            Self::Host(host) => host.sin(access).map(AccessOp::wrap),
553        }
554    }
555
556    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
557        match self {
558            #[cfg(feature = "opencl")]
559            Self::CL(cl) => cl.asin(access).map(AccessOp::wrap),
560            Self::Host(host) => host.asin(access).map(AccessOp::wrap),
561        }
562    }
563
564    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
565        match self {
566            #[cfg(feature = "opencl")]
567            Self::CL(cl) => cl.sinh(access).map(AccessOp::wrap),
568            Self::Host(host) => host.sinh(access).map(AccessOp::wrap),
569        }
570    }
571
572    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
573        match self {
574            #[cfg(feature = "opencl")]
575            Self::CL(cl) => cl.cos(access).map(AccessOp::wrap),
576            Self::Host(host) => host.cos(access).map(AccessOp::wrap),
577        }
578    }
579
580    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
581        match self {
582            #[cfg(feature = "opencl")]
583            Self::CL(cl) => cl.acos(access).map(AccessOp::wrap),
584            Self::Host(host) => host.acos(access).map(AccessOp::wrap),
585        }
586    }
587
588    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
589        match self {
590            #[cfg(feature = "opencl")]
591            Self::CL(cl) => cl.cosh(access).map(AccessOp::wrap),
592            Self::Host(host) => host.cosh(access).map(AccessOp::wrap),
593        }
594    }
595
596    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
597        match self {
598            #[cfg(feature = "opencl")]
599            Self::CL(cl) => cl.tan(access).map(AccessOp::wrap),
600            Self::Host(host) => host.tan(access).map(AccessOp::wrap),
601        }
602    }
603
604    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
605        match self {
606            #[cfg(feature = "opencl")]
607            Self::CL(cl) => cl.atan(access).map(AccessOp::wrap),
608            Self::Host(host) => host.atan(access).map(AccessOp::wrap),
609        }
610    }
611
612    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
613        match self {
614            #[cfg(feature = "opencl")]
615            Self::CL(cl) => cl.tanh(access).map(AccessOp::wrap),
616            Self::Host(host) => host.tanh(access).map(AccessOp::wrap),
617        }
618    }
619}
620
621impl<A: Access<T>, T: Float> ElementwiseUnary<A, T> for Platform {
622    type Op = Unary<A, T, T>;
623
624    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
625        match self {
626            #[cfg(feature = "opencl")]
627            Self::CL(cl) => cl.exp(access).map(AccessOp::wrap),
628            Self::Host(host) => host.exp(access).map(AccessOp::wrap),
629        }
630    }
631
632    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
633        match self {
634            #[cfg(feature = "opencl")]
635            Self::CL(cl) => cl.ln(access).map(AccessOp::wrap),
636            Self::Host(host) => host.ln(access).map(AccessOp::wrap),
637        }
638    }
639
640    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error>
641    where
642        T: Real,
643    {
644        match self {
645            #[cfg(feature = "opencl")]
646            Self::CL(cl) => cl.round(access).map(AccessOp::wrap),
647            Self::Host(host) => host.round(access).map(AccessOp::wrap),
648        }
649    }
650}
651
652impl<A: Access<T>, T: Number> ElementwiseUnaryBoolean<A, T> for Platform {
653    type Op = Unary<A, T, u8>;
654
655    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
656        match self {
657            #[cfg(feature = "opencl")]
658            Self::CL(cl) => cl.not(access).map(AccessOp::wrap),
659            Self::Host(host) => host.not(access).map(AccessOp::wrap),
660        }
661    }
662}
663
/// FFT dispatch for the global platform.
///
/// Only the host backend implements Fourier transforms. When `self` is an OpenCL
/// platform, a host platform is chosen with `host::Host::select(access.size())`
/// and the transform runs there.
#[cfg(feature = "complex")]
impl<A, T> complex::Fourier<A, num_complex::Complex<T>> for Platform
where
    A: Access<num_complex::Complex<T>>,
    T: rustfft::FftNum,
    num_complex::Complex<T>: crate::Complex,
{
    type Op = complex::FFT<A, num_complex::Complex<T>>;

    /// Forward transform. `dim` is passed through unchanged to the host implementation.
    fn fft(self, access: A, dim: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
        // TODO: support FFT on OpenCL
        let host = match self {
            #[cfg(feature = "opencl")]
            Self::CL(_cl) => host::Host::select(access.size()),
            Self::Host(host) => host,
        };

        host.fft(access, dim).map(AccessOp::wrap)
    }

    /// Inverse transform. `dim` is passed through unchanged to the host implementation.
    fn ifft(self, access: A, dim: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
        // TODO: support IFFT on OpenCL
        let host = match self {
            #[cfg(feature = "opencl")]
            Self::CL(_cl) => host::Host::select(access.size()),
            Self::Host(host) => host,
        };

        // Must dispatch to the host's *inverse* transform. Calling `fft` here would
        // silently compute a forward transform.
        host.ifft(access, dim).map(AccessOp::wrap)
    }
}
695
696impl<A, L, R, T> GatherCond<A, L, R, T> for Platform
697where
698    A: Access<u8>,
699    L: Access<T>,
700    R: Access<T>,
701    T: Number,
702{
703    type Op = Cond<A, L, R, T>;
704
705    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
706        match self {
707            #[cfg(feature = "opencl")]
708            Self::CL(cl) => cl.cond(cond, then, or_else).map(AccessOp::wrap),
709            Self::Host(host) => host.cond(cond, then, or_else).map(AccessOp::wrap),
710        }
711    }
712}
713
714impl<L, R, T> LinAlgDual<L, R, T> for Platform
715where
716    L: Access<T>,
717    R: Access<T>,
718    T: Number,
719{
720    type Op = MatMul<L, R, T>;
721
722    fn matmul(
723        self,
724        left: L,
725        right: R,
726        dims: [usize; 4],
727    ) -> Result<AccessOp<Self::Op, Self>, Error> {
728        match self {
729            #[cfg(feature = "opencl")]
730            Self::CL(cl) => cl.matmul(left, right, dims).map(AccessOp::wrap),
731            Self::Host(host) => host.matmul(left, right, dims).map(AccessOp::wrap),
732        }
733    }
734}
735
736impl<A: Access<T>, T: Number> LinAlgUnary<A, T> for Platform {
737    type Op = MatDiag<A, T>;
738
739    fn diag(
740        self,
741        access: A,
742        batch_size: usize,
743        dim: usize,
744    ) -> Result<AccessOp<Self::Op, Self>, Error> {
745        match self {
746            #[cfg(feature = "opencl")]
747            Self::CL(cl) => cl.diag(access, batch_size, dim).map(AccessOp::wrap),
748            Self::Host(host) => host.diag(access, batch_size, dim).map(AccessOp::wrap),
749        }
750    }
751}
752
753impl Random for Platform {
754    type Normal = RandomNormal;
755    type Uniform = RandomUniform;
756
757    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
758        match self {
759            #[cfg(feature = "opencl")]
760            Self::CL(cl) => cl.random_normal(size).map(AccessOp::wrap),
761            Self::Host(host) => host.random_normal(size).map(AccessOp::wrap),
762        }
763    }
764
765    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
766        match self {
767            #[cfg(feature = "opencl")]
768            Self::CL(cl) => cl.random_uniform(size).map(AccessOp::wrap),
769            Self::Host(host) => host.random_uniform(size).map(AccessOp::wrap),
770        }
771    }
772}
773
774impl<A, T> ReduceAll<A, T> for Platform
775where
776    A: Access<T>,
777    T: Number,
778{
779    fn all(self, access: A) -> Result<bool, Error> {
780        match self {
781            #[cfg(feature = "opencl")]
782            Self::CL(cl) => cl.all(access),
783            Self::Host(host) => host.all(access),
784        }
785    }
786
787    fn any(self, access: A) -> Result<bool, Error> {
788        match self {
789            #[cfg(feature = "opencl")]
790            Self::CL(cl) => cl.any(access),
791            Self::Host(host) => host.any(access),
792        }
793    }
794
795    fn max(self, access: A) -> Result<T, Error>
796    where
797        T: Real,
798    {
799        match self {
800            #[cfg(feature = "opencl")]
801            Self::CL(cl) => ReduceAll::max(cl, access),
802            Self::Host(host) => ReduceAll::max(host, access),
803        }
804    }
805
806    fn min(self, access: A) -> Result<T, Error>
807    where
808        T: Real,
809    {
810        match self {
811            #[cfg(feature = "opencl")]
812            Self::CL(cl) => ReduceAll::min(cl, access),
813            Self::Host(host) => ReduceAll::min(host, access),
814        }
815    }
816
817    fn product(self, access: A) -> Result<T, Error> {
818        match self {
819            #[cfg(feature = "opencl")]
820            Self::CL(cl) => ReduceAll::product(cl, access),
821            Self::Host(host) => ReduceAll::product(host, access),
822        }
823    }
824
825    fn sum(self, access: A) -> Result<T, Error> {
826        match self {
827            #[cfg(feature = "opencl")]
828            Self::CL(cl) => ReduceAll::sum(cl, access),
829            Self::Host(host) => ReduceAll::sum(host, access),
830        }
831    }
832}
833
834impl<A: Access<T>, T: Number> ReduceAxes<A, T> for Platform {
835    type Op = Reduce<A, T>;
836
837    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
838    where
839        T: Real,
840    {
841        match self {
842            #[cfg(feature = "opencl")]
843            Self::CL(cl) => ReduceAxes::max(cl, access, stride).map(AccessOp::wrap),
844            Self::Host(host) => ReduceAxes::max(host, access, stride).map(AccessOp::wrap),
845        }
846    }
847
848    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error>
849    where
850        T: Real,
851    {
852        match self {
853            #[cfg(feature = "opencl")]
854            Self::CL(cl) => ReduceAxes::min(cl, access, stride).map(AccessOp::wrap),
855            Self::Host(host) => ReduceAxes::min(host, access, stride).map(AccessOp::wrap),
856        }
857    }
858
859    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
860        match self {
861            #[cfg(feature = "opencl")]
862            Self::CL(cl) => ReduceAxes::product(cl, access, stride).map(AccessOp::wrap),
863            Self::Host(host) => ReduceAxes::product(host, access, stride).map(AccessOp::wrap),
864        }
865    }
866
867    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
868        match self {
869            #[cfg(feature = "opencl")]
870            Self::CL(cl) => ReduceAxes::sum(cl, access, stride).map(AccessOp::wrap),
871            Self::Host(host) => ReduceAxes::sum(host, access, stride).map(AccessOp::wrap),
872        }
873    }
874}
875
876impl<A: Access<T>, T: Number> Transform<A, T> for Platform {
877    type Broadcast = View<A, T>;
878    type Flip = Flip<A, T>;
879    type Slice = Slice<A, T>;
880    type Transpose = View<A, T>;
881
882    fn broadcast(
883        self,
884        access: A,
885        shape: Shape,
886        broadcast: Shape,
887    ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
888        match self {
889            #[cfg(feature = "opencl")]
890            Self::CL(cl) => cl.broadcast(access, shape, broadcast).map(AccessOp::wrap),
891            Self::Host(host) => host.broadcast(access, shape, broadcast).map(AccessOp::wrap),
892        }
893    }
894
895    fn flip(
896        self,
897        access: A,
898        shape: Shape,
899        axis: usize,
900    ) -> Result<AccessOp<Self::Flip, Self>, Error> {
901        match self {
902            #[cfg(feature = "opencl")]
903            Self::CL(cl) => cl.flip(access, shape, axis).map(AccessOp::wrap),
904            Self::Host(host) => host.flip(access, shape, axis).map(AccessOp::wrap),
905        }
906    }
907
908    fn slice(
909        self,
910        access: A,
911        shape: &[usize],
912        range: Range,
913    ) -> Result<AccessOp<Self::Slice, Self>, Error> {
914        match self {
915            #[cfg(feature = "opencl")]
916            Self::CL(cl) => cl.slice(access, shape, range).map(AccessOp::wrap),
917            Self::Host(host) => host.slice(access, shape, range).map(AccessOp::wrap),
918        }
919    }
920
921    fn transpose(
922        self,
923        access: A,
924        shape: Shape,
925        permutation: Axes,
926    ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
927        match self {
928            #[cfg(feature = "opencl")]
929            Self::CL(cl) => cl.transpose(access, shape, permutation).map(AccessOp::wrap),
930            Self::Host(host) => host
931                .transpose(access, shape, permutation)
932                .map(AccessOp::wrap),
933        }
934    }
935}