afarray/
array.rs

1use std::fmt;
2use std::iter::FromIterator;
3use std::ops::*;
4
5use arrayfire as af;
6use async_trait::async_trait;
7use destream::{de, en};
8use futures::TryFutureExt;
9use get_size::GetSize;
10use num_traits::{FromPrimitive, ToPrimitive};
11use number_general::*;
12use safecast::{as_type, CastFrom, CastInto};
13use serde::de::{Deserialize, Deserializer};
14use serde::ser::{Serialize, Serializer};
15
16use super::ext::*;
17use super::{error, Complex, Result};
18
19/// The [`NumberType`] of the product of an [`Array`] with the given `array_dtype`.
20pub fn product_dtype(array_dtype: NumberType) -> NumberType {
21    use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
22
23    match array_dtype {
24        NT::Bool => ArrayExt::<bool>::product_dtype(),
25        NT::Complex(ct) => match ct {
26            CT::C32 => ArrayExt::<Complex<f32>>::product_dtype(),
27            CT::C64 => ArrayExt::<Complex<f64>>::product_dtype(),
28            CT::Complex => ArrayExt::<Complex<f64>>::product_dtype(),
29        },
30        NT::Float(ft) => match ft {
31            FT::F32 => ArrayExt::<f32>::product_dtype(),
32            FT::F64 => ArrayExt::<f64>::product_dtype(),
33            FT::Float => ArrayExt::<f64>::product_dtype(),
34        },
35        NT::Int(it) => match it {
36            IT::I8 => ArrayExt::<i16>::product_dtype(),
37            IT::I16 => ArrayExt::<i16>::product_dtype(),
38            IT::I32 => ArrayExt::<i32>::product_dtype(),
39            IT::I64 => ArrayExt::<i64>::product_dtype(),
40            IT::Int => ArrayExt::<i64>::product_dtype(),
41        },
42        NT::UInt(ut) => match ut {
43            UT::U8 => ArrayExt::<u8>::product_dtype(),
44            UT::U16 => ArrayExt::<u16>::product_dtype(),
45            UT::U32 => ArrayExt::<u32>::product_dtype(),
46            UT::U64 => ArrayExt::<u64>::product_dtype(),
47            UT::UInt => ArrayExt::<u64>::product_dtype(),
48        },
49        NT::Number => ArrayExt::<f64>::product_dtype(),
50    }
51}
52
53/// The [`NumberType`] of the sum of an [`Array`] with the given `array_dtype`.
54pub fn sum_dtype(array_dtype: NumberType) -> NumberType {
55    use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
56
57    match array_dtype {
58        NT::Bool => ArrayExt::<bool>::sum_dtype(),
59        NT::Complex(ct) => match ct {
60            CT::C32 => ArrayExt::<Complex<f32>>::sum_dtype(),
61            CT::C64 => ArrayExt::<Complex<f64>>::sum_dtype(),
62            CT::Complex => ArrayExt::<Complex<f64>>::sum_dtype(),
63        },
64        NT::Float(ft) => match ft {
65            FT::F32 => ArrayExt::<f32>::sum_dtype(),
66            FT::F64 => ArrayExt::<f64>::sum_dtype(),
67            FT::Float => ArrayExt::<f64>::sum_dtype(),
68        },
69        NT::Int(it) => match it {
70            IT::I8 => ArrayExt::<i16>::sum_dtype(),
71            IT::I16 => ArrayExt::<i16>::sum_dtype(),
72            IT::I32 => ArrayExt::<i32>::sum_dtype(),
73            IT::I64 => ArrayExt::<i64>::sum_dtype(),
74            IT::Int => ArrayExt::<i64>::sum_dtype(),
75        },
76        NT::UInt(ut) => match ut {
77            UT::U8 => ArrayExt::<u8>::sum_dtype(),
78            UT::U16 => ArrayExt::<u16>::sum_dtype(),
79            UT::U32 => ArrayExt::<u32>::sum_dtype(),
80            UT::U64 => ArrayExt::<u64>::sum_dtype(),
81            UT::UInt => ArrayExt::<u64>::sum_dtype(),
82        },
83        NT::Number => ArrayExt::<f64>::sum_dtype(),
84    }
85}
86
/// Match `$this` against every variant of `Array`, bind the wrapped `ArrayExt` and pass it to
/// `$call`, expanding to whatever `$call` returns.
///
/// Each arm calls `$call` with a different element type, so `$call` must be a generic function
/// (or method path) that accepts all twelve.
macro_rules! dispatch {
    ($this:expr, $call:expr) => {
        match $this {
            Array::Bool(inner) => $call(inner),
            Array::C32(inner) => $call(inner),
            Array::C64(inner) => $call(inner),
            Array::F32(inner) => $call(inner),
            Array::F64(inner) => $call(inner),
            Array::I16(inner) => $call(inner),
            Array::I32(inner) => $call(inner),
            Array::I64(inner) => $call(inner),
            Array::U8(inner) => $call(inner),
            Array::U16(inner) => $call(inner),
            Array::U32(inner) => $call(inner),
            Array::U64(inner) => $call(inner),
        }
    };
}
105
/// Variant of `dispatch!` for two-argument calls: matches `$this` against every variant of
/// `Array` and expands to `$reduce(wrapped_array_ext, $stride)`.
///
/// `$stride` is only evaluated in the arm that matches, i.e. exactly once.
macro_rules! reduce {
    ($this:expr, $reduce:expr, $stride:expr) => {
        match $this {
            Array::Bool(inner) => $reduce(inner, $stride),
            Array::C32(inner) => $reduce(inner, $stride),
            Array::C64(inner) => $reduce(inner, $stride),
            Array::F32(inner) => $reduce(inner, $stride),
            Array::F64(inner) => $reduce(inner, $stride),
            Array::I16(inner) => $reduce(inner, $stride),
            Array::I32(inner) => $reduce(inner, $stride),
            Array::I64(inner) => $reduce(inner, $stride),
            Array::U8(inner) => $reduce(inner, $stride),
            Array::U16(inner) => $reduce(inner, $stride),
            Array::U32(inner) => $reduce(inner, $stride),
            Array::U64(inner) => $reduce(inner, $stride),
        }
    };
}
124
/// Generate a public `Array` method named `$fun`.
///
/// The generated method calls the `ArrayExt` method of the same name on whichever variant `self`
/// holds (presumably the one from `ArrayInstanceTrig`, which the helper's bounds require) and
/// converts the result back into an `Array`.
macro_rules! trig {
    ($fun:ident) => {
        pub fn $fun(&self) -> Array {
            fn $fun<E>(input: &ArrayExt<E>) -> Array
            where
                E: af::HasAfEnum + Default,
                ArrayExt<E>: ArrayInstanceTrig<E>,
                Array: From<ArrayExt<E::UnaryOutType>>,
            {
                let output = input.$fun();
                output.into()
            }

            dispatch!(self, $fun)
        }
    };
}
141
/// A generic one-dimensional array whose elements may be any of the concrete numeric types below.
///
/// Each variant wraps an `ArrayExt` of a single element type. There is no `I8` variant, so 8-bit
/// signed integers have no direct representation here. The generic `NumberType`s (`Complex`,
/// `Float`, `Int`, `UInt`, `Number`) likewise have no variant of their own.
#[derive(Clone)]
pub enum Array {
    /// Boolean elements.
    Bool(ArrayExt<bool>),
    /// Complex elements with `f32` components.
    C32(ArrayExt<Complex<f32>>),
    /// Complex elements with `f64` components.
    C64(ArrayExt<Complex<f64>>),
    /// 32-bit floating-point elements.
    F32(ArrayExt<f32>),
    /// 64-bit floating-point elements.
    F64(ArrayExt<f64>),
    /// 16-bit signed integer elements.
    I16(ArrayExt<i16>),
    /// 32-bit signed integer elements.
    I32(ArrayExt<i32>),
    /// 64-bit signed integer elements.
    I64(ArrayExt<i64>),
    /// 8-bit unsigned integer elements.
    U8(ArrayExt<u8>),
    /// 16-bit unsigned integer elements.
    U16(ArrayExt<u16>),
    /// 32-bit unsigned integer elements.
    U32(ArrayExt<u32>),
    /// 64-bit unsigned integer elements.
    U64(ArrayExt<u64>),
}
158
159impl GetSize for Array {
160    fn get_size(&self) -> usize {
161        self.dtype().size() * self.len()
162    }
163}
164
165impl Array {
166    /// Cast the values of this array into an `ArrayExt<T>`.
167    pub fn type_cast<T: af::HasAfEnum>(&self) -> ArrayExt<T> {
168        dispatch!(self, ArrayExt::type_cast)
169    }
170
171    /// Concatenate two `Array`s.
172    pub fn concatenate(left: &Array, right: &Array) -> Array {
173        use Array::*;
174        match (left, right) {
175            (Bool(l), Bool(r)) => Bool(ArrayExt::concatenate(l, r)),
176
177            (F32(l), F32(r)) => F32(ArrayExt::concatenate(l, r)),
178            (F64(l), F64(r)) => F64(ArrayExt::concatenate(l, r)),
179
180            (C32(l), C32(r)) => C32(ArrayExt::concatenate(l, r)),
181            (C64(l), C64(r)) => C64(ArrayExt::concatenate(l, r)),
182
183            (I16(l), I16(r)) => I16(ArrayExt::concatenate(l, r)),
184            (I32(l), I32(r)) => I32(ArrayExt::concatenate(l, r)),
185            (I64(l), I64(r)) => I64(ArrayExt::concatenate(l, r)),
186
187            (U8(l), U8(r)) => U8(ArrayExt::concatenate(l, r)),
188            (U16(l), U16(r)) => U16(ArrayExt::concatenate(l, r)),
189            (U32(l), U32(r)) => U32(ArrayExt::concatenate(l, r)),
190            (U64(l), U64(r)) => U64(ArrayExt::concatenate(l, r)),
191
192            (l, r) if l.dtype() > r.dtype() => Array::concatenate(l, &r.cast_into(l.dtype())),
193            (l, r) if l.dtype() < r.dtype() => Array::concatenate(&l.cast_into(r.dtype()), r),
194
195            (l, r) => unreachable!("concatenate {}, {}", l, r),
196        }
197    }
198
199    /// Construct a new `Array` with the given constant value and length.
200    pub fn constant(value: Number, length: usize) -> Array {
201        use number_general::Complex;
202        use Array::*;
203
204        match value {
205            Number::Bool(b) => {
206                let b: bool = b.into();
207                Bool(ArrayExt::constant(b, length))
208            }
209            Number::Complex(c) => match c {
210                Complex::C32(c) => C32(ArrayExt::constant(c, length)),
211                Complex::C64(c) => C64(ArrayExt::constant(c, length)),
212            },
213            Number::Float(f) => match f {
214                Float::F32(f) => F32(ArrayExt::constant(f, length)),
215                Float::F64(f) => F64(ArrayExt::constant(f, length)),
216            },
217            Number::Int(i) => match i {
218                Int::I16(i) => I16(ArrayExt::constant(i, length)),
219                Int::I32(i) => I32(ArrayExt::constant(i, length)),
220                Int::I64(i) => I64(ArrayExt::constant(i, length)),
221                other => panic!("ArrayFire does not support {}", other),
222            },
223            Number::UInt(u) => match u {
224                UInt::U8(u) => U8(ArrayExt::constant(u, length)),
225                UInt::U16(u) => U16(ArrayExt::constant(u, length)),
226                UInt::U32(u) => U32(ArrayExt::constant(u, length)),
227                UInt::U64(u) => U64(ArrayExt::constant(u, length)),
228            },
229        }
230    }
231
232    /// Construct a new `Array` with a random normal distribution.
233    pub fn random_normal(dtype: FloatType, length: usize) -> Array {
234        match dtype {
235            FloatType::F32 => Array::F32(ArrayExt::random_normal(length)),
236            _ => Array::F64(ArrayExt::random_normal(length)),
237        }
238    }
239
240    /// Construct a new `Array` with a uniform random distribution.
241    pub fn random_uniform(dtype: FloatType, length: usize) -> Array {
242        match dtype {
243            FloatType::F32 => Array::F32(ArrayExt::random_uniform(length)),
244            _ => Array::F64(ArrayExt::random_uniform(length)),
245        }
246    }
247
248    /// The [`NumberType`] of this `Array`.
249    pub fn dtype(&self) -> NumberType {
250        use number_general::DType;
251        use Array::*;
252
253        match self {
254            Bool(_) => bool::dtype(),
255            C32(_) => Complex::<f32>::dtype(),
256            C64(_) => Complex::<f64>::dtype(),
257            F32(_) => f32::dtype(),
258            F64(_) => f64::dtype(),
259            I16(_) => i16::dtype(),
260            I32(_) => i32::dtype(),
261            I64(_) => i64::dtype(),
262            U8(_) => u8::dtype(),
263            U16(_) => u16::dtype(),
264            U32(_) => u32::dtype(),
265            U64(_) => u64::dtype(),
266        }
267    }
268
269    /// Cast into an `Array` of a different `NumberType`.
270    pub fn cast_into(&self, dtype: NumberType) -> Array {
271        use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
272
273        match dtype {
274            NT::Bool => Self::Bool(self.type_cast()),
275            NT::Complex(ct) => match ct {
276                CT::C32 => Self::C32(self.type_cast()),
277                CT::C64 => Self::C64(self.type_cast()),
278                CT::Complex => Self::C64(self.type_cast()),
279            },
280            NT::Float(ft) => match ft {
281                FT::F32 => Self::F32(self.type_cast()),
282                FT::F64 => Self::F64(self.type_cast()),
283                FT::Float => Self::F64(self.type_cast()),
284            },
285            NT::Int(it) => match it {
286                IT::I16 => Self::I16(self.type_cast()),
287                IT::I32 => Self::I32(self.type_cast()),
288                IT::I64 => Self::I64(self.type_cast()),
289                IT::Int => Self::I64(self.type_cast()),
290                other => panic!("ArrayFire does not support {}", other),
291            },
292            NT::UInt(ut) => match ut {
293                UT::U8 => Self::U8(self.type_cast()),
294                UT::U16 => Self::U16(self.type_cast()),
295                UT::U32 => Self::U32(self.type_cast()),
296                UT::U64 => Self::U64(self.type_cast()),
297                UT::UInt => Self::U64(self.type_cast()),
298            },
299            NT::Number => self.clone(),
300        }
301    }
302
303    /// Copy the contents of this `Array` into a new `Vec`.
304    pub fn to_vec(&self) -> Vec<Number> {
305        fn to_vec<T>(this: &ArrayExt<T>) -> Vec<Number>
306        where
307            T: af::HasAfEnum + Clone + Default,
308            Number: From<T>,
309        {
310            this.to_vec().into_iter().map(Number::from).collect()
311        }
312
313        dispatch!(self, to_vec)
314    }
315
316    /// Calculate the element-wise absolute value.
317    pub fn abs(&self) -> Array {
318        use Array::*;
319        match self {
320            C32(c) => F32(c.abs()),
321            C64(c) => F64(c.abs()),
322            F32(f) => F32(f.abs()),
323            F64(f) => F64(f.abs()),
324            I16(i) => I16(i.abs()),
325            I32(i) => I32(i.abs()),
326            I64(i) => I64(i.abs()),
327            other => other.clone(),
328        }
329    }
330
331    /// Returns `true` if all elements of this `Array` are nonzero.
332    pub fn all(&self) -> bool {
333        dispatch!(self, ArrayExt::all)
334    }
335
336    /// Returns `true` if any element of this `Array` is nonzero.
337    pub fn any(&self) -> bool {
338        dispatch!(self, ArrayExt::any)
339    }
340
341    /// Element-wise logical and.
342    pub fn and(&self, other: &Array) -> Array {
343        let this: ArrayExt<bool> = self.type_cast();
344        let that: ArrayExt<bool> = other.type_cast();
345        Array::Bool(this.and(&that))
346    }
347
348    /// Element-wise logical and, relative to a constant `other`.
349    pub fn and_const(&self, other: Number) -> Array {
350        let this: ArrayExt<bool> = self.type_cast();
351        let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
352        Array::Bool(this.and(&that))
353    }
354
355    /// Find the maximum value in this `Array` and its offset.
356    pub fn argmax(&self) -> (usize, Number) {
357        fn imax<T: af::HasAfEnum>(x: &ArrayExt<T>) -> (usize, Number)
358        where
359            ArrayExt<T>: ArrayInstanceIndex,
360            Number: From<<ArrayExt<T> as ArrayInstance>::DType>,
361        {
362            let (i, max) = x.argmax();
363            (i, max.into())
364        }
365
366        dispatch!(self, imax)
367    }
368
369    /// Element-wise equality comparison.
370    pub fn eq(&self, other: &Array) -> Array {
371        use Array::*;
372        match (self, other) {
373            (Bool(l), Bool(r)) => Bool(l.eq(r.deref())),
374            (C32(l), C32(r)) => Bool(l.eq(r.deref())),
375            (C64(l), C64(r)) => Bool(l.eq(r.deref())),
376            (F32(l), F32(r)) => Bool(l.eq(r.deref())),
377            (F64(l), F64(r)) => Bool(l.eq(r.deref())),
378            (I16(l), I16(r)) => Bool(l.eq(r.deref())),
379            (I32(l), I32(r)) => Bool(l.eq(r.deref())),
380            (I64(l), I64(r)) => Bool(l.eq(r.deref())),
381            (U8(l), U8(r)) => Bool(l.eq(r.deref())),
382            (U16(l), U16(r)) => Bool(l.eq(r.deref())),
383            (U32(l), U32(r)) => Bool(l.eq(r.deref())),
384            (U64(l), U64(r)) => Bool(l.eq(r.deref())),
385            (l, r) => match (l.dtype(), r.dtype()) {
386                (l_dtype, r_dtype) if l_dtype > r_dtype => l.eq(&r.cast_into(l_dtype)),
387                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).eq(r),
388                (l, r) => unreachable!("{} equal to {}", l, r),
389            },
390        }
391    }
392
393    /// Element-wise equality comparison.
394    pub fn eq_const(&self, other: Number) -> Array {
395        use number_general::Complex;
396        match (self, other) {
397            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.eq(&bool::from(r))),
398
399            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.eq(&r)),
400            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.eq(&r)),
401
402            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.eq(&r)),
403            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.eq(&r)),
404
405            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.eq(&r)),
406            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.eq(&r)),
407            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.eq(&r)),
408
409            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.eq(&r)),
410            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.eq(&r)),
411            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.eq(&r)),
412            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.eq(&r)),
413
414            (l, r) => match (l.dtype(), r.class()) {
415                (l_dtype, r_dtype) if l_dtype > r_dtype => l.eq_const(r.into_type(l_dtype)),
416                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).eq_const(r),
417                (l, r) => unreachable!("{} equal to {}", l, r),
418            },
419        }
420    }
421
422    /// Raise `e` to the power of `self`.
423    pub fn exp(&self) -> Array {
424        fn exp<T>(this: &ArrayExt<T>) -> Array
425        where
426            T: af::HasAfEnum + Default,
427            Array: From<ArrayExt<T::UnaryOutType>>,
428        {
429            this.exp().into()
430        }
431
432        dispatch!(self, exp)
433    }
434
435    /// Element-wise greater-than comparison.
436    pub fn gt(&self, other: &Array) -> Array {
437        use Array::*;
438        match (self, other) {
439            (Bool(l), Bool(r)) => Bool(l.gt(r.deref())),
440            (C32(l), C32(r)) => Bool(l.gt(r.deref())),
441            (C64(l), C64(r)) => Bool(l.gt(r.deref())),
442            (F32(l), F32(r)) => Bool(l.gt(r.deref())),
443            (F64(l), F64(r)) => Bool(l.gt(r.deref())),
444            (I16(l), I16(r)) => Bool(l.gt(r.deref())),
445            (I32(l), I32(r)) => Bool(l.gt(r.deref())),
446            (I64(l), I64(r)) => Bool(l.gt(r.deref())),
447            (U8(l), U8(r)) => Bool(l.gt(r.deref())),
448            (U16(l), U16(r)) => Bool(l.gt(r.deref())),
449            (U32(l), U32(r)) => Bool(l.gt(r.deref())),
450            (U64(l), U64(r)) => Bool(l.gt(r.deref())),
451            (l, r) => match (l.dtype(), r.dtype()) {
452                (l_dtype, r_dtype) if l_dtype > r_dtype => l.gt(&r.cast_into(l_dtype)),
453                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gt(r),
454                (l, r) => unreachable!("{} greater than {}", l, r),
455            },
456        }
457    }
458
459    /// Element-wise greater-than comparison.
460    pub fn gt_const(&self, other: Number) -> Array {
461        use number_general::Complex;
462        match (self, other) {
463            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.gt(&bool::from(r))),
464            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.gt(&r)),
465            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.gt(&r)),
466            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.gt(&r)),
467            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.gt(&r)),
468            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.gt(&r)),
469            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.gt(&r)),
470            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.gt(&r)),
471            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.gt(&r)),
472            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.gt(&r)),
473            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.gt(&r)),
474            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.gt(&r)),
475            (l, r) => match (l.dtype(), r.class()) {
476                (l_dtype, r_dtype) if l_dtype > r_dtype => l.gt_const(r.into_type(l_dtype)),
477                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gt_const(r),
478                (l, r) => unreachable!("{} greater than {}", l, r),
479            },
480        }
481    }
482
483    /// Element-wise greater-or-equal comparison.
484    pub fn gte(&self, other: &Array) -> Array {
485        use Array::*;
486        match (self, other) {
487            (Bool(l), Bool(r)) => Bool(l.gte(r.deref())),
488            (C32(l), C32(r)) => Bool(l.gte(r.deref())),
489            (C64(l), C64(r)) => Bool(l.gte(r.deref())),
490            (F32(l), F32(r)) => Bool(l.gte(r.deref())),
491            (F64(l), F64(r)) => Bool(l.gte(r.deref())),
492            (I16(l), I16(r)) => Bool(l.gte(r.deref())),
493            (I32(l), I32(r)) => Bool(l.gte(r.deref())),
494            (I64(l), I64(r)) => Bool(l.gte(r.deref())),
495            (U8(l), U8(r)) => Bool(l.gte(r.deref())),
496            (U16(l), U16(r)) => Bool(l.gte(r.deref())),
497            (U32(l), U32(r)) => Bool(l.gte(r.deref())),
498            (U64(l), U64(r)) => Bool(l.gte(r.deref())),
499            (l, r) => match (l.dtype(), r.dtype()) {
500                (l_dtype, r_dtype) if l_dtype > r_dtype => l.gte(&r.cast_into(l_dtype)),
501                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gte(r),
502                (l, r) => unreachable!("{} greater than or equal to {}", l, r),
503            },
504        }
505    }
506
507    /// Element-wise greater-than-or-equal comparison.
508    pub fn gte_const(&self, other: Number) -> Array {
509        use number_general::Complex;
510        match (self, other) {
511            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.gte(&bool::from(r))),
512
513            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.gte(&r)),
514            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.gte(&r)),
515
516            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.gte(&r)),
517            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.gte(&r)),
518
519            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.gte(&r)),
520            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.gte(&r)),
521            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.gte(&r)),
522
523            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.gte(&r)),
524            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.gte(&r)),
525            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.gte(&r)),
526            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.gte(&r)),
527
528            (l, r) => match (l.dtype(), r.class()) {
529                (l_dtype, r_dtype) if l_dtype > r_dtype => l.gte_const(r.into_type(l_dtype)),
530                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gte_const(r),
531                (l, r) => unreachable!("{} greater than or equal to {}", l, r),
532            },
533        }
534    }
535
536    /// Element-wise check for infinite values.
537    pub fn is_infinite(&self) -> Array {
538        fn is_infinite<T>(this: &ArrayExt<T>) -> Array
539        where
540            T: af::HasAfEnum + Default,
541            ArrayExt<T>: ArrayInstanceUnreal,
542        {
543            this.is_infinite().into()
544        }
545
546        dispatch!(self, is_infinite)
547    }
548
549    /// Element-wise check for non-numeric (NaN) values.
550    pub fn is_nan(&self) -> Array {
551        fn is_nan<T>(this: &ArrayExt<T>) -> Array
552        where
553            T: af::HasAfEnum + Default,
554            ArrayExt<T>: ArrayInstanceUnreal,
555        {
556            this.is_nan().into()
557        }
558
559        dispatch!(self, is_nan)
560    }
561
562    /// Compute the natural log of this `Array`.
563    pub fn ln(&self) -> Array {
564        fn ln<T>(this: &ArrayExt<T>) -> Array
565        where
566            T: af::HasAfEnum + Default,
567            Array: From<ArrayExt<T::UnaryOutType>>,
568        {
569            this.ln().into()
570        }
571
572        dispatch!(self, ln)
573    }
574
575    /// Compute the logarithm of this `Array` with respect to the given `base`.
576    pub fn log(&self, base: &Array) -> Array {
577        use Array::*;
578        match (self, base) {
579            (Bool(l), Bool(r)) => l.log(r).into(),
580            (C32(l), C32(r)) => l.log(r).into(),
581            (C64(l), C64(r)) => l.log(r).into(),
582            (F32(l), F32(r)) => l.log(r).into(),
583            (F64(l), F64(r)) => l.log(r).into(),
584            (I16(l), I16(r)) => l.log(r).into(),
585            (I32(l), I32(r)) => l.log(r).into(),
586            (I64(l), I64(r)) => l.log(r).into(),
587            (U8(l), U8(r)) => l.log(r).into(),
588            (U16(l), U16(r)) => l.log(r).into(),
589            (U32(l), U32(r)) => l.log(r).into(),
590            (U64(l), U64(r)) => l.log(r).into(),
591            (l, r) => match (l.dtype(), r.dtype()) {
592                (l_dtype, r_dtype) if l_dtype > r_dtype => l.log(&r.cast_into(l_dtype)),
593                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).log(&r),
594                (l, r) => unreachable!("{} log {}", l, r),
595            },
596        }
597    }
598
599    /// Compute the logarithm of this `Array` with respect to the given constant `base`.
600    pub fn log_const(&self, base: Number) -> Array {
601        (&self.ln()) / base.ln()
602    }
603
604    /// Element-wise less-than comparison.
605    pub fn lt(&self, other: &Array) -> Array {
606        use Array::*;
607        match (self, other) {
608            (Bool(l), Bool(r)) => Bool(l.lt(r.deref())),
609            (C32(l), C32(r)) => Bool(l.lt(r.deref())),
610            (C64(l), C64(r)) => Bool(l.lt(r.deref())),
611            (F32(l), F32(r)) => Bool(l.lt(r.deref())),
612            (F64(l), F64(r)) => Bool(l.lt(r.deref())),
613            (I16(l), I16(r)) => Bool(l.lt(r.deref())),
614            (I32(l), I32(r)) => Bool(l.lt(r.deref())),
615            (I64(l), I64(r)) => Bool(l.lt(r.deref())),
616            (U8(l), U8(r)) => Bool(l.lt(r.deref())),
617            (U16(l), U16(r)) => Bool(l.lt(r.deref())),
618            (U32(l), U32(r)) => Bool(l.lt(r.deref())),
619            (U64(l), U64(r)) => Bool(l.lt(r.deref())),
620            (l, r) => match (l.dtype(), r.dtype()) {
621                (l_dtype, r_dtype) if l_dtype > r_dtype => l.lt(&r.cast_into(l_dtype)),
622                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lt(r),
623                (l, r) => unreachable!("{} less than {}", l, r),
624            },
625        }
626    }
627
628    /// Element-wise less-than comparison.
629    pub fn lt_const(&self, other: Number) -> Array {
630        use number_general::Complex;
631        match (self, other) {
632            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.lt(&bool::from(r))),
633
634            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.lt(&r)),
635            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.lt(&r)),
636
637            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.lt(&r)),
638            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.lt(&r)),
639
640            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.lt(&r)),
641            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.lt(&r)),
642            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.lt(&r)),
643
644            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.lt(&r)),
645            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.lt(&r)),
646            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.lt(&r)),
647            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.lt(&r)),
648
649            (l, r) => match (l.dtype(), r.class()) {
650                (l_dtype, r_dtype) if l_dtype > r_dtype => l.lt_const(r.into_type(l_dtype)),
651                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lt_const(r),
652                (l, r) => unreachable!("{} less than {}", l, r),
653            },
654        }
655    }
656
657    /// Element-wise less-or-equal comparison.
658    pub fn lte(&self, other: &Array) -> Array {
659        use Array::*;
660        match (self, other) {
661            (Bool(l), Bool(r)) => Bool(l.lte(r.deref())),
662            (C32(l), C32(r)) => Bool(l.lte(r.deref())),
663            (C64(l), C64(r)) => Bool(l.lte(r.deref())),
664            (F32(l), F32(r)) => Bool(l.lte(r.deref())),
665            (F64(l), F64(r)) => Bool(l.lte(r.deref())),
666            (I16(l), I16(r)) => Bool(l.lte(r.deref())),
667            (I32(l), I32(r)) => Bool(l.lte(r.deref())),
668            (I64(l), I64(r)) => Bool(l.lte(r.deref())),
669            (U8(l), U8(r)) => Bool(l.lte(r.deref())),
670            (U16(l), U16(r)) => Bool(l.lte(r.deref())),
671            (U32(l), U32(r)) => Bool(l.lte(r.deref())),
672            (U64(l), U64(r)) => Bool(l.lte(r.deref())),
673            (l, r) => match (l.dtype(), r.dtype()) {
674                (l_dtype, r_dtype) if l_dtype > r_dtype => l.lte(&r.cast_into(l_dtype)),
675                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lte(r),
676                (l, r) => unreachable!("{} less than or equal to {}", l, r),
677            },
678        }
679    }
680
681    /// Element-wise less-than-or-equal comparison.
682    pub fn lte_const(&self, other: Number) -> Array {
683        use number_general::Complex;
684        match (self, other) {
685            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.lte(&bool::from(r))),
686
687            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.lte(&r)),
688            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.lte(&r)),
689
690            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.lte(&r)),
691            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.lte(&r)),
692
693            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.lte(&r)),
694            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.lte(&r)),
695            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.lte(&r)),
696
697            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.lte(&r)),
698            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.lte(&r)),
699            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.lte(&r)),
700            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.lte(&r)),
701
702            (l, r) => match (l.dtype(), r.class()) {
703                (l_dtype, r_dtype) if l_dtype > r_dtype => l.lte_const(r.into_type(l_dtype)),
704                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lte_const(r),
705                (l, r) => unreachable!("{} less than or equal to {}", l, r),
706            },
707        }
708    }
709
710    /// Element-wise inequality comparison.
711    pub fn ne(&self, other: &Array) -> Array {
712        use Array::*;
713        match (self, other) {
714            (Bool(l), Bool(r)) => Bool(l.ne(r.deref())),
715            (C32(l), C32(r)) => Bool(l.ne(r.deref())),
716            (C64(l), C64(r)) => Bool(l.ne(r.deref())),
717            (F32(l), F32(r)) => Bool(l.ne(r.deref())),
718            (F64(l), F64(r)) => Bool(l.ne(r.deref())),
719            (I16(l), I16(r)) => Bool(l.ne(r.deref())),
720            (I32(l), I32(r)) => Bool(l.ne(r.deref())),
721            (I64(l), I64(r)) => Bool(l.ne(r.deref())),
722            (U8(l), U8(r)) => Bool(l.ne(r.deref())),
723            (U16(l), U16(r)) => Bool(l.ne(r.deref())),
724            (U32(l), U32(r)) => Bool(l.ne(r.deref())),
725            (U64(l), U64(r)) => Bool(l.ne(r.deref())),
726            (l, r) => match (l.dtype(), r.dtype()) {
727                (l_dtype, r_dtype) if l_dtype > r_dtype => l.ne(&r.cast_into(l_dtype)),
728                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).ne(r),
729                (l, r) => unreachable!("{} not equal to {}", l, r),
730            },
731        }
732    }
733
734    /// Element-wise not-equal comparison.
735    pub fn ne_const(&self, other: Number) -> Array {
736        use number_general::Complex;
737        match (self, other) {
738            (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.ne(&bool::from(r))),
739
740            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.ne(&r)),
741            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.ne(&r)),
742
743            (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.ne(&r)),
744            (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.ne(&r)),
745
746            (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.ne(&r)),
747            (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.ne(&r)),
748            (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.ne(&r)),
749
750            (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.ne(&r)),
751            (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.ne(&r)),
752            (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.ne(&r)),
753            (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.ne(&r)),
754
755            (l, r) => match (l.dtype(), r.class()) {
756                (l_dtype, r_dtype) if l_dtype > r_dtype => l.ne_const(r.into_type(l_dtype)),
757                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).ne_const(r),
758                (l, r) => unreachable!("{} not equal to {}", l, r),
759            },
760        }
761    }
762
763    /// Element-wise logical not.
764    pub fn not(&self) -> Array {
765        let this: ArrayExt<bool> = self.type_cast();
766        Array::Bool(this.not())
767    }
768
769    /// Element-wise logical or.
770    pub fn or(&self, other: &Array) -> Array {
771        let this: ArrayExt<bool> = self.type_cast();
772        let that: ArrayExt<bool> = other.type_cast();
773        Array::Bool(this.or(&that))
774    }
775
776    /// Element-wise logical or, relative to a constant `other`.
777    pub fn or_const(&self, other: Number) -> Array {
778        let this: ArrayExt<bool> = self.type_cast();
779        let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
780        Array::Bool(this.or(&that))
781    }
782
783    /// Find the maximum value of each `stride` of this `Array`
784    pub fn reduce_max(&self, stride: u64) -> Result<Array> {
785        if self.len() as u64 % stride != 0 {
786            return Err(error(format!(
787                "cannot reduce an Array of length {} with stride {}",
788                self.len(),
789                stride
790            )));
791        }
792
793        fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
794        where
795            Array: From<ArrayExt<T::InType>>,
796        {
797            reduce_block(block, stride, &mut |block| af::max(&block, 0).into()).into()
798        }
799
800        Ok(reduce!(self, reduce_block_dispatch, stride))
801    }
802
803    /// Find the maximum value of each `stride` of this `Array`
804    pub fn reduce_min(&self, stride: u64) -> Result<Array> {
805        if self.len() as u64 % stride != 0 {
806            return Err(error(format!(
807                "cannot reduce an Array of length {} with stride {}",
808                self.len(),
809                stride
810            )));
811        }
812
813        fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
814        where
815            Array: From<ArrayExt<T::InType>>,
816        {
817            reduce_block(block, stride, &mut |block| af::min(&block, 0).into()).into()
818        }
819
820        Ok(reduce!(self, reduce_block_dispatch, stride))
821    }
822
823    /// Compute the product of each `stride` of this `Array`
824    pub fn reduce_product(&self, stride: u64) -> Result<Array> {
825        if self.len() as u64 % stride != 0 {
826            return Err(error(format!(
827                "cannot reduce an Array of length {} with stride {}",
828                self.len(),
829                stride
830            )));
831        }
832
833        fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
834        where
835            Array: From<ArrayExt<T::ProductOutType>>,
836        {
837            reduce_block(block, stride, &mut |block| af::product(&block, 0).into()).into()
838        }
839
840        Ok(reduce!(self, reduce_block_dispatch, stride))
841    }
842
843    /// Compute the sum of each `stride` of a single block
844    pub fn reduce_sum(&self, stride: u64) -> Result<Array> {
845        if self.len() as u64 % stride != 0 {
846            return Err(error(format!(
847                "cannot reduce an Array of length {} with stride {}",
848                self.len(),
849                stride
850            )));
851        }
852
853        fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
854        where
855            Array: From<ArrayExt<T::AggregateOutType>>,
856        {
857            reduce_block(block, stride, &mut |block| af::sum(&block, 0).into()).into()
858        }
859
860        Ok(reduce!(self, reduce_block_dispatch, stride))
861    }
862
863    /// Find the maximum element in this `Array`.
864    pub fn max(&self) -> Number {
865        fn max<T>(this: &ArrayExt<T>) -> Number
866        where
867            T: af::HasAfEnum + Default,
868            T::AggregateOutType: number_general::DType,
869            T::ProductOutType: number_general::DType,
870            ArrayExt<T>: ArrayInstanceMinMax<T>,
871            Number: From<T>,
872        {
873            this.max().into()
874        }
875
876        dispatch!(self, max)
877    }
878
879    /// Find the maximum element in this `Array`.
880    pub fn min(&self) -> Number {
881        fn min<T>(this: &ArrayExt<T>) -> Number
882        where
883            T: af::HasAfEnum + Default,
884            T::AggregateOutType: number_general::DType,
885            T::ProductOutType: number_general::DType,
886            ArrayExt<T>: ArrayInstanceMinMax<T>,
887            Number: From<T>,
888        {
889            this.min().into()
890        }
891
892        dispatch!(self, min)
893    }
894
895    /// Calculate the cumulative product of this `Array`.
896    pub fn product(&self) -> Number {
897        fn product<T>(this: &ArrayExt<T>) -> Number
898        where
899            T: af::HasAfEnum + Default,
900            T::AggregateOutType: number_general::DType,
901            T::ProductOutType: number_general::DType,
902            ArrayExt<T>: ArrayInstanceProduct<T>,
903            Number: From<T::ProductOutType>,
904        {
905            this.product().into()
906        }
907
908        dispatch!(self, product)
909    }
910
911    /// Calculate the cumulative sum of this `Array`.
912    pub fn sum(&self) -> Number {
913        fn sum<T>(this: &ArrayExt<T>) -> Number
914        where
915            T: af::HasAfEnum + Default,
916            T::AggregateOutType: number_general::DType,
917            T::ProductOutType: number_general::DType,
918            ArrayExt<T>: ArrayInstanceSum<T>,
919            Number: From<T::AggregateOutType>,
920        {
921            this.sum().into()
922        }
923
924        dispatch!(self, sum)
925    }
926
927    /// The number of elements in this `Array`.
928    pub fn len(&self) -> usize {
929        dispatch!(self, ArrayExt::len)
930    }
931
932    /// Get the value at the specified index.
933    pub fn get_value(&self, index: usize) -> Number {
934        debug_assert!(index < self.len());
935
936        use number_general::Complex;
937        use Array::*;
938        match self {
939            Bool(b) => b.get_value(index).into(),
940            C32(c) => Complex::from(c.get_value(index)).into(),
941            C64(c) => Complex::from(c.get_value(index)).into(),
942            F32(f) => Float::from(f.get_value(index)).into(),
943            F64(f) => Float::from(f.get_value(index)).into(),
944            I16(i) => Int::from(i.get_value(index)).into(),
945            I32(i) => Int::from(i.get_value(index)).into(),
946            I64(i) => Int::from(i.get_value(index)).into(),
947            U8(u) => UInt::from(u.get_value(index)).into(),
948            U16(u) => UInt::from(u.get_value(index)).into(),
949            U32(u) => UInt::from(u.get_value(index)).into(),
950            U64(u) => UInt::from(u.get_value(index)).into(),
951        }
952    }
953
954    /// Get the values at the specified coordinates.
955    pub fn get(&self, index: &ArrayExt<u64>) -> Self {
956        let mut indexer = af::Indexer::default();
957        indexer.set_index(index.deref(), 0, None);
958        self.get_at(indexer)
959    }
960
961    fn get_at(&self, index: af::Indexer) -> Self {
962        use Array::*;
963        match self {
964            Bool(b) => Bool(b.get(index)),
965            C32(c) => C32(c.get(index)),
966            C64(c) => C64(c.get(index)),
967            F32(f) => F32(f.get(index)),
968            F64(f) => F64(f.get(index)),
969            I16(i) => I16(i.get(index)),
970            I32(i) => I32(i.get(index)),
971            I64(i) => I64(i.get(index)),
972            U8(i) => U8(i.get(index)),
973            U16(i) => U16(i.get(index)),
974            U32(i) => U32(i.get(index)),
975            U64(i) => U64(i.get(index)),
976        }
977    }
978
979    /// Return this `Array` raised to the power of `other`.
980    pub fn pow(&self, other: &Self) -> Self {
981        // af::pow only works with floating point numbers!
982        use Array::*;
983        match (self, other) {
984            (C32(l), C32(r)) => C32(l.pow(r.deref())),
985            (C64(l), C64(r)) => C64(l.pow(r.deref())),
986            (F32(l), F32(r)) => F32(l.pow(r.deref())),
987            (F64(l), F64(r)) => F64(l.pow(r.deref())),
988            (l, r) => match (l.dtype(), r.dtype()) {
989                (l_dtype, r_dtype) if l_dtype > r_dtype => l.pow(&r.cast_into(l_dtype)),
990                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).pow(r),
991                _ => Self::F64(l.type_cast()).pow(r),
992            },
993        }
994    }
995
996    /// Return this `Array` raised to the power of `other`.
997    pub fn pow_const(&self, other: Number) -> Self {
998        // af::pow only works with floating point numbers!
999        use number_general::Complex;
1000        match (self, other) {
1001            (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::C32(l.pow(&r)),
1002            (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::C64(l.pow(&r)),
1003            (Self::F32(l), Number::Float(Float::F32(r))) => Self::F32(l.pow(&r)),
1004            (Self::F64(l), Number::Float(Float::F64(r))) => Self::F64(l.pow(&r)),
1005            (l, r) => match (l.dtype(), r.class()) {
1006                (l_dtype, r_dtype) if l_dtype > r_dtype => l.pow_const(r.into_type(l_dtype)),
1007                (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).pow_const(r),
1008                _ => Self::F64(l.type_cast()).pow_const(r),
1009            },
1010        }
1011    }
1012
1013    /// Round this `Array` to the nearest integer, element-wise
1014    pub fn round(&self) -> Self {
1015        fn round<T: af::HasAfEnum>(x: &ArrayExt<T>) -> Array
1016        where
1017            Array: From<ArrayExt<<ArrayExt<T> as ArrayInstanceRound>::Round>>,
1018        {
1019            x.round().into()
1020        }
1021
1022        dispatch!(self, round)
1023    }
1024
1025    /// Set the values at the specified coordinates to the corresponding values in `other`.
1026    pub fn set(&mut self, index: &ArrayExt<u64>, other: &Array) -> Result<()> {
1027        let mut indexer = af::Indexer::default();
1028        indexer.set_index(index.deref(), 0, None);
1029        self.set_at(indexer, other)
1030    }
1031
1032    /// Set the value at the specified coordinate to `value`.
1033    pub fn set_value(&mut self, offset: usize, value: Number) -> Result<()> {
1034        use Array::*;
1035        match self {
1036            Bool(b) => {
1037                let value: Boolean = value.cast_into();
1038                b.set_at(offset, value.cast_into());
1039            }
1040            C32(c) => {
1041                let value: Complex<f32> = value.cast_into();
1042                c.set_at(offset, value.cast_into())
1043            }
1044            C64(c) => {
1045                let value: Complex<f64> = value.cast_into();
1046                c.set_at(offset, value.cast_into())
1047            }
1048            F32(f) => {
1049                let value: Float = value.cast_into();
1050                f.set_at(offset, value.cast_into())
1051            }
1052            F64(f) => {
1053                let value: Float = value.cast_into();
1054                f.set_at(offset, value.cast_into())
1055            }
1056            I16(i) => {
1057                let value: Int = value.cast_into();
1058                i.set_at(offset, value.cast_into())
1059            }
1060            I32(i) => {
1061                let value: Int = value.cast_into();
1062                i.set_at(offset, value.cast_into())
1063            }
1064            I64(i) => {
1065                let value: Int = value.cast_into();
1066                i.set_at(offset, value.cast_into())
1067            }
1068            U8(u) => {
1069                let value: UInt = value.cast_into();
1070                u.set_at(offset, value.cast_into())
1071            }
1072            U16(u) => {
1073                let value: UInt = value.cast_into();
1074                u.set_at(offset, value.cast_into())
1075            }
1076            U32(u) => {
1077                let value: UInt = value.cast_into();
1078                u.set_at(offset, value.cast_into())
1079            }
1080            U64(u) => {
1081                let value: UInt = value.cast_into();
1082                u.set_at(offset, value.cast_into())
1083            }
1084        }
1085
1086        Ok(())
1087    }
1088
1089    fn set_at(&mut self, index: af::Indexer, value: &Array) -> Result<()> {
1090        use Array::*;
1091        match self {
1092            Bool(l) => l.set(&index, &value.type_cast()),
1093            C32(l) => l.set(&index, &value.type_cast()),
1094            C64(l) => l.set(&index, &value.type_cast()),
1095            F32(l) => l.set(&index, &value.type_cast()),
1096            F64(l) => l.set(&index, &value.type_cast()),
1097            I16(l) => l.set(&index, &value.type_cast()),
1098            I32(l) => l.set(&index, &value.type_cast()),
1099            I64(l) => l.set(&index, &value.type_cast()),
1100            U8(l) => l.set(&index, &value.type_cast()),
1101            U16(l) => l.set(&index, &value.type_cast()),
1102            U32(l) => l.set(&index, &value.type_cast()),
1103            U64(l) => l.set(&index, &value.type_cast()),
1104        }
1105
1106        Ok(())
1107    }
1108
1109    /// Return a slice of this `Array`.
1110    pub fn slice(&self, start: usize, end: usize) -> Result<Self> {
1111        if start > self.len() {
1112            return Err(error(format!(
1113                "invalid start index for array slice: {}",
1114                start
1115            )));
1116        }
1117
1118        if end > self.len() {
1119            return Err(error(format!(
1120                "invalid start index for array slice: {}",
1121                end
1122            )));
1123        }
1124
1125        use Array::*;
1126        let slice = match self {
1127            Bool(b) => b.slice(start, end).into(),
1128            C32(c) => c.slice(start, end).into(),
1129            C64(c) => c.slice(start, end).into(),
1130            F32(f) => f.slice(start, end).into(),
1131            F64(f) => f.slice(start, end).into(),
1132            I16(i) => i.slice(start, end).into(),
1133            I32(i) => i.slice(start, end).into(),
1134            I64(i) => i.slice(start, end).into(),
1135            U8(u) => u.slice(start, end).into(),
1136            U16(u) => u.slice(start, end).into(),
1137            U32(u) => u.slice(start, end).into(),
1138            U64(u) => u.slice(start, end).into(),
1139        };
1140
1141        Ok(slice)
1142    }
1143
1144    /// Compute the indices needed to sort this `Array`.
1145    pub fn argsort(&self, ascending: bool) -> Result<(Self, ArrayExt<u64>)> {
1146        macro_rules! argsort {
1147            ($arr:expr) => {{
1148                let (sorted, indices) = $arr.sort_index(ascending);
1149                (sorted.into(), indices.type_cast())
1150            }};
1151        }
1152
1153        use Array::*;
1154        let (sorted, indices) = match self {
1155            Bool(b) => argsort!(b),
1156            F32(f) => argsort!(f),
1157            F64(f) => argsort!(f),
1158            I16(i) => argsort!(i),
1159            I32(i) => argsort!(i),
1160            I64(i) => argsort!(i),
1161            U8(u) => argsort!(u),
1162            U16(u) => argsort!(u),
1163            U32(u) => argsort!(u),
1164            U64(u) => argsort!(u),
1165            other => {
1166                return Err(error(format!(
1167                    "{} does not support ordering",
1168                    other.dtype()
1169                )))
1170            }
1171        };
1172
1173        Ok((sorted, indices))
1174    }
1175
1176    /// Sort this `Array` in-place.
1177    pub fn sort(&mut self, ascending: bool) -> Result<()> {
1178        use Array::*;
1179        match self {
1180            Bool(b) => b.sort(ascending),
1181            F32(f) => f.sort(ascending),
1182            F64(f) => f.sort(ascending),
1183            I16(i) => i.sort(ascending),
1184            I32(i) => i.sort(ascending),
1185            I64(i) => i.sort(ascending),
1186            U8(u) => u.sort(ascending),
1187            U16(u) => u.sort(ascending),
1188            U32(u) => u.sort(ascending),
1189            U64(u) => u.sort(ascending),
1190            other => {
1191                return Err(error(format!(
1192                    "{} does not support ordering",
1193                    other.dtype()
1194                )))
1195            }
1196        }
1197
1198        Ok(())
1199    }
1200
1201    /// Split this `Array` into two new instances at the given pivot.
1202    pub fn split(&self, at: usize) -> Result<(Array, Array)> {
1203        if at > self.len() {
1204            return Err(error(format!(
1205                "Invalid pivot for Array of length {}",
1206                self.len()
1207            )));
1208        }
1209
1210        use Array::*;
1211        match self {
1212            Bool(u) => {
1213                let (l, r) = u.split(at);
1214                Ok((Bool(l), Bool(r)))
1215            }
1216            C32(u) => {
1217                let (l, r) = u.split(at);
1218                Ok((C32(l), C32(r)))
1219            }
1220            C64(u) => {
1221                let (l, r) = u.split(at);
1222                Ok((C64(l), C64(r)))
1223            }
1224            F32(u) => {
1225                let (l, r) = u.split(at);
1226                Ok((F32(l), F32(r)))
1227            }
1228            F64(u) => {
1229                let (l, r) = u.split(at);
1230                Ok((F64(l), F64(r)))
1231            }
1232            I16(u) => {
1233                let (l, r) = u.split(at);
1234                Ok((I16(l), I16(r)))
1235            }
1236            I32(u) => {
1237                let (l, r) = u.split(at);
1238                Ok((I32(l), I32(r)))
1239            }
1240            I64(u) => {
1241                let (l, r) = u.split(at);
1242                Ok((I64(l), I64(r)))
1243            }
1244            U8(u) => {
1245                let (l, r) = u.split(at);
1246                Ok((U8(l), U8(r)))
1247            }
1248            U16(u) => {
1249                let (l, r) = u.split(at);
1250                Ok((U16(l), U16(r)))
1251            }
1252            U32(u) => {
1253                let (l, r) = u.split(at);
1254                Ok((U32(l), U32(r)))
1255            }
1256            U64(u) => {
1257                let (l, r) = u.split(at);
1258                Ok((U64(l), U64(r)))
1259            }
1260        }
1261    }
1262
    // TODO: how to include documentation in macro invocations?
    // One option: have `trig!` accept leading `$(#[$meta:meta])*` attributes and emit them
    // on the generated item, so each invocation below could carry its own `///` docs
    // (a bare `///` on a macro invocation is ignored by rustdoc).
    //
    // NOTE(review): each invocation presumably generates an element-wise method named after
    // its argument (`sin`, `asin`, ...); confirm against the `trig!` definition.

    trig! {sin}
    trig! {asin}
    trig! {sinh}
    trig! {asinh}
    trig! {cos}
    trig! {acos}
    trig! {cosh}
    trig! {acosh}
    trig! {tan}
    trig! {atan}
    trig! {tanh}
    trig! {atanh}
1277
1278    /// Element-wise logical xor.
1279    pub fn xor(&self, other: &Array) -> Array {
1280        let this: ArrayExt<bool> = self.type_cast();
1281        let that: ArrayExt<bool> = other.type_cast();
1282        Array::Bool(this.xor(&that))
1283    }
1284
1285    /// Element-wise logical xor, relative to a constant `other`.
1286    pub fn xor_const(&self, other: Number) -> Array {
1287        let this: ArrayExt<bool> = self.type_cast();
1288        let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
1289        Array::Bool(this.xor(&that))
1290    }
1291}
1292
1293impl PartialEq for Array {
1294    fn eq(&self, other: &Array) -> bool {
1295        if self.len() != other.len() {
1296            return false;
1297        } else {
1298            Array::eq(self, other).all()
1299        }
1300    }
1301}
1302
1303impl Add for &Array {
1304    type Output = Array;
1305
1306    fn add(self, other: &Array) -> Self::Output {
1307        use Array::*;
1308        match (self, other) {
1309            (Bool(l), Bool(r)) => Bool(l + r),
1310            (C32(l), C32(r)) => C32(l + r),
1311            (C64(l), C64(r)) => C64(l + r),
1312            (F32(l), F32(r)) => F32(l + r),
1313            (F64(l), F64(r)) => F64(l + r),
1314            (I16(l), I16(r)) => I16(l + r),
1315            (I32(l), I32(r)) => I32(l + r),
1316            (I64(l), I64(r)) => I64(l + r),
1317            (U8(l), U8(r)) => U8(l + r),
1318            (U16(l), U16(r)) => U16(l + r),
1319            (U32(l), U32(r)) => U32(l + r),
1320            (U64(l), U64(r)) => U64(l + r),
1321            (l, r) => match (l.dtype(), r.dtype()) {
1322                (l_dtype, r_dtype) if l_dtype > r_dtype => l + &r.cast_into(l_dtype),
1323                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) + r,
1324                (l, r) => unreachable!("add {}, {}", l, r),
1325            },
1326        }
1327    }
1328}
1329
1330impl Add<Number> for &Array {
1331    type Output = Array;
1332
1333    fn add(self, rhs: Number) -> Self::Output {
1334        use number_general::Complex;
1335        match (self, rhs) {
1336            (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() + bool::from(r)).into()),
1337
1338            (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() + r).into()),
1339            (Array::F64(l), Number::Float(Float::F32(r))) => Array::F64((l.deref() + r).into()),
1340            (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() + r).into()),
1341
1342            (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() + r).into()),
1343            (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() + r).into()),
1344
1345            (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() + r).into()),
1346            (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() + r).into()),
1347            (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() + r).into()),
1348
1349            (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() + r).into()),
1350            (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() + r).into()),
1351            (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() + r).into()),
1352            (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() + r).into()),
1353
1354            (l, r) => match (l.dtype(), r.class()) {
1355                (l_dtype, r_dtype) if l_dtype > r_dtype => l + r.into_type(l_dtype),
1356                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) + r,
1357                (l, r) => unreachable!("add {}, {}", l, r),
1358            },
1359        }
1360    }
1361}
1362
1363impl AddAssign<&Array> for Array {
1364    fn add_assign(&mut self, other: &Array) {
1365        let sum = &*self + other;
1366        *self = sum;
1367    }
1368}
1369
1370impl AddAssign<Number> for Array {
1371    fn add_assign(&mut self, rhs: Number) {
1372        *self = &*self + rhs;
1373    }
1374}
1375
1376impl Sub for &Array {
1377    type Output = Array;
1378
1379    fn sub(self, other: &Array) -> Self::Output {
1380        use Array::*;
1381        match (self, other) {
1382            (Bool(l), Bool(r)) => Bool(l - r),
1383            (C32(l), C32(r)) => C32(l - r),
1384            (C64(l), C64(r)) => C64(l - r),
1385            (F32(l), F32(r)) => F32(l - r),
1386            (F64(l), F64(r)) => F64(l - r),
1387            (I16(l), I16(r)) => I16(l - r),
1388            (I32(l), I32(r)) => I32(l - r),
1389            (I64(l), I64(r)) => I64(l - r),
1390            (U8(l), U8(r)) => U8(l - r),
1391            (U16(l), U16(r)) => U16(l - r),
1392            (U32(l), U32(r)) => U32(l - r),
1393            (U64(l), U64(r)) => U64(l - r),
1394            (l, r) => match (l.dtype(), r.dtype()) {
1395                (l_dtype, r_dtype) if l_dtype > r_dtype => l - &r.cast_into(l_dtype),
1396                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) - r,
1397                (l, r) => unreachable!("subtract {}, {}", l, r),
1398            },
1399        }
1400    }
1401}
1402
1403impl Sub<Number> for &Array {
1404    type Output = Array;
1405
1406    fn sub(self, rhs: Number) -> Self::Output {
1407        use number_general::Complex;
1408        match (self, rhs) {
1409            (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() - bool::from(r)).into()),
1410
1411            (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() - r).into()),
1412            (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() - r).into()),
1413
1414            (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() - r).into()),
1415            (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() - r).into()),
1416
1417            (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() - r).into()),
1418            (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() - r).into()),
1419            (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() - r).into()),
1420
1421            (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() - r).into()),
1422            (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() - r).into()),
1423            (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() - r).into()),
1424            (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() - r).into()),
1425
1426            (l, r) => match (l.dtype(), r.class()) {
1427                (l_dtype, r_dtype) if l_dtype > r_dtype => l - r.into_type(l_dtype),
1428                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) - r,
1429                (l, r) => unreachable!("subtract {}, {}", l, r),
1430            },
1431        }
1432    }
1433}
1434
1435impl SubAssign<&Array> for Array {
1436    fn sub_assign(&mut self, other: &Array) {
1437        let diff = &*self - other;
1438        *self = diff;
1439    }
1440}
1441
1442impl SubAssign<Number> for Array {
1443    fn sub_assign(&mut self, rhs: Number) {
1444        *self = &*self - rhs;
1445    }
1446}
1447
1448impl Mul for &Array {
1449    type Output = Array;
1450
1451    fn mul(self, other: &Array) -> Self::Output {
1452        use Array::*;
1453        match (self, other) {
1454            (Bool(l), Bool(r)) => Bool(l * r),
1455            (C32(l), C32(r)) => C32(l * r),
1456            (C64(l), C64(r)) => C64(l * r),
1457            (F32(l), F32(r)) => F32(l * r),
1458            (F64(l), F64(r)) => F64(l * r),
1459            (I16(l), I16(r)) => I16(l * r),
1460            (I32(l), I32(r)) => I32(l * r),
1461            (I64(l), I64(r)) => I64(l * r),
1462            (U8(l), U8(r)) => U8(l * r),
1463            (U16(l), U16(r)) => U16(l * r),
1464            (U32(l), U32(r)) => U32(l * r),
1465            (U64(l), U64(r)) => U64(l * r),
1466            (l, r) => match (l.dtype(), r.dtype()) {
1467                (l_dtype, r_dtype) if l_dtype > r_dtype => l * &r.cast_into(l_dtype),
1468                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) * r,
1469                (l, r) => unreachable!("multiply {}, {}", l, r),
1470            },
1471        }
1472    }
1473}
1474
1475impl Mul<Number> for &Array {
1476    type Output = Array;
1477
1478    fn mul(self, rhs: Number) -> Self::Output {
1479        use number_general::Complex;
1480        match (self, rhs) {
1481            (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() * bool::from(r)).into()),
1482
1483            (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() * r).into()),
1484            (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() * r).into()),
1485
1486            (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() * r).into()),
1487            (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() * r).into()),
1488
1489            (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() * r).into()),
1490            (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() * r).into()),
1491            (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() * r).into()),
1492
1493            (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() * r).into()),
1494            (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() * r).into()),
1495            (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() * r).into()),
1496            (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() * r).into()),
1497
1498            (l, r) => match (l.dtype(), r.class()) {
1499                (l_dtype, r_dtype) if l_dtype > r_dtype => l * r.into_type(l_dtype),
1500                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) * r,
1501                (l, r) => unreachable!("subtract {}, {}", l, r),
1502            },
1503        }
1504    }
1505}
1506
1507impl MulAssign<&Array> for Array {
1508    fn mul_assign(&mut self, other: &Array) {
1509        let product = &*self * other;
1510        *self = product;
1511    }
1512}
1513
1514impl MulAssign<Number> for Array {
1515    fn mul_assign(&mut self, rhs: Number) {
1516        *self = &*self * rhs;
1517    }
1518}
1519
1520impl Div for &Array {
1521    type Output = Array;
1522
1523    fn div(self, other: &Array) -> Self::Output {
1524        use Array::*;
1525        match (self, other) {
1526            (Bool(l), Bool(r)) => Bool(l / r),
1527            (C32(l), C32(r)) => C32(l / r),
1528            (C64(l), C64(r)) => C64(l / r),
1529            (F32(l), F32(r)) => F32(l / r),
1530            (F64(l), F64(r)) => F64(l / r),
1531            (I16(l), I16(r)) => I16(l / r),
1532            (I32(l), I32(r)) => I32(l / r),
1533            (I64(l), I64(r)) => I64(l / r),
1534            (U8(l), U8(r)) => U8(l / r),
1535            (U16(l), U16(r)) => U16(l / r),
1536            (U32(l), U32(r)) => U32(l / r),
1537            (U64(l), U64(r)) => U64(l / r),
1538            (l, r) => match (l.dtype(), r.dtype()) {
1539                (l_dtype, r_dtype) if l_dtype > r_dtype => l / &r.cast_into(l_dtype),
1540                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) / r,
1541                (l, r) => unreachable!("divide {}, {}", l, r),
1542            },
1543        }
1544    }
1545}
1546
1547impl Div<Number> for &Array {
1548    type Output = Array;
1549
1550    fn div(self, rhs: Number) -> Self::Output {
1551        use number_general::Complex;
1552        match (self, rhs) {
1553            (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() / bool::from(r)).into()),
1554
1555            (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() / r).into()),
1556            (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() / r).into()),
1557
1558            (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() / r).into()),
1559            (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() / r).into()),
1560
1561            (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() / r).into()),
1562            (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() / r).into()),
1563            (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() / r).into()),
1564
1565            (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() / r).into()),
1566            (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() / r).into()),
1567            (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() / r).into()),
1568            (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() / r).into()),
1569
1570            (l, r) => match (l.dtype(), r.class()) {
1571                (l_dtype, r_dtype) if l_dtype > r_dtype => l / r.into_type(l_dtype),
1572                (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) / r,
1573                (l, r) => unreachable!("subtract {}, {}", l, r),
1574            },
1575        }
1576    }
1577}
1578
1579impl DivAssign<&Array> for Array {
1580    fn div_assign(&mut self, other: &Array) {
1581        let div = &*self / other;
1582        *self = div;
1583    }
1584}
1585
1586impl DivAssign<Number> for Array {
1587    fn div_assign(&mut self, rhs: Number) {
1588        *self = &*self / rhs;
1589    }
1590}
1591
1592impl<T: af::HasAfEnum> CastFrom<Array> for ArrayExt<T> {
1593    fn cast_from(array: Array) -> ArrayExt<T> {
1594        use Array::*;
1595        match array {
1596            Bool(b) => b.type_cast(),
1597            C32(c) => c.type_cast(),
1598            C64(c) => c.type_cast(),
1599            F32(f) => f.type_cast(),
1600            F64(f) => f.type_cast(),
1601            I16(i) => i.type_cast(),
1602            I32(i) => i.type_cast(),
1603            I64(i) => i.type_cast(),
1604            U8(u) => u.type_cast(),
1605            U16(u) => u.type_cast(),
1606            U32(u) => u.type_cast(),
1607            U64(u) => u.type_cast(),
1608        }
1609    }
1610}
1611
// Generate conversions between `Array` and the concrete `ArrayExt<T>` held by each variant,
// via `safecast::as_type!`. NOTE(review): the exact impls the macro emits (e.g. `From`,
// `AsType`/`TryFrom`) are defined in `safecast`; confirm there before relying on a specific one.
as_type!(Array, Bool, ArrayExt<bool>);
as_type!(Array, C32, ArrayExt<Complex<f32>>);
as_type!(Array, C64, ArrayExt<Complex<f64>>);
as_type!(Array, F32, ArrayExt<f32>);
as_type!(Array, F64, ArrayExt<f64>);
as_type!(Array, I16, ArrayExt<i16>);
as_type!(Array, I32, ArrayExt<i32>);
as_type!(Array, I64, ArrayExt<i64>);
as_type!(Array, U8, ArrayExt<u8>);
as_type!(Array, U16, ArrayExt<u16>);
as_type!(Array, U32, ArrayExt<u32>);
as_type!(Array, U64, ArrayExt<u64>);
1624
1625impl<T: af::HasAfEnum> From<Vec<T>> for Array
1626where
1627    Array: From<ArrayExt<T>>,
1628{
1629    fn from(values: Vec<T>) -> Self {
1630        ArrayExt::from(values.as_slice()).into()
1631    }
1632}
1633
1634impl<T: af::HasAfEnum> From<&[T]> for Array
1635where
1636    Array: From<ArrayExt<T>>,
1637{
1638    fn from(values: &[T]) -> Self {
1639        ArrayExt::from(values).into()
1640    }
1641}
1642
1643impl<T: af::HasAfEnum> FromIterator<T> for Array
1644where
1645    Array: From<ArrayExt<T>>,
1646{
1647    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
1648        ArrayExt::from_iter(iter).into()
1649    }
1650}
1651
1652impl From<Vec<Number>> for Array {
1653    fn from(elements: Vec<Number>) -> Self {
1654        use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
1655
1656        let dtype = elements.iter().map(|n| n.class()).fold(NT::Bool, Ord::max);
1657
1658        let array = match dtype {
1659            NT::Bool => Self::Bool(array_from(elements)),
1660            NT::Complex(ct) => match ct {
1661                CT::C32 => Self::C32(array_from(elements)),
1662                _ => Self::C64(array_from(elements)),
1663            },
1664            NT::Float(ft) => match ft {
1665                FT::F32 => Self::F32(array_from(elements)),
1666                _ => Self::F64(array_from(elements)),
1667            },
1668            NT::Int(it) => match it {
1669                IT::I8 => Self::I16(array_from(elements)),
1670                IT::I16 => Self::I16(array_from(elements)),
1671                IT::I32 => Self::I32(array_from(elements)),
1672                _ => Self::I64(array_from(elements)),
1673            },
1674            NT::UInt(ut) => match ut {
1675                UT::U8 => Self::U8(array_from(elements)),
1676                UT::U16 => Self::U16(array_from(elements)),
1677                UT::U32 => Self::U32(array_from(elements)),
1678                _ => Self::U64(array_from(elements)),
1679            },
1680            NT::Number => Self::F64(array_from(elements)),
1681        };
1682
1683        array
1684    }
1685}
1686
1687impl<'de> Deserialize<'de> for Array {
1688    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
1689        Vec::<Number>::deserialize(deserializer).map(Self::from)
1690    }
1691}
1692
1693fn array_from<T: af::HasAfEnum + CastFrom<Number>>(elements: Vec<Number>) -> ArrayExt<T> {
1694    elements
1695        .into_iter()
1696        .map(|n| n.cast_into())
1697        .collect::<Vec<T>>()
1698        .as_slice()
1699        .into()
1700}
1701
1702impl Serialize for Array {
1703    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
1704        self.to_vec().serialize(serializer)
1705    }
1706}
1707
1708#[async_trait]
1709impl de::FromStream for Array {
1710    type Context = ();
1711
1712    async fn from_stream<D: de::Decoder>(
1713        _: (),
1714        decoder: &mut D,
1715    ) -> std::result::Result<Self, D::Error> {
1716        decoder.decode_seq(ArrayVisitor).await
1717    }
1718}
1719
1720impl<'en> en::ToStream<'en> for Array {
1721    fn to_stream<E: en::Encoder<'en>>(
1722        &'en self,
1723        encoder: E,
1724    ) -> std::result::Result<E::Ok, E::Error> {
1725        use en::IntoStream;
1726
1727        match self {
1728            Self::Bool(array) => (DType::Bool, array).into_stream(encoder),
1729            Self::C32(array) => (DType::C32, array.re(), array.im()).into_stream(encoder),
1730            Self::C64(array) => (DType::C64, array.re(), array.im()).into_stream(encoder),
1731            Self::F32(array) => (DType::F32, array).into_stream(encoder),
1732            Self::F64(array) => (DType::F64, array).into_stream(encoder),
1733            Self::I16(array) => (DType::I16, array).into_stream(encoder),
1734            Self::I32(array) => (DType::I32, array).into_stream(encoder),
1735            Self::I64(array) => (DType::I64, array).into_stream(encoder),
1736            Self::U8(array) => (DType::U8, array).into_stream(encoder),
1737            Self::U16(array) => (DType::U16, array).into_stream(encoder),
1738            Self::U32(array) => (DType::U32, array).into_stream(encoder),
1739            Self::U64(array) => (DType::U64, array).into_stream(encoder),
1740        }
1741    }
1742}
1743
1744impl<'en> en::IntoStream<'en> for Array {
1745    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> std::result::Result<E::Ok, E::Error> {
1746        match self {
1747            Self::Bool(array) => (DType::Bool, array).into_stream(encoder),
1748            Self::C32(array) => (DType::C32, array.re(), array.im()).into_stream(encoder),
1749            Self::C64(array) => (DType::C64, array.re(), array.im()).into_stream(encoder),
1750            Self::F32(array) => (DType::F32, array).into_stream(encoder),
1751            Self::F64(array) => (DType::F64, array).into_stream(encoder),
1752            Self::I16(array) => (DType::I16, array).into_stream(encoder),
1753            Self::I32(array) => (DType::I32, array).into_stream(encoder),
1754            Self::I64(array) => (DType::I64, array).into_stream(encoder),
1755            Self::U8(array) => (DType::U8, array).into_stream(encoder),
1756            Self::U16(array) => (DType::U16, array).into_stream(encoder),
1757            Self::U32(array) => (DType::U32, array).into_stream(encoder),
1758            Self::U64(array) => (DType::U64, array).into_stream(encoder),
1759        }
1760    }
1761}
1762
1763impl fmt::Debug for Array {
1764    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1765        match self {
1766            Self::Bool(array) => fmt::Debug::fmt(array, f),
1767            Self::C32(array) => fmt::Debug::fmt(array, f),
1768            Self::C64(array) => fmt::Debug::fmt(array, f),
1769            Self::F32(array) => fmt::Debug::fmt(array, f),
1770            Self::F64(array) => fmt::Debug::fmt(array, f),
1771            Self::I16(array) => fmt::Debug::fmt(array, f),
1772            Self::I32(array) => fmt::Debug::fmt(array, f),
1773            Self::I64(array) => fmt::Debug::fmt(array, f),
1774            Self::U8(array) => fmt::Debug::fmt(array, f),
1775            Self::U16(array) => fmt::Debug::fmt(array, f),
1776            Self::U32(array) => fmt::Debug::fmt(array, f),
1777            Self::U64(array) => fmt::Debug::fmt(array, f),
1778        }
1779    }
1780}
1781
1782impl fmt::Display for Array {
1783    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1784        match self {
1785            Self::Bool(array) => fmt::Display::fmt(array, f),
1786            Self::C32(array) => fmt::Display::fmt(array, f),
1787            Self::C64(array) => fmt::Display::fmt(array, f),
1788            Self::F32(array) => fmt::Display::fmt(array, f),
1789            Self::F64(array) => fmt::Display::fmt(array, f),
1790            Self::I16(array) => fmt::Display::fmt(array, f),
1791            Self::I32(array) => fmt::Display::fmt(array, f),
1792            Self::I64(array) => fmt::Display::fmt(array, f),
1793            Self::U8(array) => fmt::Display::fmt(array, f),
1794            Self::U16(array) => fmt::Display::fmt(array, f),
1795            Self::U32(array) => fmt::Display::fmt(array, f),
1796            Self::U64(array) => fmt::Display::fmt(array, f),
1797        }
1798    }
1799}
1800
1801struct ArrayVisitor;
1802
1803impl ArrayVisitor {
1804    async fn visit_array<A: de::SeqAccess, T: af::HasAfEnum>(
1805        seq: &mut A,
1806    ) -> std::result::Result<ArrayExt<T>, A::Error>
1807    where
1808        ArrayExt<T>: de::FromStream<Context = ()>,
1809    {
1810        seq.next_element(())
1811            .await?
1812            .ok_or_else(|| de::Error::custom("missing array"))
1813    }
1814}
1815
1816#[async_trait]
1817impl de::Visitor for ArrayVisitor {
1818    type Value = Array;
1819
1820    fn expecting() -> &'static str {
1821        "a numeric array"
1822    }
1823
1824    async fn visit_seq<A: de::SeqAccess>(
1825        self,
1826        mut seq: A,
1827    ) -> std::result::Result<Self::Value, A::Error> {
1828        let dtype = seq
1829            .next_element::<DType>(())
1830            .await?
1831            .ok_or_else(|| de::Error::custom("missing array data type"))?;
1832
1833        match dtype {
1834            DType::Bool => Self::visit_array(&mut seq).map_ok(Array::Bool).await,
1835            DType::C32 => {
1836                let re = Self::visit_array(&mut seq).await?;
1837                let im = Self::visit_array(&mut seq).await?;
1838                Ok(Array::C32(ArrayExt::from((re, im))))
1839            }
1840            DType::C64 => {
1841                let re = Self::visit_array(&mut seq).await?;
1842                let im = Self::visit_array(&mut seq).await?;
1843                Ok(Array::C64(ArrayExt::from((re, im))))
1844            }
1845            DType::F32 => Self::visit_array(&mut seq).map_ok(Array::F32).await,
1846            DType::F64 => Self::visit_array(&mut seq).map_ok(Array::F64).await,
1847            DType::I16 => Self::visit_array(&mut seq).map_ok(Array::I16).await,
1848            DType::I32 => Self::visit_array(&mut seq).map_ok(Array::I32).await,
1849            DType::I64 => Self::visit_array(&mut seq).map_ok(Array::I64).await,
1850            DType::U8 => Self::visit_array(&mut seq).map_ok(Array::U8).await,
1851            DType::U16 => Self::visit_array(&mut seq).map_ok(Array::U16).await,
1852            DType::U32 => Self::visit_array(&mut seq).map_ok(Array::U32).await,
1853            DType::U64 => Self::visit_array(&mut seq).map_ok(Array::U64).await,
1854        }
1855    }
1856}
1857
/// Element-type tag written at the start of an encoded [`Array`] stream.
///
/// Each variant is encoded as its numeric discriminant (via `ToPrimitive`) and decoded with
/// `FromPrimitive::from_u8`. Discriminants follow declaration order, so reordering or
/// inserting variants changes the wire format.
#[derive(Clone, Copy, Eq, PartialEq, num_derive::FromPrimitive, num_derive::ToPrimitive)]
enum DType {
    Bool,
    C32,
    C64,
    F32,
    F64,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
}
1873
1874#[async_trait]
1875impl de::FromStream for DType {
1876    type Context = ();
1877
1878    async fn from_stream<D: de::Decoder>(
1879        cxt: (),
1880        decoder: &mut D,
1881    ) -> std::result::Result<Self, D::Error> {
1882        let dtype = u8::from_stream(cxt, decoder).await?;
1883        Self::from_u8(dtype).ok_or_else(|| de::Error::invalid_value(dtype, "an array data type"))
1884    }
1885}
1886
1887impl<'en> en::IntoStream<'en> for DType {
1888    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> std::result::Result<E::Ok, E::Error> {
1889        self.to_u8().into_stream(encoder)
1890    }
1891}
1892
1893pub(crate) fn reduce_block<T, B, R>(block: &ArrayExt<T>, stride: u64, reduce: &mut R) -> ArrayExt<B>
1894where
1895    T: af::HasAfEnum,
1896    B: af::HasAfEnum,
1897    R: FnMut(af::Array<T>) -> ArrayExt<B>,
1898{
1899    assert_eq!(block.len() as u64 % stride, 0);
1900    let shape = af::Dim4::new(&[stride, block.len() as u64 / stride, 1, 1]);
1901    let block = af::moddims(&block, shape);
1902    let reduced = reduce(block.into());
1903    let shape = af::Dim4::new(&[reduced.len() as u64, 1, 1, 1]);
1904    af::moddims(&reduced, shape).into()
1905}
1906
#[cfg(test)]
mod tests {
    // Unit tests for `Array` indexing, arithmetic, comparison, reductions, sorting and
    // stream (de)serialization.
    use super::*;

    // `get_value` returns the element at a single index as a `Number`.
    #[test]
    fn test_get_value() {
        assert_eq!(Array::from(&[1, 2, 3][..]).get_value(1), Number::from(2));
    }

    // `get` gathers the elements at the given indices.
    #[test]
    fn test_get() {
        let arr = Array::from(vec![1, 2, 3].as_slice());
        let actual = arr.get(&(&[1, 2][..]).into());
        let expected = Array::from(&[2, 3][..]);
        assert_eq!(actual, expected)
    }

    // `set` writes the given values at the given indices, in place.
    #[test]
    fn test_set() {
        let mut actual = Array::from(&[1, 2, 3][..]);
        actual
            .set(&(&[1, 2][..]).into(), &Array::from(&[4, 5][..]))
            .unwrap();

        let expected = Array::from(&[1, 4, 5][..]);
        assert_eq!(actual, expected)
    }

    // Integer addition: a one-element right operand is applied to every element, equal-length
    // operands add element-wise, and a scalar `Number` can be added directly.
    #[test]
    fn test_add() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [1][..].into();
        assert_eq!(&a + &b, [2, 3, 4][..].into());

        let b: Array = [3, 2, 1][..].into();
        assert_eq!(&a + &b, [4, 4, 4][..].into());

        assert_eq!(&b + Number::from(1), [4, 3, 2][..].into());
    }

    // Addition mixing integer and floating-point operands.
    #[test]
    fn test_add_float() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2.0][..].into();
        assert_eq!(&a + &b, [3.0, 4.0, 5.0][..].into());

        let b: Array = [-1., -4., 4.][..].into();
        assert_eq!(&a + &b, [0., -2., 7.][..].into());

        assert_eq!(&b + Number::from(3), [2, -1, 7][..].into());
    }

    // Greater-than-or-equal against a one-element array and against a scalar.
    #[test]
    fn test_gte() {
        let a: Array = [0, 1, 2][..].into();
        let b: Array = [1][..].into();
        assert_eq!(a.gte(&b), [false, true, true][..].into());
        assert_eq!(a.gte_const(Number::from(1)), [false, true, true][..].into());
    }

    // Integer subtraction, including results below zero.
    #[test]
    fn test_sub() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [1][..].into();
        assert_eq!(&a - &b, [0, 1, 2][..].into());

        let b: Array = [3, 2, 1][..].into();
        assert_eq!(&a - &b, [-2, 0, 2][..].into());
    }

    // Subtraction mixing integer and floating-point operands.
    #[test]
    fn test_sub_float() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2.0][..].into();
        assert_eq!(&a - &b, [-1.0, 0., 1.0][..].into());

        let b: Array = [-1., -4., 4.][..].into();
        assert_eq!(&a - &b, [2., 6., -1.][..].into());
    }

    // Integer multiplication by a one-element array and element-wise.
    #[test]
    fn test_mul() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2][..].into();
        assert_eq!(&a * &b, [2, 4, 6][..].into());

        let b: Array = [5, 4, 3][..].into();
        assert_eq!(&a * &b, [5, 8, 9][..].into());
    }

    // An integer array multiplied by an `f32` scalar, compared against a float array.
    #[test]
    fn test_mul_const() {
        let a: Array = [1, 2, 3][..].into();
        let b: Number = 2f32.into();
        assert_eq!(&a * b, [2.0, 4.0, 6.0][..].into());
    }

    // `f32` multiplication by a one-element array and element-wise.
    #[test]
    fn test_mul_float() {
        let a: Array = [1.0f32, 2.0f32, 3.0f32][..].into();
        let b: Array = [2.0f32][..].into();
        assert_eq!(&a * &b, [2.0, 4.0, 6.0][..].into());

        let b: Array = [-1., -4., 4.][..].into();
        assert_eq!(&a * &b, [-1., -8., 12.][..].into());
    }

    // Integer array divided by a float array yields fractional results.
    #[test]
    fn test_div() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2.0][..].into();
        assert_eq!(&a / &b, [0.5, 1.0, 1.5][..].into());

        let b: Array = [-1., -4., 4.][..].into();
        assert_eq!(&a / &b, [-1., -0.5, 0.75][..].into());
    }

    // `pow` with integer and float bases and exponents in each combination.
    #[test]
    fn test_pow() {
        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2][..].into();
        assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());

        let a: Array = [1, 2, 3][..].into();
        let b: Array = [2.0][..].into();
        assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());

        let a: Array = [1.0, 2.0, 3.0][..].into();
        let b: Array = [2][..].into();
        assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());
    }

    // Whole-array minimum and maximum.
    #[test]
    fn test_min_and_max() {
        let a: Array = [3, 1, 4, 2][..].into();
        assert_eq!(a.min(), 1.into());
        assert_eq!(a.max(), 4.into());
    }

    // Whole-array sum.
    #[test]
    fn test_sum() {
        let a: Array = [1, 2, 3, 4][..].into();
        assert_eq!(a.sum(), 10.into());
    }

    // Whole-array product.
    #[test]
    fn test_product() {
        let a: Array = [1, 2, 3, 4][..].into();
        assert_eq!(a.product(), 24.into());
    }

    // The sorted output of `argsort` equals the input gathered at the returned indices.
    #[test]
    fn test_argsort() {
        let a = Array::random_uniform(FloatType::F32, 10);
        let (sorted, indices) = a.argsort(true).expect("argsort");
        assert_eq!(sorted, a.get(&indices))
    }

    // Round-trip through the TBON encoder and decoder.
    #[tokio::test]
    async fn test_serialization() {
        let expected: Array = [1, 2, 3, 4][..].into();
        let serialized = tbon::en::encode(&expected).expect("encode");
        let actual = tbon::de::try_decode((), serialized).await.expect("decode");
        assert!(expected.eq(&actual).all());
    }
}
2072}