1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Rem, Sub};
4
5use number_general as ng;
6use safecast::CastFrom;
7
8#[cfg(feature = "complex")]
9pub use num_complex as complex;
10pub use smallvec::smallvec as axes;
11pub use smallvec::smallvec as coord;
12pub use smallvec::smallvec as range;
13pub use smallvec::smallvec as slice;
14pub use smallvec::smallvec as shape;
15pub use smallvec::smallvec as stackvec;
16use smallvec::SmallVec;
17
18pub use access::*;
19pub use array::{
20 same_shape, MatrixDual, MatrixUnary, NDArray, NDArrayAbs, NDArrayBoolean, NDArrayBooleanScalar,
21 NDArrayCast, NDArrayCompare, NDArrayCompareScalar, NDArrayMath, NDArrayMathScalar,
22 NDArrayNumeric, NDArrayRead, NDArrayReduce, NDArrayReduceAll, NDArrayReduceBoolean,
23 NDArrayTransform, NDArrayTrig, NDArrayUnary, NDArrayUnaryBoolean, NDArrayWhere, NDArrayWrite,
24};
25#[cfg(feature = "complex")]
26pub use array::{MatrixUnaryComplex, NDArrayComplex, NDArrayFourier};
27pub use buffer::{Buffer, BufferConverter, BufferInstance, BufferMut};
28pub use host::StackVec;
29pub use platform::*;
30
31mod access;
32mod array;
33mod buffer;
34#[cfg(feature = "complex")]
35pub mod fft;
36pub mod host;
37#[cfg(feature = "opencl")]
38pub mod opencl;
39pub mod ops;
40mod platform;
41
/// The identity function: returns its argument unchanged.
///
/// Passed to the `number!` and `real!` invocations wherever an operation is a no-op, e.g. the
/// `abs` of an unsigned integer or the `round` of an integer.
fn id<T>(value: T) -> T {
    value
}
45
/// Base bound for every array element type.
///
/// With the `opencl` feature enabled, an element type must also implement
/// `opencl::CLElement`.
#[cfg(feature = "opencl")]
pub trait CLType:
    'static + Copy + PartialEq + Send + Sync + fmt::Debug + fmt::Display + opencl::CLElement
{
}

/// Base bound for every array element type (host-only build: no OpenCL requirement).
#[cfg(not(feature = "opencl"))]
pub trait CLType: 'static + Copy + PartialEq + Send + Sync + fmt::Debug + fmt::Display {}
54
55impl CLType for f32 {}
56impl CLType for f64 {}
57impl CLType for i8 {}
58impl CLType for i16 {}
59impl CLType for i32 {}
60impl CLType for i64 {}
61impl CLType for u8 {}
62impl CLType for u16 {}
63impl CLType for u32 {}
64impl CLType for u64 {}
65#[cfg(feature = "complex")]
66impl CLType for complex::Complex<f32> {}
67#[cfg(feature = "complex")]
68impl CLType for complex::Complex<f64> {}
69
70pub trait Number: CLType + Into<ng::Number> + CastFrom<ng::Number> + Default {
72 const ZERO: Self;
74
75 const ONE: Self;
77
78 type Abs: Number;
80
81 fn abs(self) -> Self::Abs;
85
86 fn add(self, other: Self) -> Self;
88
89 fn div(self, other: Self) -> Self;
91
92 fn mul(self, other: Self) -> Self;
94
95 fn sub(self, other: Self) -> Self;
97
98 fn pow(self, exp: Self) -> Self;
100}
101
/// Implements [`Number`] for a primitive type.
///
/// Positional arguments: the type, its `Abs` type, the `ONE` and `ZERO` constants, then one
/// callable (a function path or a closure) each for `abs`, `add`, `div`, `mul`, `sub` and `pow`.
macro_rules! number {
    (
        $ty:ty,
        $abs_ty:ty,
        $one:expr,
        $zero:expr,
        $abs_fn:expr,
        $add_fn:expr,
        $div_fn:expr,
        $mul_fn:expr,
        $sub_fn:expr,
        $pow_fn:expr
    ) => {
        impl Number for $ty {
            type Abs = $abs_ty;

            const ZERO: Self = $zero;

            const ONE: Self = $one;

            fn abs(self) -> Self::Abs {
                ($abs_fn)(self)
            }

            fn add(self, other: Self) -> Self {
                ($add_fn)(self, other)
            }

            fn sub(self, other: Self) -> Self {
                ($sub_fn)(self, other)
            }

            fn mul(self, other: Self) -> Self {
                ($mul_fn)(self, other)
            }

            fn div(self, other: Self) -> Self {
                ($div_fn)(self, other)
            }

            fn pow(self, exp: Self) -> Self {
                ($pow_fn)(self, exp)
            }
        }
    };
}
137
// Complex numbers:
// - `abs` is the norm, so `Abs` is the real component type.
// - Arithmetic uses the standard operator traits.
// - `pow` takes a complex exponent via `powc`.
#[cfg(feature = "complex")]
number!(
    complex::Complex32,
    f32,
    complex::Complex32::ONE,
    complex::Complex32::ZERO,
    complex::Complex32::norm,
    <complex::Complex32 as Add>::add,
    <complex::Complex32 as Div>::div,
    <complex::Complex32 as Mul>::mul,
    <complex::Complex32 as Sub>::sub,
    complex::Complex32::powc
);

#[cfg(feature = "complex")]
number!(
    complex::Complex64,
    f64,
    complex::Complex64::ONE,
    complex::Complex64::ZERO,
    complex::Complex64::norm,
    <complex::Complex64 as Add>::add,
    <complex::Complex64 as Div>::div,
    <complex::Complex64 as Mul>::mul,
    <complex::Complex64 as Sub>::sub,
    complex::Complex64::powc
);
165
166number!(
167 f32,
168 Self,
169 1.,
170 0.,
171 f32::abs,
172 Add::add,
173 Div::div,
174 Mul::mul,
175 Sub::sub,
176 f32::powf
177);
178
179number!(
180 f64,
181 Self,
182 1.,
183 0.,
184 f64::abs,
185 Add::add,
186 Div::div,
187 Mul::mul,
188 Sub::sub,
189 f64::powf
190);
191
192number!(
193 i8,
194 Self,
195 1,
196 0,
197 Self::wrapping_abs,
198 Self::wrapping_add,
199 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
200 Self::wrapping_mul,
201 Self::wrapping_sub,
202 |a, e| f32::powi(a as f32, e as i32) as i8
203);
204
205number!(
206 i16,
207 Self,
208 1,
209 0,
210 Self::wrapping_abs,
211 Self::wrapping_add,
212 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
213 Self::wrapping_mul,
214 Self::wrapping_sub,
215 |a, e| f32::powi(a as f32, e as i32) as i16
216);
217
218number!(
219 i32,
220 Self,
221 1,
222 0,
223 Self::wrapping_abs,
224 Self::wrapping_add,
225 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
226 Self::wrapping_mul,
227 Self::wrapping_sub,
228 |a, e| f32::powi(a as f32, e) as i32
229);
230
231number!(
232 i64,
233 Self,
234 1,
235 0,
236 Self::wrapping_abs,
237 Self::wrapping_add,
238 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
239 Self::wrapping_mul,
240 Self::wrapping_sub,
241 |a, e| f64::powi(
242 a as f64,
243 i32::try_from(e).unwrap_or(if e >= 0 { i32::MAX } else { i32::MIN })
244 ) as i64
245);
246
247number!(
248 u8,
249 Self,
250 1,
251 0,
252 id,
253 Self::wrapping_add,
254 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
255 Self::wrapping_mul,
256 Self::wrapping_sub,
257 |a, e| u8::pow(a, e as u32)
258);
259
260number!(
261 u16,
262 Self,
263 1,
264 0,
265 id,
266 Self::wrapping_add,
267 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
268 Self::wrapping_mul,
269 Self::wrapping_sub,
270 |a, e| u16::pow(a, e as u32)
271);
272
273number!(
274 u32,
275 Self,
276 1,
277 0,
278 id,
279 Self::wrapping_add,
280 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
281 Self::wrapping_mul,
282 Self::wrapping_sub,
283 u32::pow
284);
285
286number!(
287 u64,
288 Self,
289 1,
290 0,
291 id,
292 Self::wrapping_add,
293 |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
294 Self::wrapping_mul,
295 Self::wrapping_sub,
296 |a, e| u64::pow(a, u32::try_from(e).unwrap_or(u32::MAX))
297);
298
299#[cfg(not(feature = "opencl"))]
300pub trait Real: Number + PartialOrd {
302 const MAX: Self;
304
305 const MIN: Self;
307
308 fn max(l: Self, r: Self) -> Self;
310
311 fn min(l: Self, r: Self) -> Self;
313
314 fn rem(self, other: Self) -> Self;
316
317 fn round(self) -> Self;
319}
320
321#[cfg(feature = "opencl")]
322pub trait Real: Number + PartialOrd + opencl::CLElementReal {
324 const MAX: Self;
326
327 const MIN: Self;
329
330 fn max(l: Self, r: Self) -> Self;
332
333 fn min(l: Self, r: Self) -> Self;
335
336 fn rem(self, other: Self) -> Self;
338
339 fn round(self) -> Self;
341}
342
/// Implements [`Real`] for a primitive type.
///
/// Positional arguments:
/// - the type;
/// - a callable for `rem`;
/// - a comparison `fn(&T, &T) -> Ordering` used by `max` and `min`;
/// - a callable for `round`.
///
/// When the comparison reports the operands equal, `max` and `min` both return `l`.
macro_rules! real {
    ($ty:ty, $rem_fn:expr, $cmp_fn:expr, $round_fn:expr) => {
        impl Real for $ty {
            const MIN: Self = <$ty>::MIN;

            const MAX: Self = <$ty>::MAX;

            fn min(l: Self, r: Self) -> $ty {
                // keep `l` unless `r` is strictly smaller
                if ($cmp_fn)(&l, &r) == Ordering::Greater {
                    r
                } else {
                    l
                }
            }

            fn max(l: Self, r: Self) -> $ty {
                // keep `l` unless `r` is strictly larger
                if ($cmp_fn)(&l, &r) == Ordering::Less {
                    r
                } else {
                    l
                }
            }

            fn rem(self, other: Self) -> Self {
                ($rem_fn)(self, other)
            }

            fn round(self) -> Self {
                ($round_fn)(self)
            }
        }
    };
}
374
375real!(f32, Rem::rem, f32::total_cmp, f32::round);
376real!(f64, Rem::rem, f64::total_cmp, f64::round);
377real!(i8, Self::wrapping_rem, Ord::cmp, id);
378real!(i16, Self::wrapping_rem, Ord::cmp, id);
379real!(i32, Self::wrapping_rem, Ord::cmp, id);
380real!(i64, Self::wrapping_rem, Ord::cmp, id);
381real!(u8, Self::wrapping_rem, Ord::cmp, id);
382real!(u16, Self::wrapping_rem, Ord::cmp, id);
383real!(u32, Self::wrapping_rem, Ord::cmp, id);
384real!(u64, Self::wrapping_rem, Ord::cmp, id);
385
386#[cfg(not(feature = "opencl"))]
387pub trait Float: Number {
389 fn is_inf(self) -> bool;
392
393 fn is_nan(self) -> bool;
395
396 fn exp(self) -> Self;
399
400 fn ln(self) -> Self;
402
403 fn log(self, base: Self) -> Self;
405
406 fn sin(self) -> Self;
409
410 fn asin(self) -> Self;
412
413 fn sinh(self) -> Self;
415
416 fn cos(self) -> Self;
418
419 fn acos(self) -> Self;
421
422 fn cosh(self) -> Self;
424
425 fn tan(self) -> Self;
427
428 fn atan(self) -> Self;
430
431 fn tanh(self) -> Self;
433}
434
435#[cfg(feature = "opencl")]
436pub trait Float: Number + opencl::CLElementTrig {
438 fn is_inf(self) -> bool;
441
442 fn is_nan(self) -> bool;
444
445 fn exp(self) -> Self;
448
449 fn ln(self) -> Self;
451
452 fn log(self, base: Self) -> Self;
454
455 fn sin(self) -> Self;
458
459 fn asin(self) -> Self;
461
462 fn sinh(self) -> Self;
464
465 fn cos(self) -> Self;
467
468 fn acos(self) -> Self;
470
471 fn cosh(self) -> Self;
473
474 fn tan(self) -> Self;
476
477 fn atan(self) -> Self;
479
480 fn tanh(self) -> Self;
482}
483
/// Implements [`Float`] for `$ty` by calling the type's own associated functions.
///
/// `$is_inf` and `$is_nan` are callables used for classification. `log` is computed with the
/// change-of-base formula `ln(x) / ln(base)`.
macro_rules! float_type {
    ($ty:ty, $is_inf:expr, $is_nan:expr) => {
        impl Float for $ty {
            fn is_inf(self) -> bool {
                ($is_inf)(self)
            }

            fn is_nan(self) -> bool {
                ($is_nan)(self)
            }

            fn exp(self) -> Self {
                <$ty>::exp(self)
            }

            fn ln(self) -> Self {
                <$ty>::ln(self)
            }

            fn log(self, base: Self) -> Self {
                // change of base: log_b(x) = ln(x) / ln(b)
                let numerator = <$ty>::ln(self);
                let denominator = <$ty>::ln(base);
                numerator / denominator
            }

            fn sin(self) -> Self {
                <$ty>::sin(self)
            }

            fn cos(self) -> Self {
                <$ty>::cos(self)
            }

            fn tan(self) -> Self {
                <$ty>::tan(self)
            }

            fn asin(self) -> Self {
                <$ty>::asin(self)
            }

            fn acos(self) -> Self {
                <$ty>::acos(self)
            }

            fn atan(self) -> Self {
                <$ty>::atan(self)
            }

            fn sinh(self) -> Self {
                <$ty>::sinh(self)
            }

            fn cosh(self) -> Self {
                <$ty>::cosh(self)
            }

            fn tanh(self) -> Self {
                <$ty>::tanh(self)
            }
        }
    };
}
545
546#[cfg(feature = "complex")]
547float_type!(complex::Complex32, |_| false, |_| false);
548#[cfg(feature = "complex")]
549float_type!(complex::Complex64, |_| false, |_| false);
550float_type!(f32, f32::is_infinite, f32::is_nan);
551float_type!(f64, f64::is_infinite, f64::is_nan);
552
/// A complex [`Float`] whose `abs` (the norm) has the type of its components.
#[cfg(all(feature = "complex", not(feature = "opencl")))]
pub trait Complex: Float<Abs = Self::Real> {
    /// The type of the real and imaginary components.
    type Real: Float + Real;

    /// The real component.
    fn re(self) -> Self::Real;

    /// The imaginary component.
    fn im(self) -> Self::Real;

    /// The argument (phase angle) of the number.
    fn angle(self) -> Self::Real;

    /// The complex conjugate.
    fn conj(self) -> Self;
}

/// A complex [`Float`] whose `abs` (the norm) has the type of its components. With the
/// `opencl` feature it must also implement `opencl::CLElementComplex`.
#[cfg(all(feature = "complex", feature = "opencl"))]
pub trait Complex: Float<Abs = Self::Real> + opencl::CLElementComplex {
    /// The type of the real and imaginary components.
    type Real: Float + Real;

    /// The real component.
    fn re(self) -> Self::Real;

    /// The imaginary component.
    fn im(self) -> Self::Real;

    /// The argument (phase angle) of the number.
    fn angle(self) -> Self::Real;

    /// The complex conjugate.
    fn conj(self) -> Self;
}
580
/// Implements [`Complex`] for `num_complex::Complex<$real>`.
///
/// The components are read directly from the `re`/`im` fields. `angle` and `conj` use
/// num_complex's `arg` and `conj`.
#[cfg(feature = "complex")]
macro_rules! complex_type {
    ($ty:ty, $real:ty) => {
        impl Complex for $ty {
            type Real = $real;

            fn re(self) -> $real {
                self.re
            }

            fn im(self) -> $real {
                self.im
            }

            fn angle(self) -> $real {
                complex::Complex::<$real>::arg(self)
            }

            fn conj(self) -> Self {
                complex::Complex::<$real>::conj(&self)
            }
        }
    };
}

#[cfg(feature = "complex")]
complex_type!(complex::Complex32, f32);
#[cfg(feature = "complex")]
complex_type!(complex::Complex64, f64);
610
/// The error type of this crate.
pub enum Error {
    /// A shape, index or range was out of bounds or incompatible; carries a message.
    Bounds(String),
    /// The requested operation is not supported; carries a message.
    Unsupported(String),
    /// An error reported by the `ocl` crate, shared behind an `Arc` so that `Error` can be
    /// cloned.
    #[cfg(feature = "opencl")]
    OCL(std::sync::Arc<ocl::Error>),
}
618
619impl Error {
620 pub fn bounds(msg: String) -> Self {
621 #[cfg(feature = "debug_crash")]
622 panic!("{}", msg);
623
624 #[cfg(not(feature = "debug_crash"))]
625 Self::Bounds(msg)
626 }
627
628 pub fn unsupported(msg: String) -> Self {
629 #[cfg(feature = "debug_crash")]
630 panic!("{}", msg);
631
632 #[cfg(not(feature = "debug_crash"))]
633 Self::Unsupported(msg)
634 }
635}
636
637impl Clone for Error {
640 fn clone(&self) -> Self {
641 match self {
642 Self::Bounds(msg) => Self::Bounds(msg.clone()),
643 Self::Unsupported(msg) => Self::Unsupported(msg.clone()),
644 #[cfg(feature = "opencl")]
645 Self::OCL(cause) => Self::OCL(cause.clone()),
646 }
647 }
648}
649
/// Wraps an `ocl::Error` in [`Error::OCL`].
///
/// # Panics
/// With the `debug_crash` feature enabled, panics with the OpenCL error instead of returning.
#[cfg(feature = "opencl")]
impl From<ocl::Error> for Error {
    fn from(cause: ocl::Error) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("OpenCL error: {:?}", cause);
        }

        Self::OCL(std::sync::Arc::new(cause))
    }
}
660
661impl fmt::Debug for Error {
662 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
663 match self {
664 Self::Bounds(cause) => f.write_str(cause),
665 Self::Unsupported(cause) => f.write_str(cause),
666 #[cfg(feature = "opencl")]
667 Self::OCL(cause) => cause.fmt(f),
668 }
669 }
670}
671
672impl fmt::Display for Error {
673 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
674 match self {
675 Self::Bounds(cause) => f.write_str(cause),
676 Self::Unsupported(cause) => f.write_str(cause),
677 #[cfg(feature = "opencl")]
678 Self::OCL(cause) => cause.fmt(f),
679 }
680 }
681}
682
683impl std::error::Error for Error {}
684
685pub type Axes = SmallVec<[usize; 8]>;
687
688pub type Range = SmallVec<[AxisRange; 8]>;
690
691pub type Shape = SmallVec<[usize; 8]>;
693
694pub type Strides = SmallVec<[usize; 8]>;
696
697pub type Array<T, A> = array::Array<T, A, Platform>;
699
700pub type ArrayBuf<T, B> = array::Array<T, AccessBuf<B>, Platform>;
702
703pub type ArrayOp<T, Op> = array::Array<T, AccessOp<Op>, Platform>;
705
706pub type ArrayAccess<'a, T> = array::Array<T, Accessor<'a, T>, Platform>;
708
709pub type AccessOp<Op> = access::AccessOp<Op, Platform>;
711
712#[derive(Clone, Eq, PartialEq, Hash)]
714pub enum AxisRange {
715 At(usize),
716 In(usize, usize, usize),
717 Of(SmallVec<[usize; 8]>),
718}
719
720impl AxisRange {
721 pub fn is_index(&self) -> bool {
723 matches!(self, Self::At(_))
724 }
725
726 pub fn size(&self) -> Option<usize> {
729 match self {
730 Self::At(_) => None,
731 Self::In(start, stop, step) => Some((stop - start) / step),
732 Self::Of(indices) => Some(indices.len()),
733 }
734 }
735}
736
737impl From<usize> for AxisRange {
738 fn from(i: usize) -> Self {
739 Self::At(i)
740 }
741}
742
743impl From<std::ops::Range<usize>> for AxisRange {
744 fn from(range: std::ops::Range<usize>) -> Self {
745 Self::In(range.start, range.end, 1)
746 }
747}
748
749impl From<SmallVec<[usize; 8]>> for AxisRange {
750 fn from(indices: SmallVec<[usize; 8]>) -> Self {
751 Self::Of(indices)
752 }
753}
754
755impl fmt::Debug for AxisRange {
756 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
757 match self {
758 Self::At(i) => write!(f, "{}", i),
759 Self::In(start, stop, 1) => write!(f, "{}:{}", start, stop),
760 Self::In(start, stop, step) => write!(f, "{}:{}:{}", start, stop, step),
761 Self::Of(indices) => write!(f, "{:?}", indices),
762 }
763 }
764}
765
766#[inline]
768pub fn broadcast_shape(left: &[usize], right: &[usize]) -> Result<Shape, Error> {
769 let ndim = usize::max(left.len(), right.len());
770 let mut shape = Shape::with_capacity(ndim);
771
772 let mut left = left.iter().rev().copied();
773 let mut right = right.iter().rev().copied();
774
775 while let Some(dim) = broadcast_dim(left.next(), right.next())? {
776 shape.push(dim)
777 }
778
779 shape.reverse();
780
781 Ok(shape)
782}
783
784#[inline]
786pub fn broadcast_matmul_shape(left: &[usize], right: &[usize]) -> Result<(Shape, Shape), Error> {
787 let (left_ndim, right_ndim) = (left.len(), right.len());
788 let ndim = usize::max(left_ndim, right_ndim);
789
790 let mut left = left.iter().rev().copied();
791 let mut right = right.iter().rev().copied();
792
793 let k = right.next().unwrap_or(1);
794 let j = match (left.next(), right.next()) {
795 (Some(jl), Some(jr)) => match (jl, jr) {
796 (jl, jr) if jl == jr => Ok(jl),
797 (1, jr) => Ok(jr),
798 (jl, 1) => Ok(jl),
799 _ => Err(Error::bounds(format!(
800 "cannot matrix-multiply shapes {left:?} and {right:?}"
801 ))),
802 },
803 (Some(jl), None) => Ok(jl),
804 (None, Some(jr)) => Ok(jr),
805 (None, None) => Ok(1),
806 }?;
807 let i = left.next().unwrap_or(1);
808
809 let mut broadcast_shape = Shape::with_capacity(ndim);
810 while let Some(dim) = broadcast_dim(left.next(), right.next())? {
811 broadcast_shape.push(dim);
812 }
813
814 broadcast_shape.reverse();
815
816 let left = broadcast_shape.iter().copied().chain([i, j]).collect();
817 let right = broadcast_shape.into_iter().chain([j, k]).collect();
818 Ok((left, right))
819}
820
821#[inline]
822fn broadcast_dim(left: Option<usize>, right: Option<usize>) -> Result<Option<usize>, Error> {
823 match (left, right) {
824 (Some(l), Some(r)) if l == r => Ok(Some(l)),
825 (Some(1), Some(r)) => Ok(Some(r)),
826 (Some(l), Some(1)) => Ok(Some(l)),
827 (None, Some(r)) => Ok(Some(r)),
828 (Some(l), None) => Ok(Some(l)),
829 (None, None) => Ok(None),
830 (l, r) => Err(Error::bounds(format!(
831 "cannot broadcast dimensions {l:?} and {r:?}"
832 ))),
833 }
834}
835
836#[inline]
837fn range_shape(source_shape: &[usize], range: &[AxisRange]) -> Shape {
838 debug_assert_eq!(source_shape.len(), range.len());
839 range.iter().filter_map(|ar| ar.size()).collect()
840}
841
/// Returns the row-major strides of an array of the given `shape`, left-padded with zeros
/// to `ndim` axes.
///
/// Each stride is the product of all later dimensions. Axes of size 1, and the padding axes,
/// get a stride of 0, so the same element is reused along them, as when broadcasting.
///
/// `ndim` must be at least `shape.len()`; this is checked only in debug builds.
#[inline]
pub fn strides_for<'a>(shape: &'a [usize], ndim: usize) -> impl Iterator<Item = usize> + 'a {
    debug_assert!(ndim >= shape.len());

    let padding = ndim - shape.len();

    let axis_strides = shape.iter().enumerate().map(move |(axis, &dim)| match dim {
        1 => 0,
        _ => shape[axis + 1..].iter().product(),
    });

    std::iter::repeat(0).take(padding).chain(axis_strides)
}