ha_ndarray/
lib.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::ops::{Add, Div, Mul, Rem, Sub};
4
5pub use smallvec::smallvec as axes;
6pub use smallvec::smallvec as range;
7pub use smallvec::smallvec as slice;
8pub use smallvec::smallvec as shape;
9pub use smallvec::smallvec as stackvec;
10use smallvec::SmallVec;
11
12pub use access::*;
13pub use array::{
14    MatrixDual, MatrixUnary, NDArray, NDArrayBoolean, NDArrayBooleanScalar, NDArrayCast,
15    NDArrayCompare, NDArrayCompareScalar, NDArrayMath, NDArrayMathScalar, NDArrayNumeric,
16    NDArrayRead, NDArrayReduce, NDArrayReduceAll, NDArrayReduceBoolean, NDArrayTransform,
17    NDArrayTrig, NDArrayUnary, NDArrayUnaryBoolean, NDArrayWhere, NDArrayWrite,
18};
19pub use buffer::{Buffer, BufferConverter, BufferInstance, BufferMut};
20pub use host::StackVec;
21pub use platform::*;
22
23mod access;
24mod array;
25mod buffer;
26pub mod host;
27#[cfg(feature = "opencl")]
28pub mod opencl;
29pub mod ops;
30mod platform;
31
/// A numeric type supported by ha-ndarray
///
/// This definition is compiled when the `opencl` feature is enabled; it differs from the
/// non-OpenCL definition only by the additional `ocl::OclPrm` supertrait bound.
#[cfg(feature = "opencl")]
pub trait CType:
    ocl::OclPrm + PartialEq + PartialOrd + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static
{
    // type information

    /// The C-language type name of this data type.
    const TYPE: &'static str;

    /// The maximum value of this data type.
    const MAX: Self;

    /// The minimum value of this data type.
    const MIN: Self;

    /// The zero value of this data type.
    const ZERO: Self;

    /// The one value of this data type.
    const ONE: Self;

    /// Whether this is a floating-point data type.
    const IS_FLOAT: bool;

    /// The floating-point type used to represent this type in floating-point-only operations.
    type Float: Float;

    // constructors

    /// Construct an instance of this type from a [`f64`].
    fn from_f64(float: f64) -> Self;

    /// Construct an instance of this type from an instance of its floating-point type.
    fn from_float(float: Self::Float) -> Self;

    // arithmetic

    /// Return the absolute value of this instance.
    fn abs(self) -> Self;

    /// Add two instances of this type.
    fn add(self, other: Self) -> Self;

    /// Divide two instances of this type.
    fn div(self, other: Self) -> Self;

    /// Multiply two instances of this type.
    fn mul(self, other: Self) -> Self;

    /// Subtract two instances of this type.
    fn sub(self, other: Self) -> Self;

    /// Compute the remainder of `self.div(other)`.
    fn rem(self, other: Self) -> Self;

    // comparisons

    /// Return the minimum of two values of this type.
    fn min(l: Self, r: Self) -> Self;

    /// Return the maximum of two values of this type.
    fn max(l: Self, r: Self) -> Self;

    // exponentiation

    /// Raise this value to the power of the given `exp`onent.
    fn pow(self, exp: Self) -> Self;

    // conversions

    /// Round this value to the nearest integer.
    fn round(self) -> Self;

    /// Cast this value to an [`f64`].
    fn to_f64(self) -> f64;

    /// Convert this value to a floating-point value.
    fn to_float(self) -> Self::Float;
}
112
/// A numeric type supported by ha-ndarray
///
/// This definition is compiled when the `opencl` feature is disabled; the OpenCL definition
/// declares the same items with an additional `ocl::OclPrm` supertrait bound.
#[cfg(not(feature = "opencl"))]
pub trait CType:
    PartialEq + PartialOrd + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static
{
    // type information

    /// The C-language type name of this data type.
    const TYPE: &'static str;

    /// The maximum value of this data type.
    const MAX: Self;

    /// The minimum value of this data type.
    const MIN: Self;

    /// The zero value of this data type.
    const ZERO: Self;

    /// The one value of this data type.
    const ONE: Self;

    /// Whether this is a floating-point data type.
    const IS_FLOAT: bool;

    /// The floating-point type used to represent this type in floating-point-only operations.
    type Float: Float;

    // constructors

    /// Construct an instance of this type from a [`f64`].
    fn from_f64(float: f64) -> Self;

    /// Construct an instance of this type from an instance of its floating-point type.
    fn from_float(float: Self::Float) -> Self;

    // arithmetic

    /// Return the absolute value of this instance.
    fn abs(self) -> Self;

    /// Add two instances of this type.
    fn add(self, other: Self) -> Self;

    /// Divide two instances of this type.
    fn div(self, other: Self) -> Self;

    /// Multiply two instances of this type.
    fn mul(self, other: Self) -> Self;

    /// Subtract two instances of this type.
    fn sub(self, other: Self) -> Self;

    /// Compute the remainder of `self.div(other)`.
    fn rem(self, other: Self) -> Self;

    // comparisons

    /// Return the minimum of two values of this type.
    fn min(l: Self, r: Self) -> Self;

    /// Return the maximum of two values of this type.
    fn max(l: Self, r: Self) -> Self;

    // exponentiation

    /// Raise this value to the power of the given `exp`onent.
    fn pow(self, exp: Self) -> Self;

    // conversions

    /// Round this value to the nearest integer.
    fn round(self) -> Self;

    /// Cast this value to an [`f64`].
    fn to_f64(self) -> f64;

    /// Convert this value to a floating-point value.
    fn to_float(self) -> Self::Float;
}
193
// Implement `CType` for the primitive numeric type `$t`.
//
// Arguments, in order:
//  - `$t`: the Rust type to implement `CType` for
//  - `$str`: the value of `CType::TYPE`
//  - `$is_float`: the value of `CType::IS_FLOAT`
//  - `$one`, `$zero`: the values of `CType::ONE` and `CType::ZERO` (note: one comes first)
//  - `$float`: the associated `CType::Float` type
//  - `$abs`, `$add`, `$div`, `$mul`, `$sub`, `$rem`, `$round`, `$pow`: callables (paths or
//    closures) which each method forwards its arguments to
//  - `$cmp_max`, `$cmp_min`: callables backing `max` and `min` (note: max comes first)
//
// `MAX` and `MIN` are taken from the associated constants of `$t` itself, and all numeric
// conversions are plain `as` casts.
macro_rules! c_type {
    ($t:ty, $str:expr, $is_float:expr, $one:expr, $zero:expr, $float:ty, $abs:expr, $add:expr, $div:expr, $mul:expr, $sub:expr, $rem:expr, $round:expr, $pow:expr, $cmp_max:expr, $cmp_min:expr) => {
        impl CType for $t {
            const TYPE: &'static str = $str;

            const MAX: Self = <$t>::MAX;

            const MIN: Self = <$t>::MIN;

            const ZERO: Self = $zero;

            const ONE: Self = $one;

            const IS_FLOAT: bool = $is_float;

            type Float = $float;

            fn from_f64(float: f64) -> Self {
                float as $t
            }

            fn from_float(float: $float) -> Self {
                float as $t
            }

            fn abs(self) -> Self {
                $abs(self)
            }

            fn add(self, other: Self) -> Self {
                $add(self, other)
            }

            fn div(self, other: Self) -> Self {
                $div(self, other)
            }

            fn mul(self, other: Self) -> Self {
                $mul(self, other)
            }

            fn sub(self, other: Self) -> Self {
                $sub(self, other)
            }

            fn rem(self, other: Self) -> Self {
                $rem(self, other)
            }

            fn min(l: Self, r: Self) -> Self {
                $cmp_min(l, r)
            }

            fn max(l: Self, r: Self) -> Self {
                $cmp_max(l, r)
            }

            fn pow(self, exp: Self) -> Self {
                ($pow)(self, exp)
            }

            fn round(self) -> Self {
                $round(self)
            }

            fn to_f64(self) -> f64 {
                self as f64
            }

            fn to_float(self) -> $float {
                self as $float
            }
        }
    };
}
269
// `CType` for `f32`: arithmetic forwards to the standard operator traits, the associated
// `Float` type is `f32` itself, and `min`/`max` use the `total_cmp`-based helpers below.
c_type!(
    f32,
    "float",
    true,
    1.,
    0.,
    Self,
    f32::abs,
    Add::add,
    Div::div,
    Mul::mul,
    Sub::sub,
    Rem::rem,
    f32::round,
    f32::powf,
    max_f32,
    min_f32
);

// `CType` for `f64`: same structure as the `f32` implementation above.
c_type!(
    f64,
    "double",
    true,
    1.,
    0.,
    Self,
    f64::abs,
    Add::add,
    Div::div,
    Mul::mul,
    Sub::sub,
    Rem::rem,
    f64::round,
    f64::powf,
    max_f64,
    min_f64
);
307
308c_type!(
309    i8,
310    "char",
311    false,
312    1,
313    0,
314    f32,
315    Self::wrapping_abs,
316    Self::wrapping_add,
317    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
318    Self::wrapping_mul,
319    Self::wrapping_sub,
320    Self::wrapping_rem,
321    id,
322    |a, e| f32::powi(a as f32, e as i32) as i8,
323    Ord::max,
324    Ord::min
325);
326
327c_type!(
328    i16,
329    "short",
330    false,
331    1,
332    0,
333    f32,
334    Self::wrapping_abs,
335    Self::wrapping_add,
336    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
337    Self::wrapping_mul,
338    Self::wrapping_sub,
339    Self::wrapping_rem,
340    id,
341    |a, e| f32::powi(a as f32, e as i32) as i16,
342    Ord::max,
343    Ord::min
344);
345
346c_type!(
347    i32,
348    "int",
349    false,
350    1,
351    0,
352    f32,
353    Self::wrapping_abs,
354    Self::wrapping_add,
355    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
356    Self::wrapping_mul,
357    Self::wrapping_sub,
358    Self::wrapping_rem,
359    id,
360    |a, e| f32::powi(a as f32, e) as i32,
361    Ord::max,
362    Ord::min
363);
364
365c_type!(
366    i64,
367    "long",
368    false,
369    1,
370    0,
371    f64,
372    Self::wrapping_abs,
373    Self::wrapping_add,
374    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
375    Self::wrapping_mul,
376    Self::wrapping_sub,
377    Self::wrapping_rem,
378    id,
379    |a, e| f64::powi(
380        a as f64,
381        i32::try_from(e).unwrap_or_else(|_| if e >= 0 { i32::MAX } else { i32::MIN })
382    ) as i64,
383    Ord::max,
384    Ord::min
385);
386
387c_type!(
388    u8,
389    "uchar",
390    false,
391    1,
392    0,
393    f32,
394    id,
395    Self::wrapping_add,
396    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
397    Self::wrapping_mul,
398    Self::wrapping_sub,
399    Self::wrapping_rem,
400    id,
401    |a, e| u8::pow(a, e as u32),
402    Ord::max,
403    Ord::min
404);
405
406c_type!(
407    u16,
408    "ushort",
409    false,
410    1,
411    0,
412    f32,
413    id,
414    Self::wrapping_add,
415    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
416    Self::wrapping_mul,
417    Self::wrapping_sub,
418    Self::wrapping_rem,
419    id,
420    |a, e| u16::pow(a, e as u32),
421    Ord::max,
422    Ord::min
423);
424
425c_type!(
426    u32,
427    "uint",
428    false,
429    1,
430    0,
431    f32,
432    id,
433    Self::wrapping_add,
434    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
435    Self::wrapping_mul,
436    Self::wrapping_sub,
437    Self::wrapping_rem,
438    id,
439    |a, e| u32::pow(a, e),
440    Ord::max,
441    Ord::min
442);
443
444c_type!(
445    u64,
446    "ulong",
447    false,
448    1,
449    0,
450    f64,
451    id,
452    Self::wrapping_add,
453    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
454    Self::wrapping_mul,
455    Self::wrapping_sub,
456    Self::wrapping_rem,
457    id,
458    |a, e| u64::pow(a, u32::try_from(e).unwrap_or(u32::MAX)),
459    Ord::max,
460    Ord::min
461);
462
/// The identity function: return the given value unchanged.
fn id<T>(this: T) -> T {
    std::convert::identity(this)
}
466
/// Return the greater of `l` and `r` according to the total order given by [`f32::total_cmp`].
fn max_f32(l: f32, r: f32) -> f32 {
    if l.total_cmp(&r) == Ordering::Less {
        r
    } else {
        l
    }
}
474
/// Return the lesser of `l` and `r` according to the total order given by [`f32::total_cmp`].
fn min_f32(l: f32, r: f32) -> f32 {
    if l.total_cmp(&r) == Ordering::Greater {
        r
    } else {
        l
    }
}
482
/// Return the greater of `l` and `r` according to the total order given by [`f64::total_cmp`].
fn max_f64(l: f64, r: f64) -> f64 {
    match l.total_cmp(&r) {
        Ordering::Less => r,
        Ordering::Equal | Ordering::Greater => l,
    }
}
490
/// Return the lesser of `l` and `r` according to the total order given by [`f64::total_cmp`].
fn min_f64(l: f64, r: f64) -> f64 {
    match l.total_cmp(&r) {
        Ordering::Greater => r,
        Ordering::Less | Ordering::Equal => l,
    }
}
498
/// A floating-point [`CType`]
///
/// The supertrait bound `CType<Float = Self>` requires that a [`Float`] is its own
/// floating-point representation.
pub trait Float: CType<Float = Self> {
    // numeric methods
    /// Return `true` if this [`Float`] is infinite (positive or negative infinity).
    fn is_inf(self) -> bool;

    /// Return `true` if this [`Float`] is not a number (e.g. a float representation of `0.0 / 0.0`).
    fn is_nan(self) -> bool;

    // logarithms
    /// Exponentiate this number (equivalent to `consts::E.pow(self)`).
    fn exp(self) -> Self;

    /// Return the natural logarithm of this [`Float`].
    fn ln(self) -> Self;

    /// Calculate the logarithm of this [`Float`] w/r/t the given `base`.
    fn log(self, base: Self) -> Self;

    // trigonometry
    /// Return the sine of this [`Float`] (in radians).
    fn sin(self) -> Self;

    /// Return the arcsine of this [`Float`] (in radians).
    fn asin(self) -> Self;

    /// Return the hyperbolic sine of this [`Float`].
    fn sinh(self) -> Self;

    /// Return the cosine of this [`Float`] (in radians).
    fn cos(self) -> Self;

    /// Return the arccosine of this [`Float`] (in radians).
    fn acos(self) -> Self;

    /// Return the hyperbolic cosine of this [`Float`].
    fn cosh(self) -> Self;

    /// Return the tangent of this [`Float`] (in radians).
    fn tan(self) -> Self;

    /// Return the arctangent of this [`Float`] (in radians).
    fn atan(self) -> Self;

    /// Return the hyperbolic tangent of this [`Float`].
    fn tanh(self) -> Self;

    // utility
    /// Cast this [`Float`] to an [`f64`].
    ///
    /// Note: [`CType`] declares a method with the same name and signature, so a call through
    /// a generic `T: Float` may need to be written as `Float::to_f64(value)` to disambiguate.
    fn to_f64(self) -> f64;
}
550
551macro_rules! float_type {
552    ($t:ty) => {
553        impl Float for $t {
554            fn is_inf(self) -> bool {
555                <$t>::is_infinite(self)
556            }
557
558            fn is_nan(self) -> bool {
559                <$t>::is_nan(self)
560            }
561
562            fn exp(self) -> Self {
563                <$t>::exp(self)
564            }
565
566            fn ln(self) -> Self {
567                <$t>::ln(self)
568            }
569
570            fn log(self, base: Self) -> Self {
571                <$t>::log(self, base)
572            }
573
574            fn sin(self) -> Self {
575                <$t>::sin(self)
576            }
577
578            fn asin(self) -> Self {
579                <$t>::asin(self)
580            }
581
582            fn sinh(self) -> Self {
583                <$t>::sinh(self)
584            }
585
586            fn cos(self) -> Self {
587                <$t>::cos(self)
588            }
589
590            fn acos(self) -> Self {
591                <$t>::acos(self)
592            }
593
594            fn cosh(self) -> Self {
595                <$t>::cosh(self)
596            }
597
598            fn tan(self) -> Self {
599                <$t>::tan(self)
600            }
601
602            fn atan(self) -> Self {
603                <$t>::atan(self)
604            }
605
606            fn tanh(self) -> Self {
607                <$t>::tanh(self)
608            }
609
610            fn to_f64(self) -> f64 {
611                self as f64
612            }
613        }
614    };
615}
616
617float_type!(f32);
618float_type!(f64);
619
/// An array math error
///
/// Every variant other than `OCL` carries a human-readable message, which is what the
/// `Debug` and `Display` implementations print.
pub enum Error {
    /// A shape or bounds error, e.g. two shapes which cannot be broadcast together.
    Bounds(String),
    /// NOTE(review): presumably an error in how an API was called; confirm against usage.
    Interface(String),
    /// NOTE(review): presumably a request for an unsupported operation; confirm against usage.
    Unsupported(String),
    /// An error from the `ocl` crate, reference-counted so that this enum can be cloned.
    #[cfg(feature = "opencl")]
    OCL(std::sync::Arc<ocl::Error>),
}
628
629// Clone is required to support memoizing OpenCL programs
630// since constructing an [`ocl::Program`] may return an error
631impl Clone for Error {
632    fn clone(&self) -> Self {
633        match self {
634            Self::Bounds(msg) => Self::Bounds(msg.clone()),
635            Self::Interface(msg) => Self::Interface(msg.clone()),
636            Self::Unsupported(msg) => Self::Unsupported(msg.clone()),
637            #[cfg(feature = "opencl")]
638            Self::OCL(cause) => Self::OCL(cause.clone()),
639        }
640    }
641}
642
#[cfg(feature = "opencl")]
impl From<ocl::Error> for Error {
    /// Wrap an OpenCL error.
    ///
    /// # Panics
    /// In builds with debug assertions enabled, this panics with the cause instead of
    /// returning, so that an OpenCL failure is surfaced immediately during development.
    fn from(cause: ocl::Error) -> Self {
        if cfg!(debug_assertions) {
            panic!("OpenCL error: {:?}", cause);
        }

        Self::OCL(std::sync::Arc::new(cause))
    }
}
653
654impl fmt::Debug for Error {
655    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
656        match self {
657            Self::Bounds(cause) => f.write_str(cause),
658            Self::Interface(cause) => f.write_str(cause),
659            Self::Unsupported(cause) => f.write_str(cause),
660            #[cfg(feature = "opencl")]
661            Self::OCL(cause) => cause.fmt(f),
662        }
663    }
664}
665
666impl fmt::Display for Error {
667    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
668        match self {
669            Self::Bounds(cause) => f.write_str(cause),
670            Self::Interface(cause) => f.write_str(cause),
671            Self::Unsupported(cause) => f.write_str(cause),
672            #[cfg(feature = "opencl")]
673            Self::OCL(cause) => cause.fmt(f),
674        }
675    }
676}
677
// Marker impl using only the default methods of `std::error::Error`
// (so `source()` returns `None`, even for an OpenCL cause).
impl std::error::Error for Error {}
679
/// A list of n-dimensional array axes
pub type Axes = SmallVec<[usize; 8]>;

/// An n-dimensional selection range, used to slice an array (one [`AxisRange`] per axis)
pub type Range = SmallVec<[AxisRange; 8]>;

/// The shape of an n-dimensional array (one dimension size per axis)
pub type Shape = SmallVec<[usize; 8]>;

/// The strides used to access an n-dimensional array
pub type Strides = SmallVec<[usize; 8]>;

/// An n-dimensional array on the top-level [`Platform`]
pub type Array<T, A> = array::Array<T, A, Platform>;

/// An n-dimensional array backed by a buffer on the top-level [`Platform`]
pub type ArrayBuf<T, B> = array::Array<T, AccessBuf<B>, Platform>;

/// The result of an n-dimensional array operation
pub type ArrayOp<T, Op> = array::Array<T, AccessOp<Op>, Platform>;

/// A general type of n-dimensional array used to elide recursive types
pub type ArrayAccess<T> = array::Array<T, Accessor<T>, Platform>;

/// An accessor for the result of an n-dimensional array operation on the top-level [`Platform`]
pub type AccessOp<Op> = access::AccessOp<Op, Platform>;
706
/// Bounds on an individual array axis
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum AxisRange {
    /// A single index on the axis.
    At(usize),
    /// A stepped range, with fields in the order `(start, stop, step)`.
    In(usize, usize, usize),
    /// An explicit list of indices on the axis.
    Of(SmallVec<[usize; 8]>),
}
714
715impl AxisRange {
716    /// Return `true` if this is an index bound (i.e. not a slice)
717    pub fn is_index(&self) -> bool {
718        match self {
719            Self::At(_) => true,
720            _ => false,
721        }
722    }
723
724    /// Return the number of elements contained within this bound.
725    /// Returns `None` for an index bound.
726    pub fn size(&self) -> Option<usize> {
727        match self {
728            Self::At(_) => None,
729            Self::In(start, stop, step) => Some((stop - start) / step),
730            Self::Of(indices) => Some(indices.len()),
731        }
732    }
733}
734
735impl From<usize> for AxisRange {
736    fn from(i: usize) -> Self {
737        Self::At(i)
738    }
739}
740
741impl From<std::ops::Range<usize>> for AxisRange {
742    fn from(range: std::ops::Range<usize>) -> Self {
743        Self::In(range.start, range.end, 1)
744    }
745}
746
747impl From<SmallVec<[usize; 8]>> for AxisRange {
748    fn from(indices: SmallVec<[usize; 8]>) -> Self {
749        Self::Of(indices)
750    }
751}
752
753impl fmt::Debug for AxisRange {
754    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
755        match self {
756            Self::At(i) => write!(f, "{}", i),
757            Self::In(start, stop, 1) => write!(f, "{}:{}", start, stop),
758            Self::In(start, stop, step) => write!(f, "{}:{}:{}", start, stop, step),
759            Self::Of(indices) => write!(f, "{:?}", indices),
760        }
761    }
762}
763
764/// Compute the shape which results from broadcasting the `left` and `right` shapes, if possible.
765#[inline]
766pub fn broadcast_shape(left: &[usize], right: &[usize]) -> Result<Shape, Error> {
767    if left.is_empty() || right.is_empty() {
768        return Err(Error::Bounds("cannot broadcast empty shape".to_string()));
769    } else if left.len() < right.len() {
770        return broadcast_shape(right, left);
771    }
772
773    let offset = left.len() - right.len();
774
775    let mut shape = Shape::with_capacity(left.len());
776    shape.extend_from_slice(&left[..offset]);
777
778    for (l, r) in left.into_iter().copied().zip(right.into_iter().copied()) {
779        if r == 1 || r == l {
780            shape.push(l);
781        } else if l == 1 {
782            shape.push(r);
783        } else {
784            return Err(Error::Bounds(format!(
785                "cannot broadcast dimensions {l} and {r}"
786            )));
787        }
788    }
789
790    debug_assert!(!shape.iter().any(|dim| *dim == 0));
791
792    Ok(shape)
793}
794
795#[inline]
796fn range_shape(source_shape: &[usize], range: &[AxisRange]) -> Shape {
797    debug_assert_eq!(source_shape.len(), range.len());
798    range.iter().filter_map(|ar| ar.size()).collect()
799}
800
/// Construct an iterator over the strides for the given shape and number of dimensions.
///
/// The output has `ndim` items: `ndim - shape.len()` leading zeros, followed by one stride
/// per axis of `shape`. The stride of an axis is the product of the dimensions which follow it,
/// except that an axis of size 1 always has a stride of zero.
#[inline]
pub fn strides_for<'a>(shape: &'a [usize], ndim: usize) -> impl Iterator<Item = usize> + 'a {
    debug_assert!(ndim >= shape.len());

    let padding = ndim - shape.len();

    let strides = (0..shape.len()).map(move |x| {
        if shape[x] == 1 {
            0
        } else {
            // multiply the trailing dimensions, last axis first
            shape[x + 1..].iter().rev().product::<usize>()
        }
    });

    std::iter::repeat(0).take(padding).chain(strides)
}