1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Rem, Sub};
4
5pub use smallvec::smallvec as axes;
6pub use smallvec::smallvec as range;
7pub use smallvec::smallvec as slice;
8pub use smallvec::smallvec as shape;
9pub use smallvec::smallvec as stackvec;
10use smallvec::SmallVec;
11
12pub use access::*;
13pub use array::{
14 MatrixDual, MatrixUnary, NDArray, NDArrayBoolean, NDArrayBooleanScalar, NDArrayCast,
15 NDArrayCompare, NDArrayCompareScalar, NDArrayMath, NDArrayMathScalar, NDArrayNumeric,
16 NDArrayRead, NDArrayReduce, NDArrayReduceAll, NDArrayReduceBoolean, NDArrayTransform,
17 NDArrayTrig, NDArrayUnary, NDArrayUnaryBoolean, NDArrayWhere, NDArrayWrite,
18};
19pub use buffer::{Buffer, BufferConverter, BufferInstance, BufferMut};
20pub use host::StackVec;
21pub use platform::*;
22
23mod access;
24mod array;
25mod buffer;
26pub mod host;
27#[cfg(feature = "opencl")]
28pub mod opencl;
29pub mod ops;
30mod platform;
31
32#[cfg(feature = "opencl")]
34pub trait CType:
35 ocl::OclPrm + PartialEq + PartialOrd + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static
36{
37 const TYPE: &'static str;
41
42 const MAX: Self;
44
45 const MIN: Self;
47
48 const ZERO: Self;
50
51 const ONE: Self;
53
54 const IS_FLOAT: bool;
56
57 type Float: Float;
59
60 fn from_f64(float: f64) -> Self;
64
65 fn from_float(float: Self::Float) -> Self;
67
68 fn abs(self) -> Self;
72
73 fn add(self, other: Self) -> Self;
75
76 fn div(self, other: Self) -> Self;
78
79 fn mul(self, other: Self) -> Self;
81
82 fn sub(self, other: Self) -> Self;
84
85 fn rem(self, other: Self) -> Self;
87
88 fn min(l: Self, r: Self) -> Self;
92
93 fn max(l: Self, r: Self) -> Self;
95
96 fn pow(self, exp: Self) -> Self;
100
101 fn round(self) -> Self;
105
106 fn to_f64(self) -> f64;
108
109 fn to_float(self) -> Self::Float;
111}
112
113#[cfg(not(feature = "opencl"))]
115pub trait CType:
116 PartialEq + PartialOrd + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static
117{
118 const TYPE: &'static str;
122
123 const MAX: Self;
125
126 const MIN: Self;
128
129 const ZERO: Self;
131
132 const ONE: Self;
134
135 const IS_FLOAT: bool;
137
138 type Float: Float;
140
141 fn from_f64(float: f64) -> Self;
145
146 fn from_float(float: Self::Float) -> Self;
148
149 fn abs(self) -> Self;
153
154 fn add(self, other: Self) -> Self;
156
157 fn div(self, other: Self) -> Self;
159
160 fn mul(self, other: Self) -> Self;
162
163 fn sub(self, other: Self) -> Self;
165
166 fn rem(self, other: Self) -> Self;
168
169 fn min(l: Self, r: Self) -> Self;
173
174 fn max(l: Self, r: Self) -> Self;
176
177 fn pow(self, exp: Self) -> Self;
181
182 fn round(self) -> Self;
186
187 fn to_f64(self) -> f64;
189
190 fn to_float(self) -> Self::Float;
192}
193
// Implements `CType` for one primitive numeric type.
//
// Positional arguments: the type, its `TYPE` name, the `IS_FLOAT` flag, the `ONE` and `ZERO`
// constants, the associated `Float` type, then one callable expression each for
// abs, add, div, mul, sub, rem, round, pow, max and min (in that order).
macro_rules! c_type {
    (
        $ty:ty, $name:expr, $float_flag:expr, $unit:expr, $nil:expr, $fty:ty,
        $abs_fn:expr, $add_fn:expr, $div_fn:expr, $mul_fn:expr, $sub_fn:expr, $rem_fn:expr,
        $round_fn:expr, $pow_fn:expr, $max_fn:expr, $min_fn:expr
    ) => {
        impl CType for $ty {
            type Float = $fty;

            const TYPE: &'static str = $name;

            const MAX: Self = <$ty>::MAX;

            const MIN: Self = <$ty>::MIN;

            const ZERO: Self = $nil;

            const ONE: Self = $unit;

            const IS_FLOAT: bool = $float_flag;

            // conversions are plain `as` casts
            fn from_f64(float: f64) -> Self {
                float as $ty
            }

            fn from_float(float: $fty) -> Self {
                float as $ty
            }

            fn to_f64(self) -> f64 {
                self as f64
            }

            fn to_float(self) -> $fty {
                self as $fty
            }

            // every operation below delegates to the callable supplied by the invocation
            fn abs(self) -> Self {
                ($abs_fn)(self)
            }

            fn add(self, other: Self) -> Self {
                ($add_fn)(self, other)
            }

            fn sub(self, other: Self) -> Self {
                ($sub_fn)(self, other)
            }

            fn mul(self, other: Self) -> Self {
                ($mul_fn)(self, other)
            }

            fn div(self, other: Self) -> Self {
                ($div_fn)(self, other)
            }

            fn rem(self, other: Self) -> Self {
                ($rem_fn)(self, other)
            }

            fn min(l: Self, r: Self) -> Self {
                ($min_fn)(l, r)
            }

            fn max(l: Self, r: Self) -> Self {
                ($max_fn)(l, r)
            }

            fn pow(self, exp: Self) -> Self {
                ($pow_fn)(self, exp)
            }

            fn round(self) -> Self {
                ($round_fn)(self)
            }
        }
    };
}
269
270c_type!(
271 f32,
272 "float",
273 true,
274 1.,
275 0.,
276 Self,
277 f32::abs,
278 Add::add,
279 Div::div,
280 Mul::mul,
281 Sub::sub,
282 Rem::rem,
283 f32::round,
284 f32::powf,
285 max_f32,
286 min_f32
287);
288
289c_type!(
290 f64,
291 "double",
292 true,
293 1.,
294 0.,
295 Self,
296 f64::abs,
297 Add::add,
298 Div::div,
299 Mul::mul,
300 Sub::sub,
301 Rem::rem,
302 f64::round,
303 f64::powf,
304 max_f64,
305 min_f64
306);
307
308c_type!(
309 i8,
310 "char",
311 false,
312 1,
313 0,
314 f32,
315 Self::wrapping_abs,
316 Self::wrapping_add,
317 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
318 Self::wrapping_mul,
319 Self::wrapping_sub,
320 Self::wrapping_rem,
321 id,
322 |a, e| f32::powi(a as f32, e as i32) as i8,
323 Ord::max,
324 Ord::min
325);
326
327c_type!(
328 i16,
329 "short",
330 false,
331 1,
332 0,
333 f32,
334 Self::wrapping_abs,
335 Self::wrapping_add,
336 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
337 Self::wrapping_mul,
338 Self::wrapping_sub,
339 Self::wrapping_rem,
340 id,
341 |a, e| f32::powi(a as f32, e as i32) as i16,
342 Ord::max,
343 Ord::min
344);
345
346c_type!(
347 i32,
348 "int",
349 false,
350 1,
351 0,
352 f32,
353 Self::wrapping_abs,
354 Self::wrapping_add,
355 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
356 Self::wrapping_mul,
357 Self::wrapping_sub,
358 Self::wrapping_rem,
359 id,
360 |a, e| f32::powi(a as f32, e) as i32,
361 Ord::max,
362 Ord::min
363);
364
365c_type!(
366 i64,
367 "long",
368 false,
369 1,
370 0,
371 f64,
372 Self::wrapping_abs,
373 Self::wrapping_add,
374 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
375 Self::wrapping_mul,
376 Self::wrapping_sub,
377 Self::wrapping_rem,
378 id,
379 |a, e| f64::powi(
380 a as f64,
381 i32::try_from(e).unwrap_or_else(|_| if e >= 0 { i32::MAX } else { i32::MIN })
382 ) as i64,
383 Ord::max,
384 Ord::min
385);
386
387c_type!(
388 u8,
389 "uchar",
390 false,
391 1,
392 0,
393 f32,
394 id,
395 Self::wrapping_add,
396 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
397 Self::wrapping_mul,
398 Self::wrapping_sub,
399 Self::wrapping_rem,
400 id,
401 |a, e| u8::pow(a, e as u32),
402 Ord::max,
403 Ord::min
404);
405
406c_type!(
407 u16,
408 "ushort",
409 false,
410 1,
411 0,
412 f32,
413 id,
414 Self::wrapping_add,
415 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
416 Self::wrapping_mul,
417 Self::wrapping_sub,
418 Self::wrapping_rem,
419 id,
420 |a, e| u16::pow(a, e as u32),
421 Ord::max,
422 Ord::min
423);
424
425c_type!(
426 u32,
427 "uint",
428 false,
429 1,
430 0,
431 f32,
432 id,
433 Self::wrapping_add,
434 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
435 Self::wrapping_mul,
436 Self::wrapping_sub,
437 Self::wrapping_rem,
438 id,
439 |a, e| u32::pow(a, e),
440 Ord::max,
441 Ord::min
442);
443
444c_type!(
445 u64,
446 "ulong",
447 false,
448 1,
449 0,
450 f64,
451 id,
452 Self::wrapping_add,
453 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
454 Self::wrapping_mul,
455 Self::wrapping_sub,
456 Self::wrapping_rem,
457 id,
458 |a, e| u64::pow(a, u32::try_from(e).unwrap_or(u32::MAX)),
459 Ord::max,
460 Ord::min
461);
462
/// Return the argument unchanged.
///
/// Used where the `c_type!` macro needs a callable but the operation is a no-op.
fn id<T>(this: T) -> T {
    std::convert::identity(this)
}
466
/// Return the greater of `l` and `r` according to `f32::total_cmp`.
///
/// When the two compare equal, `l` is returned.
fn max_f32(l: f32, r: f32) -> f32 {
    if l.total_cmp(&r) == Ordering::Less {
        r
    } else {
        l
    }
}
474
/// Return the lesser of `l` and `r` according to `f32::total_cmp`.
///
/// When the two compare equal, `l` is returned.
fn min_f32(l: f32, r: f32) -> f32 {
    if l.total_cmp(&r) == Ordering::Greater {
        r
    } else {
        l
    }
}
482
/// Return the greater of `l` and `r` according to `f64::total_cmp`.
///
/// When the two compare equal, `l` is returned.
fn max_f64(l: f64, r: f64) -> f64 {
    if l.total_cmp(&r) == Ordering::Less {
        r
    } else {
        l
    }
}
490
/// Return the lesser of `l` and `r` according to `f64::total_cmp`.
///
/// When the two compare equal, `l` is returned.
fn min_f64(l: f64, r: f64) -> f64 {
    if l.total_cmp(&r) == Ordering::Greater {
        r
    } else {
        l
    }
}
498
499pub trait Float: CType<Float = Self> {
501 fn is_inf(self) -> bool;
504
505 fn is_nan(self) -> bool;
507
508 fn exp(self) -> Self;
511
512 fn ln(self) -> Self;
514
515 fn log(self, base: Self) -> Self;
517
518 fn sin(self) -> Self;
521
522 fn asin(self) -> Self;
524
525 fn sinh(self) -> Self;
527
528 fn cos(self) -> Self;
530
531 fn acos(self) -> Self;
533
534 fn cosh(self) -> Self;
536
537 fn tan(self) -> Self;
539
540 fn atan(self) -> Self;
542
543 fn tanh(self) -> Self;
545
546 fn to_f64(self) -> f64;
549}
550
551macro_rules! float_type {
552 ($t:ty) => {
553 impl Float for $t {
554 fn is_inf(self) -> bool {
555 <$t>::is_infinite(self)
556 }
557
558 fn is_nan(self) -> bool {
559 <$t>::is_nan(self)
560 }
561
562 fn exp(self) -> Self {
563 <$t>::exp(self)
564 }
565
566 fn ln(self) -> Self {
567 <$t>::ln(self)
568 }
569
570 fn log(self, base: Self) -> Self {
571 <$t>::log(self, base)
572 }
573
574 fn sin(self) -> Self {
575 <$t>::sin(self)
576 }
577
578 fn asin(self) -> Self {
579 <$t>::asin(self)
580 }
581
582 fn sinh(self) -> Self {
583 <$t>::sinh(self)
584 }
585
586 fn cos(self) -> Self {
587 <$t>::cos(self)
588 }
589
590 fn acos(self) -> Self {
591 <$t>::acos(self)
592 }
593
594 fn cosh(self) -> Self {
595 <$t>::cosh(self)
596 }
597
598 fn tan(self) -> Self {
599 <$t>::tan(self)
600 }
601
602 fn atan(self) -> Self {
603 <$t>::atan(self)
604 }
605
606 fn tanh(self) -> Self {
607 <$t>::tanh(self)
608 }
609
610 fn to_f64(self) -> f64 {
611 self as f64
612 }
613 }
614 };
615}
616
617float_type!(f32);
618float_type!(f64);
619
/// The error type of this crate.
///
/// Every variant except `OCL` carries a human-readable message, which is what the `Debug`
/// and `Display` implementations print.
pub enum Error {
    /// A bounds error, e.g. two shapes which cannot be broadcast together.
    Bounds(String),
    /// An interface error, with a message describing the problem.
    Interface(String),
    /// An unsupported-operation error, with a message describing the problem.
    Unsupported(String),
    /// An error reported by OpenCL (only with the `opencl` feature enabled).
    #[cfg(feature = "opencl")]
    OCL(std::sync::Arc<ocl::Error>),
}
628
629impl Clone for Error {
632 fn clone(&self) -> Self {
633 match self {
634 Self::Bounds(msg) => Self::Bounds(msg.clone()),
635 Self::Interface(msg) => Self::Interface(msg.clone()),
636 Self::Unsupported(msg) => Self::Unsupported(msg.clone()),
637 #[cfg(feature = "opencl")]
638 Self::OCL(cause) => Self::OCL(cause.clone()),
639 }
640 }
641}
642
#[cfg(feature = "opencl")]
impl From<ocl::Error> for Error {
    /// Wrap an OpenCL error.
    ///
    /// In builds with debug assertions enabled this panics with the error's debug
    /// representation instead of returning.
    fn from(cause: ocl::Error) -> Self {
        if cfg!(debug_assertions) {
            panic!("OpenCL error: {:?}", cause);
        }

        Self::OCL(std::sync::Arc::new(cause))
    }
}
653
654impl fmt::Debug for Error {
655 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
656 match self {
657 Self::Bounds(cause) => f.write_str(cause),
658 Self::Interface(cause) => f.write_str(cause),
659 Self::Unsupported(cause) => f.write_str(cause),
660 #[cfg(feature = "opencl")]
661 Self::OCL(cause) => cause.fmt(f),
662 }
663 }
664}
665
666impl fmt::Display for Error {
667 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
668 match self {
669 Self::Bounds(cause) => f.write_str(cause),
670 Self::Interface(cause) => f.write_str(cause),
671 Self::Unsupported(cause) => f.write_str(cause),
672 #[cfg(feature = "opencl")]
673 Self::OCL(cause) => cause.fmt(f),
674 }
675 }
676}
677
678impl std::error::Error for Error {}
679
680pub type Axes = SmallVec<[usize; 8]>;
682
683pub type Range = SmallVec<[AxisRange; 8]>;
685
686pub type Shape = SmallVec<[usize; 8]>;
688
689pub type Strides = SmallVec<[usize; 8]>;
691
692pub type Array<T, A> = array::Array<T, A, Platform>;
694
695pub type ArrayBuf<T, B> = array::Array<T, AccessBuf<B>, Platform>;
697
698pub type ArrayOp<T, Op> = array::Array<T, AccessOp<Op>, Platform>;
700
701pub type ArrayAccess<T> = array::Array<T, Accessor<T>, Platform>;
703
704pub type AccessOp<Op> = access::AccessOp<Op, Platform>;
706
707#[derive(Clone, Eq, PartialEq, Hash)]
709pub enum AxisRange {
710 At(usize),
711 In(usize, usize, usize),
712 Of(SmallVec<[usize; 8]>),
713}
714
715impl AxisRange {
716 pub fn is_index(&self) -> bool {
718 match self {
719 Self::At(_) => true,
720 _ => false,
721 }
722 }
723
724 pub fn size(&self) -> Option<usize> {
727 match self {
728 Self::At(_) => None,
729 Self::In(start, stop, step) => Some((stop - start) / step),
730 Self::Of(indices) => Some(indices.len()),
731 }
732 }
733}
734
735impl From<usize> for AxisRange {
736 fn from(i: usize) -> Self {
737 Self::At(i)
738 }
739}
740
741impl From<std::ops::Range<usize>> for AxisRange {
742 fn from(range: std::ops::Range<usize>) -> Self {
743 Self::In(range.start, range.end, 1)
744 }
745}
746
747impl From<SmallVec<[usize; 8]>> for AxisRange {
748 fn from(indices: SmallVec<[usize; 8]>) -> Self {
749 Self::Of(indices)
750 }
751}
752
753impl fmt::Debug for AxisRange {
754 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
755 match self {
756 Self::At(i) => write!(f, "{}", i),
757 Self::In(start, stop, 1) => write!(f, "{}:{}", start, stop),
758 Self::In(start, stop, step) => write!(f, "{}:{}:{}", start, stop, step),
759 Self::Of(indices) => write!(f, "{:?}", indices),
760 }
761 }
762}
763
764#[inline]
766pub fn broadcast_shape(left: &[usize], right: &[usize]) -> Result<Shape, Error> {
767 if left.is_empty() || right.is_empty() {
768 return Err(Error::Bounds("cannot broadcast empty shape".to_string()));
769 } else if left.len() < right.len() {
770 return broadcast_shape(right, left);
771 }
772
773 let offset = left.len() - right.len();
774
775 let mut shape = Shape::with_capacity(left.len());
776 shape.extend_from_slice(&left[..offset]);
777
778 for (l, r) in left.into_iter().copied().zip(right.into_iter().copied()) {
779 if r == 1 || r == l {
780 shape.push(l);
781 } else if l == 1 {
782 shape.push(r);
783 } else {
784 return Err(Error::Bounds(format!(
785 "cannot broadcast dimensions {l} and {r}"
786 )));
787 }
788 }
789
790 debug_assert!(!shape.iter().any(|dim| *dim == 0));
791
792 Ok(shape)
793}
794
795#[inline]
796fn range_shape(source_shape: &[usize], range: &[AxisRange]) -> Shape {
797 debug_assert_eq!(source_shape.len(), range.len());
798 range.iter().filter_map(|ar| ar.size()).collect()
799}
800
/// Compute the strides of `shape` when viewed with `ndim` dimensions.
///
/// The result starts with `ndim - shape.len()` zeros, followed by one stride per dimension
/// of `shape`: zero for a dimension of size one, otherwise the product of all the
/// dimensions after it. `ndim` must not be less than `shape.len()`.
#[inline]
pub fn strides_for<'a>(shape: &'a [usize], ndim: usize) -> impl Iterator<Item = usize> + 'a {
    debug_assert!(ndim >= shape.len());

    let padding = ndim - shape.len();
    let leading = (0..padding).map(|_| 0usize);

    let trailing = (0..shape.len()).map(move |axis| match shape[axis] {
        1 => 0usize,
        _ => shape[axis + 1..].iter().product::<usize>(),
    });

    leading.chain(trailing)
}