ha_ndarray/host/
platform.rs

1use rayon::prelude::*;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::BufferConverter;
5use crate::host::StackVec;
6use crate::ops::{
7    Construct, ElementwiseBoolean, ElementwiseBooleanScalar, ElementwiseCast, ElementwiseCompare,
8    ElementwiseDual, ElementwiseNumeric, ElementwiseScalar, ElementwiseScalarCompare,
9    ElementwiseTrig, ElementwiseUnary, ElementwiseUnaryBoolean, GatherCond, LinAlgDual,
10    LinAlgUnary, Random, ReduceAll, ReduceAxes, Transform,
11};
12use crate::platform::{Convert, PlatformInstance};
13use crate::{stackvec, Axes, CType, Constant, Error, Float, Range, Shape};
14
15use super::buffer::Buffer;
16use super::ops::*;
17
/// Size-hint threshold used by `Host::select`: hints strictly below this value select
/// [`Stack`], hints at or above it select [`Heap`].
pub const VEC_MIN_SIZE: usize = 64;
19
/// Host platform whose buffers are [`StackVec`]s.
///
/// `Host::select` picks this platform for size hints below [`VEC_MIN_SIZE`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Stack;
22
23impl PlatformInstance for Stack {
24    fn select(_size_hint: usize) -> Self {
25        Self
26    }
27}
28
29impl<T: CType> Constant<T> for Stack {
30    type Buffer = StackVec<T>;
31
32    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
33        Ok(stackvec![value; size])
34    }
35}
36
37impl<T: CType> Convert<T> for Stack {
38    type Buffer = StackVec<T>;
39
40    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
41        buffer.to_slice().map(|buf| buf.into_stackvec())
42    }
43}
44
45impl<A, T> ReduceAll<A, T> for Stack
46where
47    A: Access<T>,
48    T: CType,
49{
50    fn all(self, access: A) -> Result<bool, Error> {
51        access
52            .read()
53            .and_then(|buf| buf.to_slice())
54            .map(|slice| slice.iter().copied().all(|n| n != T::ZERO))
55    }
56
57    fn any(self, access: A) -> Result<bool, Error> {
58        access
59            .read()
60            .and_then(|buf| buf.to_slice())
61            .map(|slice| slice.iter().copied().any(|n| n != T::ZERO))
62    }
63
64    fn max(self, access: A) -> Result<T, Error> {
65        access
66            .read()
67            .and_then(|buf| buf.to_slice())
68            .map(|slice| slice.iter().copied().reduce(T::max).expect("max"))
69    }
70
71    fn min(self, access: A) -> Result<T, Error> {
72        access
73            .read()
74            .and_then(|buf| buf.to_slice())
75            .map(|slice| slice.iter().copied().reduce(T::min).expect("min"))
76    }
77
78    fn product(self, access: A) -> Result<T, Error> {
79        access
80            .read()
81            .and_then(|buf| buf.to_slice())
82            .map(|slice| slice.iter().copied().reduce(T::mul).expect("product"))
83    }
84
85    fn sum(self, access: A) -> Result<T, Error> {
86        access
87            .read()
88            .and_then(|buf| buf.to_slice())
89            .map(|slice| slice.iter().copied().reduce(T::add).expect("sum"))
90    }
91}
92
/// Host platform whose buffers are `Vec`s; its whole-buffer reductions use rayon.
///
/// `Host::select` picks this platform for size hints of at least [`VEC_MIN_SIZE`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Heap;
95
96impl PlatformInstance for Heap {
97    fn select(_size_hint: usize) -> Self {
98        Self
99    }
100}
101
102impl<T: CType> Constant<T> for Heap {
103    type Buffer = Vec<T>;
104
105    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
106        Ok(vec![value; size])
107    }
108}
109
110impl<T: CType> Convert<T> for Heap {
111    type Buffer = Vec<T>;
112
113    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
114        buffer.to_slice().map(|buf| buf.into_vec())
115    }
116}
117
118impl<A, T> ReduceAll<A, T> for Heap
119where
120    A: Access<T>,
121    T: CType,
122{
123    fn all(self, access: A) -> Result<bool, Error> {
124        access
125            .read()
126            .and_then(|buf| buf.to_slice())
127            .map(|slice| slice.into_par_iter().copied().all(|n| n != T::ZERO))
128    }
129
130    fn any(self, access: A) -> Result<bool, Error> {
131        access
132            .read()
133            .and_then(|buf| buf.to_slice())
134            .map(|slice| slice.into_par_iter().copied().any(|n| n != T::ZERO))
135    }
136
137    fn max(self, access: A) -> Result<T, Error> {
138        access
139            .read()
140            .and_then(|buf| buf.to_slice())
141            .map(|slice| slice.into_par_iter().copied().reduce(|| T::MIN, T::max))
142    }
143
144    fn min(self, access: A) -> Result<T, Error> {
145        access
146            .read()
147            .and_then(|buf| buf.to_slice())
148            .map(|slice| slice.into_par_iter().copied().reduce(|| T::MAX, T::min))
149    }
150
151    fn product(self, access: A) -> Result<T, Error> {
152        access
153            .read()
154            .and_then(|buf| buf.to_slice())
155            .map(|slice| slice.into_par_iter().copied().reduce(|| T::ONE, T::mul))
156    }
157
158    fn sum(self, access: A) -> Result<T, Error> {
159        access
160            .read()
161            .and_then(|buf| buf.to_slice())
162            .map(|slice| slice.into_par_iter().copied().reduce(|| T::ZERO, T::add))
163    }
164}
165
166#[derive(Debug, Copy, Clone, Eq, PartialEq)]
167pub enum Host {
168    Stack(Stack),
169    Heap(Heap),
170}
171
172impl PlatformInstance for Host {
173    fn select(size_hint: usize) -> Self {
174        if size_hint < VEC_MIN_SIZE {
175            Self::Stack(Stack)
176        } else {
177            Self::Heap(Heap)
178        }
179    }
180}
181
182impl<T: CType> Constant<T> for Host {
183    type Buffer = Buffer<T>;
184
185    fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
186        match self {
187            Self::Heap(heap) => heap.constant(value, size).map(Buffer::Heap),
188            Self::Stack(stack) => stack.constant(value, size).map(Buffer::Stack),
189        }
190    }
191}
192
193impl<T: CType> Convert<T> for Host {
194    type Buffer = Buffer<T>;
195
196    fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
197        match self {
198            Self::Heap(heap) => heap.convert(buffer).map(Buffer::Heap),
199            Self::Stack(stack) => stack.convert(buffer).map(Buffer::Stack),
200        }
201    }
202}
203
204impl From<Heap> for Host {
205    fn from(heap: Heap) -> Self {
206        Self::Heap(heap)
207    }
208}
209
210impl From<Stack> for Host {
211    fn from(stack: Stack) -> Self {
212        Self::Stack(stack)
213    }
214}
215
216impl<T: CType> Construct<T> for Host {
217    type Range = Linear<T>;
218
219    fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
220        if start <= stop {
221            let step = T::sub(stop, start).to_f64() / size as f64;
222            Ok(Linear::new(start, step, size).into())
223        } else {
224            Err(Error::Bounds(format!("invalid range: [{start}, {stop})")))
225        }
226    }
227}
228
229impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Host {
230    type Op = Cast<A, IT, OT>;
231
232    fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
233        Ok(Cast::new(access).into())
234    }
235}
236
237impl<A, L, R, T> GatherCond<A, L, R, T> for Host
238where
239    A: Access<u8>,
240    L: Access<T>,
241    R: Access<T>,
242    T: CType,
243{
244    type Op = Cond<A, L, R, T>;
245
246    fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
247        Ok(Cond::new(cond, then, or_else).into())
248    }
249}
250
251impl<L, R, T> ElementwiseBoolean<L, R, T> for Host
252where
253    L: Access<T>,
254    R: Access<T>,
255    T: CType,
256{
257    type Op = Dual<L, R, T, u8>;
258
259    fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
260        Ok(Dual::and(left, right).into())
261    }
262
263    fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
264        Ok(Dual::or(left, right).into())
265    }
266
267    fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
268        Ok(Dual::xor(left, right).into())
269    }
270}
271
272impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Host {
273    type Op = Scalar<A, T, u8>;
274
275    fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
276        Ok(Scalar::and(left, right).into())
277    }
278
279    fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
280        Ok(Scalar::or(left, right).into())
281    }
282
283    fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
284        Ok(Scalar::xor(left, right).into())
285    }
286}
287
288impl<L, R, T> ElementwiseCompare<L, R, T> for Host
289where
290    L: Access<T>,
291    R: Access<T>,
292    T: CType,
293{
294    type Op = Dual<L, R, T, u8>;
295
296    fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
297        Ok(Dual::eq(left, right).into())
298    }
299
300    fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
301        Ok(Dual::ge(left, right).into())
302    }
303
304    fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
305        Ok(Dual::gt(left, right).into())
306    }
307
308    fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
309        Ok(Dual::le(left, right).into())
310    }
311
312    fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
313        Ok(Dual::lt(left, right).into())
314    }
315
316    fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
317        Ok(Dual::ne(left, right).into())
318    }
319}
320
321impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Host {
322    type Op = Scalar<A, T, u8>;
323
324    fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
325        Ok(Scalar::eq(left, right).into())
326    }
327
328    fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
329        Ok(Scalar::ge(left, right).into())
330    }
331
332    fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
333        Ok(Scalar::gt(left, right).into())
334    }
335
336    fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
337        Ok(Scalar::le(left, right).into())
338    }
339
340    fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
341        Ok(Scalar::lt(left, right).into())
342    }
343
344    fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
345        Ok(Scalar::ne(left, right).into())
346    }
347}
348
349impl<L, R, T> ElementwiseDual<L, R, T> for Host
350where
351    L: Access<T>,
352    R: Access<T>,
353    T: CType,
354{
355    type Op = Dual<L, R, T, T>;
356
357    fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
358        Ok(Dual::add(left, right).into())
359    }
360
361    fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
362        Ok(Dual::div(left, right).into())
363    }
364
365    fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
366        Ok(Dual::log(arg, base).into())
367    }
368
369    fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
370        Ok(Dual::mul(left, right).into())
371    }
372
373    fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
374        Ok(Dual::pow(arg, exp).into())
375    }
376
377    fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
378        Ok(Dual::rem(left, right).into())
379    }
380
381    fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
382        Ok(Dual::sub(left, right).into())
383    }
384}
385
386impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Host {
387    type Op = Scalar<A, T, T>;
388
389    fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
390        Ok(Scalar::add(left, right).into())
391    }
392
393    fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
394        Ok(Scalar::div(left, right).into())
395    }
396
397    fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
398        Ok(Scalar::log(arg, base).into())
399    }
400
401    fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
402        Ok(Scalar::mul(left, right).into())
403    }
404
405    fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
406        Ok(Scalar::pow(arg, exp).into())
407    }
408
409    fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
410        Ok(Scalar::rem(left, right).into())
411    }
412
413    fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
414        Ok(Scalar::sub(left, right).into())
415    }
416}
417
418impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Host {
419    type Op = Unary<A, T, u8>;
420
421    fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
422        Ok(Unary::inf(access).into())
423    }
424
425    fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
426        Ok(Unary::nan(access).into())
427    }
428}
429
430impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Host {
431    type Op = Unary<A, T, T::Float>;
432
433    fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
434        Ok(Unary::sin(access).into())
435    }
436
437    fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
438        Ok(Unary::asin(access).into())
439    }
440
441    fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
442        Ok(Unary::sinh(access).into())
443    }
444
445    fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
446        Ok(Unary::cos(access).into())
447    }
448
449    fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
450        Ok(Unary::acos(access).into())
451    }
452
453    fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
454        Ok(Unary::cosh(access).into())
455    }
456
457    fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
458        Ok(Unary::tan(access).into())
459    }
460
461    fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
462        Ok(Unary::atan(access).into())
463    }
464
465    fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
466        Ok(Unary::tanh(access).into())
467    }
468}
469
470impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Host {
471    type Op = Unary<A, T, T>;
472
473    fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
474        Ok(Unary::abs(access).into())
475    }
476
477    fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
478        Ok(Unary::exp(access).into())
479    }
480
481    fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
482        Ok(Unary::ln(access).into())
483    }
484
485    fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
486        Ok(Unary::round(access).into())
487    }
488}
489
490impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Host {
491    type Op = Unary<A, T, u8>;
492
493    fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
494        Ok(Unary::not(access).into())
495    }
496}
497
498impl<L, R, T> LinAlgDual<L, R, T> for Host
499where
500    L: Access<T>,
501    R: Access<T>,
502    T: CType,
503{
504    type Op = MatMul<L, R, T>;
505
506    fn matmul(
507        self,
508        left: L,
509        right: R,
510        dims: [usize; 4],
511    ) -> Result<AccessOp<Self::Op, Self>, Error> {
512        Ok(MatMul::new(left, right, dims).into())
513    }
514}
515
516impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Host {
517    type Op = MatDiag<A, T>;
518
519    fn diag(
520        self,
521        access: A,
522        batch_size: usize,
523        dim: usize,
524    ) -> Result<AccessOp<Self::Op, Self>, Error> {
525        Ok(MatDiag::new(access, batch_size, dim).into())
526    }
527}
528
529impl Random for Host {
530    type Normal = RandomNormal;
531    type Uniform = RandomUniform;
532
533    fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
534        Ok(RandomNormal::new(size).into())
535    }
536
537    fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
538        Ok(RandomUniform::new(size).into())
539    }
540}
541
542impl<A: Access<T>, T: CType> ReduceAll<A, T> for Host {
543    fn all(self, access: A) -> Result<bool, Error> {
544        match self {
545            Self::Heap(heap) => heap.all(access),
546            Self::Stack(stack) => stack.all(access),
547        }
548    }
549
550    fn any(self, access: A) -> Result<bool, Error> {
551        match self {
552            Self::Heap(heap) => heap.any(access),
553            Self::Stack(stack) => stack.any(access),
554        }
555    }
556
557    fn max(self, access: A) -> Result<T, Error> {
558        match self {
559            Self::Heap(heap) => heap.max(access),
560            Self::Stack(stack) => stack.max(access),
561        }
562    }
563
564    fn min(self, access: A) -> Result<T, Error> {
565        match self {
566            Self::Heap(heap) => heap.min(access),
567            Self::Stack(stack) => stack.min(access),
568        }
569    }
570
571    fn product(self, access: A) -> Result<T, Error> {
572        match self {
573            Self::Heap(heap) => heap.product(access),
574            Self::Stack(stack) => stack.product(access),
575        }
576    }
577
578    fn sum(self, access: A) -> Result<T, Error> {
579        match self {
580            Self::Heap(heap) => heap.sum(access),
581            Self::Stack(stack) => stack.sum(access),
582        }
583    }
584}
585
586impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Host {
587    type Op = Reduce<A, T>;
588
589    fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
590        Ok(Reduce::max(access, stride).into())
591    }
592
593    fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
594        Ok(Reduce::min(access, stride).into())
595    }
596
597    fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
598        Ok(Reduce::product(access, stride).into())
599    }
600
601    fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
602        Ok(Reduce::sum(access, stride).into())
603    }
604}
605
606impl<'a, A, T> Transform<A, T> for Host
607where
608    A: Access<T>,
609    T: CType,
610{
611    type Broadcast = View<A, T>;
612    type Slice = Slice<A, T>;
613    type Transpose = View<A, T>;
614
615    fn broadcast(
616        self,
617        access: A,
618        shape: Shape,
619        broadcast: Shape,
620    ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
621        Ok(View::broadcast(access, shape, broadcast).into())
622    }
623
624    fn slice(
625        self,
626        access: A,
627        shape: &[usize],
628        range: Range,
629    ) -> Result<AccessOp<Self::Slice, Self>, Error> {
630        Ok(Slice::new(access, shape, range).into())
631    }
632
633    fn transpose(
634        self,
635        access: A,
636        shape: Shape,
637        permutation: Axes,
638    ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
639        Ok(View::transpose(access, shape, permutation).into())
640    }
641}