1use rayon::prelude::*;
2
3use crate::access::{Access, AccessOp};
4use crate::buffer::BufferConverter;
5use crate::host::StackVec;
6use crate::ops::{
7 Construct, ElementwiseBoolean, ElementwiseBooleanScalar, ElementwiseCast, ElementwiseCompare,
8 ElementwiseDual, ElementwiseNumeric, ElementwiseScalar, ElementwiseScalarCompare,
9 ElementwiseTrig, ElementwiseUnary, ElementwiseUnaryBoolean, GatherCond, LinAlgDual,
10 LinAlgUnary, Random, ReduceAll, ReduceAxes, Transform,
11};
12use crate::platform::{Convert, PlatformInstance};
13use crate::{stackvec, Axes, CType, Constant, Error, Float, Range, Shape};
14
15use super::buffer::Buffer;
16use super::ops::*;
17
/// Size-hint threshold used by `Host::select`.
///
/// Hints strictly below this value select the [`Stack`] platform; every other hint
/// selects the [`Heap`] platform.
pub const VEC_MIN_SIZE: usize = 64;
19
20#[derive(Debug, Copy, Clone, Eq, PartialEq)]
21pub struct Stack;
22
23impl PlatformInstance for Stack {
24 fn select(_size_hint: usize) -> Self {
25 Self
26 }
27}
28
29impl<T: CType> Constant<T> for Stack {
30 type Buffer = StackVec<T>;
31
32 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
33 Ok(stackvec![value; size])
34 }
35}
36
37impl<T: CType> Convert<T> for Stack {
38 type Buffer = StackVec<T>;
39
40 fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
41 buffer.to_slice().map(|buf| buf.into_stackvec())
42 }
43}
44
45impl<A, T> ReduceAll<A, T> for Stack
46where
47 A: Access<T>,
48 T: CType,
49{
50 fn all(self, access: A) -> Result<bool, Error> {
51 access
52 .read()
53 .and_then(|buf| buf.to_slice())
54 .map(|slice| slice.iter().copied().all(|n| n != T::ZERO))
55 }
56
57 fn any(self, access: A) -> Result<bool, Error> {
58 access
59 .read()
60 .and_then(|buf| buf.to_slice())
61 .map(|slice| slice.iter().copied().any(|n| n != T::ZERO))
62 }
63
64 fn max(self, access: A) -> Result<T, Error> {
65 access
66 .read()
67 .and_then(|buf| buf.to_slice())
68 .map(|slice| slice.iter().copied().reduce(T::max).expect("max"))
69 }
70
71 fn min(self, access: A) -> Result<T, Error> {
72 access
73 .read()
74 .and_then(|buf| buf.to_slice())
75 .map(|slice| slice.iter().copied().reduce(T::min).expect("min"))
76 }
77
78 fn product(self, access: A) -> Result<T, Error> {
79 access
80 .read()
81 .and_then(|buf| buf.to_slice())
82 .map(|slice| slice.iter().copied().reduce(T::mul).expect("product"))
83 }
84
85 fn sum(self, access: A) -> Result<T, Error> {
86 access
87 .read()
88 .and_then(|buf| buf.to_slice())
89 .map(|slice| slice.iter().copied().reduce(T::add).expect("sum"))
90 }
91}
92
93#[derive(Debug, Copy, Clone, Eq, PartialEq)]
94pub struct Heap;
95
96impl PlatformInstance for Heap {
97 fn select(_size_hint: usize) -> Self {
98 Self
99 }
100}
101
102impl<T: CType> Constant<T> for Heap {
103 type Buffer = Vec<T>;
104
105 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
106 Ok(vec![value; size])
107 }
108}
109
110impl<T: CType> Convert<T> for Heap {
111 type Buffer = Vec<T>;
112
113 fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
114 buffer.to_slice().map(|buf| buf.into_vec())
115 }
116}
117
118impl<A, T> ReduceAll<A, T> for Heap
119where
120 A: Access<T>,
121 T: CType,
122{
123 fn all(self, access: A) -> Result<bool, Error> {
124 access
125 .read()
126 .and_then(|buf| buf.to_slice())
127 .map(|slice| slice.into_par_iter().copied().all(|n| n != T::ZERO))
128 }
129
130 fn any(self, access: A) -> Result<bool, Error> {
131 access
132 .read()
133 .and_then(|buf| buf.to_slice())
134 .map(|slice| slice.into_par_iter().copied().any(|n| n != T::ZERO))
135 }
136
137 fn max(self, access: A) -> Result<T, Error> {
138 access
139 .read()
140 .and_then(|buf| buf.to_slice())
141 .map(|slice| slice.into_par_iter().copied().reduce(|| T::MIN, T::max))
142 }
143
144 fn min(self, access: A) -> Result<T, Error> {
145 access
146 .read()
147 .and_then(|buf| buf.to_slice())
148 .map(|slice| slice.into_par_iter().copied().reduce(|| T::MAX, T::min))
149 }
150
151 fn product(self, access: A) -> Result<T, Error> {
152 access
153 .read()
154 .and_then(|buf| buf.to_slice())
155 .map(|slice| slice.into_par_iter().copied().reduce(|| T::ONE, T::mul))
156 }
157
158 fn sum(self, access: A) -> Result<T, Error> {
159 access
160 .read()
161 .and_then(|buf| buf.to_slice())
162 .map(|slice| slice.into_par_iter().copied().reduce(|| T::ZERO, T::add))
163 }
164}
165
166#[derive(Debug, Copy, Clone, Eq, PartialEq)]
167pub enum Host {
168 Stack(Stack),
169 Heap(Heap),
170}
171
172impl PlatformInstance for Host {
173 fn select(size_hint: usize) -> Self {
174 if size_hint < VEC_MIN_SIZE {
175 Self::Stack(Stack)
176 } else {
177 Self::Heap(Heap)
178 }
179 }
180}
181
182impl<T: CType> Constant<T> for Host {
183 type Buffer = Buffer<T>;
184
185 fn constant(&self, value: T, size: usize) -> Result<Self::Buffer, Error> {
186 match self {
187 Self::Heap(heap) => heap.constant(value, size).map(Buffer::Heap),
188 Self::Stack(stack) => stack.constant(value, size).map(Buffer::Stack),
189 }
190 }
191}
192
193impl<T: CType> Convert<T> for Host {
194 type Buffer = Buffer<T>;
195
196 fn convert(&self, buffer: BufferConverter<T>) -> Result<Self::Buffer, Error> {
197 match self {
198 Self::Heap(heap) => heap.convert(buffer).map(Buffer::Heap),
199 Self::Stack(stack) => stack.convert(buffer).map(Buffer::Stack),
200 }
201 }
202}
203
204impl From<Heap> for Host {
205 fn from(heap: Heap) -> Self {
206 Self::Heap(heap)
207 }
208}
209
210impl From<Stack> for Host {
211 fn from(stack: Stack) -> Self {
212 Self::Stack(stack)
213 }
214}
215
216impl<T: CType> Construct<T> for Host {
217 type Range = Linear<T>;
218
219 fn range(self, start: T, stop: T, size: usize) -> Result<AccessOp<Self::Range, Self>, Error> {
220 if start <= stop {
221 let step = T::sub(stop, start).to_f64() / size as f64;
222 Ok(Linear::new(start, step, size).into())
223 } else {
224 Err(Error::Bounds(format!("invalid range: [{start}, {stop})")))
225 }
226 }
227}
228
229impl<A: Access<IT>, IT: CType, OT: CType> ElementwiseCast<A, IT, OT> for Host {
230 type Op = Cast<A, IT, OT>;
231
232 fn cast(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
233 Ok(Cast::new(access).into())
234 }
235}
236
237impl<A, L, R, T> GatherCond<A, L, R, T> for Host
238where
239 A: Access<u8>,
240 L: Access<T>,
241 R: Access<T>,
242 T: CType,
243{
244 type Op = Cond<A, L, R, T>;
245
246 fn cond(self, cond: A, then: L, or_else: R) -> Result<AccessOp<Self::Op, Self>, Error> {
247 Ok(Cond::new(cond, then, or_else).into())
248 }
249}
250
251impl<L, R, T> ElementwiseBoolean<L, R, T> for Host
252where
253 L: Access<T>,
254 R: Access<T>,
255 T: CType,
256{
257 type Op = Dual<L, R, T, u8>;
258
259 fn and(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
260 Ok(Dual::and(left, right).into())
261 }
262
263 fn or(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
264 Ok(Dual::or(left, right).into())
265 }
266
267 fn xor(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
268 Ok(Dual::xor(left, right).into())
269 }
270}
271
272impl<A: Access<T>, T: CType> ElementwiseBooleanScalar<A, T> for Host {
273 type Op = Scalar<A, T, u8>;
274
275 fn and_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
276 Ok(Scalar::and(left, right).into())
277 }
278
279 fn or_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
280 Ok(Scalar::or(left, right).into())
281 }
282
283 fn xor_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
284 Ok(Scalar::xor(left, right).into())
285 }
286}
287
288impl<L, R, T> ElementwiseCompare<L, R, T> for Host
289where
290 L: Access<T>,
291 R: Access<T>,
292 T: CType,
293{
294 type Op = Dual<L, R, T, u8>;
295
296 fn eq(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
297 Ok(Dual::eq(left, right).into())
298 }
299
300 fn ge(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
301 Ok(Dual::ge(left, right).into())
302 }
303
304 fn gt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
305 Ok(Dual::gt(left, right).into())
306 }
307
308 fn le(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
309 Ok(Dual::le(left, right).into())
310 }
311
312 fn lt(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
313 Ok(Dual::lt(left, right).into())
314 }
315
316 fn ne(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
317 Ok(Dual::ne(left, right).into())
318 }
319}
320
321impl<A: Access<T>, T: CType> ElementwiseScalarCompare<A, T> for Host {
322 type Op = Scalar<A, T, u8>;
323
324 fn eq_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
325 Ok(Scalar::eq(left, right).into())
326 }
327
328 fn ge_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
329 Ok(Scalar::ge(left, right).into())
330 }
331
332 fn gt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
333 Ok(Scalar::gt(left, right).into())
334 }
335
336 fn le_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
337 Ok(Scalar::le(left, right).into())
338 }
339
340 fn lt_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
341 Ok(Scalar::lt(left, right).into())
342 }
343
344 fn ne_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
345 Ok(Scalar::ne(left, right).into())
346 }
347}
348
349impl<L, R, T> ElementwiseDual<L, R, T> for Host
350where
351 L: Access<T>,
352 R: Access<T>,
353 T: CType,
354{
355 type Op = Dual<L, R, T, T>;
356
357 fn add(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
358 Ok(Dual::add(left, right).into())
359 }
360
361 fn div(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
362 Ok(Dual::div(left, right).into())
363 }
364
365 fn log(self, arg: L, base: R) -> Result<AccessOp<Self::Op, Self>, Error> {
366 Ok(Dual::log(arg, base).into())
367 }
368
369 fn mul(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
370 Ok(Dual::mul(left, right).into())
371 }
372
373 fn pow(self, arg: L, exp: R) -> Result<AccessOp<Self::Op, Self>, Error> {
374 Ok(Dual::pow(arg, exp).into())
375 }
376
377 fn rem(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
378 Ok(Dual::rem(left, right).into())
379 }
380
381 fn sub(self, left: L, right: R) -> Result<AccessOp<Self::Op, Self>, Error> {
382 Ok(Dual::sub(left, right).into())
383 }
384}
385
386impl<A: Access<T>, T: CType> ElementwiseScalar<A, T> for Host {
387 type Op = Scalar<A, T, T>;
388
389 fn add_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
390 Ok(Scalar::add(left, right).into())
391 }
392
393 fn div_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
394 Ok(Scalar::div(left, right).into())
395 }
396
397 fn log_scalar(self, arg: A, base: T) -> Result<AccessOp<Self::Op, Self>, Error> {
398 Ok(Scalar::log(arg, base).into())
399 }
400
401 fn mul_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
402 Ok(Scalar::mul(left, right).into())
403 }
404
405 fn pow_scalar(self, arg: A, exp: T) -> Result<AccessOp<Self::Op, Self>, Error> {
406 Ok(Scalar::pow(arg, exp).into())
407 }
408
409 fn rem_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
410 Ok(Scalar::rem(left, right).into())
411 }
412
413 fn sub_scalar(self, left: A, right: T) -> Result<AccessOp<Self::Op, Self>, Error> {
414 Ok(Scalar::sub(left, right).into())
415 }
416}
417
418impl<A: Access<T>, T: Float> ElementwiseNumeric<A, T> for Host {
419 type Op = Unary<A, T, u8>;
420
421 fn is_inf(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
422 Ok(Unary::inf(access).into())
423 }
424
425 fn is_nan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
426 Ok(Unary::nan(access).into())
427 }
428}
429
430impl<A: Access<T>, T: CType> ElementwiseTrig<A, T> for Host {
431 type Op = Unary<A, T, T::Float>;
432
433 fn sin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
434 Ok(Unary::sin(access).into())
435 }
436
437 fn asin(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
438 Ok(Unary::asin(access).into())
439 }
440
441 fn sinh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
442 Ok(Unary::sinh(access).into())
443 }
444
445 fn cos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
446 Ok(Unary::cos(access).into())
447 }
448
449 fn acos(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
450 Ok(Unary::acos(access).into())
451 }
452
453 fn cosh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
454 Ok(Unary::cosh(access).into())
455 }
456
457 fn tan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
458 Ok(Unary::tan(access).into())
459 }
460
461 fn atan(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
462 Ok(Unary::atan(access).into())
463 }
464
465 fn tanh(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
466 Ok(Unary::tanh(access).into())
467 }
468}
469
470impl<A: Access<T>, T: CType> ElementwiseUnary<A, T> for Host {
471 type Op = Unary<A, T, T>;
472
473 fn abs(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
474 Ok(Unary::abs(access).into())
475 }
476
477 fn exp(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
478 Ok(Unary::exp(access).into())
479 }
480
481 fn ln(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
482 Ok(Unary::ln(access).into())
483 }
484
485 fn round(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
486 Ok(Unary::round(access).into())
487 }
488}
489
490impl<A: Access<T>, T: CType> ElementwiseUnaryBoolean<A, T> for Host {
491 type Op = Unary<A, T, u8>;
492
493 fn not(self, access: A) -> Result<AccessOp<Self::Op, Self>, Error> {
494 Ok(Unary::not(access).into())
495 }
496}
497
498impl<L, R, T> LinAlgDual<L, R, T> for Host
499where
500 L: Access<T>,
501 R: Access<T>,
502 T: CType,
503{
504 type Op = MatMul<L, R, T>;
505
506 fn matmul(
507 self,
508 left: L,
509 right: R,
510 dims: [usize; 4],
511 ) -> Result<AccessOp<Self::Op, Self>, Error> {
512 Ok(MatMul::new(left, right, dims).into())
513 }
514}
515
516impl<A: Access<T>, T: CType> LinAlgUnary<A, T> for Host {
517 type Op = MatDiag<A, T>;
518
519 fn diag(
520 self,
521 access: A,
522 batch_size: usize,
523 dim: usize,
524 ) -> Result<AccessOp<Self::Op, Self>, Error> {
525 Ok(MatDiag::new(access, batch_size, dim).into())
526 }
527}
528
529impl Random for Host {
530 type Normal = RandomNormal;
531 type Uniform = RandomUniform;
532
533 fn random_normal(self, size: usize) -> Result<AccessOp<Self::Normal, Self>, Error> {
534 Ok(RandomNormal::new(size).into())
535 }
536
537 fn random_uniform(self, size: usize) -> Result<AccessOp<Self::Uniform, Self>, Error> {
538 Ok(RandomUniform::new(size).into())
539 }
540}
541
542impl<A: Access<T>, T: CType> ReduceAll<A, T> for Host {
543 fn all(self, access: A) -> Result<bool, Error> {
544 match self {
545 Self::Heap(heap) => heap.all(access),
546 Self::Stack(stack) => stack.all(access),
547 }
548 }
549
550 fn any(self, access: A) -> Result<bool, Error> {
551 match self {
552 Self::Heap(heap) => heap.any(access),
553 Self::Stack(stack) => stack.any(access),
554 }
555 }
556
557 fn max(self, access: A) -> Result<T, Error> {
558 match self {
559 Self::Heap(heap) => heap.max(access),
560 Self::Stack(stack) => stack.max(access),
561 }
562 }
563
564 fn min(self, access: A) -> Result<T, Error> {
565 match self {
566 Self::Heap(heap) => heap.min(access),
567 Self::Stack(stack) => stack.min(access),
568 }
569 }
570
571 fn product(self, access: A) -> Result<T, Error> {
572 match self {
573 Self::Heap(heap) => heap.product(access),
574 Self::Stack(stack) => stack.product(access),
575 }
576 }
577
578 fn sum(self, access: A) -> Result<T, Error> {
579 match self {
580 Self::Heap(heap) => heap.sum(access),
581 Self::Stack(stack) => stack.sum(access),
582 }
583 }
584}
585
586impl<A: Access<T>, T: CType> ReduceAxes<A, T> for Host {
587 type Op = Reduce<A, T>;
588
589 fn max(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
590 Ok(Reduce::max(access, stride).into())
591 }
592
593 fn min(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
594 Ok(Reduce::min(access, stride).into())
595 }
596
597 fn product(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
598 Ok(Reduce::product(access, stride).into())
599 }
600
601 fn sum(self, access: A, stride: usize) -> Result<AccessOp<Self::Op, Self>, Error> {
602 Ok(Reduce::sum(access, stride).into())
603 }
604}
605
606impl<'a, A, T> Transform<A, T> for Host
607where
608 A: Access<T>,
609 T: CType,
610{
611 type Broadcast = View<A, T>;
612 type Slice = Slice<A, T>;
613 type Transpose = View<A, T>;
614
615 fn broadcast(
616 self,
617 access: A,
618 shape: Shape,
619 broadcast: Shape,
620 ) -> Result<AccessOp<Self::Broadcast, Self>, Error> {
621 Ok(View::broadcast(access, shape, broadcast).into())
622 }
623
624 fn slice(
625 self,
626 access: A,
627 shape: &[usize],
628 range: Range,
629 ) -> Result<AccessOp<Self::Slice, Self>, Error> {
630 Ok(Slice::new(access, shape, range).into())
631 }
632
633 fn transpose(
634 self,
635 access: A,
636 shape: Shape,
637 permutation: Axes,
638 ) -> Result<AccessOp<Self::Transpose, Self>, Error> {
639 Ok(View::transpose(access, shape, permutation).into())
640 }
641}